From db52fb50515cde9094869b9e40bfc6b293564cf1 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Tue, 13 Dec 2022 00:05:21 +0800 Subject: [PATCH] add set --- src/builtins.h | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/compiler.h | 4 +-- src/pocketpy.h | 3 ++ src/vm.h | 2 +- tests/_set.py | 79 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 182 insertions(+), 3 deletions(-) create mode 100644 tests/_set.py diff --git a/src/builtins.h b/src/builtins.h index 53a3c5c9..cf5b6889 100644 --- a/src/builtins.h +++ b/src/builtins.h @@ -296,6 +296,103 @@ class FileIO: def open(path, mode='r'): return FileIO(path, mode) + + +class set: + def __init__(self, iterable=None): + iterable = iterable or [] + self._a = dict() + for item in iterable: + self.add(item) + + def add(self, elem): + self._a[elem] = None + + def discard(self, elem): + if elem in self._a: + del self._a[elem] + + def remove(self, elem): + del self._a[elem] + + def clear(self): + self._a.clear() + + def update(self,other): + for elem in other: + self.add(elem) + return self + + def __len__(self): + return len(self._a) + + def copy(self): + return set(self._a.keys()) + + def __and__(self, other): + ret = set() + for elem in self: + if elem in other: + ret.add(elem) + return ret + + def __or__(self, other): + ret = self.copy() + for elem in other: + ret.add(elem) + return ret + + def __sub__(self, other): + ret = set() + for elem in self: + if elem not in other: + ret.add(elem) + return ret + + def __xor__(self, other): + ret = set() + for elem in self: + if elem not in other: + ret.add(elem) + for elem in other: + if elem not in self: + ret.add(elem) + return ret + + def union(self, other): + return self | other + + def intersection(self, other): + return self & other + + def difference(self, other): + return self - other + + def symmetric_difference(self, other): + return self ^ other + + def __eq__(self, other): + return self.__xor__(other).__len__() == 0 + + def isdisjoint(self, other): + return self.__and__(other).__len__() == 0 + + def issubset(self, other): + return self.__sub__(other).__len__() == 0 + + def issuperset(self, other): + return other.__sub__(self).__len__() == 0 + + def __contains__(self, elem): + return elem in self._a + + def __repr__(self): + if len(self) == 0: + return 'set()' + return '{'+ ', '.join(self._a.keys()) + '}' + + def __iter__(self): + return self._a.keys().__iter__() )"; const char* __OS_CODE = R"( diff --git a/src/compiler.h b/src/compiler.h index 956d2b3c..25e9d4dd 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -409,7 +409,7 @@ public: case TK("*="): emitCode(OP_BINARY_OP, 2); break; case TK("/="): emitCode(OP_BINARY_OP, 3); break; case TK("//="): emitCode(OP_BINARY_OP, 4); break; - + case TK("%="): emitCode(OP_BINARY_OP, 5); break; case TK("&="): emitCode(OP_BITWISE_OP, 2); break; case TK("|="): emitCode(OP_BITWISE_OP, 3); break; @@ -575,7 +575,7 @@ __LISTCOMP: matchNewLines(); consume(TK("}")); - if(parsing_dict) emitCode(OP_BUILD_MAP, size); + if(size == 0 || parsing_dict) emitCode(OP_BUILD_MAP, size); else emitCode(OP_BUILD_SET, size); } diff --git a/src/pocketpy.h b/src/pocketpy.h index fc6699c6..45078cb7 100644 --- a/src/pocketpy.h +++ b/src/pocketpy.h @@ -118,6 +118,9 @@ void __initializeBuiltinFunctions(VM* _vm) { } PyVarList ret; for (const auto& name : names) ret.push_back(vm->PyStr(name)); + std::sort(ret.begin(), ret.end(), [vm](const PyVar& a, const PyVar& b) { + return vm->PyStr_AS_C(a) < vm->PyStr_AS_C(b); + }); return vm->PyList(ret); }); diff --git a/src/vm.h b/src/vm.h index 4ebceeae..2643deb7 100644 --- a/src/vm.h +++ b/src/vm.h @@ -285,7 +285,7 @@ protected: PyVar obj = frame->popValue(this); PyVarOrNull iter_fn = getAttr(obj, __iter__, false); if(iter_fn != nullptr){ - PyVar tmp = call(iter_fn, pkpy::oneArg(obj)); + PyVar tmp = call(iter_fn); PyVarRef var = frame->__pop(); __checkType(var, _tp_ref); PyIter_AS_C(tmp)->var = var; diff --git a/tests/_set.py b/tests/_set.py new file mode 100644 index 00000000..1f69a1c3 --- /dev/null +++ b/tests/_set.py @@ -0,0 +1,79 @@ +a = {1, 2, 3} +a |= {2, 3, 4} + +assert a == {1, 2, 3, 4} + +a = {1, 2, 3} +a &= {2, 3, 4} + +assert a == {2, 3} + +a = {1, 2, 3} +a ^= {2, 3, 4} + +assert a == {1, 4} + +a = {1, 2, 3} +a -= {2, 3, 4} + +assert a == {1} + +a = {1, 2, 3} +a |= {2, 3, 4} + +assert a == {1, 2, 3, 4} + +a = set([1, 2, 3]) +a |= set([2, 3, 4]) + +assert a == {1, 2, 3, 4} + +a.add(5) +assert a == {1, 2, 3, 4, 5} + +a.remove(5) +assert a == {1, 2, 3, 4} + +a.discard(4) +assert a == {1, 2, 3} + +a.discard(4) +assert a == {1, 2, 3} + +assert a.union({2, 3, 4}) == {1, 2, 3, 4} +assert a.intersection({2, 3, 4}) == {2, 3} +assert a.difference({2, 3, 4}) == {1} +assert a.symmetric_difference({2, 3, 4}) == {1, 4} + +assert a | {2, 3, 4} == {1, 2, 3, 4} +assert a & {2, 3, 4} == {2, 3} +assert a - {2, 3, 4} == {1} +assert a ^ {2, 3, 4} == {1, 4} + +a.update({2, 3, 4}) +assert a == {1, 2, 3, 4} + +assert 3 in a +assert 5 not in a + +assert len(a) == 4 +a.clear() + +assert len(a) == 0 +assert a == set() + +b = {1, 2, 3} +c = b.copy() + +assert b == c +assert b is not c +b.add(4) +assert b == {1, 2, 3, 4} +assert c == {1, 2, 3} + +assert type({}) is dict + +assert {1,2}.issubset({1,2,3}) +assert {1,2,3}.issuperset({1,2}) +assert {1,2,3}.isdisjoint({4,5,6}) +assert not {1,2,3}.isdisjoint({2,3,4}) \ No newline at end of file