From e227910dbd3b48686b06053c032c12921d4e64ba Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Mon, 2 Jan 2023 16:45:04 +0800 Subject: [PATCH] some fix --- src/__stl__.h | 2 +- src/builtins.h | 24 +++++++++--------------- src/compiler.h | 6 ++++-- src/hash_table8.hpp | 8 +------- src/main.cpp | 4 ++-- src/pocketpy.h | 10 ++++++++++ src/safestl.h | 6 +----- src/str.h | 3 +-- tests/_basic.py | 11 +++++++++++ tests/_functions.py | 11 +++++++++++ 10 files changed, 51 insertions(+), 34 deletions(-) diff --git a/src/__stl__.h b/src/__stl__.h index 34276032..ffcc79f0 100644 --- a/src/__stl__.h +++ b/src/__stl__.h @@ -4,6 +4,7 @@ #pragma warning (disable:4267) #pragma warning (disable:4101) #define _CRT_NONSTDC_NO_DEPRECATE +#define strdup _strdup #endif #include @@ -18,7 +19,6 @@ #include #include #include -#include #include #include diff --git a/src/builtins.h b/src/builtins.h index c556e0c6..9168c186 100644 --- a/src/builtins.h +++ b/src/builtins.h @@ -1,33 +1,27 @@ #pragma once const char* __BUILTINS_CODE = R"( -def len(x): - return x.__len__() - def print(*args, sep=' ', end='\n'): s = sep.join([str(i) for i in args]) __sys_stdout_write(s + end) -def round(x): +def round(x, ndigits=0): + assert ndigits >= 0 + if ndigits == 0: + return x >= 0 ? int(x + 0.5) : int(x - 0.5) if x >= 0: - return int(x + 0.5) + return int(x * 10**ndigits + 0.5) / 10**ndigits else: - return int(x - 0.5) + return int(x * 10**ndigits - 0.5) / 10**ndigits def abs(x): - if x < 0: - return -x - return x + return x < 0 ? -x : x def max(a, b): - if a > b: - return a - return b + return a > b ? a : b def min(a, b): - if a < b: - return a - return b + return a < b ? a : b def sum(iterable): res = 0 diff --git a/src/compiler.h b/src/compiler.h index 42286cae..e3aed30d 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -883,8 +883,10 @@ __LISTCOMP: emitCode(OP_DELETE_REF); consumeEndStatement(); } else if(match(TK("global"))){ - consume(TK("@id")); - getCode()->co_global_names.push_back(parser->previous.str()); + do { + consume(TK("@id")); + getCode()->co_global_names.push_back(parser->previous.str()); + } while (match(TK(","))); consumeEndStatement(); } else if(match(TK("pass"))){ consumeEndStatement(); diff --git a/src/hash_table8.hpp b/src/hash_table8.hpp index acba4298..eb609cf0 100644 --- a/src/hash_table8.hpp +++ b/src/hash_table8.hpp @@ -25,11 +25,6 @@ #pragma once -// Modification: -// 1. Add #define EMH_WYHASH_HASH 1 -// 2. Add static for wymix -#define EMH_WYHASH_HASH 1 - #include #include #include @@ -1665,7 +1660,7 @@ one-way search strategy. #if EMH_WYHASH_HASH //#define WYHASH_CONDOM 1 - inline static uint64_t wymix(uint64_t A, uint64_t B) + inline uint64_t wymix(uint64_t A, uint64_t B) { #if defined(__SIZEOF_INT128__) __uint128_t r = A; r *= B; @@ -1791,4 +1786,3 @@ private: size_type _etail; }; } // namespace emhash - diff --git a/src/main.cpp b/src/main.cpp index bd292e06..cde1830c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -3,8 +3,8 @@ #include "pocketpy.h" -#define PK_DEBUG_TIME -//#define PK_DEBUG_THREADED +//#define PK_DEBUG_TIME +#define PK_DEBUG_THREADED struct Timer{ const char* title; diff --git a/src/pocketpy.h b/src/pocketpy.h index 66a63ee1..02f2bab5 100644 --- a/src/pocketpy.h +++ b/src/pocketpy.h @@ -85,6 +85,11 @@ void __initializeBuiltinFunctions(VM* _vm) { return vm->PyInt(vm->hash(args[0])); }); + _vm->bindBuiltinFunc("len", [](VM* vm, const pkpy::ArgList& args) { + vm->__checkArgSize(args, 1); + return vm->call(args[0], __len__, pkpy::noArg()); + }); + _vm->bindBuiltinFunc("chr", [](VM* vm, const pkpy::ArgList& args) { vm->__checkArgSize(args, 1); _Int i = vm->PyInt_AS_C(args[0]); @@ -146,6 +151,11 @@ void __initializeBuiltinFunctions(VM* _vm) { return args[0]->_type; }); + _vm->bindMethod("type", "__eq__", [](VM* vm, const pkpy::ArgList& args) { + vm->__checkArgSize(args, 2, true); + return vm->PyBool(args[0] == args[1]); + }); + _vm->bindMethod("range", "__new__", [](VM* vm, const pkpy::ArgList& args) { _Range r; switch (args.size()) { diff --git a/src/safestl.h b/src/safestl.h index 8a04472e..7a3bc673 100644 --- a/src/safestl.h +++ b/src/safestl.h @@ -35,11 +35,7 @@ public: using std::vector::vector; }; - -class PyVarDict: public emhash8::HashMap<_Str, PyVar> { - using emhash8::HashMap<_Str, PyVar>::HashMap; -}; - +typedef emhash8::HashMap<_Str, PyVar> PyVarDict; namespace pkpy { const uint8_t MAX_POOLING_N = 10; diff --git a/src/str.h b/src/str.h index b1f33780..e77574ae 100644 --- a/src/str.h +++ b/src/str.h @@ -46,8 +46,7 @@ public: size_t hash() const{ if(!hash_initialized){ - //_hash = std::hash()(*this); - _hash = emhash8::HashMap::wyhashstr(data(), size()); + _hash = std::hash()(*this); hash_initialized = true; } return _hash; diff --git a/tests/_basic.py b/tests/_basic.py index de69b8af..be06473a 100644 --- a/tests/_basic.py +++ b/tests/_basic.py @@ -106,3 +106,14 @@ assert [1, 2, 3] * 3 == [1, 2, 3, 1, 2, 3, 1, 2, 3] a = 5 assert ((a > 3) ? 1 : 0) == 1 assert ((a < 3) ? 1 : 0) == 0 + +assert eq(round(3.1415926, 2), 3.14) +assert eq(round(3.1415926, 3), 3.142) +assert eq(round(3.1415926, 4), 3.1416) +assert eq(round(-3.1415926, 2), -3.14) +assert eq(round(-3.1415926, 3), -3.142) +assert eq(round(-3.1415926, 4), -3.1416) +assert round(23.2) == 23 +assert round(23.8) == 24 +assert round(-23.2) == -23 +assert round(-23.8) == -24 diff --git a/tests/_functions.py b/tests/_functions.py index 9989c488..651e790d 100644 --- a/tests/_functions.py +++ b/tests/_functions.py @@ -38,3 +38,14 @@ assert f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, e=1) == 58 assert f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) == 217 assert f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, d=1, e=2) == 213 +a = 1 +b = 2 + +def f(): + global a, b + a = 3 + b = 4 + +f() +assert a == 3 +assert b == 4 \ No newline at end of file