fix builtin modules and super()

This commit is contained in:
blueloveTH 2025-03-12 14:25:48 +08:00
parent 3afab89ab3
commit 9ff3417621
6 changed files with 59 additions and 51 deletions

View File

@ -8,7 +8,6 @@ MAGIC_METHOD(__ge__)
///////////////////////////// /////////////////////////////
MAGIC_METHOD(__neg__) MAGIC_METHOD(__neg__)
MAGIC_METHOD(__abs__) MAGIC_METHOD(__abs__)
MAGIC_METHOD(__float__)
MAGIC_METHOD(__int__) MAGIC_METHOD(__int__)
MAGIC_METHOD(__round__) MAGIC_METHOD(__round__)
MAGIC_METHOD(__divmod__) MAGIC_METHOD(__divmod__)

View File

@ -171,9 +171,12 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
if(name == __new__) { if(name == __new__) {
// __new__ acts like a @staticmethod // __new__ acts like a @staticmethod
if(py_istype(self, tp_type)) { if(self->type == tp_type) {
// T.__new__(...) // T.__new__(...)
type = py_totype(self); type = py_totype(self);
} else if(self->type == tp_super) {
// super(T, obj).__new__(...)
type = *(py_Type*)py_touserdata(self);
} else { } else {
// invalid usage of `__new__` // invalid usage of `__new__`
return false; return false;
@ -187,12 +190,16 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
return false; return false;
} }
py_TValue self_bak; // to avoid overlapping
// handle super() proxy // handle super() proxy
if(py_istype(self, tp_super)) { if(py_istype(self, tp_super)) {
type = *(py_Type*)py_touserdata(self); type = *(py_Type*)py_touserdata(self);
*self = *py_getslot(self, 0); // BUG: here we modify `self` which refers to the stack directly
// If `pk_loadmethod` fails, `self` will be corrupted
self_bak = *py_getslot(self, 0);
} else { } else {
type = self->type; type = self->type;
self_bak = *self;
} }
py_Ref cls_var = py_tpfindname(type, name); py_Ref cls_var = py_tpfindname(type, name);
@ -200,8 +207,6 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
switch(cls_var->type) { switch(cls_var->type) {
case tp_function: case tp_function:
case tp_nativefunc: { case tp_nativefunc: {
py_TValue self_bak = *self;
// `out` may overlap with `self`. If we assign `out`, `self` may be corrupted.
self[0] = *cls_var; self[0] = *cls_var;
self[1] = self_bak; self[1] = self_bak;
break; break;

View File

@ -283,7 +283,7 @@ static bool dict__init__(int argc, py_Ref argv) {
for(int i = 0; i < length; i++) { for(int i = 0; i < length; i++) {
py_Ref tuple = &p[i]; py_Ref tuple = &p[i];
if(!py_istuple(tuple) || py_tuple_len(tuple) != 2) { if(!py_istuple(tuple) || py_tuple_len(tuple) != 2) {
return TypeError("dict.__init__() argument must be a list of tuple-2"); return ValueError("dict.__init__() argument must be a list of tuple-2");
} }
py_Ref key = py_tuple_getitem(tuple, 0); py_Ref key = py_tuple_getitem(tuple, 0);
py_Ref val = py_tuple_getitem(tuple, 1); py_Ref val = py_tuple_getitem(tuple, 1);

View File

@ -388,7 +388,7 @@ static bool float__new__(int argc, py_Ref argv) {
py_newfloat(py_retval(), float_out); py_newfloat(py_retval(), float_out);
return true; return true;
} }
default: return pk_callmagic(__float__, 1, argv + 1); default: return TypeError("float() argument must be a string or a real number");
} }
} }

View File

@ -7,7 +7,7 @@ class TestSuperBase():
return self.base_attr return self.base_attr
def error(self): def error(self):
raise Exception('未能拦截错误') raise RuntimeError('未能拦截错误')
class TestSuperChild1(TestSuperBase): class TestSuperChild1(TestSuperBase):
@ -20,7 +20,7 @@ class TestSuperChild1(TestSuperBase):
def error_handling(self): def error_handling(self):
try: try:
super(TestSuperChild1, self).error() super(TestSuperChild1, self).error()
except: except RuntimeError:
pass pass
class TestSuperChild2(TestSuperBase): class TestSuperChild2(TestSuperBase):
@ -54,18 +54,22 @@ class TestSuperNoBaseMethod(TestSuperBase):
def __init__(self): def __init__(self):
super(TestSuperNoBaseMethod, self).append(1) super(TestSuperNoBaseMethod, self).append(1)
class TestSuperNoParent():
def method(self):
super(TestSuperNoParent, self).method()
try: try:
t = TestSuperNoParent() t = TestSuperNoParent().method()
print('未能拦截错误') print('未能拦截错误2')
exit(2) exit(2)
except: except AttributeError:
pass pass
try: try:
t = TestSuperNoBaseMethod() t = TestSuperNoBaseMethod()
print('未能拦截错误') print('未能拦截错误3')
exit(3) exit(3)
except: except AttributeError:
pass pass
class B(): class B():
@ -82,17 +86,17 @@ class D():
try: try:
c = C() c = C()
c.method() c.method()
print('未能拦截错误') print('未能拦截错误4')
exit(4) exit(4)
except: except AttributeError:
pass pass
try: try:
d = D() d = D()
d.method() d.method()
print('未能拦截错误') print('未能拦截错误5')
exit(5) exit(5)
except: except TypeError:
pass pass
# test hash: # test hash:
@ -133,16 +137,16 @@ assert type(hash(a)) is int
# 测试不可哈希对象 # 测试不可哈希对象
try: try:
hash({1:1}) hash({1:1})
print('未能拦截错误') print('未能拦截错误6')
exit(6) exit(6)
except: except TypeError:
pass pass
try: try:
hash([1]) hash([1])
print('未能拦截错误') print('未能拦截错误7')
exit(7) exit(7)
except: except TypeError:
pass pass
# test chr # test chr
@ -165,24 +169,24 @@ repr(A())
try: try:
range(1,2,3,4) range(1,2,3,4)
print('未能拦截错误, 在测试 range') print('未能拦截错误8, 在测试 range')
exit(8) exit(8)
except: except TypeError:
pass pass
# /************ int ************/ # /************ int ************/
try: try:
int('asad') int('asad')
print('未能拦截错误, 在测试 int') print('未能拦截错误9, 在测试 int')
exit(9) exit(9)
except: except ValueError:
pass pass
try: try:
int(123, 16) int(123, 16)
print('未能拦截错误, 在测试 int') print('未能拦截错误10, 在测试 int')
exit(10) exit(10)
except: except TypeError:
pass pass
assert type(10//11) is int assert type(10//11) is int
@ -191,16 +195,16 @@ assert type(11%2) is int
try: try:
float('asad') float('asad')
print('未能拦截错误, 在测试 float') print('未能拦截错误11, 在测试 float')
exit(11) exit(11)
except: except ValueError:
pass pass
try: try:
float([]) float([])
print('未能拦截错误, 在测试 float') print('未能拦截错误12, 在测试 float')
exit(12) exit(12)
except: except TypeError:
pass pass
# /************ str ************/ # /************ str ************/
@ -212,10 +216,10 @@ assert type(12 * '12') is str
assert type('25363546'.index('63')) is int assert type('25363546'.index('63')) is int
try: try:
'25363546'.index('err') '25363546'.index('err')
print('未能拦截错误, 在测试 str.index') print('未能拦截错误13, 在测试 str.index')
exit(13) exit(13)
except: except ValueError as e:
pass assert str(e) == "substring not found"
# 未完全测试准确性----------------------------------------------- # 未完全测试准确性-----------------------------------------------
@ -227,9 +231,9 @@ assert '25363546'.find('err') == -1
# /************ list ************/ # /************ list ************/
try: try:
list(1,2) list(1,2)
print('未能拦截错误, 在测试 list') print('未能拦截错误14, 在测试 list')
exit(14) exit(14)
except: except TypeError:
pass pass
# 未完全测试准确性---------------------------------------------- # 未完全测试准确性----------------------------------------------
@ -237,10 +241,10 @@ except:
assert type([1,2,3,4,5].index(4)) is int assert type([1,2,3,4,5].index(4)) is int
try: try:
[1,2,3,4,5].index(6) [1,2,3,4,5].index(6)
print('未能拦截错误, 在测试 list.index') print('未能拦截错误15, 在测试 list.index')
exit(15) exit(15)
except: except ValueError as e:
pass assert str(e) == "list.index(x): x not in list"
@ -248,19 +252,19 @@ except:
# test list.remove: # test list.remove:
try: try:
[1,2,3,4,5].remove(6) [1,2,3,4,5].remove(6)
print('未能拦截错误, 在测试 list.remove') print('未能拦截错误16, 在测试 list.remove')
exit(16) exit(16)
except: except ValueError as e:
pass assert str(e) == "list.remove(x): x not in list"
# 未完全测试准确性---------------------------------------------- # 未完全测试准确性----------------------------------------------
# test list.pop: # test list.pop:
try: try:
[1,2,3,4,5].pop(1,2,3,4) [1,2,3,4,5].pop(1,2,3,4)
print('未能拦截错误, 在测试 list.pop') print('未能拦截错误17, 在测试 list.pop')
exit(17) exit(17)
except: except TypeError:
pass pass
@ -274,9 +278,9 @@ assert type(12 * [12]) is list
# test tuple: # test tuple:
try: try:
tuple(1,2) tuple(1,2)
print('未能拦截错误, 在测试 tuple') print('未能拦截错误18, 在测试 tuple')
exit(18) exit(18)
except: except TypeError:
pass pass
assert [1,2,2,3,3,3].count(3) == 3 assert [1,2,2,3,3,3].count(3) == 3

View File

@ -70,21 +70,21 @@ try:
dict([(1, 2, 3)]) dict([(1, 2, 3)])
print('未能拦截错误, 在测试 dict') print('未能拦截错误, 在测试 dict')
exit(1) exit(1)
except: except ValueError:
pass pass
try: try:
dict([(1, 2)], 1) dict([(1, 2)], 1)
print('未能拦截错误, 在测试 dict') print('未能拦截错误, 在测试 dict')
exit(1) exit(1)
except: except TypeError:
pass pass
try: try:
hash(dict([(1,2)])) hash(dict([(1,2)]))
print('未能拦截错误, 在测试 dict.__hash__') print('未能拦截错误, 在测试 dict.__hash__')
exit(1) exit(1)
except: except TypeError:
pass pass
# test dict.__iter__ # test dict.__iter__
@ -102,7 +102,7 @@ try:
{1:2, 3:4}.get(1,1, 1) {1:2, 3:4}.get(1,1, 1)
print('未能拦截错误, 在测试 dict.get') print('未能拦截错误, 在测试 dict.get')
exit(1) exit(1)
except: except TypeError:
pass pass
# 未完全测试准确性----------------------------------------------- # 未完全测试准确性-----------------------------------------------