diff --git a/python/pickle.py b/python/pickle.py index f5101d78..f12277a1 100644 --- a/python/pickle.py +++ b/python/pickle.py @@ -21,7 +21,8 @@ def _find__new__(cls): assert False class _Pickler: - def __init__(self) -> None: + def __init__(self, obj) -> None: + self.obj = obj self.raw_memo = {} # id -> int self.memo = [] # int -> object @@ -31,81 +32,106 @@ class _Pickler: index = self.raw_memo.get(id(o), None) if index is not None: - return ["$", index] + return [index] ret = [] index = len(self.memo) self.memo.append(ret) self.raw_memo[id(o)] = index - if type(o) is list: - ret.append("list") - ret.append([self.wrap(i) for i in o]) - return ["$", index] - if type(o) is tuple: ret.append("tuple") ret.append([self.wrap(i) for i in o]) - return ["$", index] - - if type(o) is dict: - ret.append("dict") - ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()]) - return ["$", index] - + return [index] if type(o) is bytes: ret.append("bytes") ret.append([o[j] for j in range(len(o))]) - return ["$", index] + return [index] + if type(o) is list: + ret.append("list") + ret.append([self.wrap(i) for i in o]) + return [index] + if type(o) is dict: + ret.append("dict") + ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()]) + return [index] _0 = o.__class__.__name__ + if hasattr(o, "__getnewargs__"): _1 = o.__getnewargs__() # an iterable _1 = [self.wrap(i) for i in _1] else: _1 = None - if hasattr(o, "__getstate__"): - _2 = o.__getstate__() + + if o.__dict__ is None: + _2 = None else: - if o.__dict__ is None: - _2 = None - else: - _2 = {} - for k,v in o.__dict__.items(): - _2[k] = self.wrap(v) + _2 = {} + for k,v in o.__dict__.items(): + _2[k] = self.wrap(v) + ret.append(_0) ret.append(_1) ret.append(_2) - return ["$", index] + return [index] + + def run_pipe(self): + o = self.wrap(self.obj) + return [o, self.memo] + + class _Unpickler: - def __init__(self, memo: list) -> None: + def __init__(self, obj, memo: list) -> None: + self.obj = obj self.memo = memo - self.unwrapped = [None] * len(memo) + self._unwrapped = [None] * len(memo) - def unwrap_ref(self, i: int): - if self.unwrapped[i] is None: - o = self.memo[i] - assert type(o) is list - assert o[0] != '$' - self.unwrapped[i] = self.unwrap(o) - return self.unwrapped[i] + def tag(self, index, o): + assert self._unwrapped[index] is None + self._unwrapped[index] = o - def unwrap(self, o): + def unwrap(self, o, index=None): if type(o) in _BASIC_TYPES: return o assert type(o) is list - if o[0] == '$': - index = o[1] - return self.unwrap_ref(index) - if o[0] == "list": - return [self.unwrap(i) for i in o[1]] + + # reference + if type(o[0]) is int: + assert index is None # index should be None + index = o[0] + if self._unwrapped[index] is None: + o = self.memo[index] + assert type(o) is list + assert type(o[0]) is str + self.unwrap(o, index) + assert self._unwrapped[index] is not None + return self._unwrapped[index] + + # concrete reference type if o[0] == "tuple": - return tuple([self.unwrap(i) for i in o[1]]) - if o[0] == "dict": - return {self.unwrap(k): self.unwrap(v) for k,v in o[1]} + ret = tuple([self.unwrap(i) for i in o[1]]) + self.tag(index, ret) + return ret if o[0] == "bytes": - return bytes(o[1]) + ret = bytes(o[1]) + self.tag(index, ret) + return ret + + if o[0] == "list": + ret = [] + self.tag(index, ret) + for i in o[1]: + ret.append(self.unwrap(i)) + return ret + if o[0] == "dict": + ret = {} + self.tag(index, ret) + for k,v in o[1]: + ret[self.unwrap(k)] = self.unwrap(v) + return ret + # generic object cls, newargs, state = o cls = _find_class(o[0]) @@ -116,23 +142,22 @@ class _Unpickler: inst = new_f(cls, *newargs) else: inst = new_f(cls) + self.tag(index, inst) # restore state - if hasattr(inst, "__setstate__"): - inst.__setstate__(state) - else: - if state is not None: - for k,v in state.items(): - setattr(inst, k, self.unwrap(v)) + if state is not None: + for k,v in state.items(): + setattr(inst, k, self.unwrap(v)) return inst + def run_pipe(self): + return self.unwrap(self.obj) + + def _wrap(o): - p = _Pickler() - o = p.wrap(o) - return [o, p.memo] + return _Pickler(o).run_pipe() def _unwrap(packed: list): - o, memo = packed - return _Unpickler(memo).unwrap(o) + return _Unpickler(*packed).run_pipe() def dumps(o) -> bytes: o = _wrap(o) diff --git a/tests/81_pickle.py b/tests/81_pickle.py index 8238af6f..6be4122b 100644 --- a/tests/81_pickle.py +++ b/tests/81_pickle.py @@ -4,9 +4,12 @@ def test(x): ok = x == loads(dumps(x)) if not ok: _0 = _wrap(x) - _1 = _unwrap(0) + _1 = _unwrap(_0) + print('='*50) print(_0) + print('-'*50) print(_1) + print('='*50) assert False test(1) @@ -34,9 +37,7 @@ class Foo: return f"Foo({self.x}, {self.y})" test(Foo(1, 2)) - -a = [1,2] -test(Foo([1, 2], a)) +test(Foo([1, True], 'c')) from linalg import vec2 @@ -53,4 +54,10 @@ d = {'k': a, 'j': a} c = loads(dumps(d)) assert c['k'] is c['j'] -assert c == d \ No newline at end of file +assert c == d + +# test circular references +from collections import deque + +a = deque([1, 2, 3]) +test(a) \ No newline at end of file