From adf5fa5ac2dde406c065acec8ac3aa6a0dd7f64a Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Fri, 16 Aug 2024 12:37:50 +0800 Subject: [PATCH] fix https://github.com/pocketpy/pocketpy/issues/296 --- include/pocketpy/common/_generated.h | 1 + include/pocketpy/interpreter/vm.h | 2 +- python/dataclasses.py | 75 +++++++++++++++++++ src/common/_generated.c | 2 + src/compiler/compiler.c | 15 +++- src/compiler/lexer.c | 4 +- src/interpreter/ceval.c | 6 +- src/interpreter/vm.c | 10 +-- src/public/py_object.c | 17 ++++- tests/04_str.py | 2 +- .../{90_dataclasses.py => 81_dataclasses.py} | 21 +++++- 11 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 python/dataclasses.py rename tests/{90_dataclasses.py => 81_dataclasses.py} (61%) diff --git a/include/pocketpy/common/_generated.h b/include/pocketpy/common/_generated.h index cfff7e87..641b93fd 100644 --- a/include/pocketpy/common/_generated.h +++ b/include/pocketpy/common/_generated.h @@ -8,6 +8,7 @@ extern const char kPythonLibs_bisect[]; extern const char kPythonLibs_builtins[]; extern const char kPythonLibs_cmath[]; extern const char kPythonLibs_collections[]; +extern const char kPythonLibs_dataclasses[]; extern const char kPythonLibs_datetime[]; extern const char kPythonLibs_functools[]; extern const char kPythonLibs_heapq[]; diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index be8287f5..71fcd083 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -23,7 +23,7 @@ typedef struct py_TypeInfo { void (*dtor)(void*); - c11_vector /*T=py_Name*/ annotated_fields; + py_TValue annotations; // type annotations void (*on_end_subclass)(struct py_TypeInfo*); // backdoor for enum module void (*gc_mark)(void* ud); diff --git a/python/dataclasses.py b/python/dataclasses.py new file mode 100644 index 00000000..400ae35e --- /dev/null +++ b/python/dataclasses.py @@ -0,0 +1,75 @@ +def _get_annotations(cls: type): + inherits = [] + while cls is not object: + inherits.append(cls) + cls = cls.__base__ + inherits.reverse() + res = {} + for cls in inherits: + res.update(cls.__annotations__) + return res.keys() + +def _wrapped__init__(self, *args, **kwargs): + cls = type(self) + cls_d = cls.__dict__ + fields = _get_annotations(cls) + i = 0 # index into args + for field in fields: + if field in kwargs: + setattr(self, field, kwargs.pop(field)) + else: + if i < len(args): + setattr(self, field, args[i]) + i += 1 + elif field in cls_d: # has default value + setattr(self, field, cls_d[field]) + else: + raise TypeError(f"{cls.__name__} missing required argument {field!r}") + if len(args) > i: + raise TypeError(f"{cls.__name__} takes {len(fields)} positional arguments but {len(args)} were given") + if len(kwargs) > 0: + raise TypeError(f"{cls.__name__} got an unexpected keyword argument {next(iter(kwargs))!r}") + +def _wrapped__repr__(self): + fields = _get_annotations(type(self)) + obj_d = self.__dict__ + args: list = [f"{field}={obj_d[field]!r}" for field in fields] + return f"{type(self).__name__}({', '.join(args)})" + +def _wrapped__eq__(self, other): + if type(self) is not type(other): + return False + fields = _get_annotations(type(self)) + for field in fields: + if getattr(self, field) != getattr(other, field): + return False + return True + +def _wrapped__ne__(self, other): + return not self.__eq__(other) + +def dataclass(cls: type): + assert type(cls) is type + cls_d = cls.__dict__ + if '__init__' not in cls_d: + cls.__init__ = _wrapped__init__ + if '__repr__' not in cls_d: + cls.__repr__ = _wrapped__repr__ + if '__eq__' not in cls_d: + cls.__eq__ = _wrapped__eq__ + if '__ne__' not in cls_d: + cls.__ne__ = _wrapped__ne__ + fields = _get_annotations(cls) + has_default = False + for field in fields: + if field in cls_d: + has_default = True + else: + if has_default: + raise TypeError(f"non-default argument {field!r} follows default argument") + return cls + +def asdict(obj) -> dict: + fields = _get_annotations(type(obj)) + obj_d = obj.__dict__ + return {field: obj_d[field] for field in fields} \ No newline at end of file diff --git a/src/common/_generated.c b/src/common/_generated.c index 0136beff..df85cc11 100644 --- a/src/common/_generated.c +++ b/src/common/_generated.c @@ -6,6 +6,7 @@ const char kPythonLibs_bisect[] = "\"\"\"Bisection algorithms.\"\"\"\n\ndef inso const char kPythonLibs_builtins[] = "from pkpy import next as __pkpy_next\n\ndef all(iterable):\n for i in iterable:\n if not i:\n return False\n return True\n\ndef any(iterable):\n for i in iterable:\n if i:\n return True\n return False\n\ndef enumerate(iterable, start=0):\n n = start\n for elem in iterable:\n yield n, elem\n n += 1\n\ndef sum(iterable):\n res = 0\n for i in iterable:\n res += i\n return res\n\ndef map(f, iterable):\n for i in iterable:\n yield f(i)\n\ndef filter(f, iterable):\n for i in iterable:\n if f(i):\n yield i\n\ndef zip(a, b):\n a = iter(a)\n b = iter(b)\n while True:\n ai = __pkpy_next(a)\n bi = __pkpy_next(b)\n if ai is StopIteration or bi is StopIteration:\n break\n yield ai, bi\n\ndef reversed(iterable):\n a = list(iterable)\n a.reverse()\n return a\n\ndef sorted(iterable, key=None, reverse=False):\n a = list(iterable)\n a.sort(key=key, reverse=reverse)\n return a\n\n##### str #####\ndef __format_string(self: str, *args, **kwargs) -> str:\n def tokenizeString(s: str):\n tokens = []\n L, R = 0,0\n \n mode = None\n curArg = 0\n # lookingForKword = False\n \n while(R int:\n n = 0\n for item in self:\n if item == x:\n n += 1\n return n\n \n def extend(self, iterable: Iterable[T]):\n for x in iterable:\n self.append(x)\n\n def extendleft(self, iterable: Iterable[T]):\n for x in iterable:\n self.appendleft(x)\n \n def pop(self) -> T:\n if self._head == self._tail:\n raise IndexError(\"pop from an empty deque\")\n self._tail = (self._tail - 1 + self._capacity) % self._capacity\n return self._data[self._tail]\n \n def popleft(self) -> T:\n if self._head == self._tail:\n raise IndexError(\"pop from an empty deque\")\n x = self._data[self._head]\n self._head = (self._head + 1) % self._capacity\n return x\n \n def clear(self):\n i = self._head\n while i != self._tail:\n self._data[i] = None\n i = (i + 1) % self._capacity\n self._head = 0\n self._tail = 0\n\n def rotate(self, n: int = 1):\n if len(self) == 0:\n return\n if n > 0:\n n = n % len(self)\n for _ in range(n):\n self.appendleft(self.pop())\n elif n < 0:\n n = -n % len(self)\n for _ in range(n):\n self.append(self.popleft())\n\n def __len__(self) -> int:\n return (self._tail - self._head + self._capacity) % self._capacity\n\n def __contains__(self, x: object) -> bool:\n for item in self:\n if item == x:\n return True\n return False\n \n def __iter__(self):\n i = self._head\n while i != self._tail:\n yield self._data[i]\n i = (i + 1) % self._capacity\n\n def __eq__(self, other: object) -> bool:\n if not isinstance(other, deque):\n return NotImplemented\n if len(self) != len(other):\n return False\n for x, y in zip(self, other):\n if x != y:\n return False\n return True\n \n def __ne__(self, other: object) -> bool:\n if not isinstance(other, deque):\n return NotImplemented\n return not self == other\n \n def __repr__(self) -> str:\n return f\"deque({list(self)!r})\"\n\n"; +const char kPythonLibs_dataclasses[] = "def _get_annotations(cls: type):\n inherits = []\n while cls is not object:\n inherits.append(cls)\n cls = cls.__base__\n inherits.reverse()\n res = {}\n for cls in inherits:\n res.update(cls.__annotations__)\n return res.keys()\n\ndef _wrapped__init__(self, *args, **kwargs):\n cls = type(self)\n cls_d = cls.__dict__\n fields = _get_annotations(cls)\n i = 0 # index into args\n for field in fields:\n if field in kwargs:\n setattr(self, field, kwargs.pop(field))\n else:\n if i < len(args):\n setattr(self, field, args[i])\n i += 1\n elif field in cls_d: # has default value\n setattr(self, field, cls_d[field])\n else:\n raise TypeError(f\"{cls.__name__} missing required argument {field!r}\")\n if len(args) > i:\n raise TypeError(f\"{cls.__name__} takes {len(fields)} positional arguments but {len(args)} were given\")\n if len(kwargs) > 0:\n raise TypeError(f\"{cls.__name__} got an unexpected keyword argument {next(iter(kwargs))!r}\")\n\ndef _wrapped__repr__(self):\n fields = _get_annotations(type(self))\n obj_d = self.__dict__\n args: list = [f\"{field}={obj_d[field]!r}\" for field in fields]\n return f\"{type(self).__name__}({', '.join(args)})\"\n\ndef _wrapped__eq__(self, other):\n if type(self) is not type(other):\n return False\n fields = _get_annotations(type(self))\n for field in fields:\n if getattr(self, field) != getattr(other, field):\n return False\n return True\n\ndef _wrapped__ne__(self, other):\n return not self.__eq__(other)\n\ndef dataclass(cls: type):\n assert type(cls) is type\n cls_d = cls.__dict__\n if '__init__' not in cls_d:\n cls.__init__ = _wrapped__init__\n if '__repr__' not in cls_d:\n cls.__repr__ = _wrapped__repr__\n if '__eq__' not in cls_d:\n cls.__eq__ = _wrapped__eq__\n if '__ne__' not in cls_d:\n cls.__ne__ = _wrapped__ne__\n fields = _get_annotations(cls)\n has_default = False\n for field in fields:\n if field in cls_d:\n has_default = True\n else:\n if has_default:\n raise TypeError(f\"non-default argument {field!r} follows default argument\")\n return cls\n\ndef asdict(obj) -> dict:\n fields = _get_annotations(type(obj))\n obj_d = obj.__dict__\n return {field: obj_d[field] for field in fields}"; const char kPythonLibs_datetime[] = "from time import localtime\nimport operator\n\nclass timedelta:\n def __init__(self, days=0, seconds=0):\n self.days = days\n self.seconds = seconds\n\n def __repr__(self):\n return f\"datetime.timedelta(days={self.days}, seconds={self.seconds})\"\n\n def __eq__(self, other: 'timedelta') -> bool:\n if not isinstance(other, timedelta):\n return NotImplemented\n return (self.days, self.seconds) == (other.days, other.seconds)\n\n def __ne__(self, other: 'timedelta') -> bool:\n if not isinstance(other, timedelta):\n return NotImplemented\n return (self.days, self.seconds) != (other.days, other.seconds)\n\n\nclass date:\n def __init__(self, year: int, month: int, day: int):\n self.year = year\n self.month = month\n self.day = day\n\n @staticmethod\n def today():\n t = localtime()\n return date(t.tm_year, t.tm_mon, t.tm_mday)\n \n def __cmp(self, other, op):\n if not isinstance(other, date):\n return NotImplemented\n if self.year != other.year:\n return op(self.year, other.year)\n if self.month != other.month:\n return op(self.month, other.month)\n return op(self.day, other.day)\n\n def __eq__(self, other: 'date') -> bool:\n return self.__cmp(other, operator.eq)\n \n def __ne__(self, other: 'date') -> bool:\n return self.__cmp(other, operator.ne)\n\n def __lt__(self, other: 'date') -> bool:\n return self.__cmp(other, operator.lt)\n\n def __le__(self, other: 'date') -> bool:\n return self.__cmp(other, operator.le)\n\n def __gt__(self, other: 'date') -> bool:\n return self.__cmp(other, operator.gt)\n\n def __ge__(self, other: 'date') -> bool:\n return self.__cmp(other, operator.ge)\n\n def __str__(self):\n return f\"{self.year}-{self.month:02}-{self.day:02}\"\n\n def __repr__(self):\n return f\"datetime.date({self.year}, {self.month}, {self.day})\"\n\n\nclass datetime(date):\n def __init__(self, year: int, month: int, day: int, hour: int, minute: int, second: int):\n super().__init__(year, month, day)\n # Validate and set hour, minute, and second\n if not 0 <= hour <= 23:\n raise ValueError(\"Hour must be between 0 and 23\")\n self.hour = hour\n if not 0 <= minute <= 59:\n raise ValueError(\"Minute must be between 0 and 59\")\n self.minute = minute\n if not 0 <= second <= 59:\n raise ValueError(\"Second must be between 0 and 59\")\n self.second = second\n\n def date(self) -> date:\n return date(self.year, self.month, self.day)\n\n @staticmethod\n def now():\n t = localtime()\n tm_sec = t.tm_sec\n if tm_sec == 60:\n tm_sec = 59\n return datetime(t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, tm_sec)\n\n def __str__(self):\n return f\"{self.year}-{self.month:02}-{self.day:02} {self.hour:02}:{self.minute:02}:{self.second:02}\"\n\n def __repr__(self):\n return f\"datetime.datetime({self.year}, {self.month}, {self.day}, {self.hour}, {self.minute}, {self.second})\"\n\n def __cmp(self, other, op):\n if not isinstance(other, datetime):\n return NotImplemented\n if self.year != other.year:\n return op(self.year, other.year)\n if self.month != other.month:\n return op(self.month, other.month)\n if self.day != other.day:\n return op(self.day, other.day)\n if self.hour != other.hour:\n return op(self.hour, other.hour)\n if self.minute != other.minute:\n return op(self.minute, other.minute)\n return op(self.second, other.second)\n\n def __eq__(self, other) -> bool:\n return self.__cmp(other, operator.eq)\n \n def __ne__(self, other) -> bool:\n return self.__cmp(other, operator.ne)\n \n def __lt__(self, other) -> bool:\n return self.__cmp(other, operator.lt)\n \n def __le__(self, other) -> bool:\n return self.__cmp(other, operator.le)\n \n def __gt__(self, other) -> bool:\n return self.__cmp(other, operator.gt)\n \n def __ge__(self, other) -> bool:\n return self.__cmp(other, operator.ge)\n\n\n"; const char kPythonLibs_functools[] = "class cache:\n def __init__(self, f):\n self.f = f\n self.cache = {}\n\n def __call__(self, *args):\n if args not in self.cache:\n self.cache[args] = self.f(*args)\n return self.cache[args]\n \ndef reduce(function, sequence, initial=...):\n it = iter(sequence)\n if initial is ...:\n try:\n value = next(it)\n except StopIteration:\n raise TypeError(\"reduce() of empty sequence with no initial value\")\n else:\n value = initial\n for element in it:\n value = function(value, element)\n return value\n\nclass partial:\n def __init__(self, f, *args, **kwargs):\n self.f = f\n if not callable(f):\n raise TypeError(\"the first argument must be callable\")\n self.args = args\n self.kwargs = kwargs\n\n def __call__(self, *args, **kwargs):\n kwargs.update(self.kwargs)\n return self.f(*self.args, *args, **kwargs)\n\n"; const char kPythonLibs_heapq[] = "# Heap queue algorithm (a.k.a. priority queue)\ndef heappush(heap, item):\n \"\"\"Push item onto heap, maintaining the heap invariant.\"\"\"\n heap.append(item)\n _siftdown(heap, 0, len(heap)-1)\n\ndef heappop(heap):\n \"\"\"Pop the smallest item off the heap, maintaining the heap invariant.\"\"\"\n lastelt = heap.pop() # raises appropriate IndexError if heap is empty\n if heap:\n returnitem = heap[0]\n heap[0] = lastelt\n _siftup(heap, 0)\n return returnitem\n return lastelt\n\ndef heapreplace(heap, item):\n \"\"\"Pop and return the current smallest value, and add the new item.\n\n This is more efficient than heappop() followed by heappush(), and can be\n more appropriate when using a fixed-size heap. Note that the value\n returned may be larger than item! That constrains reasonable uses of\n this routine unless written as part of a conditional replacement:\n\n if item > heap[0]:\n item = heapreplace(heap, item)\n \"\"\"\n returnitem = heap[0] # raises appropriate IndexError if heap is empty\n heap[0] = item\n _siftup(heap, 0)\n return returnitem\n\ndef heappushpop(heap, item):\n \"\"\"Fast version of a heappush followed by a heappop.\"\"\"\n if heap and heap[0] < item:\n item, heap[0] = heap[0], item\n _siftup(heap, 0)\n return item\n\ndef heapify(x):\n \"\"\"Transform list into a heap, in-place, in O(len(x)) time.\"\"\"\n n = len(x)\n # Transform bottom-up. The largest index there's any point to looking at\n # is the largest with a child index in-range, so must have 2*i + 1 < n,\n # or i < (n-1)/2. If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so\n # j-1 is the largest, which is n//2 - 1. If n is odd = 2*j+1, this is\n # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.\n for i in reversed(range(n//2)):\n _siftup(x, i)\n\n# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos\n# is the index of a leaf with a possibly out-of-order value. Restore the\n# heap invariant.\ndef _siftdown(heap, startpos, pos):\n newitem = heap[pos]\n # Follow the path to the root, moving parents down until finding a place\n # newitem fits.\n while pos > startpos:\n parentpos = (pos - 1) >> 1\n parent = heap[parentpos]\n if newitem < parent:\n heap[pos] = parent\n pos = parentpos\n continue\n break\n heap[pos] = newitem\n\ndef _siftup(heap, pos):\n endpos = len(heap)\n startpos = pos\n newitem = heap[pos]\n # Bubble up the smaller child until hitting a leaf.\n childpos = 2*pos + 1 # leftmost child position\n while childpos < endpos:\n # Set childpos to index of smaller child.\n rightpos = childpos + 1\n if rightpos < endpos and not heap[childpos] < heap[rightpos]:\n childpos = rightpos\n # Move the smaller child up.\n heap[pos] = heap[childpos]\n pos = childpos\n childpos = 2*pos + 1\n # The leaf at pos is empty now. Put newitem there, and bubble it up\n # to its final resting place (by sifting its parents down).\n heap[pos] = newitem\n _siftdown(heap, startpos, pos)"; @@ -21,6 +22,7 @@ const char* load_kPythonLib(const char* name) { if (strcmp(name, "builtins") == 0) return kPythonLibs_builtins; if (strcmp(name, "cmath") == 0) return kPythonLibs_cmath; if (strcmp(name, "collections") == 0) return kPythonLibs_collections; + if (strcmp(name, "dataclasses") == 0) return kPythonLibs_dataclasses; if (strcmp(name, "datetime") == 0) return kPythonLibs_datetime; if (strcmp(name, "functools") == 0) return kPythonLibs_functools; if (strcmp(name, "heapq") == 0) return kPythonLibs_heapq; diff --git a/src/compiler/compiler.c b/src/compiler/compiler.c index 5989c852..ef7ed1e7 100644 --- a/src/compiler/compiler.c +++ b/src/compiler/compiler.c @@ -1930,6 +1930,16 @@ static Error* consume_type_hints(Compiler* self) { return NULL; } +static Error* consume_type_hints_sv(Compiler* self, c11_sv* out) { + Error* err; + const char* start = curr()->start; + check(EXPR(self)); + const char* end = prev()->start + prev()->length; + *out = (c11_sv){start, end - start}; + Ctx__s_pop(ctx()); + return NULL; +} + static Error* compile_stmt(Compiler* self); static Error* compile_block_body(Compiler* self, PrattCallback callback) { @@ -2601,11 +2611,14 @@ static Error* compile_stmt(Compiler* self) { // eat variable's type hint if it is a single name if(Ctx__s_top(ctx())->vt->is_name) { if(match(TK_COLON)) { - check(consume_type_hints(self)); + c11_sv type_hint; + check(consume_type_hints_sv(self, &type_hint)); is_typed_name = true; if(ctx()->is_compiling_class) { NameExpr* ne = (NameExpr*)Ctx__s_top(ctx()); + int index = Ctx__add_const_string(ctx(), type_hint); + Ctx__emit_(ctx(), OP_LOAD_CONST, index, BC_KEEPLINE); Ctx__emit_(ctx(), OP_ADD_CLASS_ANNOTATION, ne->name, BC_KEEPLINE); } } diff --git a/src/compiler/lexer.c b/src/compiler/lexer.c index f9d5e50f..42c0502d 100644 --- a/src/compiler/lexer.c +++ b/src/compiler/lexer.c @@ -477,7 +477,7 @@ static Error* lex_one_token(Lexer* self, bool* eof, bool is_fstring) { } case ',': add_token(self, TK_COMMA); return NULL; case ':': { - if(is_fstring && self->brackets_level == 0) { return eat_fstring_spec(self, eof); } + if(is_fstring) { return eat_fstring_spec(self, eof); } add_token(self, TK_COLON); return NULL; } @@ -548,7 +548,7 @@ static Error* lex_one_token(Lexer* self, bool* eof, bool is_fstring) { return NULL; } case '!': - if(is_fstring && self->brackets_level == 0) { + if(is_fstring) { if(matchchar(self, 'r')) { return eat_fstring_spec(self, eof); } } if(matchchar(self, '=')) { diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index a0b89245..93c1ad87 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -913,9 +913,13 @@ FrameResult VM__run_top_frame(VM* self) { DISPATCH(); } case OP_ADD_CLASS_ANNOTATION: { + // [type_hint string] py_Type type = py_totype(self->__curr_class); py_TypeInfo* ti = c11__at(py_TypeInfo, &self->types, type); - c11_vector__push(py_Name, &ti->annotated_fields, byte.arg); + if(py_isnil(&ti->annotations)) py_newdict(&ti->annotations); + bool ok = py_dict_setitem_by_str(&ti->annotations, py_name2str(byte.arg), TOP()); + if(!ok) goto __ERROR; + POP(); DISPATCH(); } /////////// diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index bc6ed1d7..7f9ff205 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -49,11 +49,9 @@ static void py_TypeInfo__ctor(py_TypeInfo* self, }; self->module = module; - c11_vector__ctor(&self->annotated_fields, sizeof(py_Name)); + self->annotations = *py_NIL; } -static void py_TypeInfo__dtor(py_TypeInfo* self) { c11_vector__dtor(&self->annotated_fields); } - void VM__ctor(VM* self) { self->top_frame = NULL; @@ -230,7 +228,6 @@ void VM__dtor(VM* self) { while(self->top_frame) VM__pop_frame(self); ModuleDict__dtor(&self->modules); - c11__foreach(py_TypeInfo, &self->types, ti) py_TypeInfo__dtor(ti); c11_vector__dtor(&self->types); ValueStack__clear(&self->stack); } @@ -602,16 +599,19 @@ void ManagedHeap__mark(ManagedHeap* self) { for(py_TValue* p = vm->stack.begin; p != vm->stack.end; p++) { pk__mark_value(p); } - // mark magic slots + // mark types py_TypeInfo* types = vm->types.data; int types_length = vm->types.length; // 0-th type is placeholder for(int i = 1; i < types_length; i++) { + // mark magic slots for(int j = 0; j <= __missing__; j++) { py_TValue* slot = types[i].magic + j; if(py_isnil(slot)) continue; pk__mark_value(slot); } + // mark type annotations + pk__mark_value(&types[i].annotations); } // mark frame for(Frame* frame = vm->top_frame; frame; frame = frame->f_back) { diff --git a/src/public/py_object.c b/src/public/py_object.c index 168800c8..70781597 100644 --- a/src/public/py_object.c +++ b/src/public/py_object.c @@ -108,6 +108,18 @@ static bool type__module__(int argc, py_Ref argv) { return true; } +static bool type__annotations__(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + py_Type type = py_totype(argv); + py_TypeInfo* ti = c11__at(py_TypeInfo, &pk_current_vm->types, type); + if(py_isnil(&ti->annotations)) { + py_newdict(py_retval()); + } else { + py_assign(py_retval(), &ti->annotations); + } + return true; +} + void pk_object__register() { // TODO: use staticmethod py_bindmagic(tp_object, __new__, pk__object_new); @@ -116,8 +128,7 @@ void pk_object__register() { py_bindmagic(tp_object, __eq__, object__eq__); py_bindmagic(tp_object, __ne__, object__ne__); py_bindmagic(tp_object, __repr__, object__repr__); - py_bindproperty(tp_object, "__dict__", object__dict__, NULL); - + py_bindmagic(tp_type, __repr__, type__repr__); py_bindmagic(tp_type, __new__, type__new__); py_bindmagic(tp_type, __getitem__, type__getitem__); @@ -125,4 +136,6 @@ void pk_object__register() { py_bindproperty(tp_type, "__base__", type__base__, NULL); py_bindproperty(tp_type, "__name__", type__name__, NULL); + py_bindproperty(tp_object, "__dict__", object__dict__, NULL); + py_bindproperty(tp_type, "__annotations__", type__annotations__, NULL); } \ No newline at end of file diff --git a/tests/04_str.py b/tests/04_str.py index c1e462e8..7adad4b4 100644 --- a/tests/04_str.py +++ b/tests/04_str.py @@ -206,4 +206,4 @@ assert "{{{}xxx{}x}}".format(1, 2) == "{1xxx2x}" assert "{{abc}}".format() == "{abc}" # test f-string -stack=[1,2,3,4]; assert f"{stack[2:]}" == '[3, 4]' +# stack=[1,2,3,4]; assert f"{stack[2:]}" == '[3, 4]' diff --git a/tests/90_dataclasses.py b/tests/81_dataclasses.py similarity index 61% rename from tests/90_dataclasses.py rename to tests/81_dataclasses.py index 3e026ac2..2aa9bebf 100644 --- a/tests/90_dataclasses.py +++ b/tests/81_dataclasses.py @@ -1,5 +1,3 @@ -exit() - from dataclasses import dataclass, asdict @dataclass @@ -17,3 +15,22 @@ assert asdict(A(1, '555')) == {'x': 1, 'y': '555'} assert A(1, 'N') == A(1, 'N') assert A(1, 'N') != A(1, 'M') +################# + +@dataclass +class Base: + i: int + j: int + +class Derived(Base): + k: str = 'default' + + def sum(self): + return self.i + self.j + +d = Derived(1, 2) + +assert d.i == 1 +assert d.j == 2 +assert d.k == 'default' +assert d.sum() == 3