fix builtin modules and super()

This commit is contained in:
blueloveTH 2025-03-12 14:25:48 +08:00
parent 3afab89ab3
commit 9ff3417621
6 changed files with 59 additions and 51 deletions

View File

@ -8,7 +8,6 @@ MAGIC_METHOD(__ge__)
/////////////////////////////
MAGIC_METHOD(__neg__)
MAGIC_METHOD(__abs__)
MAGIC_METHOD(__float__)
MAGIC_METHOD(__int__)
MAGIC_METHOD(__round__)
MAGIC_METHOD(__divmod__)

View File

@ -171,9 +171,12 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
if(name == __new__) {
// __new__ acts like a @staticmethod
if(py_istype(self, tp_type)) {
if(self->type == tp_type) {
// T.__new__(...)
type = py_totype(self);
} else if(self->type == tp_super) {
// super(T, obj).__new__(...)
type = *(py_Type*)py_touserdata(self);
} else {
// invalid usage of `__new__`
return false;
@ -187,12 +190,16 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
return false;
}
py_TValue self_bak; // to avoid overlapping
// handle super() proxy
if(py_istype(self, tp_super)) {
type = *(py_Type*)py_touserdata(self);
*self = *py_getslot(self, 0);
// BUG: here we modify `self` which refers to the stack directly
// If `pk_loadmethod` fails, `self` will be corrupted
self_bak = *py_getslot(self, 0);
} else {
type = self->type;
self_bak = *self;
}
py_Ref cls_var = py_tpfindname(type, name);
@ -200,8 +207,6 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
switch(cls_var->type) {
case tp_function:
case tp_nativefunc: {
py_TValue self_bak = *self;
// `out` may overlap with `self`. If we assign `out`, `self` may be corrupted.
self[0] = *cls_var;
self[1] = self_bak;
break;

View File

@ -283,7 +283,7 @@ static bool dict__init__(int argc, py_Ref argv) {
for(int i = 0; i < length; i++) {
py_Ref tuple = &p[i];
if(!py_istuple(tuple) || py_tuple_len(tuple) != 2) {
return TypeError("dict.__init__() argument must be a list of tuple-2");
return ValueError("dict.__init__() argument must be a list of tuple-2");
}
py_Ref key = py_tuple_getitem(tuple, 0);
py_Ref val = py_tuple_getitem(tuple, 1);

View File

@ -388,7 +388,7 @@ static bool float__new__(int argc, py_Ref argv) {
py_newfloat(py_retval(), float_out);
return true;
}
default: return pk_callmagic(__float__, 1, argv + 1);
default: return TypeError("float() argument must be a string or a real number");
}
}

View File

@ -7,7 +7,7 @@ class TestSuperBase():
return self.base_attr
def error(self):
raise Exception('未能拦截错误')
raise RuntimeError('未能拦截错误')
class TestSuperChild1(TestSuperBase):
@ -20,7 +20,7 @@ class TestSuperChild1(TestSuperBase):
def error_handling(self):
try:
super(TestSuperChild1, self).error()
except:
except RuntimeError:
pass
class TestSuperChild2(TestSuperBase):
@ -54,18 +54,22 @@ class TestSuperNoBaseMethod(TestSuperBase):
def __init__(self):
super(TestSuperNoBaseMethod, self).append(1)
class TestSuperNoParent():
def method(self):
super(TestSuperNoParent, self).method()
try:
t = TestSuperNoParent()
print('未能拦截错误')
t = TestSuperNoParent().method()
print('未能拦截错误2')
exit(2)
except:
except AttributeError:
pass
try:
t = TestSuperNoBaseMethod()
print('未能拦截错误')
print('未能拦截错误3')
exit(3)
except:
except AttributeError:
pass
class B():
@ -82,17 +86,17 @@ class D():
try:
c = C()
c.method()
print('未能拦截错误')
print('未能拦截错误4')
exit(4)
except:
except AttributeError:
pass
try:
d = D()
d.method()
print('未能拦截错误')
print('未能拦截错误5')
exit(5)
except:
except TypeError:
pass
# test hash:
@ -133,16 +137,16 @@ assert type(hash(a)) is int
# 测试不可哈希对象
try:
hash({1:1})
print('未能拦截错误')
print('未能拦截错误6')
exit(6)
except:
except TypeError:
pass
try:
hash([1])
print('未能拦截错误')
print('未能拦截错误7')
exit(7)
except:
except TypeError:
pass
# test chr
@ -165,24 +169,24 @@ repr(A())
try:
range(1,2,3,4)
print('未能拦截错误, 在测试 range')
print('未能拦截错误8, 在测试 range')
exit(8)
except:
except TypeError:
pass
# /************ int ************/
try:
int('asad')
print('未能拦截错误, 在测试 int')
print('未能拦截错误9, 在测试 int')
exit(9)
except:
except ValueError:
pass
try:
int(123, 16)
print('未能拦截错误, 在测试 int')
print('未能拦截错误10, 在测试 int')
exit(10)
except:
except TypeError:
pass
assert type(10//11) is int
@ -191,16 +195,16 @@ assert type(11%2) is int
try:
float('asad')
print('未能拦截错误, 在测试 float')
print('未能拦截错误11, 在测试 float')
exit(11)
except:
except ValueError:
pass
try:
float([])
print('未能拦截错误, 在测试 float')
print('未能拦截错误12, 在测试 float')
exit(12)
except:
except TypeError:
pass
# /************ str ************/
@ -212,10 +216,10 @@ assert type(12 * '12') is str
assert type('25363546'.index('63')) is int
try:
'25363546'.index('err')
print('未能拦截错误, 在测试 str.index')
print('未能拦截错误13, 在测试 str.index')
exit(13)
except:
pass
except ValueError as e:
assert str(e) == "substring not found"
# 未完全测试准确性-----------------------------------------------
@ -227,9 +231,9 @@ assert '25363546'.find('err') == -1
# /************ list ************/
try:
list(1,2)
print('未能拦截错误, 在测试 list')
print('未能拦截错误14, 在测试 list')
exit(14)
except:
except TypeError:
pass
# 未完全测试准确性----------------------------------------------
@ -237,10 +241,10 @@ except:
assert type([1,2,3,4,5].index(4)) is int
try:
[1,2,3,4,5].index(6)
print('未能拦截错误, 在测试 list.index')
print('未能拦截错误15, 在测试 list.index')
exit(15)
except:
pass
except ValueError as e:
assert str(e) == "list.index(x): x not in list"
@ -248,19 +252,19 @@ except:
# test list.remove:
try:
[1,2,3,4,5].remove(6)
print('未能拦截错误, 在测试 list.remove')
print('未能拦截错误16, 在测试 list.remove')
exit(16)
except:
pass
except ValueError as e:
assert str(e) == "list.remove(x): x not in list"
# 未完全测试准确性----------------------------------------------
# test list.pop:
try:
[1,2,3,4,5].pop(1,2,3,4)
print('未能拦截错误, 在测试 list.pop')
print('未能拦截错误17, 在测试 list.pop')
exit(17)
except:
except TypeError:
pass
@ -274,9 +278,9 @@ assert type(12 * [12]) is list
# test tuple:
try:
tuple(1,2)
print('未能拦截错误, 在测试 tuple')
print('未能拦截错误18, 在测试 tuple')
exit(18)
except:
except TypeError:
pass
assert [1,2,2,3,3,3].count(3) == 3

View File

@ -70,21 +70,21 @@ try:
dict([(1, 2, 3)])
print('未能拦截错误, 在测试 dict')
exit(1)
except:
except ValueError:
pass
try:
dict([(1, 2)], 1)
print('未能拦截错误, 在测试 dict')
exit(1)
except:
except TypeError:
pass
try:
hash(dict([(1,2)]))
print('未能拦截错误, 在测试 dict.__hash__')
exit(1)
except:
except TypeError:
pass
# test dict.__iter__
@ -102,7 +102,7 @@ try:
{1:2, 3:4}.get(1,1, 1)
print('未能拦截错误, 在测试 dict.get')
exit(1)
except:
except TypeError:
pass
# 未完全测试准确性-----------------------------------------------