From 475bce9999623c64f6d6885b03d5acb184b1d978 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Thu, 21 Dec 2023 23:09:23 +0800 Subject: [PATCH] add `@dataclass` --- docs/modules/dataclasses.md | 14 +++++++++ include/pocketpy/cffi.h | 1 + include/pocketpy/common.h | 4 +-- include/pocketpy/compiler.h | 3 +- include/pocketpy/opcodes.h | 3 ++ include/pocketpy/vm.h | 2 ++ python/builtins.py | 11 +++++++ python/dataclasses.py | 59 +++++++++++++++++++++++++++++++++++++ src/ceval.cpp | 18 +++++++++-- src/cffi.cpp | 5 ++-- src/compiler.cpp | 47 +++++++++++++++++++++-------- src/pocketpy.cpp | 34 +++++++-------------- src/vm.cpp | 1 - tests/82_dataclasses.py | 29 ++++++++++++++++++ tests/99_builtin_func.py | 11 ++++++- 15 files changed, 197 insertions(+), 45 deletions(-) create mode 100644 docs/modules/dataclasses.md create mode 100644 python/dataclasses.py create mode 100644 tests/82_dataclasses.py diff --git a/docs/modules/dataclasses.md b/docs/modules/dataclasses.md new file mode 100644 index 00000000..c6372233 --- /dev/null +++ b/docs/modules/dataclasses.md @@ -0,0 +1,14 @@ +--- +icon: package +label: dataclasses +--- + +### `dataclasses.dataclass` + +A decorator that is used to add generated special methods to classes, including `__init__`, `__repr__` and `__eq__`. + +### `dataclasses.asdict(obj) -> dict` + +Convert a dataclass instance to a dictionary.
+ + diff --git a/include/pocketpy/cffi.h b/include/pocketpy/cffi.h index c68749d4..a2937b06 100644 --- a/include/pocketpy/cffi.h +++ b/include/pocketpy/cffi.h @@ -22,6 +22,7 @@ namespace pkpy { throw std::runtime_error(msg.str()); \ } \ PyObject* type = vm->new_type_object(mod, #name, base); \ + mod->attr().set(#name, type); \ T::_register(vm, mod, type); \ return type; \ } diff --git a/include/pocketpy/common.h b/include/pocketpy/common.h index d897e66b..85172aab 100644 --- a/include/pocketpy/common.h +++ b/include/pocketpy/common.h @@ -23,7 +23,7 @@ #include #include -#define PK_VERSION "1.3.3" +#define PK_VERSION "1.3.5" #include "config.h" #include "export.h" @@ -166,7 +166,7 @@ static_assert(sizeof(Number::int_t) == sizeof(void*)); static_assert(sizeof(BitsCvt) == sizeof(void*)); static_assert(std::numeric_limits::is_iec559); -struct Dummy { }; +struct Dummy { }; // for special objects: True, False, None, Ellipsis, etc. struct DummyInstance { }; struct DummyModule { }; struct NoReturn { }; diff --git a/include/pocketpy/compiler.h b/include/pocketpy/compiler.h index 30dd7fc5..e7b39057 100644 --- a/include/pocketpy/compiler.h +++ b/include/pocketpy/compiler.h @@ -118,7 +118,8 @@ class Compiler { bool try_compile_assignment(); void compile_stmt(); void consume_type_hints(); - void compile_class(); + void _add_decorators(const std::vector& decorators); + void compile_class(const std::vector& decorators={}); void _compile_f_args(FuncDecl_ decl, bool enable_type_hints); void compile_function(const std::vector& decorators={}); diff --git a/include/pocketpy/opcodes.h b/include/pocketpy/opcodes.h index 6f55c956..0f241544 100644 --- a/include/pocketpy/opcodes.h +++ b/include/pocketpy/opcodes.h @@ -118,6 +118,9 @@ OPCODE(UNPACK_EX) OPCODE(BEGIN_CLASS) OPCODE(END_CLASS) OPCODE(STORE_CLASS_ATTR) +OPCODE(BEGIN_CLASS_DECORATION) +OPCODE(END_CLASS_DECORATION) +OPCODE(ADD_CLASS_ANNOTATION) /**************************/ OPCODE(WITH_ENTER) OPCODE(WITH_EXIT) diff --git 
a/include/pocketpy/vm.h b/include/pocketpy/vm.h index a25979b5..f07a09b8 100644 --- a/include/pocketpy/vm.h +++ b/include/pocketpy/vm.h @@ -55,6 +55,8 @@ struct PyTypeInfo{ Str name; bool subclass_enabled; + std::vector<StrName> annotated_fields; + // cached special methods // unary operators PyObject* (*m__repr__)(VM* vm, PyObject*) = nullptr; diff --git a/python/builtins.py b/python/builtins.py index 9e23a81f..59fba98c 100644 --- a/python/builtins.py +++ b/python/builtins.py @@ -5,6 +5,17 @@ def print(*args, sep=' ', end='\n'): s = sep.join([str(i) for i in args]) _sys.stdout.write(s + end) +def issubclass(cls, base): + if type(cls) is not type: + raise TypeError('issubclass() arg 1 must be a class') + if type(base) is not type: + raise TypeError('issubclass() arg 2 must be a class') + while cls is not None: + if cls is base: + return True + cls = cls.__base__ + return False + def _minmax_reduce(op, args, key): if key is None: if len(args) == 2: diff --git a/python/dataclasses.py b/python/dataclasses.py new file mode 100644 index 00000000..1c5b97cc --- /dev/null +++ b/python/dataclasses.py @@ -0,0 +1,59 @@ +def _wrapped__init__(self, *args, **kwargs): + cls = type(self) + cls_d = cls.__dict__ + fields: tuple[str] = cls.__annotations__ + i = 0 # index into args + for field in fields: + if field in kwargs: + setattr(self, field, kwargs.pop(field)) + else: + if i < len(args): + setattr(self, field, args[i]) + ++i + elif field in cls_d: # has default value + setattr(self, field, cls_d[field]) + else: + raise TypeError(f"{cls.__name__} missing required argument {field!r}") + if len(args) > i: + raise TypeError(f"{cls.__name__} takes {len(fields)} positional arguments but {len(args)} were given") + if len(kwargs) > 0: + raise TypeError(f"{cls.__name__} got an unexpected keyword argument {next(iter(kwargs))!r}") + +def _wrapped__repr__(self): + fields: tuple[str] = type(self).__annotations__ + obj_d = self.__dict__ + args: list = [f"{field}={obj_d[field]!r}" for field in fields]
+ return f"{type(self).__name__}({', '.join(args)})" + +def _wrapped__eq__(self, other): + if type(self) is not type(other): + return False + fields: tuple[str] = type(self).__annotations__ + for field in fields: + if getattr(self, field) != getattr(other, field): + return False + return True + +def dataclass(cls: type): + assert type(cls) is type + cls_d = cls.__dict__ + if '__init__' not in cls_d: + cls.__init__ = _wrapped__init__ + if '__repr__' not in cls_d: + cls.__repr__ = _wrapped__repr__ + if '__eq__' not in cls_d: + cls.__eq__ = _wrapped__eq__ + fields: tuple[str] = cls.__annotations__ + has_default = False + for field in fields: + if field in cls_d: + has_default = True + else: + if has_default: + raise TypeError(f"non-default argument {field!r} follows default argument") + return cls + +def asdict(obj) -> dict: + fields: tuple[str] = type(obj).__annotations__ + obj_d = obj.__dict__ + return {field: obj_d[field] for field in fields} diff --git a/src/ceval.cpp b/src/ceval.cpp index 9b9c3f14..d323a3ef 100644 --- a/src/ceval.cpp +++ b/src/ceval.cpp @@ -749,6 +749,8 @@ __NEXT_STEP:; } DISPATCH(); TARGET(END_CLASS) { PK_ASSERT(_curr_class != nullptr); + StrName _name(byte.arg); + frame->_module->attr().set(_name, _curr_class); _curr_class = nullptr; } DISPATCH(); TARGET(STORE_CLASS_ATTR){ @@ -760,6 +762,18 @@ __NEXT_STEP:; } _curr_class->attr().set(_name, _0); } DISPATCH(); + TARGET(BEGIN_CLASS_DECORATION){ + PUSH(_curr_class); + } DISPATCH(); + TARGET(END_CLASS_DECORATION){ + _curr_class = POPX(); + } DISPATCH(); + TARGET(ADD_CLASS_ANNOTATION) { + PK_ASSERT(_curr_class != nullptr); + StrName _name(byte.arg); + Type type = PK_OBJ_GET(Type, _curr_class); + _type_info(type)->annotated_fields.push_back(_name); + } DISPATCH(); /*****************************************/ TARGET(WITH_ENTER) call_method(POPX(), __enter__); @@ -818,8 +832,8 @@ __NEXT_STEP:; } DISPATCH(); #if !PK_ENABLE_COMPUTED_GOTO - static_assert(OP_DEC_GLOBAL == 107); - case 108: case 109: case 
110: case 111: case 112: case 113: case 114: case 115: case 116: case 117: case 118: case 119: case 120: case 121: case 122: case 123: case 124: case 125: case 126: case 127: case 128: case 129: case 130: case 131: case 132: case 133: case 134: case 135: case 136: case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149: case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164: case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179: case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194: case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209: case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224: case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239: case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254: case 255: FATAL_ERROR(); break; + static_assert(OP_DEC_GLOBAL == 110); + case 111: case 112: case 113: case 114: case 115: case 116: case 117: case 118: case 119: case 120: case 121: case 122: case 123: case 124: case 125: case 126: case 127: case 128: case 129: case 130: case 131: case 132: case 133: case 134: case 135: case 136: case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149: case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: 
case 159: case 160: case 161: case 162: case 163: case 164: case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179: case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194: case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209: case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224: case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239: case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254: case 255: FATAL_ERROR(); break; } #endif } diff --git a/src/cffi.cpp b/src/cffi.cpp index f83cc72b..c5684c4f 100644 --- a/src/cffi.cpp +++ b/src/cffi.cpp @@ -164,11 +164,12 @@ void add_module_c(VM* vm){ Type type_t; #define BIND_PRIMITIVE(T, CNAME) \ - vm->bind_func<1>(mod, CNAME "_", [](VM* vm, ArgsView args){ \ + vm->bind_func<1>(mod, CNAME "_", [](VM* vm, ArgsView args){ \ T val = CAST(T, args[0]); \ return VAR_T(C99Struct, &val, sizeof(T)); \ }); \ - type = vm->new_type_object(mod, CNAME "_p", VoidP::_type(vm)); \ + type = vm->new_type_object(mod, CNAME "_p", VoidP::_type(vm)); \ + mod->attr().set(CNAME "_p", type); \ type_t = PK_OBJ_GET(Type, type); \ vm->bind_method<0>(type, "read", [](VM* vm, ArgsView args){ \ VoidP& voidp = PK_OBJ_GET(VoidP, args[0]); \ diff --git a/src/compiler.cpp b/src/compiler.cpp index 7286b0ab..cd096b0c 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -712,8 +712,13 @@ __EAT_DOTS_END: decorators.push_back(ctx()->s_expr.popx()); if(!match_newlines_repl()) SyntaxError(); 
}while(match(TK("@"))); - consume(TK("def")); - compile_function(decorators); + + if(match(TK("class"))){ + compile_class(decorators); + }else{ + consume(TK("def")); + compile_function(decorators); + } } bool Compiler::try_compile_assignment(){ @@ -927,6 +932,12 @@ __EAT_DOTS_END: if(match(TK(":"))){ consume_type_hints(); is_typed_name = true; + + if(ctx()->is_compiling_class){ + // add to __annotations__ + NameExpr* ne = static_cast<NameExpr*>(ctx()->s_expr.top().get()); + ctx()->emit_(OP_ADD_CLASS_ANNOTATION, ne->name.index, BC_KEEPLINE); + } } } if(!try_compile_assignment()){ @@ -955,7 +966,18 @@ __EAT_DOTS_END: ctx()->s_expr.pop(); } - void Compiler::compile_class(){ + void Compiler::_add_decorators(const std::vector<Expr_>& decorators){ + // [obj] + for(auto it=decorators.rbegin(); it!=decorators.rend(); ++it){ + (*it)->emit_(ctx()); // [obj, f] + ctx()->emit_(OP_ROT_TWO, BC_NOARG, (*it)->line); // [f, obj] + ctx()->emit_(OP_LOAD_NULL, BC_NOARG, BC_KEEPLINE); // [f, obj, NULL] + ctx()->emit_(OP_ROT_TWO, BC_NOARG, BC_KEEPLINE); // [f, NULL, obj] + ctx()->emit_(OP_CALL, 1, (*it)->line); // [obj] + } + } + + void Compiler::compile_class(const std::vector<Expr_>& decorators){ consume(TK("@id")); int namei = StrName(prev().sv()).index; Expr_ base = nullptr; @@ -981,7 +1003,14 @@ __EAT_DOTS_END: ctx()->is_compiling_class = true; compile_block_body(); ctx()->is_compiling_class = false; - ctx()->emit_(OP_END_CLASS, BC_NOARG, BC_KEEPLINE); + + if(!decorators.empty()){ + ctx()->emit_(OP_BEGIN_CLASS_DECORATION, BC_NOARG, BC_KEEPLINE); + _add_decorators(decorators); + ctx()->emit_(OP_END_CLASS_DECORATION, BC_NOARG, BC_KEEPLINE); + } + + ctx()->emit_(OP_END_CLASS, namei, BC_KEEPLINE); } void Compiler::_compile_f_args(FuncDecl_ decl, bool enable_type_hints){ @@ -1077,14 +1106,8 @@ __EAT_DOTS_END: } ctx()->emit_(OP_LOAD_FUNCTION, ctx()->add_func_decl(decl), prev().line); - // add decorators - for(auto it=decorators.rbegin(); it!=decorators.rend(); ++it){ - (*it)->emit_(ctx()); -
ctx()->emit_(OP_ROT_TWO, BC_NOARG, (*it)->line); - ctx()->emit_(OP_LOAD_NULL, BC_NOARG, BC_KEEPLINE); - ctx()->emit_(OP_ROT_TWO, BC_NOARG, BC_KEEPLINE); - ctx()->emit_(OP_CALL, 1, (*it)->line); - } + _add_decorators(decorators); + if(!ctx()->is_compiling_class){ auto e = make_expr<NameExpr>(decl_name, name_scope()); e->emit_store(ctx()); diff --git a/src/pocketpy.cpp b/src/pocketpy.cpp index 137f2d60..22b07b4b 100644 --- a/src/pocketpy.cpp +++ b/src/pocketpy.cpp @@ -1146,29 +1146,6 @@ void init_builtins(VM* _vm) { return value; }); - // _vm->bind_method<0>("dict", "_data", [](VM* vm, ArgsView args) { - // Dict& self = _CAST(Dict&, args[0]); - // SStream ss; - // ss << "[\n"; - // for(int i=0; ipy_repr(item.first)); - // } - // if(item.second != nullptr){ - // value = CAST(Str&, vm->py_repr(item.second)); - // } - // int prev = self._nodes[i].prev; - // int next = self._nodes[i].next; - // ss << " [" << key << ", " << value << ", " << prev << ", " << next << "],\n"; - // } - // ss << "]\n"; - // vm->stdout_write(ss.str()); - // return vm->None; - // }); - _vm->bind__contains__(_vm->tp_dict, [](VM* vm, PyObject* obj, PyObject* key) { Dict& self = _CAST(Dict&, obj); return VAR(self.contains(key)); @@ -1615,6 +1592,15 @@ void VM::post_init(){ return self; // for generics }); + bind_property(_t(tp_type), "__annotations__", [](VM* vm, ArgsView args){ + PyTypeInfo* ti = vm->_type_info(PK_OBJ_GET(Type, args[0])); + Tuple t(ti->annotated_fields.size()); + for(int i=0; i<ti->annotated_fields.size(); i++){ + t[i] = VAR(ti->annotated_fields[i].sv()); + } + return VAR(std::move(t)); + }); + bind__repr__(tp_type, [](VM* vm, PyObject* self){ SStream ss; const PyTypeInfo& info = vm->_all_types[PK_OBJ_GET(Type, self)]; @@ -1679,7 +1665,7 @@ void VM::post_init(){ add_module_operator(this); add_module_csv(this); - for(const char* name: {"this", "functools", "heapq", "bisect", "pickle", "_long", "colorsys", "typing", "datetime"}){ + for(const char* name: {"this", "functools", "heapq", "bisect",
"pickle", "_long", "colorsys", "typing", "datetime", "dataclasses"}){ _lazy_modules[name] = kPythonLibs[name]; } diff --git a/src/vm.cpp b/src/vm.cpp index 0b586eff..632aeeac 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -204,7 +204,6 @@ namespace pkpy{ name.sv(), subclass_enabled, }; - if(mod != nullptr) mod->attr().set(name, obj); _all_types.push_back(info); return obj; } diff --git a/tests/82_dataclasses.py b/tests/82_dataclasses.py new file mode 100644 index 00000000..81320231 --- /dev/null +++ b/tests/82_dataclasses.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass, asdict + +@dataclass +class A: + x: int + y: str = '123' + +assert repr(A(1)) == "A(x=1, y='123')" +assert repr(A(x=3)) == "A(x=3, y='123')" +assert repr(A(1, '555')) == "A(x=1, y='555')" +assert repr(A(x=7, y='555')) == "A(x=7, y='555')" + +assert asdict(A(1, '555')) == {'x': 1, 'y': '555'} + +assert A(1, 'N') == A(1, 'N') +assert A(1, 'N') != A(1, 'M') + +def wrapped(cls): + return int + +@wrapped +@wrapped +@wrapped +@wrapped +class A: + def __init__(self) -> None: + pass + +assert A('123') == 123 diff --git a/tests/99_builtin_func.py b/tests/99_builtin_func.py index 387856bd..939473c9 100644 --- a/tests/99_builtin_func.py +++ b/tests/99_builtin_func.py @@ -978,4 +978,13 @@ assert callable(isinstance) is True # builtin function assert id(0) is None -assert id(2**62) is not None \ No newline at end of file +assert id(2**62) is not None + +# test issubclass +assert issubclass(int, int) is True +assert issubclass(int, object) is True +assert issubclass(object, int) is False +assert issubclass(object, object) is True +assert issubclass(int, type) is False +assert issubclass(type, type) is True +assert issubclass(float, int) is False