From 86dc5167911bdaae0e20da7eb982f2ef66cb084c Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Mon, 16 Dec 2024 18:25:31 +0800 Subject: [PATCH] ... --- src/modules/pickle.c | 24 +++++++++++++++++++++++- src/public/cast.c | 2 -- tests/90_pickle.py | 22 +++++++++++++++++----- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/src/modules/pickle.c b/src/modules/pickle.c index a899667d..e2566adc 100644 --- a/src/modules/pickle.c +++ b/src/modules/pickle.c @@ -372,9 +372,22 @@ static bool pkl__write_object(PickleObject* buf, py_TValue* obj) { return true; } if(ti->is_python) { + NameDict* dict = PyObject__dict(obj->_obj); + for(int i = dict->length - 1; i >= 0; i--) { + NameDict_KV* kv = c11__at(NameDict_KV, dict, i); + if(!pkl__write_object(buf, &kv->value)) return false; + } pkl__emit_op(buf, PKL_OBJECT); pkl__emit_int(buf, obj->type); buf->used_types[obj->type] = true; + pkl__emit_int(buf, dict->length); + for(int i = 0; i < dict->length; i++) { + NameDict_KV* kv = c11__at(NameDict_KV, dict, i); + c11_sv field = py_name2sv(kv->key); + // include '\0' + PickleObject__write_bytes(buf, field.data, field.size + 1); + } + // store memo pkl__store_memo(buf, obj->_obj); return true; @@ -662,7 +675,16 @@ bool py_pickle_loads_body(const unsigned char* p, int memo_length, c11_smallmap_ case PKL_OBJECT: { py_Type type = (py_Type)pkl__read_int(&p); type = pkl__fix_type(type, type_mapping); - if(!py_tpcall(type, 0, NULL)) return false; + py_newobject(py_retval(), type, -1, 0); + NameDict* dict = PyObject__dict(py_retval()->_obj); + int dict_length = pkl__read_int(&p); + for(int i = 0; i < dict_length; i++) { + py_StackRef value = py_peek(-1); + c11_sv field = {(const char*)p, strlen((const char*)p)}; + NameDict__set(dict, py_namev(field), *value); + py_pop(); + p += field.size + 1; + } py_push(py_retval()); break; } diff --git a/src/public/cast.c b/src/public/cast.c index b7ab92c0..6cda6c66 100644 --- a/src/public/cast.c +++ b/src/public/cast.c @@ -1,8 +1,6 @@ -#include "pocketpy/common/str.h" #include "pocketpy/objects/base.h" #include "pocketpy/pocketpy.h" -#include "pocketpy/common/utils.h" #include "pocketpy/objects/object.h" #include "pocketpy/interpreter/vm.h" diff --git a/tests/90_pickle.py b/tests/90_pickle.py index 8adc74ba..cd67b1c0 100644 --- a/tests/90_pickle.py +++ b/tests/90_pickle.py @@ -122,12 +122,24 @@ class A: test([A(1)]*10) class Simple: - def __init__(self): pass - def __eq__(self, other): return True - def __ne__(self, other): return False + def __init__(self, x): + self.field1 = x + self.field2 = [...] + def __eq__(self, other): return self.field1 == other.field1 + def __ne__(self, other): return self.field1 != other.field1 -test(Simple()) -test([Simple()]*10) +test(Simple(1)) +test([Simple(2)]*10) + +from dataclasses import dataclass + +@dataclass +class Data: + a: int + b: str = '2' + c: float = 3.0 + +test(Data(1)) exit()