Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions ext/zstdruby/zstdruby.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
return output;
}

static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs) {
static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs, size_t* consumed) {
VALUE out = rb_str_buf_new(0);
size_t cap = ZSTD_DStreamOutSize();
char *buf = ALLOC_N(char, cap);
Expand All @@ -64,11 +64,14 @@ static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t
}
}
xfree(buf);
if (consumed) {
*consumed = in.pos;
}
return out;
}

static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* data, size_t len) {
return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil);
return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil, NULL);
}

static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
Expand All @@ -84,6 +87,9 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
const uint32_t ZSTD_MAGIC = 0xFD2FB528U;
const uint32_t SKIP_LO = 0x184D2A50U; /* ...5F */

VALUE result = Qnil;
ZSTD_DCtx *dctx = NULL;

while (off + 4 <= in_size) {
uint32_t magic = (uint32_t)in[off]
| ((uint32_t)in[off+1] << 8)
Expand All @@ -103,23 +109,43 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
}

if (magic == ZSTD_MAGIC) {
ZSTD_DCtx *dctx = ZSTD_createDCtx();
if (!dctx) {
rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed");
if (dctx == NULL) {
dctx = ZSTD_createDCtx();
if (!dctx) {
rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed");
}
}

VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs);
size_t consumed = 0;
VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs, &consumed);
if (result == Qnil) {
/* First frame becomes the accumulator, avoiding a copy of its
(potentially large) output in the common single-frame case. */
result = out;
} else {
rb_str_cat(result, RSTRING_PTR(out), RSTRING_LEN(out));
}

ZSTD_freeDCtx(dctx);
RB_GC_GUARD(input_value);
return out;
if (consumed == 0) {
/* Guard against a non-advancing frame to avoid an infinite loop. */
break;
}
off += consumed;
continue;
}

off += 1;
}

if (dctx != NULL) {
ZSTD_freeDCtx(dctx);
}

RB_GC_GUARD(input_value);
rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)");
if (result == Qnil) {
rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)");
}
return result;
}

static void free_cdict(void *dict)
Expand Down
13 changes: 13 additions & 0 deletions spec/zstd-ruby_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@ def to_str
expect(Zstd.decompress(res)).to eq(large_strings * 3)
end

it 'should decompress concatenated frames' do
a = Zstd.compress("Hello, ")
b = Zstd.compress("World!")
expect(Zstd.decompress(a + b)).to eq("Hello, World!")
end

it 'should decompress three or more concatenated frames' do
a = Zstd.compress("Hello, ")
b = Zstd.compress("World!")
c = Zstd.compress("!!!")
expect(Zstd.decompress(a + b + c)).to eq("Hello, World!!!!")
end

it 'should raise exception with unsupported object' do
expect { Zstd.decompress(Object.new) }.to raise_error(TypeError)
end
Expand Down