From 79f18f8dc54f9c93ff4f8d1e03e90b2880a20b60 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Sun, 15 Dec 2024 20:57:14 +0800 Subject: [PATCH] improve `pickle` --- include/pocketpy/config.h | 3 + src/modules/pickle.c | 373 ++++++++++++++++++++++++-------------- src/public/cast.c | 11 +- src/public/modules.c | 1 + tests/90_pickle.py | 35 +++- tests/98_lz4.py | 2 +- 6 files changed, 286 insertions(+), 139 deletions(-) diff --git a/include/pocketpy/config.h b/include/pocketpy/config.h index 623ff1af..b8f06d33 100644 --- a/include/pocketpy/config.h +++ b/include/pocketpy/config.h @@ -36,6 +36,9 @@ // (not recommended to change this) #define PK_MAX_CO_VARNAMES 64 +// This is the maximum character length of a module path +#define PK_MAX_MODULE_PATH_LEN 63 + #ifdef _WIN32 #define PK_PLATFORM_SEP '\\' #else diff --git a/src/modules/pickle.c b/src/modules/pickle.c index c87e0346..4c3c3637 100644 --- a/src/modules/pickle.c +++ b/src/modules/pickle.c @@ -24,44 +24,50 @@ typedef enum { PKL_VEC2I, PKL_VEC3I, PKL_TYPE, PKL_ARRAY2D, + PKL_TVALUE, PKL_EOF, // clang-format on } PickleOp; typedef struct { + bool* used_types; + int used_types_length; c11_smallmap_p2i memo; c11_vector /*T=char*/ codes; } PickleObject; -typedef struct { - uint16_t memo_length; -} PickleObjectHeader; - static void PickleObject__ctor(PickleObject* self) { + self->used_types_length = pk_current_vm->types.length; + self->used_types = malloc(self->used_types_length); + memset(self->used_types, 0, self->used_types_length); c11_smallmap_p2i__ctor(&self->memo); c11_vector__ctor(&self->codes, sizeof(char)); } static void PickleObject__dtor(PickleObject* self) { + free(self->used_types); c11_smallmap_p2i__dtor(&self->memo); c11_vector__dtor(&self->codes); } -static bool PickleObject__py_submit(PickleObject* self, py_OutRef out) { - unsigned char* data = self->codes.data; - PickleObjectHeader* p = - (PickleObjectHeader*)py_newbytes(out, sizeof(PickleObjectHeader) + self->codes.length); - if(self->memo.length >= UINT16_MAX) c11__abort("PickleObject__py_submit(): memo overflow"); - p->memo_length = (uint16_t)self->memo.length; - memcpy(p + 1, data, self->codes.length); - PickleObject__dtor(self); - return true; -} +static bool PickleObject__py_submit(PickleObject* self, py_OutRef out); static void PickleObject__write_bytes(PickleObject* buf, const void* data, int size) { c11_vector__extend(char, &buf->codes, data, size); } +static void c11_sbuf__write_type_path(c11_sbuf* path_buf, py_Type type) { + py_TypeInfo* ti = pk__type_info(type); + if(py_isnil(&ti->module)) { + c11_sbuf__write_cstr(path_buf, py_name2str(ti->name)); + return; + } + const char* mod_name = py_tostr(py_getdict(&ti->module, __name__)); + c11_sbuf__write_cstr(path_buf, mod_name); + c11_sbuf__write_char(path_buf, '.'); + c11_sbuf__write_cstr(path_buf, py_name2str(ti->name)); +} + static void pkl__emit_op(PickleObject* buf, PickleOp op) { c11_vector__push(char, &buf->codes, op); } @@ -97,22 +103,10 @@ static py_i64 pkl__read_int(const unsigned char** p) { (*p)++; switch(op) { // clang-format off - case PKL_INT_0: return 0; - case PKL_INT_1: return 1; - case PKL_INT_2: return 2; - case PKL_INT_3: return 3; - case PKL_INT_4: return 4; - case PKL_INT_5: return 5; - case PKL_INT_6: return 6; - case PKL_INT_7: return 7; - case PKL_INT_8: return 8; - case PKL_INT_9: return 9; - case PKL_INT_10: return 10; - case PKL_INT_11: return 11; - case PKL_INT_12: return 12; - case PKL_INT_13: return 13; - case PKL_INT_14: return 14; - case PKL_INT_15: return 15; + case PKL_INT_0: return 0; case PKL_INT_1: return 1; case PKL_INT_2: return 2; case PKL_INT_3: return 3; + case PKL_INT_4: return 4; case PKL_INT_5: return 5; case PKL_INT_6: return 6; case PKL_INT_7: return 7; + case PKL_INT_8: return 8; case PKL_INT_9: return 9; case PKL_INT_10: return 10; case PKL_INT_11: return 11; + case PKL_INT_12: return 12; case PKL_INT_13: return 13; case PKL_INT_14: return 14; case PKL_INT_15: return 15; // clang-format on case PKL_INT8: { int8_t val; @@ -138,13 +132,6 @@ static py_i64 pkl__read_int(const unsigned char** p) { } } -const static char* pkl__read_cstr(const unsigned char** p) { - const char* p_str = (const char*)*p; - int length = strlen(p_str); - *p += length + 1; // include '\0' - return p_str; -} - static bool pickle_loads(int argc, py_Ref argv) { PY_CHECK_ARGC(1); PY_CHECK_ARG_TYPE(0, tp_bytes); @@ -165,11 +152,11 @@ void pk__add_module_pickle() { py_bindfunc(mod, "dumps", pickle_dumps); } -static bool pickle__write_object(PickleObject* buf, py_TValue* obj); +static bool pkl__write_object(PickleObject* buf, py_TValue* obj); -static bool pickle__write_array(PickleObject* buf, PickleOp op, py_TValue* arr, int length) { +static bool pkl__write_array(PickleObject* buf, PickleOp op, py_TValue* arr, int length) { for(int i = 0; i < length; i++) { - bool ok = pickle__write_object(buf, arr + i); + bool ok = pkl__write_object(buf, arr + i); if(!ok) return false; } pkl__emit_op(buf, op); @@ -177,36 +164,44 @@ static bool pickle__write_array(PickleObject* buf, PickleOp op, py_TValue* arr, return true; } -static bool pickle__write_dict_kv(py_Ref k, py_Ref v, void* ctx) { +static bool pkl__write_dict_kv(py_Ref k, py_Ref v, void* ctx) { PickleObject* buf = (PickleObject*)ctx; - if(!pickle__write_object(buf, k)) return false; - if(!pickle__write_object(buf, v)) return false; + if(!pkl__write_object(buf, k)) return false; + if(!pkl__write_object(buf, v)) return false; return true; } -static bool pickle__write_object(PickleObject* buf, py_TValue* obj) { - if(obj->is_ptr) { - void* memo_key = obj->_obj; - int index = c11_smallmap_p2i__get(&buf->memo, memo_key, -1); - if(index != -1) { - pkl__emit_op(buf, PKL_MEMO_GET); - pkl__emit_int(buf, index); - return true; - } +static bool pkl__try_memo(PickleObject* buf, PyObject* memo_key) { + int index = c11_smallmap_p2i__get(&buf->memo, memo_key, -1); + if(index != -1) { + pkl__emit_op(buf, PKL_MEMO_GET); + pkl__emit_int(buf, index); + return true; } + return false; +} + +static void pkl__store_memo(PickleObject* buf, PyObject* memo_key) { + int index = buf->memo.length; + c11_smallmap_p2i__set(&buf->memo, memo_key, index); + pkl__emit_op(buf, PKL_MEMO_SET); + pkl__emit_int(buf, index); +} + +static bool pkl__write_object(PickleObject* buf, py_TValue* obj) { switch(obj->type) { case tp_NoneType: { pkl__emit_op(buf, PKL_NONE); - break; + return true; } case tp_ellipsis: { pkl__emit_op(buf, PKL_ELLIPSIS); - break; + return true; } case tp_int: { py_i64 val = obj->_i64; pkl__emit_int(buf, val); - break; + return true; } case tp_float: { py_f64 val = obj->_f64; @@ -218,64 +213,90 @@ static bool pickle__write_object(PickleObject* buf, py_TValue* obj) { pkl__emit_op(buf, PKL_FLOAT64); PickleObject__write_bytes(buf, &val, 8); } - break; + return true; } case tp_bool: { bool val = obj->_bool; pkl__emit_op(buf, val ? PKL_TRUE : PKL_FALSE); - break; + return true; } case tp_str: { - pkl__emit_op(buf, PKL_STRING); - c11_sv sv = py_tosv(obj); - pkl__emit_int(buf, sv.size); - PickleObject__write_bytes(buf, sv.data, sv.size); - break; + if(pkl__try_memo(buf, obj->_obj)) + return true; + else { + pkl__emit_op(buf, PKL_STRING); + c11_sv sv = py_tosv(obj); + pkl__emit_int(buf, sv.size); + PickleObject__write_bytes(buf, sv.data, sv.size); + } + pkl__store_memo(buf, obj->_obj); + return true; } case tp_bytes: { - pkl__emit_op(buf, PKL_BYTES); - int size; - unsigned char* data = py_tobytes(obj, &size); - pkl__emit_int(buf, size); - PickleObject__write_bytes(buf, data, size); - break; + if(pkl__try_memo(buf, obj->_obj)) + return true; + else { + pkl__emit_op(buf, PKL_BYTES); + int size; + unsigned char* data = py_tobytes(obj, &size); + pkl__emit_int(buf, size); + PickleObject__write_bytes(buf, data, size); + } + pkl__store_memo(buf, obj->_obj); + return true; } case tp_list: { - bool ok = pickle__write_array(buf, PKL_BUILD_LIST, py_list_data(obj), py_list_len(obj)); - if(!ok) return false; - break; + if(pkl__try_memo(buf, obj->_obj)) + return true; + else { + bool ok = + pkl__write_array(buf, PKL_BUILD_LIST, py_list_data(obj), py_list_len(obj)); + if(!ok) return false; + } + pkl__store_memo(buf, obj->_obj); + return true; } case tp_tuple: { - bool ok = - pickle__write_array(buf, PKL_BUILD_TUPLE, py_tuple_data(obj), py_tuple_len(obj)); - if(!ok) return false; - break; + if(pkl__try_memo(buf, obj->_obj)) + return true; + else { + bool ok = + pkl__write_array(buf, PKL_BUILD_TUPLE, py_tuple_data(obj), py_tuple_len(obj)); + if(!ok) return false; + } + pkl__store_memo(buf, obj->_obj); + return true; } case tp_dict: { - bool ok = py_dict_apply(obj, pickle__write_dict_kv, (void*)buf); - if(!ok) return false; - pkl__emit_op(buf, PKL_BUILD_DICT); - pkl__emit_int(buf, py_dict_len(obj)); - break; + if(pkl__try_memo(buf, obj->_obj)) + return true; + else { + bool ok = py_dict_apply(obj, pkl__write_dict_kv, (void*)buf); + if(!ok) return false; + pkl__emit_op(buf, PKL_BUILD_DICT); + pkl__emit_int(buf, py_dict_len(obj)); + } + pkl__store_memo(buf, obj->_obj); + return true; } case tp_vec2: { c11_vec2 val = py_tovec2(obj); pkl__emit_op(buf, PKL_VEC2); PickleObject__write_bytes(buf, &val, sizeof(c11_vec2)); - break; + return true; } case tp_vec3: { c11_vec3 val = py_tovec3(obj); pkl__emit_op(buf, PKL_VEC3); PickleObject__write_bytes(buf, &val, sizeof(c11_vec3)); - break; + return true; } case tp_vec2i: { c11_vec2i val = py_tovec2i(obj); pkl__emit_op(buf, PKL_VEC2I); pkl__emit_int(buf, val.x); pkl__emit_int(buf, val.y); - break; + return true; } case tp_vec3i: { c11_vec3i val = py_tovec3i(obj); @@ -283,53 +304,51 @@ static bool pickle__write_object(PickleObject* buf, py_TValue* obj) { pkl__emit_int(buf, val.x); pkl__emit_int(buf, val.y); pkl__emit_int(buf, val.z); - break; + return true; } case tp_type: { pkl__emit_op(buf, PKL_TYPE); - py_TypeInfo* ti = pk__type_info(py_totype(obj)); - const char* mod_name = py_tostr(py_getdict(&ti->module, __name__)); - c11_sbuf path_buf; - c11_sbuf__ctor(&path_buf); - c11_sbuf__write_cstr(&path_buf, mod_name); - c11_sbuf__write_cstr(&path_buf, "@"); - c11_sbuf__write_cstr(&path_buf, py_name2str(ti->name)); - c11_string* path = c11_sbuf__submit(&path_buf); - // include '\0' - PickleObject__write_bytes(buf, path->data, path->size + 1); - c11_string__delete(path); - break; + py_Type type = py_totype(obj); + buf->used_types[type] = true; + pkl__emit_int(buf, type); + return true; } case tp_array2d: { - c11_array2d* arr = py_touserdata(obj); - for(int i = 0; i < arr->numel; i++) { - if(arr->data[i].is_ptr) - return TypeError( - "'array2d' object is not picklable because it contains heap-allocated objects"); + if(pkl__try_memo(buf, obj->_obj)) + return true; + else { + c11_array2d* arr = py_touserdata(obj); + for(int i = 0; i < arr->numel; i++) { + if(arr->data[i].is_ptr) + return TypeError( + "'array2d' object is not picklable because it contains heap-allocated objects"); + buf->used_types[arr->data[i].type] = true; + } + pkl__emit_op(buf, PKL_ARRAY2D); + pkl__emit_int(buf, arr->n_cols); + pkl__emit_int(buf, arr->n_rows); + PickleObject__write_bytes(buf, arr->data, arr->numel * sizeof(py_TValue)); } - pkl__emit_op(buf, PKL_ARRAY2D); - pkl__emit_int(buf, arr->n_cols); - pkl__emit_int(buf, arr->n_rows); - // TODO: fix type index which is not stable - PickleObject__write_bytes(buf, arr->data, arr->numel * sizeof(py_TValue)); - break; + pkl__store_memo(buf, obj->_obj); + return true; + } + default: { + if(!obj->is_ptr) { + pkl__emit_op(buf, PKL_TVALUE); + PickleObject__write_bytes(buf, obj, sizeof(py_TValue)); + buf->used_types[obj->type] = true; + return true; + } + return TypeError("'%t' object is not picklable", obj->type); } - default: return TypeError("'%t' object is not picklable", obj->type); } - if(obj->is_ptr) { - void* memo_key = obj->_obj; - int index = buf->memo.length; - c11_smallmap_p2i__set(&buf->memo, memo_key, index); - pkl__emit_op(buf, PKL_MEMO_SET); - pkl__emit_int(buf, index); - } - return true; + c11__unreachable(); } bool py_pickle_dumps(py_Ref val) { PickleObject buf; PickleObject__ctor(&buf); - bool ok = pickle__write_object(&buf, val); + bool ok = pkl__write_object(&buf, val); if(!ok) { PickleObject__dtor(&buf); return false; @@ -338,12 +357,74 @@ bool py_pickle_dumps(py_Ref val) { return PickleObject__py_submit(&buf, py_retval()); } +static py_Type pkl__header_find_type(c11_sv path) { + int sep_index = c11_sv__rindex(path, '.'); + if(sep_index == -1) return py_gettype(NULL, py_namev(path)); + c11_sv mod_name = c11_sv__slice2(path, 0, sep_index); + c11_sv name = c11_sv__slice(path, sep_index + 1); + char buf[PK_MAX_MODULE_PATH_LEN + 1]; + memcpy(buf, mod_name.data, mod_name.size); + buf[mod_name.size] = '\0'; + return py_gettype(buf, py_namev(name)); +} + +static c11_sv pkl__header_read_sv(const unsigned char** p, char sep) { + c11_sv text; + text.data = (const char*)*p; + const char* p_end = strchr(text.data, sep); + assert(p_end != NULL); + text.size = p_end - text.data; + *p = (const unsigned char*)p_end + 1; + return text; +} + +static py_i64 pkl__header_read_int(const unsigned char** p, char sep) { + c11_sv text = pkl__header_read_sv(p, sep); + py_i64 out; + IntParsingResult res = c11__parse_uint(text, &out, 10); + assert(res == IntParsing_SUCCESS); + return out; +} + +bool py_pickle_loads_body(const unsigned char* p, int memo_length, c11_smallmap_n2i* type_mapping); + bool py_pickle_loads(const unsigned char* data, int size) { - PickleObjectHeader* header = (PickleObjectHeader*)data; - const unsigned char* p = (const unsigned char*)(header + 1); + const unsigned char* p = data; + + c11_smallmap_n2i type_mapping; + c11_smallmap_n2i__ctor(&type_mapping); + + while(true) { + if(*p == '\n') { + p++; + break; + } + py_Type type = pkl__header_read_int(&p, '('); + c11_sv path = pkl__header_read_sv(&p, ')'); + py_Type new_type = pkl__header_find_type(path); + if(new_type == 0) { + c11_smallmap_n2i__dtor(&type_mapping); + return ImportError("cannot find type '%v'", path); + } + if(type != new_type) c11_smallmap_n2i__set(&type_mapping, type, new_type); + } + + int memo_length = pkl__header_read_int(&p, '\n'); + bool ok = py_pickle_loads_body(p, memo_length, &type_mapping); + c11_smallmap_n2i__dtor(&type_mapping); + return ok; +} + +static py_Type pkl__fix_type(py_Type type, c11_smallmap_n2i* type_mapping) { + int new_type = c11_smallmap_n2i__get(type_mapping, type, -1); + if(new_type != -1) return (py_Type)new_type; + return type; +} + +bool py_pickle_loads_body(const unsigned char* p, int memo_length, c11_smallmap_n2i* type_mapping) { py_StackRef p0 = py_peek(0); py_StackRef memo = py_pushtmp(); - py_newtuple(memo, header->memo_length); + py_newtuple(memo, memo_length); while(true) { PickleOp op = (PickleOp)*p; p++; @@ -504,18 +585,9 @@ bool py_pickle_loads(const unsigned char* data, int size) { break; } case PKL_TYPE: { - const char* path = pkl__read_cstr(&p); - char* sep_index = strchr(path, '@'); - assert(sep_index != NULL); - *sep_index = '\0'; - const char* mod_name = path; - const char* type_name = sep_index + 1; - py_Type t = py_gettype(mod_name, py_name(type_name)); - *sep_index = '@'; - if(t == 0) { - return ImportError("cannot import '%s' from '%s'", type_name, mod_name); - } - py_push(py_tpobject(t)); + py_Type type = (py_Type)pkl__read_int(&p); + type = pkl__fix_type(type, type_mapping); + py_push(py_tpobject(type)); break; } case PKL_ARRAY2D: { @@ -524,9 +596,19 @@ bool py_pickle_loads(const unsigned char* data, int size) { c11_array2d* arr = py_newarray2d(py_pushtmp(), n_cols, n_rows); int total_size = arr->numel * sizeof(py_TValue); memcpy(arr->data, p, total_size); + for(int i = 0; i < arr->numel; i++) { + arr->data[i].type = pkl__fix_type(arr->data[i].type, type_mapping); + } p += total_size; break; } + case PKL_TVALUE: { + py_TValue* tmp = py_pushtmp(); + memcpy(tmp, p, sizeof(py_TValue)); + tmp->type = pkl__fix_type(tmp->type, type_mapping); + p += sizeof(py_TValue); + break; + } case PKL_EOF: { // [memo, obj] if(py_peek(0) - p0 != 2) return ValueError("invalid pickle data"); @@ -537,6 +619,33 @@ bool py_pickle_loads(const unsigned char* data, int size) { default: c11__unreachable(); } } + c11__unreachable(); +} + +static bool PickleObject__py_submit(PickleObject* self, py_OutRef out) { + c11_sbuf cleartext; + c11_sbuf__ctor(&cleartext); + // line 1: type mappping + for(py_Type type = 0; type < self->used_types_length; type++) { + if(self->used_types[type]) { + c11_sbuf__write_int(&cleartext, type); + c11_sbuf__write_char(&cleartext, '('); + c11_sbuf__write_type_path(&cleartext, type); + c11_sbuf__write_char(&cleartext, ')'); + } + } + c11_sbuf__write_char(&cleartext, '\n'); + // line 2: memo length + c11_sbuf__write_int(&cleartext, self->memo.length); + c11_sbuf__write_char(&cleartext, '\n'); + // -------------------------------------------------- // + c11_string* header = c11_sbuf__submit(&cleartext); + int total_size = header->size + self->codes.length; + unsigned char* p = py_newbytes(py_retval(), total_size); + memcpy(p, header->data, header->size); + memcpy(p + header->size, self->codes.data, self->codes.length); + c11_string__delete(header); + PickleObject__dtor(self); return true; } diff --git a/src/public/cast.c b/src/public/cast.c index 21c113dd..b7ab92c0 100644 --- a/src/public/cast.c +++ b/src/public/cast.c @@ -24,7 +24,7 @@ bool py_castfloat(py_Ref self, double* out) { } } -bool py_castfloat32(py_Ref self, float *out){ +bool py_castfloat32(py_Ref self, float* out) { switch(self->type) { case tp_int: *out = (float)self->_i64; return true; case tp_float: *out = (float)self->_f64; return true; @@ -77,8 +77,13 @@ bool py_issubclass(py_Type derived, py_Type base) { py_Type py_typeof(py_Ref self) { return self->type; } py_Type py_gettype(const char* module, py_Name name) { - py_Ref mod = py_getmodule(module); - if(!mod) return 0; + py_Ref mod; + if(module != NULL) { + mod = py_getmodule(module); + if(!mod) return 0; + } else { + mod = &pk_current_vm->builtins; + } py_Ref object = py_getdict(mod, name); if(object && py_istype(object, tp_type)) return py_totype(object); return 0; diff --git a/src/public/modules.c b/src/public/modules.c index 7fb03530..a05e54f1 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -23,6 +23,7 @@ void py_setglobal(py_Name name, py_Ref val) { py_setdict(&pk_current_vm->main, n py_Ref py_newmodule(const char* path) { ManagedHeap* heap = &pk_current_vm->heap; + if(strlen(path) > PK_MAX_MODULE_PATH_LEN) c11__abort("module path too long: %s", path); py_Ref r0 = py_pushtmp(); py_Ref r1 = py_pushtmp(); diff --git a/tests/90_pickle.py b/tests/90_pickle.py index a3976097..8b3cc3d5 100644 --- a/tests/90_pickle.py +++ b/tests/90_pickle.py @@ -33,9 +33,9 @@ test(vec3i) # PKL_TYPE print('-'*50) from array2d import array2d -a = array2d[int].fromlist([ - [1, 2, 3], - [4, 5, 6] +a = array2d[int | bool | vec2i].fromlist([ + [1, 2, vec2i.LEFT], + [4, True, 6] ]) a_encoded = pkl.dumps(a) print(a_encoded) @@ -72,6 +72,35 @@ assert b is not a assert b[0] is b[2] assert b[1] is b[3] +from pkpy import TValue + +class Base(TValue[int]): + def __eq__(self, other): + return self.value == other.value + + def __ne__(self, other): + return self.value != other.value + +class TVal(Base): pass # type: ignore + +test(TVal(1)) + +old_bytes = pkl.dumps(TVal(1)) +print(old_bytes) + +# re-define the class so it will have a new type id +class TVal(Base): pass +# see if we can still load the old data +decoded = pkl.loads(old_bytes) +assert decoded == TVal(1) +print(pkl.dumps(decoded)) + +# test array2d with TValue +a = array2d[TVal].fromlist([ + [TVal(1), TVal(2)], + [TVal(3), 1]]) +test(a) + exit() from pickle import dumps, loads, _wrap, _unwrap diff --git a/tests/98_lz4.py b/tests/98_lz4.py index 4178aaee..fbe95e30 100644 --- a/tests/98_lz4.py +++ b/tests/98_lz4.py @@ -26,6 +26,6 @@ for i in range(100): ratio = test(gen_data()) # print(f'compression ratio: {ratio:.2f}') -# test 1GB random data +# test 64MB random data (require 1GB list[int] buffer) rnd = [random.randint(0, 255) for _ in range(1024*1024*1024//16)] test(bytes(rnd))