diff --git a/python/builtins.py b/python/builtins.py index 805ae4de..172b6e67 100644 --- a/python/builtins.py +++ b/python/builtins.py @@ -4,15 +4,6 @@ def print(*args, sep=' ', end='\n'): s = sep.join([str(i) for i in args]) _sys.stdout.write(s + end) -def round(x, ndigits=0): - assert ndigits >= 0 - if ndigits == 0: - return int(x + 0.5) if x >= 0 else int(x - 0.5) - if x >= 0: - return int(x * 10**ndigits + 0.5) / 10**ndigits - else: - return int(x * 10**ndigits - 0.5) / 10**ndigits - def abs(x): return -x if x < 0 else x diff --git a/src/pocketpy.cpp b/src/pocketpy.cpp index 40891505..c8a96b34 100644 --- a/src/pocketpy.cpp +++ b/src/pocketpy.cpp @@ -123,6 +123,28 @@ void init_builtins(VM* _vm) { return VAR(MappingProxy(mod)); }); + // def round(x, ndigits=0): + // assert ndigits >= 0 + // if ndigits == 0: + // return int(x + 0.5) if x >= 0 else int(x - 0.5) + // if x >= 0: + // return int(x * 10**ndigits + 0.5) / 10**ndigits + // else: + // return int(x * 10**ndigits - 0.5) / 10**ndigits + _vm->bind(_vm->builtins, "round(x, ndigits=0)", [](VM* vm, ArgsView args) { + f64 x = CAST(f64, args[0]); + int ndigits = CAST(int, args[1]); + if(ndigits == 0){ + return x >= 0 ? VAR((i64)(x + 0.5)) : VAR((i64)(x - 0.5)); + } + if(ndigits < 0) vm->ValueError("ndigits should be non-negative"); + if(x >= 0){ + return VAR((i64)(x * std::pow(10, ndigits) + 0.5) / std::pow(10, ndigits)); + }else{ + return VAR((i64)(x * std::pow(10, ndigits) - 0.5) / std::pow(10, ndigits)); + } + }); + _vm->bind_builtin_func<3>("pow", [](VM* vm, ArgsView args) { i64 lhs = CAST(i64, args[0]); // assume lhs>=0 i64 rhs = CAST(i64, args[1]); // assume rhs>=0 diff --git a/tests/70_builtins.py b/tests/70_builtins.py index 9b7578d1..ed703c5f 100644 --- a/tests/70_builtins.py +++ b/tests/70_builtins.py @@ -2,6 +2,21 @@ assert round(23.2) == 23 assert round(23.8) == 24 assert round(-23.2) == -23 assert round(-23.8) == -24 +# round with precision +assert round(23.2, 1) == 23.2 +assert round(23.8, 1) == 23.8 +assert round(-23.2, 1) == -23.2 +assert round(-23.8, 1) == -23.8 +assert round(3.14159, 4) == 3.1416 +assert round(3.14159, 3) == 3.142 +assert round(3.14159, 2) == 3.14 +assert round(3.14159, 1) == 3.1 +assert round(3.14159, 0) == 3 +assert round(-3.14159, 4) == -3.1416 +assert round(-3.14159, 3) == -3.142 +assert round(-3.14159, 2) == -3.14 +assert round(-3.14159, 1) == -3.1 +assert round(-3.14159, 0) == -3 a = [1,2,3,-1] assert sorted(a) == [-1,1,2,3]