diff --git a/src/obj.h b/src/obj.h index ec71f375..b873aeff 100644 --- a/src/obj.h +++ b/src/obj.h @@ -73,7 +73,7 @@ public: }; typedef std::shared_ptr _Func; -typedef std::variant<_Int,_Float,bool,_Str,PyVarList,_CppFunc,_Func,std::shared_ptr<_Iterator>,_BoundedMethod,_Range,_Slice,_Pointer> _Value; +typedef std::variant,_BoundedMethod,_Range,_Slice,_Pointer> _Value; const int _SIZEOF_VALUE = sizeof(_Value); diff --git a/src/pocketpy.h b/src/pocketpy.h index bcafbc21..36d70a7b 100644 --- a/src/pocketpy.h +++ b/src/pocketpy.h @@ -56,6 +56,13 @@ void __initializeBuiltinFunctions(VM* _vm) { return vm->PyStr(tvm->readStdin()); }); + _vm->bindBuiltinFunc("super", [](VM* vm, const pkpy::ArgList& args) { + vm->__checkArgSize(args, 0); + auto it = vm->topFrame()->f_locals.find("self"_c); + if(it == vm->topFrame()->f_locals.end()) vm->typeError("super() can only be called in a class method"); + return vm->newObject(vm->_tp_super, it->second); + }); + _vm->bindBuiltinFunc("eval", [](VM* vm, const pkpy::ArgList& args) { vm->__checkArgSize(args, 1); const _Str& expr = vm->PyStr_AS_C(args[0]); diff --git a/src/vm.h b/src/vm.h index 41552de3..df0bfd5e 100644 --- a/src/vm.h +++ b/src/vm.h @@ -588,16 +588,34 @@ public: } PyVarOrNull getAttr(const PyVar& obj, const _Str& name, bool throw_err=true) { - auto it = obj->attribs.find(name); - if(it != obj->attribs.end()) return it->second; + PyVarDict::iterator it; + PyObject* cls; + + if(obj->isType(_tp_super)){ + const PyVar* root = &obj; + int depth = 1; + while(true){ + root = &std::get((*root)->_native); + if(!(*root)->isType(_tp_super)) break; + depth++; + } + cls = (*root)->_type.get(); + for(int i=0; iattribs[__base__].get(); + + it = (*root)->attribs.find(name); + if(it != (*root)->attribs.end()) return it->second; + }else{ + it = obj->attribs.find(name); + if(it != obj->attribs.end()) return it->second; + cls = obj->_type.get(); + } - PyObject* cls = obj->_type.get(); while(cls != None.get()) { it = cls->attribs.find(name); if(it != cls->attribs.end()){ PyVar valueFromCls = it->second; if(valueFromCls->isType(_tp_function) || valueFromCls->isType(_tp_native_function)){ - return PyBoundedMethod({obj, valueFromCls}); + return PyBoundedMethod({obj, std::move(valueFromCls)}); }else{ return valueFromCls; } @@ -609,11 +627,29 @@ public: } inline void setAttr(PyVar& obj, const _Str& name, const PyVar& value) { - obj->attribs[name] = value; + if(obj->isType(_tp_super)){ + const PyVar* root = &obj; + while(true){ + root = &std::get((*root)->_native); + if(!(*root)->isType(_tp_super)) break; + } + (*root)->attribs[name] = value; + }else{ + obj->attribs[name] = value; + } } inline void setAttr(PyVar& obj, const _Str& name, PyVar&& value) { - obj->attribs[name] = std::move(value); + if(obj->isType(_tp_super)){ + const PyVar* root = &obj; + while(true){ + root = &std::get((*root)->_native); + if(!(*root)->isType(_tp_super)) break; + } + (*root)->attribs[name] = std::move(value); + }else{ + obj->attribs[name] = std::move(value); + } } void bindMethod(_Str typeName, _Str funcName, _CppFunc fn) { @@ -690,7 +726,7 @@ public: PyVar _tp_list, _tp_tuple; PyVar _tp_function, _tp_native_function, _tp_native_iterator, _tp_bounded_method; PyVar _tp_slice, _tp_range, _tp_module, _tp_pointer; - PyVar _tp_user_pointer; + PyVar _tp_user_pointer, _tp_super; __DEF_PY_POOL(Int, _Int, _tp_int, 256); __DEF_PY_AS_C(Int, _Int, _tp_int) @@ -744,6 +780,7 @@ public: _tp_native_function = newClassType("_native_function"); _tp_native_iterator = newClassType("_native_iterator"); _tp_bounded_method = newClassType("_bounded_method"); + _tp_super = newClassType("super"); this->None = newObject(_types["NoneType"], (_Int)0); this->Ellipsis = newObject(_types["ellipsis"], (_Int)0); diff --git a/tests/class.py b/tests/class.py new file mode 100644 index 00000000..f33ad758 --- /dev/null +++ b/tests/class.py @@ -0,0 +1,67 @@ +class A: + def __init__(self, a, b): + self.a = a + self.b = b + + def add(self): + return self.a + self.b + + def sub(self): + return self.a - self.b + +a = A(1, 2) +assert a.add() == 3 +assert a.sub() == -1 + +assert a.__base__ is object + +class B(A): + def __init__(self, a, b, c): + super().__init__(a, b) + self.c = c + + def add(self): + return self.a + self.b + self.c + + def sub(self): + return self.a - self.b - self.c + +assert B.__base__ is A + +b = B(1, 2, 3) +assert b.add() == 6 +assert b.sub() == -4 + +class C(B): + def __init__(self, a, b, c, d): + super().__init__(a, b, c) + self.d = d + + def add(self): + return self.a + self.b + self.c + self.d + + def sub(self): + return self.a - self.b - self.c - self.d + +assert C.__base__ is B + +c = C(1, 2, 3, 4) +assert c.add() == 10 +assert c.sub() == -8 + +class D(C): + def __init__(self, a, b, c, d, e): + super().__init__(a, b, c, d) + self.e = e + + def add(self): + return super().add() + self.e + + def sub(self): + return super().sub() - self.e + +assert D.__base__ is C + +d = D(1, 2, 3, 4, 5) +assert d.add() == 15 +assert d.sub() == -13 \ No newline at end of file