mirror of
https://github.com/pocketpy/pocketpy
synced 2025-10-24 05:20:17 +00:00
pickle supports cyclic references
This commit is contained in:
parent
81feb7a245
commit
6874141b62
133
python/pickle.py
133
python/pickle.py
@ -21,7 +21,8 @@ def _find__new__(cls):
|
|||||||
assert False
|
assert False
|
||||||
|
|
||||||
class _Pickler:
|
class _Pickler:
|
||||||
def __init__(self) -> None:
|
def __init__(self, obj) -> None:
|
||||||
|
self.obj = obj
|
||||||
self.raw_memo = {} # id -> int
|
self.raw_memo = {} # id -> int
|
||||||
self.memo = [] # int -> object
|
self.memo = [] # int -> object
|
||||||
|
|
||||||
@ -31,81 +32,106 @@ class _Pickler:
|
|||||||
|
|
||||||
index = self.raw_memo.get(id(o), None)
|
index = self.raw_memo.get(id(o), None)
|
||||||
if index is not None:
|
if index is not None:
|
||||||
return ["$", index]
|
return [index]
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
index = len(self.memo)
|
index = len(self.memo)
|
||||||
self.memo.append(ret)
|
self.memo.append(ret)
|
||||||
self.raw_memo[id(o)] = index
|
self.raw_memo[id(o)] = index
|
||||||
|
|
||||||
if type(o) is list:
|
|
||||||
ret.append("list")
|
|
||||||
ret.append([self.wrap(i) for i in o])
|
|
||||||
return ["$", index]
|
|
||||||
|
|
||||||
if type(o) is tuple:
|
if type(o) is tuple:
|
||||||
ret.append("tuple")
|
ret.append("tuple")
|
||||||
ret.append([self.wrap(i) for i in o])
|
ret.append([self.wrap(i) for i in o])
|
||||||
return ["$", index]
|
return [index]
|
||||||
|
|
||||||
if type(o) is dict:
|
|
||||||
ret.append("dict")
|
|
||||||
ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
|
|
||||||
return ["$", index]
|
|
||||||
|
|
||||||
if type(o) is bytes:
|
if type(o) is bytes:
|
||||||
ret.append("bytes")
|
ret.append("bytes")
|
||||||
ret.append([o[j] for j in range(len(o))])
|
ret.append([o[j] for j in range(len(o))])
|
||||||
return ["$", index]
|
return [index]
|
||||||
|
if type(o) is list:
|
||||||
|
ret.append("list")
|
||||||
|
ret.append([self.wrap(i) for i in o])
|
||||||
|
return [index]
|
||||||
|
if type(o) is dict:
|
||||||
|
ret.append("dict")
|
||||||
|
ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
|
||||||
|
return [index]
|
||||||
|
|
||||||
_0 = o.__class__.__name__
|
_0 = o.__class__.__name__
|
||||||
|
|
||||||
if hasattr(o, "__getnewargs__"):
|
if hasattr(o, "__getnewargs__"):
|
||||||
_1 = o.__getnewargs__() # an iterable
|
_1 = o.__getnewargs__() # an iterable
|
||||||
_1 = [self.wrap(i) for i in _1]
|
_1 = [self.wrap(i) for i in _1]
|
||||||
else:
|
else:
|
||||||
_1 = None
|
_1 = None
|
||||||
if hasattr(o, "__getstate__"):
|
|
||||||
_2 = o.__getstate__()
|
if o.__dict__ is None:
|
||||||
|
_2 = None
|
||||||
else:
|
else:
|
||||||
if o.__dict__ is None:
|
_2 = {}
|
||||||
_2 = None
|
for k,v in o.__dict__.items():
|
||||||
else:
|
_2[k] = self.wrap(v)
|
||||||
_2 = {}
|
|
||||||
for k,v in o.__dict__.items():
|
|
||||||
_2[k] = self.wrap(v)
|
|
||||||
ret.append(_0)
|
ret.append(_0)
|
||||||
ret.append(_1)
|
ret.append(_1)
|
||||||
ret.append(_2)
|
ret.append(_2)
|
||||||
return ["$", index]
|
return [index]
|
||||||
|
|
||||||
|
def run_pipe(self):
|
||||||
|
o = self.wrap(self.obj)
|
||||||
|
return [o, self.memo]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class _Unpickler:
|
class _Unpickler:
|
||||||
def __init__(self, memo: list) -> None:
|
def __init__(self, obj, memo: list) -> None:
|
||||||
|
self.obj = obj
|
||||||
self.memo = memo
|
self.memo = memo
|
||||||
self.unwrapped = [None] * len(memo)
|
self._unwrapped = [None] * len(memo)
|
||||||
|
|
||||||
def unwrap_ref(self, i: int):
|
def tag(self, index, o):
|
||||||
if self.unwrapped[i] is None:
|
assert self._unwrapped[index] is None
|
||||||
o = self.memo[i]
|
self._unwrapped[index] = o
|
||||||
assert type(o) is list
|
|
||||||
assert o[0] != '$'
|
|
||||||
self.unwrapped[i] = self.unwrap(o)
|
|
||||||
return self.unwrapped[i]
|
|
||||||
|
|
||||||
def unwrap(self, o):
|
def unwrap(self, o, index=None):
|
||||||
if type(o) in _BASIC_TYPES:
|
if type(o) in _BASIC_TYPES:
|
||||||
return o
|
return o
|
||||||
assert type(o) is list
|
assert type(o) is list
|
||||||
if o[0] == '$':
|
|
||||||
index = o[1]
|
# reference
|
||||||
return self.unwrap_ref(index)
|
if type(o[0]) is int:
|
||||||
if o[0] == "list":
|
assert index is None # index should be None
|
||||||
return [self.unwrap(i) for i in o[1]]
|
index = o[0]
|
||||||
|
if self._unwrapped[index] is None:
|
||||||
|
o = self.memo[index]
|
||||||
|
assert type(o) is list
|
||||||
|
assert type(o[0]) is str
|
||||||
|
self.unwrap(o, index)
|
||||||
|
assert self._unwrapped[index] is not None
|
||||||
|
return self._unwrapped[index]
|
||||||
|
|
||||||
|
# concrete reference type
|
||||||
if o[0] == "tuple":
|
if o[0] == "tuple":
|
||||||
return tuple([self.unwrap(i) for i in o[1]])
|
ret = tuple([self.unwrap(i) for i in o[1]])
|
||||||
if o[0] == "dict":
|
self.tag(index, ret)
|
||||||
return {self.unwrap(k): self.unwrap(v) for k,v in o[1]}
|
return ret
|
||||||
if o[0] == "bytes":
|
if o[0] == "bytes":
|
||||||
return bytes(o[1])
|
ret = bytes(o[1])
|
||||||
|
self.tag(index, ret)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
if o[0] == "list":
|
||||||
|
ret = []
|
||||||
|
self.tag(index, ret)
|
||||||
|
for i in o[1]:
|
||||||
|
ret.append(self.unwrap(i))
|
||||||
|
return ret
|
||||||
|
if o[0] == "dict":
|
||||||
|
ret = {}
|
||||||
|
self.tag(index, ret)
|
||||||
|
for k,v in o[1]:
|
||||||
|
ret[self.unwrap(k)] = self.unwrap(v)
|
||||||
|
return ret
|
||||||
|
|
||||||
# generic object
|
# generic object
|
||||||
cls, newargs, state = o
|
cls, newargs, state = o
|
||||||
cls = _find_class(o[0])
|
cls = _find_class(o[0])
|
||||||
@ -116,23 +142,22 @@ class _Unpickler:
|
|||||||
inst = new_f(cls, *newargs)
|
inst = new_f(cls, *newargs)
|
||||||
else:
|
else:
|
||||||
inst = new_f(cls)
|
inst = new_f(cls)
|
||||||
|
self.tag(index, inst)
|
||||||
# restore state
|
# restore state
|
||||||
if hasattr(inst, "__setstate__"):
|
if state is not None:
|
||||||
inst.__setstate__(state)
|
for k,v in state.items():
|
||||||
else:
|
setattr(inst, k, self.unwrap(v))
|
||||||
if state is not None:
|
|
||||||
for k,v in state.items():
|
|
||||||
setattr(inst, k, self.unwrap(v))
|
|
||||||
return inst
|
return inst
|
||||||
|
|
||||||
|
def run_pipe(self):
|
||||||
|
return self.unwrap(self.obj)
|
||||||
|
|
||||||
|
|
||||||
def _wrap(o):
|
def _wrap(o):
|
||||||
p = _Pickler()
|
return _Pickler(o).run_pipe()
|
||||||
o = p.wrap(o)
|
|
||||||
return [o, p.memo]
|
|
||||||
|
|
||||||
def _unwrap(packed: list):
|
def _unwrap(packed: list):
|
||||||
o, memo = packed
|
return _Unpickler(*packed).run_pipe()
|
||||||
return _Unpickler(memo).unwrap(o)
|
|
||||||
|
|
||||||
def dumps(o) -> bytes:
|
def dumps(o) -> bytes:
|
||||||
o = _wrap(o)
|
o = _wrap(o)
|
||||||
|
|||||||
@ -4,9 +4,12 @@ def test(x):
|
|||||||
ok = x == loads(dumps(x))
|
ok = x == loads(dumps(x))
|
||||||
if not ok:
|
if not ok:
|
||||||
_0 = _wrap(x)
|
_0 = _wrap(x)
|
||||||
_1 = _unwrap(0)
|
_1 = _unwrap(_0)
|
||||||
|
print('='*50)
|
||||||
print(_0)
|
print(_0)
|
||||||
|
print('-'*50)
|
||||||
print(_1)
|
print(_1)
|
||||||
|
print('='*50)
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
test(1)
|
test(1)
|
||||||
@ -34,9 +37,7 @@ class Foo:
|
|||||||
return f"Foo({self.x}, {self.y})"
|
return f"Foo({self.x}, {self.y})"
|
||||||
|
|
||||||
test(Foo(1, 2))
|
test(Foo(1, 2))
|
||||||
|
test(Foo([1, True], 'c'))
|
||||||
a = [1,2]
|
|
||||||
test(Foo([1, 2], a))
|
|
||||||
|
|
||||||
from linalg import vec2
|
from linalg import vec2
|
||||||
|
|
||||||
@ -53,4 +54,10 @@ d = {'k': a, 'j': a}
|
|||||||
c = loads(dumps(d))
|
c = loads(dumps(d))
|
||||||
|
|
||||||
assert c['k'] is c['j']
|
assert c['k'] is c['j']
|
||||||
assert c == d
|
assert c == d
|
||||||
|
|
||||||
|
# test circular references
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
a = deque([1, 2, 3])
|
||||||
|
test(a)
|
||||||
Loading…
x
Reference in New Issue
Block a user