refactor defaultdict

This commit is contained in:
blueloveTH 2024-03-28 19:21:56 +08:00
parent d7545071e5
commit c77fef35a2
8 changed files with 54 additions and 83 deletions

View File

@ -14,5 +14,4 @@ A double-ended queue.
### `collections.defaultdict`
A `dict` wrapper that calls a factory function to supply missing values.
It is not a subclass of `dict`.
A dictionary that returns a default value when a key is not found.

File diff suppressed because one or more lines are too long

View File

@ -224,6 +224,7 @@ const StrName __all__ = StrName::get("__all__");
const StrName __package__ = StrName::get("__package__");
const StrName __path__ = StrName::get("__path__");
const StrName __class__ = StrName::get("__class__");
const StrName __missing__ = StrName::get("__missing__");
const StrName pk_id_add = StrName::get("add");
const StrName pk_id_set = StrName::get("set");

View File

@ -82,7 +82,6 @@ struct PyTypeInfo{
void (*m__setattr__)(VM* vm, PyObject*, StrName, PyObject*) = nullptr;
PyObject* (*m__getattr__)(VM* vm, PyObject*, StrName) = nullptr;
bool (*m__delattr__)(VM* vm, PyObject*, StrName) = nullptr;
};
typedef void(*PrintFunc)(const char*, int);

View File

@ -7,63 +7,19 @@ def Counter(iterable):
a[x] = 1
return a
class defaultdict:
def __init__(self, default_factory) -> None:
class defaultdict(dict):
def __init__(self, default_factory, *args):
super().__init__(*args)
self._enable_instance_dict()
self.default_factory = default_factory
self._a = {}
def __getitem__(self, key):
if key not in self._a:
self._a[key] = self.default_factory()
return self._a[key]
def __setitem__(self, key, value):
self._a[key] = value
def __delitem__(self, key):
del self._a[key]
def __missing__(self, key):
self[key] = self.default_factory()
return self[key]
def __repr__(self) -> str:
return f"defaultdict({self.default_factory}, {self._a})"
def __eq__(self, __o: object) -> bool:
if not isinstance(__o, defaultdict):
return False
if self.default_factory != __o.default_factory:
return False
return self._a == __o._a
def __iter__(self):
return iter(self._a)
def __contains__(self, key):
return key in self._a
def __len__(self):
return len(self._a)
def keys(self):
return self._a.keys()
def values(self):
return self._a.values()
def items(self):
return self._a.items()
def pop(self, *args):
return self._a.pop(*args)
def clear(self):
self._a.clear()
return f"defaultdict({self.default_factory}, {super().__repr__()})"
def copy(self):
new_dd = defaultdict(self.default_factory)
new_dd._a = self._a.copy()
return new_dd
return defaultdict(self.default_factory, self)
def get(self, key, default):
return self._a.get(key, default)
def update(self, other):
self._a.update(other)

View File

@ -1228,15 +1228,22 @@ void init_builtins(VM* _vm) {
// tp_dict
_vm->bind_constructor<-1>(_vm->_t(VM::tp_dict), [](VM* vm, ArgsView args){
return VAR(Dict(vm));
Type cls_t = PK_OBJ_GET(Type, args[0]);
return vm->heap.gcnew<Dict>(cls_t, vm);
});
_vm->bind_method<-1>(VM::tp_dict, "__init__", [](VM* vm, ArgsView args){
if(args.size() == 1+0) return vm->None;
if(args.size() == 1+1){
auto _lock = vm->heap.gc_scope_lock();
Dict& self = _CAST(Dict&, args[0]);
List& list = CAST(List&, args[1]);
Dict& self = PK_OBJ_GET(Dict, args[0]);
if(is_non_tagged_type(args[1], vm->tp_dict)){
Dict& other = CAST(Dict&, args[1]);
self.update(other);
return vm->None;
}
if(is_non_tagged_type(args[1], vm->tp_list)){
List& list = PK_OBJ_GET(List, args[1]);
for(PyObject* item : list){
Tuple& t = CAST(Tuple&, item);
if(t.size() != 2){
@ -1245,20 +1252,29 @@ void init_builtins(VM* _vm) {
}
self.set(t[0], t[1]);
}
}
return vm->None;
}
vm->TypeError("dict() takes at most 1 argument");
return vm->None;
PK_UNREACHABLE()
});
_vm->bind__len__(VM::tp_dict, [](VM* vm, PyObject* _0) {
return (i64)_CAST(Dict&, _0).size();
return (i64)PK_OBJ_GET(Dict, _0).size();
});
_vm->bind__getitem__(VM::tp_dict, [](VM* vm, PyObject* _0, PyObject* _1) {
Dict& self = _CAST(Dict&, _0);
Dict& self = PK_OBJ_GET(Dict, _0);
PyObject* ret = self.try_get(_1);
if(ret == nullptr) vm->KeyError(_1);
if(ret == nullptr){
// try __missing__
PyObject* self;
PyObject* f_missing = vm->get_unbound_method(_0, __missing__, &self, false);
if(f_missing != nullptr){
return vm->call_method(self, f_missing, _1);
}
vm->KeyError(_1);
}
return ret;
});
@ -1372,7 +1388,7 @@ void init_builtins(VM* _vm) {
_vm->bind__eq__(VM::tp_dict, [](VM* vm, PyObject* _0, PyObject* _1) {
Dict& self = _CAST(Dict&, _0);
if(!is_non_tagged_type(_1, vm->tp_dict)) return vm->NotImplemented;
if(!vm->isinstance(_1, vm->tp_dict)) return vm->NotImplemented;
Dict& other = _CAST(Dict&, _1);
if(self.size() != other.size()) return vm->False;
for(int i=0; i<self._capacity; i++){

View File

@ -714,7 +714,7 @@ void VM::init_builtin_types(){
if(tp_exception != _new_type_object("Exception", 0, true)) exit(-3);
if(tp_bytes != _new_type_object("bytes")) exit(-3);
if(tp_mappingproxy != _new_type_object("mappingproxy")) exit(-3);
if(tp_dict != _new_type_object("dict")) exit(-3);
if(tp_dict != _new_type_object("dict", 0, true)) exit(-3); // dict can be subclassed
if(tp_property != _new_type_object("property")) exit(-3);
if(tp_star_wrapper != _new_type_object("_star_wrapper")) exit(-3);
@ -1301,7 +1301,7 @@ void VM::bind__hash__(Type type, i64 (*f)(VM*, PyObject*)){
PyObject* obj = _t(type);
_all_types[type].m__hash__ = f;
PyObject* nf = bind_method<0>(obj, "__hash__", [](VM* vm, ArgsView args){
i64 ret = lambda_get_userdata<i64(*)(VM*, PyObject*)>(args.begin())(vm, args[0]);
i64 ret = lambda_get_userdata<decltype(f)>(args.begin())(vm, args[0]);
return VAR(ret);
});
PK_OBJ_GET(NativeFunc, nf).set_userdata(f);
@ -1311,7 +1311,7 @@ void VM::bind__len__(Type type, i64 (*f)(VM*, PyObject*)){
PyObject* obj = _t(type);
_all_types[type].m__len__ = f;
PyObject* nf = bind_method<0>(obj, "__len__", [](VM* vm, ArgsView args){
i64 ret = lambda_get_userdata<i64(*)(VM*, PyObject*)>(args.begin())(vm, args[0]);
i64 ret = lambda_get_userdata<decltype(f)>(args.begin())(vm, args[0]);
return VAR(ret);
});
PK_OBJ_GET(NativeFunc, nf).set_userdata(f);

View File

@ -2,15 +2,15 @@ from collections import Counter, deque, defaultdict
import random
import pickle
import gc
import builtins
dd_dict_keys = sorted(defaultdict.__dict__.keys())
d_dict_keys = sorted(dict.__dict__.keys())
d_dict_keys.remove('__new__')
if dd_dict_keys != d_dict_keys:
print("dd_dict_keys:", dd_dict_keys)
print("d_dict_keys:", d_dict_keys)
raise Exception("dd_dict_keys != d_dict_keys")
# test defaultdict
assert issubclass(defaultdict, dict)
a = defaultdict(int)
a['1'] += 1
assert a == {'1': 1}
a = defaultdict(list)
a['1'].append(1)
assert a == {'1': [1]}
q = deque()
q.append(1)