diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index 58cb747b..ae6816e0 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -954,7 +954,6 @@ FrameResult VM__run_top_frame(VM* self) { } case OP_END_CLASS: { // [cls or decorated] - // TODO: if __eq__ is defined, check __ne__ and provide a default implementation py_Name name = byte.arg; // set into f_globals py_setdict(frame->module, name, TOP()); @@ -966,7 +965,15 @@ FrameResult VM__run_top_frame(VM* self) { py_TypeInfo* base_ti = ti->base_ti; if(base_ti->on_end_subclass) base_ti->on_end_subclass(ti); } + if(!py_isnil(&ti->magic[__eq__])) { + if(py_isnil(&ti->magic[__ne__])) { + TypeError("'%n' implements '__eq__' but not '__ne__'", ti->name); + goto __ERROR; + } + } } + // class with decorator is unsafe currently + // it skips the above check POP(); self->__curr_class = NULL; DISPATCH(); diff --git a/src/public/py_dict.c b/src/public/py_dict.c index 4bf44c1b..01227224 100644 --- a/src/public/py_dict.c +++ b/src/public/py_dict.c @@ -129,15 +129,15 @@ __RETRY: Dict__ctor(self, new_capacity, old_dict.entries.capacity); // move entries from old dict to new dict for(int i = 0; i < old_dict.entries.length; i++) { - DictEntry* entry = c11__at(DictEntry, &old_dict.entries, i); - if(py_isnil(&entry->key)) continue; - int idx = entry->hash % new_capacity; + DictEntry* old_entry = c11__at(DictEntry, &old_dict.entries, i); + if(py_isnil(&old_entry->key)) continue; + int idx = old_entry->hash % new_capacity; bool success = false; for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) { int idx2 = self->indices[idx]._[i]; if(idx2 == -1) { // insert new entry (empty slot) - c11_vector__push(DictEntry, &self->entries, *entry); + c11_vector__push(DictEntry, &self->entries, *old_entry); self->indices[idx]._[i] = self->entries.length - 1; self->length++; success = true; @@ -210,8 +210,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) { } // no empty slot found if(self->capacity >= (uint32_t)self->entries.length * 10) { - // raise error if we reach the minimum load factor (10%) - return RuntimeError("dict has too much collision: %d/%d/%d", + return RuntimeError("dict: %d/%d/%d: minimum load factor reached", self->entries.length, self->entries.capacity, self->capacity); diff --git a/tests/72_collections.py b/tests/72_collections.py index d6cebef0..5c75782e 100644 --- a/tests/72_collections.py +++ b/tests/72_collections.py @@ -209,6 +209,8 @@ assertEqual((n+1) not in d, True) class BadCmp: def __eq__(self, other): raise RuntimeError + def __ne__(self, other): + raise RuntimeError # # Test detection of comparison exceptions diff --git a/tests/77_builtin_func.py b/tests/77_builtin_func.py index b6abce3c..49abed1a 100644 --- a/tests/77_builtin_func.py +++ b/tests/77_builtin_func.py @@ -382,6 +382,7 @@ a = hash(object()) # object is hashable a = hash(A()) # A is hashable class B: def __eq__(self, o): return True + def __ne__(self, o): return False try: hash(B()) diff --git a/tests/99_extras.py b/tests/99_extras.py index 31612ac5..665937c1 100644 --- a/tests/99_extras.py +++ b/tests/99_extras.py @@ -103,3 +103,13 @@ class Context: for _ in range(5): with Context() as x: assert x == 1 + +# bad dict hash +class A: + def __eq__(self, o): return False + def __ne__(self, o): return True + def __hash__(self): return 1 + +bad_dict = {A(): 1, A(): 2, A(): 3, A(): 4} +assert len(bad_dict) == 4 +bad_dict[A()] = 5 # error