diff --git a/include/pocketpy/xmacros/magics.h b/include/pocketpy/xmacros/magics.h index 966ffb22..f4941a2e 100644 --- a/include/pocketpy/xmacros/magics.h +++ b/include/pocketpy/xmacros/magics.h @@ -8,7 +8,6 @@ MAGIC_METHOD(__ge__) ///////////////////////////// MAGIC_METHOD(__neg__) MAGIC_METHOD(__abs__) -MAGIC_METHOD(__float__) MAGIC_METHOD(__int__) MAGIC_METHOD(__round__) MAGIC_METHOD(__divmod__) diff --git a/src/public/internal.c b/src/public/internal.c index 7037409f..d0741faa 100644 --- a/src/public/internal.c +++ b/src/public/internal.c @@ -171,9 +171,12 @@ bool pk_loadmethod(py_StackRef self, py_Name name) { if(name == __new__) { // __new__ acts like a @staticmethod - if(py_istype(self, tp_type)) { + if(self->type == tp_type) { // T.__new__(...) type = py_totype(self); + } else if(self->type == tp_super) { + // super(T, obj).__new__(...) + type = *(py_Type*)py_touserdata(self); } else { // invalid usage of `__new__` return false; @@ -187,12 +190,16 @@ bool pk_loadmethod(py_StackRef self, py_Name name) { return false; } + py_TValue self_bak; // to avoid overlapping // handle super() proxy if(py_istype(self, tp_super)) { type = *(py_Type*)py_touserdata(self); - *self = *py_getslot(self, 0); + // BUG: here we modify `self` which refers to the stack directly + // If `pk_loadmethod` fails, `self` will be corrupted + self_bak = *py_getslot(self, 0); } else { type = self->type; + self_bak = *self; } py_Ref cls_var = py_tpfindname(type, name); @@ -200,8 +207,6 @@ bool pk_loadmethod(py_StackRef self, py_Name name) { switch(cls_var->type) { case tp_function: case tp_nativefunc: { - py_TValue self_bak = *self; - // `out` may overlap with `self`. If we assign `out`, `self` may be corrupted. self[0] = *cls_var; self[1] = self_bak; break; diff --git a/src/public/py_dict.c b/src/public/py_dict.c index 4ab786c8..a4569151 100644 --- a/src/public/py_dict.c +++ b/src/public/py_dict.c @@ -283,7 +283,7 @@ static bool dict__init__(int argc, py_Ref argv) { for(int i = 0; i < length; i++) { py_Ref tuple = &p[i]; if(!py_istuple(tuple) || py_tuple_len(tuple) != 2) { - return TypeError("dict.__init__() argument must be a list of tuple-2"); + return ValueError("dict.__init__() argument must be a list of tuple-2"); } py_Ref key = py_tuple_getitem(tuple, 0); py_Ref val = py_tuple_getitem(tuple, 1); diff --git a/src/public/py_number.c b/src/public/py_number.c index 6b2ee092..164b9009 100644 --- a/src/public/py_number.c +++ b/src/public/py_number.c @@ -388,7 +388,7 @@ static bool float__new__(int argc, py_Ref argv) { py_newfloat(py_retval(), float_out); return true; } - default: return pk_callmagic(__float__, 1, argv + 1); + default: return TypeError("float() argument must be a string or a real number"); } } diff --git a/tests/77_builtin_func_1.py b/tests/77_builtin_func_1.py index cf7c4ba3..622a8eea 100644 --- a/tests/77_builtin_func_1.py +++ b/tests/77_builtin_func_1.py @@ -7,7 +7,7 @@ class TestSuperBase(): return self.base_attr def error(self): - raise Exception('未能拦截错误') + raise RuntimeError('未能拦截错误') class TestSuperChild1(TestSuperBase): @@ -20,7 +20,7 @@ class TestSuperChild1(TestSuperBase): def error_handling(self): try: super(TestSuperChild1, self).error() - except: + except RuntimeError: pass class TestSuperChild2(TestSuperBase): @@ -54,18 +54,22 @@ class TestSuperNoBaseMethod(TestSuperBase): def __init__(self): super(TestSuperNoBaseMethod, self).append(1) +class TestSuperNoParent(): + def method(self): + super(TestSuperNoParent, self).method() + try: - t = TestSuperNoParent() - print('未能拦截错误') + t = TestSuperNoParent().method() + print('未能拦截错误2') exit(2) -except: +except AttributeError: pass try: t = TestSuperNoBaseMethod() - print('未能拦截错误') + print('未能拦截错误3') exit(3) -except: +except AttributeError: pass class B(): @@ -82,17 +86,17 @@ class D(): try: c = C() c.method() - print('未能拦截错误') + print('未能拦截错误4') exit(4) -except: +except AttributeError: pass try: d = D() d.method() - print('未能拦截错误') + print('未能拦截错误5') exit(5) -except: +except TypeError: pass # test hash: @@ -133,16 +137,16 @@ assert type(hash(a)) is int # 测试不可哈希对象 try: hash({1:1}) - print('未能拦截错误') + print('未能拦截错误6') exit(6) -except: +except TypeError: pass try: hash([1]) - print('未能拦截错误') + print('未能拦截错误7') exit(7) -except: +except TypeError: pass # test chr @@ -165,24 +169,24 @@ repr(A()) try: range(1,2,3,4) - print('未能拦截错误, 在测试 range') + print('未能拦截错误8, 在测试 range') exit(8) -except: +except TypeError: pass # /************ int ************/ try: int('asad') - print('未能拦截错误, 在测试 int') + print('未能拦截错误9, 在测试 int') exit(9) -except: +except ValueError: pass try: int(123, 16) - print('未能拦截错误, 在测试 int') + print('未能拦截错误10, 在测试 int') exit(10) -except: +except TypeError: pass assert type(10//11) is int @@ -191,16 +195,16 @@ assert type(11%2) is int try: float('asad') - print('未能拦截错误, 在测试 float') + print('未能拦截错误11, 在测试 float') exit(11) -except: +except ValueError: pass try: float([]) - print('未能拦截错误, 在测试 float') + print('未能拦截错误12, 在测试 float') exit(12) -except: +except TypeError: pass # /************ str ************/ @@ -212,10 +216,10 @@ assert type(12 * '12') is str assert type('25363546'.index('63')) is int try: '25363546'.index('err') - print('未能拦截错误, 在测试 str.index') + print('未能拦截错误13, 在测试 str.index') exit(13) -except: - pass +except ValueError as e: + assert str(e) == "substring not found" # 未完全测试准确性----------------------------------------------- @@ -227,9 +231,9 @@ assert '25363546'.find('err') == -1 # /************ list ************/ try: list(1,2) - print('未能拦截错误, 在测试 list') + print('未能拦截错误14, 在测试 list') exit(14) -except: +except TypeError: pass # 未完全测试准确性---------------------------------------------- @@ -237,10 +241,10 @@ except: assert type([1,2,3,4,5].index(4)) is int try: [1,2,3,4,5].index(6) - print('未能拦截错误, 在测试 list.index') + print('未能拦截错误15, 在测试 list.index') exit(15) -except: - pass +except ValueError as e: + assert str(e) == "list.index(x): x not in list" @@ -248,19 +252,19 @@ except: # test list.remove: try: [1,2,3,4,5].remove(6) - print('未能拦截错误, 在测试 list.remove') + print('未能拦截错误16, 在测试 list.remove') exit(16) -except: - pass +except ValueError as e: + assert str(e) == "list.remove(x): x not in list" # 未完全测试准确性---------------------------------------------- # test list.pop: try: [1,2,3,4,5].pop(1,2,3,4) - print('未能拦截错误, 在测试 list.pop') + print('未能拦截错误17, 在测试 list.pop') exit(17) -except: +except TypeError: pass @@ -274,9 +278,9 @@ assert type(12 * [12]) is list # test tuple: try: tuple(1,2) - print('未能拦截错误, 在测试 tuple') + print('未能拦截错误18, 在测试 tuple') exit(18) -except: +except TypeError: pass assert [1,2,2,3,3,3].count(3) == 3 diff --git a/tests/77_builtin_func_2.py b/tests/77_builtin_func_2.py index 46b28a0f..b31c4cfd 100644 --- a/tests/77_builtin_func_2.py +++ b/tests/77_builtin_func_2.py @@ -70,21 +70,21 @@ try: dict([(1, 2, 3)]) print('未能拦截错误, 在测试 dict') exit(1) -except: +except ValueError: pass try: dict([(1, 2)], 1) print('未能拦截错误, 在测试 dict') exit(1) -except: +except TypeError: pass try: hash(dict([(1,2)])) print('未能拦截错误, 在测试 dict.__hash__') exit(1) -except: +except TypeError: pass # test dict.__iter__ @@ -102,7 +102,7 @@ try: {1:2, 3:4}.get(1,1, 1) print('未能拦截错误, 在测试 dict.get') exit(1) -except: +except TypeError: pass # 未完全测试准确性-----------------------------------------------