This commit is contained in:
blueloveTH 2023-06-13 00:18:21 +08:00
parent 8aba78c17f
commit 651a51dc49
2 changed files with 125 additions and 60 deletions

View File

@ -1,6 +1,8 @@
import json import json
import builtins import builtins
_BASIC_TYPES = [int, float, str, bool, type(None)]
def _find_class(path: str): def _find_class(path: str):
if "." not in path: if "." not in path:
g = globals() g = globals()
@ -16,47 +18,92 @@ def _find__new__(cls):
if "__new__" in d: if "__new__" in d:
return d["__new__"] return d["__new__"]
cls = cls.__base__ cls = cls.__base__
raise PickleError(f"cannot find __new__ for {cls.__name__}") assert False
def _wrap(o): class _Pickler:
if type(o) in (int, float, str, bool, type(None)): def __init__(self) -> None:
return o self.raw_memo = {} # id -> int
if type(o) is list: self.memo = [] # int -> object
return ["list", [_wrap(i) for i in o]]
if type(o) is tuple: def wrap(self, o):
return ["tuple", [_wrap(i) for i in o]] if type(o) in _BASIC_TYPES:
if type(o) is dict: return o
return ["dict", [[_wrap(k), _wrap(v)] for k,v in o.items()]]
if type(o) is bytes: index = self.raw_memo.get(id(o), None)
return ["bytes", [o[j] for j in range(len(o))]] if index is not None:
return ["$", index]
_0 = o.__class__.__name__
if hasattr(o, "__getnewargs__"): ret = []
_1 = o.__getnewargs__() # an iterable index = len(self.memo)
_1 = [_wrap(i) for i in _1] self.memo.append(ret)
else: self.raw_memo[id(o)] = index
_1 = None
if hasattr(o, "__getstate__"): if type(o) is list:
_2 = o.__getstate__() ret.append("list")
else: ret.append([self.wrap(i) for i in o])
if o.__dict__ is None: return ["$", index]
_2 = None
if type(o) is tuple:
ret.append("tuple")
ret.append([self.wrap(i) for i in o])
return ["$", index]
if type(o) is dict:
ret.append("dict")
ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
return ["$", index]
if type(o) is bytes:
ret.append("bytes")
ret.append([o[j] for j in range(len(o))])
return ["$", index]
_0 = o.__class__.__name__
if hasattr(o, "__getnewargs__"):
_1 = o.__getnewargs__() # an iterable
_1 = [self.wrap(i) for i in _1]
else: else:
_2 = {} _1 = None
for k,v in o.__dict__.items(): if hasattr(o, "__getstate__"):
_2[k] = _wrap(v) _2 = o.__getstate__()
return [_0, _1, _2] else:
if o.__dict__ is None:
_2 = None
else:
_2 = {}
for k,v in o.__dict__.items():
_2[k] = self.wrap(v)
ret.append(_0)
ret.append(_1)
ret.append(_2)
return ["$", index]
def _unwrap(o): class _Unpickler:
if type(o) in (int, float, str, bool, type(None)): def __init__(self, memo: list) -> None:
return o self.memo = memo
if isinstance(o, list): self.unwrapped = [None] * len(memo)
def unwrap_ref(self, i: int):
if self.unwrapped[i] is None:
o = self.memo[i]
assert type(o) is list
assert o[0] != '$'
self.unwrapped[i] = self.unwrap(o)
return self.unwrapped[i]
def unwrap(self, o):
if type(o) in _BASIC_TYPES:
return o
assert type(o) is list
if o[0] == '$':
index = o[1]
return self.unwrap_ref(index)
if o[0] == "list": if o[0] == "list":
return [_unwrap(i) for i in o[1]] return [self.unwrap(i) for i in o[1]]
if o[0] == "tuple": if o[0] == "tuple":
return tuple([_unwrap(i) for i in o[1]]) return tuple([self.unwrap(i) for i in o[1]])
if o[0] == "dict": if o[0] == "dict":
return {_unwrap(k): _unwrap(v) for k,v in o[1]} return {self.unwrap(k): self.unwrap(v) for k,v in o[1]}
if o[0] == "bytes": if o[0] == "bytes":
return bytes(o[1]) return bytes(o[1])
# generic object # generic object
@ -65,7 +112,7 @@ def _unwrap(o):
# create uninitialized instance # create uninitialized instance
new_f = _find__new__(cls) new_f = _find__new__(cls)
if newargs is not None: if newargs is not None:
newargs = [_unwrap(i) for i in newargs] newargs = [self.unwrap(i) for i in newargs]
inst = new_f(cls, *newargs) inst = new_f(cls, *newargs)
else: else:
inst = new_f(cls) inst = new_f(cls)
@ -75,14 +122,21 @@ def _unwrap(o):
else: else:
if state is not None: if state is not None:
for k,v in state.items(): for k,v in state.items():
setattr(inst, k, _unwrap(v)) setattr(inst, k, self.unwrap(v))
return inst return inst
raise PickleError(f"cannot unpickle {type(o).__name__} object")
def _wrap(o):
p = _Pickler()
o = p.wrap(o)
return [o, p.memo]
def _unwrap(packed: list):
o, memo = packed
return _Unpickler(memo).unwrap(o)
def dumps(o) -> bytes: def dumps(o) -> bytes:
return json.dumps(_wrap(o)).encode() o = _wrap(o)
return json.dumps(o).encode()
def loads(b) -> object: def loads(b) -> object:
assert type(b) is bytes assert type(b) is bytes

View File

@ -1,22 +1,24 @@
from pickle import dumps, loads, _wrap, _unwrap from pickle import dumps, loads, _wrap, _unwrap
def test(x, y): def test(x):
_0 = _wrap(x) ok = x == loads(dumps(x))
_1 = _unwrap(y) if not ok:
assert _0 == y, f"{_0} != {y}" _0 = _wrap(x)
assert _1 == x, f"{_1} != {x}" _1 = _unwrap(0)
assert x == loads(dumps(x)) print(_0)
print(_1)
assert False
test(1, 1) test(1)
test(1.0, 1.0) test(1.0)
test("hello", "hello") test("hello")
test(True, True) test(True)
test(False, False) test(False)
test(None, None) test(None)
test([1, 2, 3], ["list", [1, 2, 3]]) test([1, 2, 3])
test((1, 2, 3), ["tuple", [1, 2, 3]]) test((1, 2, 3))
test({1: 2, 3: 4}, ["dict", [[1, 2], [3, 4]]]) test({1: 2, 3: 4})
class Foo: class Foo:
def __init__(self, x, y): def __init__(self, x, y):
@ -31,15 +33,24 @@ class Foo:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Foo({self.x}, {self.y})" return f"Foo({self.x}, {self.y})"
foo = Foo(1, 2) test(Foo(1, 2))
test(foo, ["__main__.Foo", None, {"x": 1, "y": 2}])
a = [1,2]
test(Foo([1, 2], a))
from linalg import vec2 from linalg import vec2
test(vec2(1, 2), ["linalg.vec2", [1, 2], None]) test(vec2(1, 2))
a = {1, 2, 3, 4} a = {1, 2, 3, 4}
test(a, ['set', None, {'_a': ['dict', [[1, None], [2, None], [3, None], [4, None]]]}]) test(a)
a = bytes([1, 2, 3, 4]) a = bytes([1, 2, 3, 4])
assert loads(dumps(a)) == a test(a)
a = [1, 2]
d = {'k': a, 'j': a}
c = loads(dumps(d))
assert c['k'] is c['j']
assert c == d