diff --git a/include/pocketpy/xmacros/magics.h b/include/pocketpy/xmacros/magics.h index 2fc26836..e9bd603e 100644 --- a/include/pocketpy/xmacros/magics.h +++ b/include/pocketpy/xmacros/magics.h @@ -59,6 +59,9 @@ MAGIC_METHOD(__package__) MAGIC_METHOD(__path__) MAGIC_METHOD(__class__) MAGIC_METHOD(__abs__) +MAGIC_METHOD(__float__) +MAGIC_METHOD(__int__) +MAGIC_METHOD(__round__) MAGIC_METHOD(__getattr__) MAGIC_METHOD(__missing__) diff --git a/src/public/modules.c b/src/public/modules.c index 11037057..3674b76e 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -287,22 +287,23 @@ static bool builtins_round(int argc, py_Ref argv) { return TypeError("round() takes 1 or 2 arguments"); } - if(py_isint(py_arg(0))) { + if(argv->type == tp_int) { py_assign(py_retval(), py_arg(0)); return true; - } - - PY_CHECK_ARG_TYPE(0, tp_float); - py_f64 x = py_tofloat(py_arg(0)); - py_f64 offset = x >= 0 ? 0.5 : -0.5; - if(ndigits == -1) { - py_newint(py_retval(), (py_i64)(x + offset)); + } else if(argv->type == tp_float) { + PY_CHECK_ARG_TYPE(0, tp_float); + py_f64 x = py_tofloat(py_arg(0)); + py_f64 offset = x >= 0 ? 0.5 : -0.5; + if(ndigits == -1) { + py_newint(py_retval(), (py_i64)(x + offset)); + return true; + } + py_f64 factor = pow(10, ndigits); + py_newfloat(py_retval(), (py_i64)(x * factor + offset) / factor); return true; } - py_f64 factor = pow(10, ndigits); - py_newfloat(py_retval(), (py_i64)(x * factor + offset) / factor); - return true; + return pk_callmagic(__round__, argc, argv); } static bool builtins_print(int argc, py_Ref argv) { @@ -442,12 +443,11 @@ static bool builtins_ord(int argc, py_Ref argv) { PY_CHECK_ARG_TYPE(0, tp_str); c11_sv sv = py_tosv(py_arg(0)); if(c11_sv__u8_length(sv) != 1) { - return TypeError("ord() expected a character, but string of length %d found", c11_sv__u8_length(sv)); + return TypeError("ord() expected a character, but string of length %d found", + c11_sv__u8_length(sv)); } int u8bytes = c11__u8_header(sv.data[0], true); - if (u8bytes == 0) { - return ValueError("invalid char: %c", sv.data[0]); - } + if(u8bytes == 0) { return ValueError("invalid char: %c", sv.data[0]); } int value = c11__u8_value(u8bytes, sv.data); py_newint(py_retval(), value); return true; diff --git a/src/public/py_number.c b/src/public/py_number.c index 5a89315f..e278e309 100644 --- a/src/public/py_number.c +++ b/src/public/py_number.c @@ -297,7 +297,7 @@ static bool int__new__(int argc, py_Ref argv) { return true; } case tp_str: break; // leave to the next block - default: return TypeError("invalid arguments for int()"); + default: return pk_callmagic(__int__, 1, argv+1); } } // 2+ args -> error @@ -350,26 +350,27 @@ static bool float__new__(int argc, py_Ref argv) { py_newfloat(py_retval(), py_tobool(&argv[1])); return true; } - case tp_str: break; // leave to the next block - default: return TypeError("invalid arguments for float()"); - } - // str to float - c11_sv sv = py_tosv(py_arg(1)); + case tp_str: { + // str to float + c11_sv sv = py_tosv(py_arg(1)); - if(c11__sveq2(sv, "inf")) { - py_newfloat(py_retval(), INFINITY); - return true; - } - if(c11__sveq2(sv, "-inf")) { - py_newfloat(py_retval(), -INFINITY); - return true; - } + if(c11__sveq2(sv, "inf")) { + py_newfloat(py_retval(), INFINITY); + return true; + } + if(c11__sveq2(sv, "-inf")) { + py_newfloat(py_retval(), -INFINITY); + return true; + } - char* p_end; - py_f64 float_out = strtod(sv.data, &p_end); - if(p_end != sv.data + sv.size) return ValueError("invalid literal for float(): %q", sv); - py_newfloat(py_retval(), float_out); - return true; + char* p_end; + py_f64 float_out = strtod(sv.data, &p_end); + if(p_end != sv.data + sv.size) return ValueError("invalid literal for float(): %q", sv); + py_newfloat(py_retval(), float_out); + return true; + } + default: return pk_callmagic(__float__, 1, argv+1); + } } // tp_bool diff --git a/tests/99_extras.py b/tests/99_extras.py index 299ddaf8..0f421f8b 100644 --- a/tests/99_extras.py +++ b/tests/99_extras.py @@ -51,3 +51,28 @@ assert A()[1:2, :A()[3:4, ::-1]] == (slice(1, 2, None), slice(None, (slice(3, 4, # test right associative assert 2**2**3 == 256 assert (2**2**3)**2 == 65536 + +class Number: + def __float__(self): + return 1.0 + + def __int__(self): + return 2 + + def __divmod__(self, other): + return 3, 4 + + def __round__(self, *args): + return args + +assert divmod(Number(), 0) == (3, 4) +assert float(Number()) == 1.0 +assert int(Number()) == 2 + +assert round(Number()) == tuple() +assert round(Number(), 1) == (1,) + + + + +