diff --git a/3rd/msgpack/src/bindings.c b/3rd/msgpack/src/bindings.c index 99f28c9b..bfc0396b 100644 --- a/3rd/msgpack/src/bindings.c +++ b/3rd/msgpack/src/bindings.c @@ -56,8 +56,10 @@ static bool mpack_to_py(mpack_node_t node) { for(size_t i = 0; i < count; i++) { mpack_node_t key_node = mpack_node_map_key_at(node, i); mpack_node_t val_node = mpack_node_map_value_at(node, i); - if(mpack_node_type(key_node) != mpack_type_str) { - return TypeError("msgpack: key must be strings"); + mpack_type_t key_type = mpack_node_type(key_node); + if(key_type != mpack_type_str && key_type != mpack_type_int && + key_type != mpack_type_uint) { + return TypeError("msgpack: key must be string or integer"); } if(!mpack_to_py(key_node)) return false; if(!mpack_to_py(val_node)) return false; @@ -101,9 +103,14 @@ static bool py_to_mpack(py_Ref object, mpack_writer_t* writer); static bool mpack_write_dict_kv(py_Ref k, py_Ref v, void* ctx) { mpack_writer_t* writer = ctx; - if(k->type != tp_str) return TypeError("msgpack: key must be strings"); - c11_sv sv = py_tosv(k); - mpack_write_str(writer, sv.data, (size_t)sv.size); + if(k->type == tp_str) { + c11_sv sv = py_tosv(k); + mpack_write_str(writer, sv.data, (size_t)sv.size); + } else if(k->type == tp_int) { + mpack_write_int(writer, py_toint(k)); + } else { + return TypeError("msgpack: key must be string or integer"); + } bool ok = py_to_mpack(v, writer); if(!ok) mpack_write_nil(writer); return ok; @@ -160,7 +167,10 @@ static bool msgpack_dumps(int argc, py_Ref argv) { mpack_writer_init_growable(&writer, &data, &size); bool ok = py_to_mpack(argv, &writer); if(mpack_writer_destroy(&writer) != mpack_ok) { assert(false); } - if(!ok) return false; + if(!ok) { + MPACK_FREE(data); + return false; + } assert(size <= INT32_MAX); unsigned char* byte_data = py_newbytes(py_retval(), (int)size); memcpy(byte_data, data, size); diff --git a/tests/723_msgpack.py b/tests/723_msgpack.py index f7781c3e..12cf2beb 100644 --- a/tests/723_msgpack.py +++ b/tests/723_msgpack.py @@ -65,11 +65,8 @@ assert msgpack.dumps([]) == b'\x90' assert msgpack.dumps([1, 2, 3]) == b'\x93\x01\x02\x03' assert msgpack.dumps([1]) == b'\x91\x01' -try: - msgpack.dumps({1: 2}) - assert False -except TypeError: - assert True +_o = msgpack.dumps({1: 2}) +assert msgpack.loads(_o) == {1: 2} try: msgpack.dumps(type)