From 60e666c12e883ebcab1a1198cc557270a735c24b Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Mon, 18 Sep 2023 00:23:37 +0800 Subject: [PATCH] fix https://github.com/blueloveTH/pocketpy/issues/131 --- include/pocketpy/str.h | 1 + src/ceval.cpp | 7 ++- src/pocketpy.cpp | 30 ++++++++--- tests/40_class_ex.py | 119 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 tests/40_class_ex.py diff --git a/include/pocketpy/str.h b/include/pocketpy/str.h index 36832d34..2f387aa1 100644 --- a/include/pocketpy/str.h +++ b/include/pocketpy/str.h @@ -190,6 +190,7 @@ const StrName __name__ = StrName::get("__name__"); const StrName __all__ = StrName::get("__all__"); const StrName __package__ = StrName::get("__package__"); const StrName __path__ = StrName::get("__path__"); +const StrName __class__ = StrName::get("__class__"); const StrName pk_id_add = StrName::get("add"); const StrName pk_id_set = StrName::get("set"); diff --git a/src/ceval.cpp b/src/ceval.cpp index 0c34c75e..49f26f18 100644 --- a/src/ceval.cpp +++ b/src/ceval.cpp @@ -659,11 +659,14 @@ __NEXT_STEP:; _0 = POPX(); _0->attr()._try_perfect_rehash(); DISPATCH(); - TARGET(STORE_CLASS_ATTR) + TARGET(STORE_CLASS_ATTR){ _name = StrName(byte.arg); _0 = POPX(); + if(is_non_tagged_type(_0, tp_function)){ + _0->attr().set(__class__, TOP()); + } TOP()->attr().set(_name, _0); - DISPATCH(); + } DISPATCH(); /*****************************************/ TARGET(WITH_ENTER) call_method(POPX(), __enter__); diff --git a/src/pocketpy.cpp b/src/pocketpy.cpp index ea53a69a..22855fdf 100644 --- a/src/pocketpy.cpp +++ b/src/pocketpy.cpp @@ -108,16 +108,32 @@ void init_builtins(VM* _vm) { #undef BIND_NUM_ARITH_OPT #undef BIND_NUM_LOGICAL_OPT - _vm->bind_builtin_func<2>("super", [](VM* vm, ArgsView args) { - vm->check_non_tagged_type(args[0], vm->tp_type); - Type type = PK_OBJ_GET(Type, args[0]); - if(!vm->isinstance(args[1], type)){ - Str _0 = obj_type_name(vm, PK_OBJ_GET(Type, vm->_t(args[1]))); + _vm->bind_builtin_func<-1>("super", [](VM* vm, ArgsView args) { + PyObject* class_arg = nullptr; + PyObject* self_arg = nullptr; + if(args.size() == 2){ + class_arg = args[0]; + self_arg = args[1]; + }else if(args.size() == 0){ + FrameId frame = vm->top_frame(); + if(frame->_callable != nullptr){ + class_arg = frame->_callable->attr().try_get(__class__); + if(frame->_locals.size() > 0) self_arg = frame->_locals[0]; + } + if(class_arg == nullptr || self_arg == nullptr){ + vm->TypeError("super(): unable to determine the class context, use super(class, self) instead"); + } + }else{ + vm->TypeError("super() takes 0 or 2 arguments"); + } + vm->check_non_tagged_type(class_arg, vm->tp_type); + Type type = PK_OBJ_GET(Type, class_arg); + if(!vm->isinstance(self_arg, type)){ + Str _0 = obj_type_name(vm, PK_OBJ_GET(Type, vm->_t(self_arg))); Str _1 = obj_type_name(vm, type); vm->TypeError("super(): " + _0.escape() + " is not an instance of " + _1.escape()); } - Type base = vm->_all_types[type].base; - return vm->heap.gcnew(vm->tp_super, args[1], base); + return vm->heap.gcnew(vm->tp_super, self_arg, vm->_all_types[type].base); }); _vm->bind_builtin_func<2>("isinstance", [](VM* vm, ArgsView args) { diff --git a/tests/40_class_ex.py b/tests/40_class_ex.py new file mode 100644 index 00000000..7856aeab --- /dev/null +++ b/tests/40_class_ex.py @@ -0,0 +1,119 @@ +class A: + def __init__(self, a, b): + self.a = a + self.b = b + + def add(self): + return self.a + self.b + + def sub(self): + return self.a - self.b + +a = A(1, 2) +assert a.add() == 3 +assert a.sub() == -1 + +assert A.__base__ is object + +class B(A): + def __init__(self, a, b, c): + super().__init__(a, b) + self.c = c + + def add(self): + return self.a + self.b + self.c + + def sub(self): + return self.a - self.b - self.c + +assert B.__base__ is A + +b = B(1, 2, 3) +assert b.add() == 6 +assert b.sub() == -4 + +class C(B): + def __init__(self, a, b, c, d): + super().__init__(a, b, c) + self.d = d + + def add(self): + return self.a + self.b + self.c + self.d + + def sub(self): + return self.a - self.b - self.c - self.d + +assert C.__base__ is B + +c = C(1, 2, 3, 4) +assert c.add() == 10 +assert c.sub() == -8 + +class D(C): + def __init__(self, a, b, c, d, e): + super().__init__(a, b, c, d) + self.e = e + + def add(self): + return super().add() + self.e + + def sub(self): + return super().sub() - self.e + +assert D.__base__ is C + +d = D(1, 2, 3, 4, 5) +assert d.add() == 15 +assert d.sub() == -13 + +assert isinstance(1, int) +assert isinstance(1, object) +assert isinstance(C, type) +assert isinstance(C, object) +assert isinstance(d, object) +assert isinstance(d, C) +assert isinstance(d, B) +assert isinstance(d, A) +assert isinstance(object, object) +assert isinstance(type, object) + +assert isinstance(1, (float, int)) +assert isinstance(1, (float, object)) +assert not isinstance(1, (float, str)) +assert isinstance(object, (int, type, float)) +assert not isinstance(object, (int, float, str)) + +try: + isinstance(1, (1, 2)) + exit(1) +except TypeError: + pass + +try: + isinstance(1, 1) + exit(1) +except TypeError: + pass + +class A: + a = 1 + b = 2 + +assert A.a == 1 +assert A.b == 2 + +class B(A): + b = 3 + c = 4 + +# assert B.a == 1 ...bug here +assert B.b == 3 +assert B.c == 4 + +import c + +class A(c.void_p): + pass + +a = A() +assert repr(a).startswith('