diff --git a/src/objects/dict.c b/src/objects/dict.c index f1c91590..31da0e7a 100644 --- a/src/objects/dict.c +++ b/src/objects/dict.c @@ -69,7 +69,7 @@ static void pkpy_Dict__htset(pkpy_Dict* self, int h, int v) { } } -static int pkpy_Dict__probe(const pkpy_Dict* self, void* vm, pkpy_Var key, int64_t hash) { +static int pkpy_Dict__probe0(const pkpy_Dict* self, void* vm, pkpy_Var key, int64_t hash) { const int null = pkpy_Dict__idx_null(self); const int mask = self->_htcap - 1; for(int h = hash & mask;; h = (h + 1) & mask) { @@ -83,6 +83,19 @@ static int pkpy_Dict__probe(const pkpy_Dict* self, void* vm, pkpy_Var key, int64 PK_UNREACHABLE(); } +static int pkpy_Dict__probe1(const pkpy_Dict* self, void* vm, pkpy_Var key, int64_t hash) { + const int null = pkpy_Dict__idx_null(self); + const int mask = self->_htcap - 1; + for(int h = hash & mask;; h = (h + 1) & mask) { + int idx = pkpy_Dict__htget(self, h); + if(idx == null) return h; + + struct pkpy_DictEntry* entry = &c11__getitem(struct pkpy_DictEntry, &self->_entries, idx); + if(entry->hash == hash && pkpy_Var__eq__(vm, entry->key, key)) return h; + } + PK_UNREACHABLE(); +} + static void pkpy_Dict__extendht(pkpy_Dict* self, void* vm) { self->_version += 1; free(self->_hashtable); @@ -94,14 +107,14 @@ static void pkpy_Dict__extendht(pkpy_Dict* self, void* vm) { struct pkpy_DictEntry* entry = &c11__getitem(struct pkpy_DictEntry, &self->_entries, i); if(pkpy_Var__is_null(&entry->key)) continue; - int h = pkpy_Dict__probe(self, vm, entry->key, entry->hash); + int h = pkpy_Dict__probe0(self, vm, entry->key, entry->hash); pkpy_Dict__htset(self, h, i); } } bool pkpy_Dict__set(pkpy_Dict* self, void* vm, pkpy_Var key, pkpy_Var val) { int hash = pkpy_Var__hash__(vm, key); - int h = pkpy_Dict__probe(self, vm, key, hash); + int h = pkpy_Dict__probe1(self, vm, key, hash); int idx = pkpy_Dict__htget(self, h); if(idx == pkpy_Dict__idx_null(self)) { @@ -114,6 +127,7 @@ bool pkpy_Dict__set(pkpy_Dict* self, void* vm, pkpy_Var key, pkpy_Var val) { .key = key, .val = val, })); + h = pkpy_Dict__probe0(self, vm, key, hash); pkpy_Dict__htset(self, h, idx); self->count += 1; if(self->count >= self->_htcap * 0.75) pkpy_Dict__extendht(self, vm); @@ -121,17 +135,31 @@ bool pkpy_Dict__set(pkpy_Dict* self, void* vm, pkpy_Var key, pkpy_Var val) { } struct pkpy_DictEntry* entry = &c11__getitem(struct pkpy_DictEntry, &self->_entries, idx); - assert(entry->hash == hash && pkpy_Var__eq__(vm, entry->key, key)); - entry->val = val; + + if(entry->hash == hash || pkpy_Var__eq__(vm, entry->key, key)) { + entry->val = val; + } else { + self->_version += 1; + self->count += 1; + h = pkpy_Dict__probe0(self, vm, key, hash); + idx = pkpy_Dict__htget(self, h); + struct pkpy_DictEntry* entry = &c11__getitem(struct pkpy_DictEntry, &self->_entries, idx); + entry->key = key; + entry->val = val; + entry->hash = hash; + } return false; } bool pkpy_Dict__contains(const pkpy_Dict* self, void* vm, pkpy_Var key) { int hash = pkpy_Var__hash__(vm, key); - int h = pkpy_Dict__probe(self, vm, key, hash); + int h = pkpy_Dict__probe1(self, vm, key, hash); int idx = pkpy_Dict__htget(self, h); if(idx == pkpy_Dict__idx_null(self)) return false; + + struct pkpy_DictEntry* entry = &c11__getitem(struct pkpy_DictEntry, &self->_entries, idx); + assert(entry->hash == hash && pkpy_Var__eq__(vm, entry->key, key)); return true; } @@ -155,7 +183,7 @@ static bool pkpy_Dict__refactor(pkpy_Dict* self, void* vm) { int j = self->_entries.count; c11_vector__push(struct pkpy_DictEntry, &self->_entries, *entry); - int h = pkpy_Dict__probe(self, vm, entry->key, entry->hash); + int h = pkpy_Dict__probe0(self, vm, entry->key, entry->hash); pkpy_Dict__htset(self, h, j); } c11_vector__dtor(&old_entries); @@ -164,15 +192,14 @@ static bool pkpy_Dict__refactor(pkpy_Dict* self, void* vm) { bool pkpy_Dict__del(pkpy_Dict* self, void* vm, pkpy_Var key) { int hash = pkpy_Var__hash__(vm, key); - int h = pkpy_Dict__probe(self, vm, key, hash); + int h = pkpy_Dict__probe1(self, vm, key, hash); int idx = pkpy_Dict__htget(self, h), null = pkpy_Dict__idx_null(self); if(idx == null) return false; - self->_version += 1; struct pkpy_DictEntry* entry = &c11__getitem(struct pkpy_DictEntry, &self->_entries, idx); assert(entry->hash == hash && pkpy_Var__eq__(vm, entry->key, key)); + self->_version += 1; pkpy_Var__set_null(&entry->key); - pkpy_Dict__htset(self, h, null); self->count -= 1; pkpy_Dict__refactor(self, vm); return true; @@ -180,7 +207,7 @@ bool pkpy_Dict__del(pkpy_Dict* self, void* vm, pkpy_Var key) { const pkpy_Var *pkpy_Dict__try_get(const pkpy_Dict* self, void* vm, pkpy_Var key) { int hash = pkpy_Var__hash__(vm, key); - int h = pkpy_Dict__probe(self, vm, key, hash); + int h = pkpy_Dict__probe1(self, vm, key, hash); int idx = pkpy_Dict__htget(self, h); if(idx == pkpy_Dict__idx_null(self)) return NULL;