add super()

This commit is contained in:
blueloveTH 2022-11-24 02:40:30 +08:00
parent d9cf73a6d3
commit 5412186365
4 changed files with 119 additions and 8 deletions

View File

@ -73,7 +73,7 @@ public:
};
typedef std::shared_ptr<Function> _Func;
typedef std::variant<_Int,_Float,bool,_Str,PyVarList,_CppFunc,_Func,std::shared_ptr<_Iterator>,_BoundedMethod,_Range,_Slice,_Pointer> _Value;
typedef std::variant<PyVar,_Int,_Float,bool,_Str,PyVarList,_CppFunc,_Func,std::shared_ptr<_Iterator>,_BoundedMethod,_Range,_Slice,_Pointer> _Value;
const int _SIZEOF_VALUE = sizeof(_Value);

View File

@ -56,6 +56,13 @@ void __initializeBuiltinFunctions(VM* _vm) {
return vm->PyStr(tvm->readStdin());
});
_vm->bindBuiltinFunc("super", [](VM* vm, const pkpy::ArgList& args) {
vm->__checkArgSize(args, 0);
auto it = vm->topFrame()->f_locals.find("self"_c);
if(it == vm->topFrame()->f_locals.end()) vm->typeError("super() can only be called in a class method");
return vm->newObject(vm->_tp_super, it->second);
});
_vm->bindBuiltinFunc("eval", [](VM* vm, const pkpy::ArgList& args) {
vm->__checkArgSize(args, 1);
const _Str& expr = vm->PyStr_AS_C(args[0]);

View File

@ -588,16 +588,34 @@ public:
}
PyVarOrNull getAttr(const PyVar& obj, const _Str& name, bool throw_err=true) {
auto it = obj->attribs.find(name);
if(it != obj->attribs.end()) return it->second;
PyVarDict::iterator it;
PyObject* cls;
if(obj->isType(_tp_super)){
const PyVar* root = &obj;
int depth = 1;
while(true){
root = &std::get<PyVar>((*root)->_native);
if(!(*root)->isType(_tp_super)) break;
depth++;
}
cls = (*root)->_type.get();
for(int i=0; i<depth; i++) cls = cls->attribs[__base__].get();
it = (*root)->attribs.find(name);
if(it != (*root)->attribs.end()) return it->second;
}else{
it = obj->attribs.find(name);
if(it != obj->attribs.end()) return it->second;
cls = obj->_type.get();
}
PyObject* cls = obj->_type.get();
while(cls != None.get()) {
it = cls->attribs.find(name);
if(it != cls->attribs.end()){
PyVar valueFromCls = it->second;
if(valueFromCls->isType(_tp_function) || valueFromCls->isType(_tp_native_function)){
return PyBoundedMethod({obj, valueFromCls});
return PyBoundedMethod({obj, std::move(valueFromCls)});
}else{
return valueFromCls;
}
@ -609,11 +627,29 @@ public:
}
inline void setAttr(PyVar& obj, const _Str& name, const PyVar& value) {
obj->attribs[name] = value;
if(obj->isType(_tp_super)){
const PyVar* root = &obj;
while(true){
root = &std::get<PyVar>((*root)->_native);
if(!(*root)->isType(_tp_super)) break;
}
(*root)->attribs[name] = value;
}else{
obj->attribs[name] = value;
}
}
inline void setAttr(PyVar& obj, const _Str& name, PyVar&& value) {
obj->attribs[name] = std::move(value);
if(obj->isType(_tp_super)){
const PyVar* root = &obj;
while(true){
root = &std::get<PyVar>((*root)->_native);
if(!(*root)->isType(_tp_super)) break;
}
(*root)->attribs[name] = std::move(value);
}else{
obj->attribs[name] = std::move(value);
}
}
void bindMethod(_Str typeName, _Str funcName, _CppFunc fn) {
@ -690,7 +726,7 @@ public:
PyVar _tp_list, _tp_tuple;
PyVar _tp_function, _tp_native_function, _tp_native_iterator, _tp_bounded_method;
PyVar _tp_slice, _tp_range, _tp_module, _tp_pointer;
PyVar _tp_user_pointer;
PyVar _tp_user_pointer, _tp_super;
__DEF_PY_POOL(Int, _Int, _tp_int, 256);
__DEF_PY_AS_C(Int, _Int, _tp_int)
@ -744,6 +780,7 @@ public:
_tp_native_function = newClassType("_native_function");
_tp_native_iterator = newClassType("_native_iterator");
_tp_bounded_method = newClassType("_bounded_method");
_tp_super = newClassType("super");
this->None = newObject(_types["NoneType"], (_Int)0);
this->Ellipsis = newObject(_types["ellipsis"], (_Int)0);

67
tests/class.py Normal file
View File

@ -0,0 +1,67 @@
class A:
def __init__(self, a, b):
self.a = a
self.b = b
def add(self):
return self.a + self.b
def sub(self):
return self.a - self.b
a = A(1, 2)
assert a.add() == 3
assert a.sub() == -1
assert a.__base__ is object
class B(A):
def __init__(self, a, b, c):
super().__init__(a, b)
self.c = c
def add(self):
return self.a + self.b + self.c
def sub(self):
return self.a - self.b - self.c
assert B.__base__ is A
b = B(1, 2, 3)
assert b.add() == 6
assert b.sub() == -4
class C(B):
def __init__(self, a, b, c, d):
super().__init__(a, b, c)
self.d = d
def add(self):
return self.a + self.b + self.c + self.d
def sub(self):
return self.a - self.b - self.c - self.d
assert C.__base__ is B
c = C(1, 2, 3, 4)
assert c.add() == 10
assert c.sub() == -8
class D(C):
def __init__(self, a, b, c, d, e):
super().__init__(a, b, c, d)
self.e = e
def add(self):
return super().add() + self.e
def sub(self):
return super().sub() - self.e
assert D.__base__ is C
d = D(1, 2, 3, 4, 5)
assert d.add() == 15
assert d.sub() == -13