From 7bd99279e52cf6ffa4a4a9d1be126d5a99a058d9 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Sun, 4 Feb 2024 17:53:31 +0800 Subject: [PATCH] fix https://github.com/pocketpy/pocketpy/issues/196 --- amalgamate.py | 2 +- include/pocketpy/dataclasses.h | 9 +++ include/pocketpy/pocketpy.h | 1 + src/compiler.cpp | 1 - src/dataclasses.cpp | 117 +++++++++++++++++++++++++++++++++ src/pocketpy.cpp | 10 +-- 6 files changed, 129 insertions(+), 11 deletions(-) create mode 100644 include/pocketpy/dataclasses.h create mode 100644 src/dataclasses.cpp diff --git a/amalgamate.py b/amalgamate.py index 176eafa2..cc9e553b 100644 --- a/amalgamate.py +++ b/amalgamate.py @@ -9,7 +9,7 @@ pipeline = [ ["config.h", "export.h", "common.h", "memory.h", "vector.h", "str.h", "tuplelist.h", "namedict.h", "error.h"], ["obj.h", "dict.h", "codeobject.h", "frame.h"], ["gc.h", "vm.h", "ceval.h", "lexer.h", "expr.h", "compiler.h", "repl.h"], - ["_generated.h", "cffi.h", "bindings.h", "iter.h", "base64.h", "csv.h", "collections.h", "random.h", "linalg.h", "easing.h", "io.h", "modules.h"], + ["_generated.h", "cffi.h", "bindings.h", "iter.h", "base64.h", "csv.h", "collections.h", "dataclasses.h", "random.h", "linalg.h", "easing.h", "io.h", "modules.h"], ["pocketpy.h", "pocketpy_c.h"] ] diff --git a/include/pocketpy/dataclasses.h b/include/pocketpy/dataclasses.h new file mode 100644 index 00000000..e45b7c25 --- /dev/null +++ b/include/pocketpy/dataclasses.h @@ -0,0 +1,9 @@ +#pragma once + +#include "cffi.h" + +namespace pkpy{ + +void add_module_dataclasses(VM* vm); + +} // namespace pkpy \ No newline at end of file diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index 9eea4de0..6ea88c21 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -14,4 +14,5 @@ #include "bindings.h" #include "collections.h" #include "csv.h" +#include "dataclasses.h" #include "modules.h" diff --git a/src/compiler.cpp b/src/compiler.cpp index 39bc1706..72c26f04 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -944,7 +944,6 @@ __EAT_DOTS_END: is_typed_name = true; if(ctx()->is_compiling_class){ - // add to __annotations__ NameExpr* ne = static_cast(ctx()->s_expr.top().get()); ctx()->emit_(OP_ADD_CLASS_ANNOTATION, ne->name.index, BC_KEEPLINE); } diff --git a/src/dataclasses.cpp b/src/dataclasses.cpp new file mode 100644 index 00000000..f85bbb3f --- /dev/null +++ b/src/dataclasses.cpp @@ -0,0 +1,117 @@ +#include "pocketpy/dataclasses.h" + +namespace pkpy{ + +static void patch__init__(VM* vm, Type cls){ + vm->bind(vm->_t(cls), "__init__(self, *args, **kwargs)", [](VM* vm, ArgsView _view){ + PyObject* self = _view[0]; + const Tuple& args = CAST(Tuple&, _view[1]); + const Dict& kwargs_ = CAST(Dict&, _view[2]); + NameDict kwargs; + kwargs_.apply([&](PyObject* k, PyObject* v){ + kwargs.set(CAST(Str&, k), v); + }); + + Type cls = vm->_tp(self); + const PyTypeInfo* cls_info = &vm->_all_types[cls]; + NameDict& cls_d = cls_info->obj->attr(); + const auto& fields = cls_info->annotated_fields; + + int i = 0; // index into args + for(StrName field: fields){ + if(kwargs.contains(field)){ + self->attr().set(field, kwargs[field]); + kwargs.del(field); + }else{ + if(i < args.size()){ + self->attr().set(field, args[i]); + ++i; + }else if(cls_d.contains(field)){ // has default value + self->attr().set(field, cls_d[field]); + }else{ + vm->TypeError(_S(cls_info->name, " missing required argument ", field.escape())); + } + } + } + if(args.size() > i){ + vm->TypeError(_S(cls_info->name, " takes ", fields.size(), " positional arguments but ", args.size(), " were given")); + } + if(kwargs.size() > 0){ + StrName unexpected_key = kwargs.items()[0].first; + vm->TypeError(_S(cls_info->name, " got an unexpected keyword argument ", unexpected_key.escape())); + } + return vm->None; + }); +} + +static void patch__repr__(VM* vm, Type cls){ + vm->bind__repr__(cls, [](VM* vm, PyObject* _0){ + auto _lock = vm->heap.gc_scope_lock(); + const PyTypeInfo* cls_info = &vm->_all_types[vm->_tp(_0)]; + const auto& fields = cls_info->annotated_fields; + const NameDict& obj_d = _0->attr(); + SStream ss; + ss << cls_info->name << "("; + bool first = true; + for(StrName field: fields){ + if(first) first = false; + else ss << ", "; + ss << field << "=" << CAST(Str&, vm->py_repr(obj_d[field])); + } + ss << ")"; + return VAR(ss.str()); + }); +} + +static void patch__eq__(VM* vm, Type cls){ + vm->bind__eq__(cls, [](VM* vm, PyObject* _0, PyObject* _1){ + if(vm->_tp(_0) != vm->_tp(_1)) return vm->NotImplemented; + const PyTypeInfo* cls_info = &vm->_all_types[vm->_tp(_0)]; + const auto& fields = cls_info->annotated_fields; + for(StrName field: fields){ + PyObject* lhs = _0->attr(field); + PyObject* rhs = _1->attr(field); + if(vm->py_ne(lhs, rhs)) return vm->False; + } + return vm->True; + }); +} + +void add_module_dataclasses(VM* vm){ + PyObject* mod = vm->new_module("dataclasses"); + + vm->bind_func<1>(mod, "dataclass", [](VM* vm, ArgsView args){ + vm->check_non_tagged_type(args[0], VM::tp_type); + Type cls = PK_OBJ_GET(Type, args[0]); + NameDict& cls_d = args[0]->attr(); + + if(!cls_d.contains("__init__")) patch__init__(vm, cls); + if(!cls_d.contains("__repr__")) patch__repr__(vm, cls); + if(!cls_d.contains("__eq__")) patch__eq__(vm, cls); + + const auto& fields = vm->_all_types[cls].annotated_fields; + bool has_default = false; + for(StrName field: fields){ + if(cls_d.contains(field)){ + has_default = true; + }else{ + if(has_default){ + vm->TypeError(_S("non-default argument ", field.escape(), " follows default argument")); + } + } + } + return args[0]; + }); + + vm->bind_func<1>(mod, "asdict", [](VM* vm, ArgsView args){ + const auto& fields = vm->_inst_type_info(args[0])->annotated_fields; + const NameDict& obj_d = args[0]->attr(); + Dict d(vm); + for(StrName field: fields){ + d.set(VAR(field.sv()), obj_d[field]); + } + return VAR(std::move(d)); + }); +} + +} // namespace pkpy \ No newline at end of file diff --git a/src/pocketpy.cpp b/src/pocketpy.cpp index 0e5fee46..b7fa059b 100644 --- a/src/pocketpy.cpp +++ b/src/pocketpy.cpp @@ -1380,15 +1380,6 @@ void VM::post_init(){ return self; // for generics }); - bind_property(_t(tp_type), "__annotations__", [](VM* vm, ArgsView args){ - const PyTypeInfo* ti = &vm->_all_types[(PK_OBJ_GET(Type, args[0]))]; - Tuple t(ti->annotated_fields.size()); - for(int i=0; iannotated_fields.size(); i++){ - t[i] = VAR(ti->annotated_fields[i].sv()); - } - return VAR(std::move(t)); - }); - bind__repr__(tp_type, [](VM* vm, PyObject* self){ SStream ss; const PyTypeInfo& info = vm->_all_types[PK_OBJ_GET(Type, self)]; @@ -1451,6 +1442,7 @@ void VM::post_init(){ add_module_base64(this); add_module_operator(this); add_module_csv(this); + add_module_dataclasses(this); for(const char* name: {"this", "functools", "heapq", "bisect", "pickle", "_long", "colorsys", "typing", "datetime", "dataclasses", "cmath"}){ _lazy_modules[name] = kPythonLibs[name];