This commit is contained in:
blueloveTH 2024-12-16 18:25:31 +08:00
parent 58b5455871
commit 86dc516791
3 changed files with 40 additions and 8 deletions

View File

@ -372,9 +372,22 @@ static bool pkl__write_object(PickleObject* buf, py_TValue* obj) {
return true;
}
if(ti->is_python) {
NameDict* dict = PyObject__dict(obj->_obj);
for(int i = dict->length - 1; i >= 0; i--) {
NameDict_KV* kv = c11__at(NameDict_KV, dict, i);
if(!pkl__write_object(buf, &kv->value)) return false;
}
pkl__emit_op(buf, PKL_OBJECT);
pkl__emit_int(buf, obj->type);
buf->used_types[obj->type] = true;
pkl__emit_int(buf, dict->length);
for(int i = 0; i < dict->length; i++) {
NameDict_KV* kv = c11__at(NameDict_KV, dict, i);
c11_sv field = py_name2sv(kv->key);
// include '\0'
PickleObject__write_bytes(buf, field.data, field.size + 1);
}
// store memo
pkl__store_memo(buf, obj->_obj);
return true;
@ -662,7 +675,16 @@ bool py_pickle_loads_body(const unsigned char* p, int memo_length, c11_smallmap_
case PKL_OBJECT: {
py_Type type = (py_Type)pkl__read_int(&p);
type = pkl__fix_type(type, type_mapping);
if(!py_tpcall(type, 0, NULL)) return false;
py_newobject(py_retval(), type, -1, 0);
NameDict* dict = PyObject__dict(py_retval()->_obj);
int dict_length = pkl__read_int(&p);
for(int i = 0; i < dict_length; i++) {
py_StackRef value = py_peek(-1);
c11_sv field = {(const char*)p, strlen((const char*)p)};
NameDict__set(dict, py_namev(field), *value);
py_pop();
p += field.size + 1;
}
py_push(py_retval());
break;
}

View File

@ -1,8 +1,6 @@
#include "pocketpy/common/str.h"
#include "pocketpy/objects/base.h"
#include "pocketpy/pocketpy.h"
#include "pocketpy/common/utils.h"
#include "pocketpy/objects/object.h"
#include "pocketpy/interpreter/vm.h"

View File

@ -122,12 +122,24 @@ class A:
test([A(1)]*10)
class Simple:
def __init__(self): pass
def __eq__(self, other): return True
def __ne__(self, other): return False
def __init__(self, x):
self.field1 = x
self.field2 = [...]
def __eq__(self, other): return self.field1 == other.field1
def __ne__(self, other): return self.field1 != other.field1
test(Simple())
test([Simple()]*10)
test(Simple(1))
test([Simple(2)]*10)
from dataclasses import dataclass
@dataclass
class Data:
a: int
b: str = '2'
c: float = 3.0
test(Data(1))
exit()