Skip to content

Add raw_decode method to JSON and MsgPack decoders #821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ JSON
:members: encode, encode_lines, encode_into

.. autoclass:: Decoder
:members: decode, decode_lines
:members: decode, decode_lines, raw_decode

.. autofunction:: encode

Expand All @@ -80,7 +80,7 @@ MessagePack
:members: encode, encode_into

.. autoclass:: Decoder
:members: decode
:members: decode, raw_decode

.. autoclass:: Ext
:members:
Expand Down
130 changes: 130 additions & 0 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -16128,11 +16128,71 @@ Decoder_decode(Decoder *self, PyObject *const *args, Py_ssize_t nargs)
return NULL;
}

PyDoc_STRVAR(Decoder_raw_decode__doc__,
    "raw_decode(self, buf)\n"
    "--\n"
    "\n"
    "Deserialize an object from MessagePack, allowing trailing data.\n"
    "\n"
    "Parameters\n"
    "----------\n"
    "buf : bytes-like\n"
    " The message to decode.\n"
    "\n"
    "Returns\n"
    "-------\n"
    "obj_and_index : 2-tuple of Any and int\n"
    " A tuple containing the deserialized object, as well as the index into\n"
    " the input at which the object ended.\n"
);

/*
 * Decoder.raw_decode(buf) -> (obj, index)
 *
 * Decodes one value from the start of `buf` and reports how far the decoder
 * advanced, without checking what follows the decoded value.
 *
 * Takes exactly one positional argument, which must support the buffer
 * protocol (contiguous, read-only access is requested).
 *
 * Returns a new reference to a 2-tuple `(obj, index)`, or NULL if the
 * argument count check, the buffer acquisition, the decode, or the tuple
 * construction fails.
 */
static PyObject*
Decoder_raw_decode(Decoder *self, PyObject *const *args, Py_ssize_t nargs)
{
    if (!check_positional_nargs(nargs, 1, 1)) {
        return NULL;
    }

    DecoderState state = {
        .type = self->type,
        .strict = self->strict,
        .dec_hook = self->dec_hook,
        .ext_hook = self->ext_hook
    };

    Py_buffer view;
    view.buf = NULL;
    if (PyObject_GetBuffer(args[0], &view, PyBUF_CONTIG_RO) < 0) {
        return NULL;
    }

    /* Point the decoder at the whole exported buffer. */
    state.buffer_obj = args[0];
    state.input_start = view.buf;
    state.input_pos = view.buf;
    state.input_end = state.input_pos + view.len;

    PyObject *out = NULL;
    PyObject *decoded = mpack_decode(&state, state.type, NULL, false);

    if (decoded != NULL) {
        /* Distance the read cursor moved from the start of the input. */
        Py_ssize_t consumed = (Py_ssize_t)(state.input_pos - state.input_start);

        /* "O" takes its own reference, so ours is dropped afterwards
         * regardless of whether the tuple could be built. */
        out = Py_BuildValue("(On)", decoded, consumed);
        Py_DECREF(decoded);
    }

    /* The buffer is released on both the success and the failure path. */
    PyBuffer_Release(&view);
    return out;
}

/*
 * Method table for the Decoder type.
 *
 * `decode` and `raw_decode` are registered with the METH_FASTCALL calling
 * convention, so their C implementations receive an argument array and a
 * count. `__class_getitem__` is a class method taking one argument and is
 * served by Py_GenericAlias. The table ends with a NULL sentinel entry.
 */
static struct PyMethodDef Decoder_methods[] = {
{
"decode", (PyCFunction) Decoder_decode, METH_FASTCALL,
Decoder_decode__doc__,
},
{
"raw_decode", (PyCFunction) Decoder_raw_decode, METH_FASTCALL,
Decoder_raw_decode__doc__,
},
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS},
{NULL, NULL} /* sentinel */
};
Expand Down Expand Up @@ -19174,6 +19234,72 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na
return NULL;
}

PyDoc_STRVAR(JSONDecoder_raw_decode__doc__,
    "raw_decode(self, buf)\n"
    "--\n"
    "\n"
    "Deserialize an object from JSON, allowing trailing data.\n"
    "\n"
    "Parameters\n"
    "----------\n"
    "buf : bytes-like or str\n"
    " The message to decode.\n"
    "\n"
    "Returns\n"
    "-------\n"
    "obj_and_index : 2-tuple of Any and int\n"
    " A tuple containing the deserialized object, as well as the index into\n"
    " the input at which the object ended.\n"
);

/*
 * JSONDecoder.raw_decode(buf) -> (obj, index)
 *
 * Decodes one JSON value from the start of `buf` and reports where it ended,
 * without checking what follows the decoded value.
 *
 * Takes exactly one positional argument; per the docstring it is a
 * bytes-like object or a str.
 *
 * Returns a new reference to a 2-tuple `(obj, index)`, or NULL if the
 * argument count check, the buffer acquisition, the decode, or the tuple
 * construction fails.
 *
 * `index` is expressed in the units of the argument that was passed: a byte
 * offset for bytes-like input, a character (code point) offset for str
 * input, so that `buf[index:]` is the unparsed remainder in both cases.
 */
static PyObject*
JSONDecoder_raw_decode(
    JSONDecoder *self,
    PyObject *const *args,
    Py_ssize_t nargs
) {
    if (!check_positional_nargs(nargs, 1, 1)) {
        return NULL;
    }

    JSONDecoderState state = {
        .type = self->type,
        .strict = self->strict,
        .dec_hook = self->dec_hook,
        .float_hook = self->float_hook,
        .scratch = NULL,
        .scratch_capacity = 0,
        .scratch_len = 0
    };

    Py_buffer buffer;
    buffer.buf = NULL;
    if (ms_get_buffer(args[0], &buffer) < 0) {
        return NULL;
    }

    state.buffer_obj = args[0];
    state.input_start = buffer.buf;
    state.input_pos = buffer.buf;
    state.input_end = state.input_pos + buffer.len;

    PyObject *res = json_decode(&state, state.type, NULL);

    if (res != NULL) {
        /* Offset of the read cursor, measured in bytes of the buffer. */
        Py_ssize_t end = (Py_ssize_t)(state.input_pos - state.input_start);

        if (PyUnicode_Check(args[0])) {
            /* The caller indexes a str by code point, not by encoded byte.
             * NOTE(review): this assumes ms_get_buffer exposes a str as its
             * UTF-8 encoding -- confirm against ms_get_buffer. Under that
             * assumption every code point contributes exactly one byte that
             * is not a continuation byte (10xxxxxx), so counting those bytes
             * in the consumed prefix yields the character offset. */
            const unsigned char *consumed =
                (const unsigned char *)state.input_start;
            Py_ssize_t nchars = 0;
            for (Py_ssize_t i = 0; i < end; i++) {
                if ((consumed[i] & 0xC0) != 0x80) {
                    nchars++;
                }
            }
            end = nchars;
        }

        /* "O" takes its own reference, so ours is dropped afterwards
         * regardless of whether the tuple could be built. */
        PyObject *tup = Py_BuildValue("(On)", res, end);
        Py_DECREF(res);
        res = tup;
    }

    /* Cleanup runs on both the success and the failure path. */
    ms_release_buffer(&buffer);
    PyMem_Free(state.scratch);
    return res;
}

static struct PyMethodDef JSONDecoder_methods[] = {
{
"decode", (PyCFunction) JSONDecoder_decode, METH_FASTCALL,
Expand All @@ -19183,6 +19309,10 @@ static struct PyMethodDef JSONDecoder_methods[] = {
"decode_lines", (PyCFunction) JSONDecoder_decode_lines, METH_FASTCALL,
JSONDecoder_decode_lines__doc__,
},
{
"raw_decode", (PyCFunction) JSONDecoder_raw_decode, METH_FASTCALL,
JSONDecoder_raw_decode__doc__,
},
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS},
{NULL, NULL} /* sentinel */
};
Expand Down
1 change: 1 addition & 0 deletions msgspec/json.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Decoder(Generic[T]):
) -> None: ...
def decode(self, buf: Union[Buffer, str], /) -> T: ...
def decode_lines(self, buf: Union[Buffer, str], /) -> list[T]: ...
def raw_decode(self, buf: Union[Buffer, str], /) -> Tuple[T, int]: ...

@overload
def decode(
Expand Down
2 changes: 2 additions & 0 deletions msgspec/msgpack.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from typing import (
Generic,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -58,6 +59,7 @@ class Decoder(Generic[T]):
ext_hook: ext_hook_sig = None,
) -> None: ...
def decode(self, buf: Buffer, /) -> T: ...
def raw_decode(self, buf: Buffer, /) -> Tuple[T, int]: ...

class Encoder:
enc_hook: enc_hook_sig
Expand Down
29 changes: 29 additions & 0 deletions tests/basic_typing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,22 @@ def check_msgpack_Decoder_decode_type_comment() -> None:
reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ


# Typing example: raw_decode on an untyped msgpack Decoder.
# NOTE(review): the trailing `# assert ...` comment on the reveal_type line
# looks like it is evaluated by an external harness against the revealed type
# string `typ` -- confirm, and keep that comment on the same line.
def check_msgpack_Decoder_raw_decode_any() -> None:
dec = msgspec.msgpack.Decoder()
b = msgspec.msgpack.encode([1, 2, 3])
o = dec.raw_decode(b)

reveal_type(o) # assert "tuple" in typ.lower() and "Any" in typ and "int" in typ


# Typing example: raw_decode on a msgpack Decoder constructed with `int`.
# NOTE(review): the trailing `# assert ...` comment on the reveal_type line
# looks like it is evaluated by an external harness against the revealed type
# string `typ` -- confirm, and keep that comment on the same line.
def check_msgpack_Decoder_raw_decode_typed() -> None:
dec = msgspec.msgpack.Decoder(int)
b = msgspec.msgpack.encode(1)
o = dec.raw_decode(b)

reveal_type(o) # assert ("Tuple" in typ or "tuple" in typ) and typ.count("int") == 2


def check_msgpack_decode_any() -> None:
b = msgspec.msgpack.encode([1, 2, 3])
o = msgspec.msgpack.decode(b)
Expand Down Expand Up @@ -814,6 +830,19 @@ def check_json_Decoder_decode_lines_typed() -> None:
reveal_type(o) # assert "list" in typ.lower() and "int" in typ.lower()


# Typing example: raw_decode on an untyped JSON Decoder.
# NOTE(review): the trailing `# assert ...` comment on the reveal_type line
# looks like it is evaluated by an external harness against the revealed type
# string `typ` -- confirm, and keep that comment on the same line.
def check_json_Decoder_raw_decode_any() -> None:
dec = msgspec.json.Decoder()
o = dec.raw_decode(b'1')

reveal_type(o) # assert "tuple" in typ.lower() and "any" in typ.lower() and "int" in typ.lower()


# Typing example: raw_decode on a JSON Decoder constructed with `int`.
# NOTE(review): the trailing `# assert ...` comment on the reveal_type line
# looks like it is evaluated by an external harness against the revealed type
# string `typ` -- confirm, and keep that comment on the same line.
def check_json_Decoder_raw_decode_typed() -> None:
dec = msgspec.json.Decoder(int)
o = dec.raw_decode(b'1')
reveal_type(o) # assert "tuple" in typ.lower() and typ.lower().count("int") == 2


def check_json_decode_any() -> None:
b = msgspec.json.encode([1, 2, 3])
o = msgspec.json.decode(b)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,30 @@ def test_decode_lines_bad_call(self):
with pytest.raises(TypeError):
dec.decode(1)

def test_raw_decode(self):
    """raw_decode returns the value and the offset where it ended,
    leaving trailing bytes unparsed."""
    decoder = msgspec.json.Decoder()

    decoded, end = decoder.raw_decode(b"[1, 2, 3]trailing invalid")

    expected_end = len(b"[1, 2, 3]")
    assert decoded == [1, 2, 3]
    assert end == expected_end

def test_raw_decode_malformed(self):
    """A value that is itself invalid still raises DecodeError."""
    decoder = msgspec.json.Decoder()
    broken = b'{"x": efg'

    with pytest.raises(msgspec.DecodeError, match="malformed"):
        decoder.raw_decode(broken)

def test_raw_decode_bad_call(self):
    """Wrong arity or a non-buffer argument raises TypeError."""
    decoder = msgspec.json.Decoder()

    # No arguments, one argument too many, and an unsupported argument type.
    bad_calls = [(), ("{}", 2), (1,)]

    for call_args in bad_calls:
        with pytest.raises(TypeError):
            decoder.raw_decode(*call_args)

def test_decoder_init_float_hook(self):
dec = msgspec.json.Decoder()
assert dec.float_hook is None
Expand Down
22 changes: 22 additions & 0 deletions tests/test_msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,28 @@ def test_decoding_large_arrays_as_keys_doesnt_preallocate(self):
with pytest.raises(msgspec.DecodeError, match="truncated"):
msgspec.msgpack.decode(b)

def test_raw_decode(self):
    """raw_decode returns the value and the offset where it ended,
    leaving trailing bytes unparsed."""
    decoder = msgspec.msgpack.Decoder()
    encoded = msgspec.msgpack.encode([1, 2, 3])

    decoded, end = decoder.raw_decode(encoded + b"trailing")

    assert decoded == [1, 2, 3]
    assert end == len(encoded)

def test_raw_decode_skip_invalid_submessage_raises(self):
    """Ensure errors in submessage skipping are raised"""

    class Test(msgspec.Struct):
        x: int

    # "y" is not a field of Test; its value is shortened in place so the
    # encoded message no longer matches its own length prefixes.
    intact = msgspec.msgpack.encode({"x": 1, "y": ["one", "two", "three"]})
    corrupted = intact.replace(b"three", b"tree")

    decoder = msgspec.msgpack.Decoder(type=Test)
    with pytest.raises(msgspec.DecodeError, match="truncated"):
        decoder.raw_decode(corrupted)


class TestTypedDecoder:
def check_unexpected_type(self, dec_type, val, msg):
Expand Down