diff --git a/src/pocketpy.cpp b/src/pocketpy.cpp index 1b407038..d0d806e6 100644 --- a/src/pocketpy.cpp +++ b/src/pocketpy.cpp @@ -399,31 +399,37 @@ void init_builtins(VM* _vm) { if(args.size() == 1+0) return VAR(0); // 1 arg if(args.size() == 1+1){ - if (is_type(args[1], vm->tp_float)) return VAR((i64)CAST(f64, args[1])); - if (is_type(args[1], vm->tp_int)) return args[1]; - if (is_type(args[1], vm->tp_bool)) return VAR(_CAST(bool, args[1]) ? 1 : 0); + switch(vm->_tp(args[1]).index){ + case VM::tp_float.index: + return VAR((i64)_CAST(f64, args[1])); + case VM::tp_int.index: + return args[1]; + case VM::tp_bool.index: + return VAR(args[1]==vm->True ? 1 : 0); + case VM::tp_str.index: + break; + default: + vm->TypeError("invalid arguments for int()"); + } } + // 2+ args -> error if(args.size() > 1+2) vm->TypeError("int() takes at most 2 arguments"); - // 2 args - if (is_type(args[1], vm->tp_str)) { - int base = 10; - if(args.size() == 1+2) base = CAST(i64, args[2]); - const Str& s = CAST(Str&, args[1]); - std::string_view sv = s.sv(); - bool negative = false; - if(!sv.empty() && (sv[0] == '+' || sv[0] == '-')){ - negative = sv[0] == '-'; - sv.remove_prefix(1); - } - i64 val; - if(parse_int(sv, &val, base) != IntParsingResult::Success){ - vm->ValueError(_S("invalid literal for int() with base ", base, ": ", s.escape())); - } - if(negative) val = -val; - return VAR(val); + // 1 or 2 args with str + int base = 10; + if(args.size() == 1+2) base = CAST(i64, args[2]); + const Str& s = CAST(Str&, args[1]); + std::string_view sv = s.sv(); + bool negative = false; + if(!sv.empty() && (sv[0] == '+' || sv[0] == '-')){ + negative = sv[0] == '-'; + sv.remove_prefix(1); } - vm->TypeError("invalid arguments for int()"); - return vm->None; + i64 val; + if(parse_int(sv, &val, base) != IntParsingResult::Success){ + vm->ValueError(_S("invalid literal for int() with base ", base, ": ", s.escape())); + } + if(negative) val = -val; + return VAR(val); }); _vm->bind__floordiv__(VM::tp_int, [](VM* vm, PyObject* _0, PyObject* _1) { @@ -460,26 +466,32 @@ void init_builtins(VM* _vm) { if(args.size() == 1+0) return VAR(0.0); if(args.size() > 1+1) vm->TypeError("float() takes at most 1 argument"); // 1 arg - if (is_type(args[1], vm->tp_int)) return VAR((f64)CAST(i64, args[1])); - if (is_type(args[1], vm->tp_float)) return args[1]; - if (is_type(args[1], vm->tp_bool)) return VAR(_CAST(bool, args[1]) ? 1.0 : 0.0); - if (is_type(args[1], vm->tp_str)) { - const Str& s = CAST(Str&, args[1]); - if(s == "inf") return VAR(INFINITY); - if(s == "-inf") return VAR(-INFINITY); - - double float_out; - char* p_end; - try{ - float_out = std::strtod(s.data, &p_end); - PK_ASSERT(p_end == s.end()); - }catch(...){ - vm->ValueError("invalid literal for float(): " + s.escape()); - } - return VAR(float_out); + switch(vm->_tp(args[1]).index){ + case VM::tp_int.index: + return VAR((f64)CAST(i64, args[1])); + case VM::tp_float.index: + return args[1]; + case VM::tp_bool.index: + return VAR(args[1]==vm->True ? 1.0 : 0.0); + case VM::tp_str.index: + break; + default: + vm->TypeError("invalid arguments for float()"); } - vm->TypeError("invalid arguments for float()"); - return vm->None; + // str to float + const Str& s = PK_OBJ_GET(Str, args[1]); + if(s == "inf") return VAR(INFINITY); + if(s == "-inf") return VAR(-INFINITY); + + double float_out; + char* p_end; + try{ + float_out = std::strtod(s.data, &p_end); + PK_ASSERT(p_end == s.end()); + }catch(...){ + vm->ValueError("invalid literal for float(): " + s.escape()); + } + return VAR(float_out); }); _vm->bind__hash__(VM::tp_float, [](VM* vm, PyObject* _0) { diff --git a/tests/01_int.py b/tests/01_int.py index ce96102c..a1cb61a3 100644 --- a/tests/01_int.py +++ b/tests/01_int.py @@ -53,6 +53,10 @@ assert str(1) == '1' assert repr(1) == '1' # test int() +assert int() == 0 +assert int(True) == 1 +assert int(False) == 0 + assert int(1) == 1 assert int(1.0) == 1 assert int(1.1) == 1 diff --git a/tests/02_float.py b/tests/02_float.py index 2b2587c3..07d972b8 100644 --- a/tests/02_float.py +++ b/tests/02_float.py @@ -36,6 +36,12 @@ assert str(1.0) == '1.0' assert repr(1.0) == '1.0' # test float() +assert float() == 0.0 +assert float(True) == 1.0 +assert float(False) == 0.0 +assert float(1) == 1.0 +assert float(-2) == -2.0 + assert eq(float(1), 1.0) assert eq(float(1.0), 1.0) assert eq(float(1.1), 1.1) diff --git a/tests/99_builtin_func.py b/tests/99_builtin_func.py index fdda3383..afe89381 100644 --- a/tests/99_builtin_func.py +++ b/tests/99_builtin_func.py @@ -160,22 +160,6 @@ class A(): repr(A()) - -# 未完全测试准确性----------------------------------------------- -# 33600: 318: _vm->bind_constructor<-1>("range", [](VM* vm, ArgsView args) { -# 16742: 319: args._begin += 1; // skip cls -# 16742: 320: Range r; -# 16742: 321: switch (args.size()) { -# 8735: 322: case 1: r.stop = CAST(i64, args[0]); break; -# 3867: 323: case 2: r.start = CAST(i64, args[0]); r.stop = CAST(i64, args[1]); break; -# 4140: 324: case 3: r.start = CAST(i64, args[0]); r.stop = CAST(i64, args[1]); r.step = CAST(i64, args[2]); break; -# #####: 325: default: vm->TypeError("expected 1-3 arguments, got " + std::to_string(args.size())); -# #####: 326: } -# 33484: 327: return VAR(r); -# 16742: 328: }); -# -: 329: -# test range: - try: range(1,2,3,4) print('未能拦截错误, 在测试 range')