From 651a51dc49b39befec17d32c877dd2e567dc7461 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Tue, 13 Jun 2023 00:18:21 +0800 Subject: [PATCH] ... --- python/pickle.py | 134 +++++++++++++++++++++++++++++++-------------- tests/81_pickle.py | 51 ++++++++++------- 2 files changed, 125 insertions(+), 60 deletions(-) diff --git a/python/pickle.py b/python/pickle.py index 3a3a65f2..f5101d78 100644 --- a/python/pickle.py +++ b/python/pickle.py @@ -1,6 +1,8 @@ import json import builtins +_BASIC_TYPES = [int, float, str, bool, type(None)] + def _find_class(path: str): if "." not in path: g = globals() @@ -16,47 +18,92 @@ def _find__new__(cls): if "__new__" in d: return d["__new__"] cls = cls.__base__ - raise PickleError(f"cannot find __new__ for {cls.__name__}") + assert False -def _wrap(o): - if type(o) in (int, float, str, bool, type(None)): - return o - if type(o) is list: - return ["list", [_wrap(i) for i in o]] - if type(o) is tuple: - return ["tuple", [_wrap(i) for i in o]] - if type(o) is dict: - return ["dict", [[_wrap(k), _wrap(v)] for k,v in o.items()]] - if type(o) is bytes: - return ["bytes", [o[j] for j in range(len(o))]] - - _0 = o.__class__.__name__ - if hasattr(o, "__getnewargs__"): - _1 = o.__getnewargs__() # an iterable - _1 = [_wrap(i) for i in _1] - else: - _1 = None - if hasattr(o, "__getstate__"): - _2 = o.__getstate__() - else: - if o.__dict__ is None: - _2 = None +class _Pickler: + def __init__(self) -> None: + self.raw_memo = {} # id -> int + self.memo = [] # int -> object + + def wrap(self, o): + if type(o) in _BASIC_TYPES: + return o + + index = self.raw_memo.get(id(o), None) + if index is not None: + return ["$", index] + + ret = [] + index = len(self.memo) + self.memo.append(ret) + self.raw_memo[id(o)] = index + + if type(o) is list: + ret.append("list") + ret.append([self.wrap(i) for i in o]) + return ["$", index] + + if type(o) is tuple: + ret.append("tuple") + ret.append([self.wrap(i) for i in o]) + return ["$", index] + + if type(o) is dict: + ret.append("dict") + ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()]) + return ["$", index] + + if type(o) is bytes: + ret.append("bytes") + ret.append([o[j] for j in range(len(o))]) + return ["$", index] + + _0 = o.__class__.__name__ + if hasattr(o, "__getnewargs__"): + _1 = o.__getnewargs__() # an iterable + _1 = [self.wrap(i) for i in _1] else: - _2 = {} - for k,v in o.__dict__.items(): - _2[k] = _wrap(v) - return [_0, _1, _2] + _1 = None + if hasattr(o, "__getstate__"): + _2 = o.__getstate__() + else: + if o.__dict__ is None: + _2 = None + else: + _2 = {} + for k,v in o.__dict__.items(): + _2[k] = self.wrap(v) + ret.append(_0) + ret.append(_1) + ret.append(_2) + return ["$", index] -def _unwrap(o): - if type(o) in (int, float, str, bool, type(None)): - return o - if isinstance(o, list): +class _Unpickler: + def __init__(self, memo: list) -> None: + self.memo = memo + self.unwrapped = [None] * len(memo) + + def unwrap_ref(self, i: int): + if self.unwrapped[i] is None: + o = self.memo[i] + assert type(o) is list + assert o[0] != '$' + self.unwrapped[i] = self.unwrap(o) + return self.unwrapped[i] + + def unwrap(self, o): + if type(o) in _BASIC_TYPES: + return o + assert type(o) is list + if o[0] == '$': + index = o[1] + return self.unwrap_ref(index) if o[0] == "list": - return [_unwrap(i) for i in o[1]] + return [self.unwrap(i) for i in o[1]] if o[0] == "tuple": - return tuple([_unwrap(i) for i in o[1]]) + return tuple([self.unwrap(i) for i in o[1]]) if o[0] == "dict": - return {_unwrap(k): _unwrap(v) for k,v in o[1]} + return {self.unwrap(k): self.unwrap(v) for k,v in o[1]} if o[0] == "bytes": return bytes(o[1]) # generic object @@ -65,7 +112,7 @@ def _unwrap(o): # create uninitialized instance new_f = _find__new__(cls) if newargs is not None: - newargs = [_unwrap(i) for i in newargs] + newargs = [self.unwrap(i) for i in newargs] inst = new_f(cls, *newargs) else: inst = new_f(cls) @@ -75,14 +122,21 @@ def _unwrap(o): else: if state is not None: for k,v in state.items(): - setattr(inst, k, _unwrap(v)) + setattr(inst, k, self.unwrap(v)) return inst - raise PickleError(f"cannot unpickle {type(o).__name__} object") +def _wrap(o): + p = _Pickler() + o = p.wrap(o) + return [o, p.memo] + +def _unwrap(packed: list): + o, memo = packed + return _Unpickler(memo).unwrap(o) def dumps(o) -> bytes: - return json.dumps(_wrap(o)).encode() - + o = _wrap(o) + return json.dumps(o).encode() def loads(b) -> object: assert type(b) is bytes diff --git a/tests/81_pickle.py b/tests/81_pickle.py index baba8783..8238af6f 100644 --- a/tests/81_pickle.py +++ b/tests/81_pickle.py @@ -1,22 +1,24 @@ from pickle import dumps, loads, _wrap, _unwrap -def test(x, y): - _0 = _wrap(x) - _1 = _unwrap(y) - assert _0 == y, f"{_0} != {y}" - assert _1 == x, f"{_1} != {x}" - assert x == loads(dumps(x)) +def test(x): + ok = x == loads(dumps(x)) + if not ok: + _0 = _wrap(x) + _1 = _unwrap(0) + print(_0) + print(_1) + assert False -test(1, 1) -test(1.0, 1.0) -test("hello", "hello") -test(True, True) -test(False, False) -test(None, None) +test(1) +test(1.0) +test("hello") +test(True) +test(False) +test(None) -test([1, 2, 3], ["list", [1, 2, 3]]) -test((1, 2, 3), ["tuple", [1, 2, 3]]) -test({1: 2, 3: 4}, ["dict", [[1, 2], [3, 4]]]) +test([1, 2, 3]) +test((1, 2, 3)) +test({1: 2, 3: 4}) class Foo: def __init__(self, x, y): @@ -31,15 +33,24 @@ class Foo: def __repr__(self) -> str: return f"Foo({self.x}, {self.y})" -foo = Foo(1, 2) -test(foo, ["__main__.Foo", None, {"x": 1, "y": 2}]) +test(Foo(1, 2)) + +a = [1,2] +test(Foo([1, 2], a)) from linalg import vec2 -test(vec2(1, 2), ["linalg.vec2", [1, 2], None]) +test(vec2(1, 2)) a = {1, 2, 3, 4} -test(a, ['set', None, {'_a': ['dict', [[1, None], [2, None], [3, None], [4, None]]]}]) +test(a) a = bytes([1, 2, 3, 4]) -assert loads(dumps(a)) == a \ No newline at end of file +test(a) + +a = [1, 2] +d = {'k': a, 'j': a} +c = loads(dumps(d)) + +assert c['k'] is c['j'] +assert c == d \ No newline at end of file