raise error on mismatched eq/ne

This commit is contained in:
blueloveTH 2024-12-13 16:40:18 +08:00
parent dc1b6255a0
commit 5ca7abed5c
5 changed files with 26 additions and 7 deletions

View File

@ -954,7 +954,6 @@ FrameResult VM__run_top_frame(VM* self) {
} }
case OP_END_CLASS: { case OP_END_CLASS: {
// [cls or decorated] // [cls or decorated]
// TODO: if __eq__ is defined, check __ne__ and provide a default implementation
py_Name name = byte.arg; py_Name name = byte.arg;
// set into f_globals // set into f_globals
py_setdict(frame->module, name, TOP()); py_setdict(frame->module, name, TOP());
@ -966,7 +965,15 @@ FrameResult VM__run_top_frame(VM* self) {
py_TypeInfo* base_ti = ti->base_ti; py_TypeInfo* base_ti = ti->base_ti;
if(base_ti->on_end_subclass) base_ti->on_end_subclass(ti); if(base_ti->on_end_subclass) base_ti->on_end_subclass(ti);
} }
if(!py_isnil(&ti->magic[__eq__])) {
if(py_isnil(&ti->magic[__ne__])) {
TypeError("'%n' implements '__eq__' but not '__ne__'", ti->name);
goto __ERROR;
}
}
} }
// class with decorator is unsafe currently
// it skips the above check
POP(); POP();
self->__curr_class = NULL; self->__curr_class = NULL;
DISPATCH(); DISPATCH();

View File

@ -129,15 +129,15 @@ __RETRY:
Dict__ctor(self, new_capacity, old_dict.entries.capacity); Dict__ctor(self, new_capacity, old_dict.entries.capacity);
// move entries from old dict to new dict // move entries from old dict to new dict
for(int i = 0; i < old_dict.entries.length; i++) { for(int i = 0; i < old_dict.entries.length; i++) {
DictEntry* entry = c11__at(DictEntry, &old_dict.entries, i); DictEntry* old_entry = c11__at(DictEntry, &old_dict.entries, i);
if(py_isnil(&entry->key)) continue; if(py_isnil(&old_entry->key)) continue;
int idx = entry->hash % new_capacity; int idx = old_entry->hash % new_capacity;
bool success = false; bool success = false;
for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) { for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
int idx2 = self->indices[idx]._[i]; int idx2 = self->indices[idx]._[i];
if(idx2 == -1) { if(idx2 == -1) {
// insert new entry (empty slot) // insert new entry (empty slot)
c11_vector__push(DictEntry, &self->entries, *entry); c11_vector__push(DictEntry, &self->entries, *old_entry);
self->indices[idx]._[i] = self->entries.length - 1; self->indices[idx]._[i] = self->entries.length - 1;
self->length++; self->length++;
success = true; success = true;
@ -210,8 +210,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
} }
// no empty slot found // no empty slot found
if(self->capacity >= (uint32_t)self->entries.length * 10) { if(self->capacity >= (uint32_t)self->entries.length * 10) {
// raise error if we reach the minimum load factor (10%) return RuntimeError("dict: %d/%d/%d: minimum load factor reached",
return RuntimeError("dict has too much collision: %d/%d/%d",
self->entries.length, self->entries.length,
self->entries.capacity, self->entries.capacity,
self->capacity); self->capacity);

View File

@ -209,6 +209,8 @@ assertEqual((n+1) not in d, True)
class BadCmp: class BadCmp:
def __eq__(self, other): def __eq__(self, other):
raise RuntimeError raise RuntimeError
def __ne__(self, other):
raise RuntimeError
# # Test detection of comparison exceptions # # Test detection of comparison exceptions

View File

@ -382,6 +382,7 @@ a = hash(object()) # object is hashable
a = hash(A()) # A is hashable a = hash(A()) # A is hashable
class B: class B:
def __eq__(self, o): return True def __eq__(self, o): return True
def __ne__(self, o): return False
try: try:
hash(B()) hash(B())

View File

@ -103,3 +103,13 @@ class Context:
for _ in range(5): for _ in range(5):
with Context() as x: with Context() as x:
assert x == 1 assert x == 1
# bad dict hash
class A:
def __eq__(self, o): return False
def __ne__(self, o): return True
def __hash__(self): return 1
bad_dict = {A(): 1, A(): 2, A(): 3, A(): 4}
assert len(bad_dict) == 4
bad_dict[A()] = 5 # error