From c7935564c3329740c93a90b44286b5015ac09692 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Wed, 4 Dec 2024 23:34:04 +0800 Subject: [PATCH] fix a bug of `super` --- src/interpreter/ceval.c | 11 ++++++++--- src/public/modules.c | 12 +++++++++--- tests/41_class_ex.py | 14 ++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index 3e282dda..d489cb55 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -60,7 +60,7 @@ static bool stack_format_object(VM* self, c11_sv spec); case RES_RETURN: PUSH(&self->last_retval); break; \ case RES_CALL: frame = self->top_frame; goto __NEXT_FRAME; \ case RES_ERROR: goto __ERROR; \ - default: c11__unreachable(); \ + default: c11__unreachable(); \ } \ } while(0) @@ -974,8 +974,13 @@ FrameResult VM__run_top_frame(VM* self) { case OP_STORE_CLASS_ATTR: { assert(self->__curr_class); py_Name name = byte.arg; - if(py_istype(TOP(), tp_function)) { - Function* ud = py_touserdata(TOP()); + // TOP() can be a function, classmethod or custom decorator + py_Ref actual_func = TOP(); + if(actual_func->type == tp_classmethod) { + actual_func = py_getslot(actual_func, 0); + } + if(actual_func->type == tp_function) { + Function* ud = py_touserdata(actual_func); ud->clazz = self->__curr_class->_obj; } py_setdict(self->__curr_class, name, TOP()); diff --git a/src/public/modules.c b/src/public/modules.c index 20bae895..92b578c3 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -763,9 +763,15 @@ static bool super__new__(int argc, py_Ref argv) { if(argc == 1) { // super() if(frame->has_function) { - Function* func = py_touserdata(frame->p0); - *class_arg = *(py_Type*)PyObject__userdata(func->clazz); - if(frame->co->nlocals > 0) self_arg = &frame->locals[0]; + py_TValue* callable = frame->p0; + if(callable->type == tp_boundmethod) callable = py_getslot(frame->p0, 1); + if(callable->type == tp_function) { + Function* func = py_touserdata(callable); + if(func->clazz != NULL) { + *class_arg = *(py_Type*)PyObject__userdata(func->clazz); + if(frame->co->nlocals > 0) self_arg = &frame->locals[0]; + } + } } if(class_arg == 0 || self_arg == NULL) return RuntimeError("super(): no arguments"); } else if(argc == 3) { diff --git a/tests/41_class_ex.py b/tests/41_class_ex.py index 2a234c2a..e52d811d 100644 --- a/tests/41_class_ex.py +++ b/tests/41_class_ex.py @@ -154,3 +154,17 @@ assert B.static_method(123) == 123 assert B.class_method('123') == 'B123' assert B().static_method(123) == 123 assert B().class_method('123') == 'B123' + +# test super() with classmethod +class BaseClass: + @classmethod + def f(cls): + return 'BaseClass' + +class DerivedClass(BaseClass): + @classmethod + def f(cls): + return super().f() + + +assert DerivedClass.f() == 'BaseClass' \ No newline at end of file