pickle supports cyclic references

This commit is contained in:
blueloveTH 2023-06-13 23:31:09 +08:00
parent 81feb7a245
commit 6874141b62
2 changed files with 91 additions and 59 deletions

View File

@ -21,7 +21,8 @@ def _find__new__(cls):
assert False
class _Pickler:
def __init__(self) -> None:
def __init__(self, obj) -> None:
self.obj = obj
self.raw_memo = {} # id -> int
self.memo = [] # int -> object
@ -31,81 +32,106 @@ class _Pickler:
index = self.raw_memo.get(id(o), None)
if index is not None:
return ["$", index]
return [index]
ret = []
index = len(self.memo)
self.memo.append(ret)
self.raw_memo[id(o)] = index
if type(o) is list:
ret.append("list")
ret.append([self.wrap(i) for i in o])
return ["$", index]
if type(o) is tuple:
ret.append("tuple")
ret.append([self.wrap(i) for i in o])
return ["$", index]
if type(o) is dict:
ret.append("dict")
ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
return ["$", index]
return [index]
if type(o) is bytes:
ret.append("bytes")
ret.append([o[j] for j in range(len(o))])
return ["$", index]
return [index]
if type(o) is list:
ret.append("list")
ret.append([self.wrap(i) for i in o])
return [index]
if type(o) is dict:
ret.append("dict")
ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
return [index]
_0 = o.__class__.__name__
if hasattr(o, "__getnewargs__"):
_1 = o.__getnewargs__() # an iterable
_1 = [self.wrap(i) for i in _1]
else:
_1 = None
if hasattr(o, "__getstate__"):
_2 = o.__getstate__()
if o.__dict__ is None:
_2 = None
else:
if o.__dict__ is None:
_2 = None
else:
_2 = {}
for k,v in o.__dict__.items():
_2[k] = self.wrap(v)
_2 = {}
for k,v in o.__dict__.items():
_2[k] = self.wrap(v)
ret.append(_0)
ret.append(_1)
ret.append(_2)
return ["$", index]
return [index]
def run_pipe(self):
o = self.wrap(self.obj)
return [o, self.memo]
class _Unpickler:
def __init__(self, memo: list) -> None:
def __init__(self, obj, memo: list) -> None:
self.obj = obj
self.memo = memo
self.unwrapped = [None] * len(memo)
self._unwrapped = [None] * len(memo)
def unwrap_ref(self, i: int):
if self.unwrapped[i] is None:
o = self.memo[i]
assert type(o) is list
assert o[0] != '$'
self.unwrapped[i] = self.unwrap(o)
return self.unwrapped[i]
def tag(self, index, o):
assert self._unwrapped[index] is None
self._unwrapped[index] = o
def unwrap(self, o):
def unwrap(self, o, index=None):
if type(o) in _BASIC_TYPES:
return o
assert type(o) is list
if o[0] == '$':
index = o[1]
return self.unwrap_ref(index)
if o[0] == "list":
return [self.unwrap(i) for i in o[1]]
# reference
if type(o[0]) is int:
assert index is None # index should be None
index = o[0]
if self._unwrapped[index] is None:
o = self.memo[index]
assert type(o) is list
assert type(o[0]) is str
self.unwrap(o, index)
assert self._unwrapped[index] is not None
return self._unwrapped[index]
# concrete reference type
if o[0] == "tuple":
return tuple([self.unwrap(i) for i in o[1]])
if o[0] == "dict":
return {self.unwrap(k): self.unwrap(v) for k,v in o[1]}
ret = tuple([self.unwrap(i) for i in o[1]])
self.tag(index, ret)
return ret
if o[0] == "bytes":
return bytes(o[1])
ret = bytes(o[1])
self.tag(index, ret)
return ret
if o[0] == "list":
ret = []
self.tag(index, ret)
for i in o[1]:
ret.append(self.unwrap(i))
return ret
if o[0] == "dict":
ret = {}
self.tag(index, ret)
for k,v in o[1]:
ret[self.unwrap(k)] = self.unwrap(v)
return ret
# generic object
cls, newargs, state = o
cls = _find_class(o[0])
@ -116,23 +142,22 @@ class _Unpickler:
inst = new_f(cls, *newargs)
else:
inst = new_f(cls)
self.tag(index, inst)
# restore state
if hasattr(inst, "__setstate__"):
inst.__setstate__(state)
else:
if state is not None:
for k,v in state.items():
setattr(inst, k, self.unwrap(v))
if state is not None:
for k,v in state.items():
setattr(inst, k, self.unwrap(v))
return inst
def run_pipe(self):
return self.unwrap(self.obj)
def _wrap(o):
p = _Pickler()
o = p.wrap(o)
return [o, p.memo]
return _Pickler(o).run_pipe()
def _unwrap(packed: list):
o, memo = packed
return _Unpickler(memo).unwrap(o)
return _Unpickler(*packed).run_pipe()
def dumps(o) -> bytes:
o = _wrap(o)

View File

@ -4,9 +4,12 @@ def test(x):
ok = x == loads(dumps(x))
if not ok:
_0 = _wrap(x)
_1 = _unwrap(0)
_1 = _unwrap(_0)
print('='*50)
print(_0)
print('-'*50)
print(_1)
print('='*50)
assert False
test(1)
@ -34,9 +37,7 @@ class Foo:
return f"Foo({self.x}, {self.y})"
test(Foo(1, 2))
a = [1,2]
test(Foo([1, 2], a))
test(Foo([1, True], 'c'))
from linalg import vec2
@ -53,4 +54,10 @@ d = {'k': a, 'j': a}
c = loads(dumps(d))
assert c['k'] is c['j']
assert c == d
assert c == d
# test circular references
from collections import deque
a = deque([1, 2, 3])
test(a)