improve dict

This commit is contained in:
blueloveTH 2025-06-29 21:44:57 +08:00
parent be2aae493a
commit caf7505dc2
6 changed files with 161 additions and 95 deletions

View File

@ -12,9 +12,9 @@ typedef struct {
typedef struct { typedef struct {
int length; int length;
uint32_t capacity; uint32_t capacity;
void* indices;
bool index_is_short;
uint32_t null_index_value; uint32_t null_index_value;
bool index_is_short;
void* indices;
c11_vector /*T=DictEntry*/ entries; c11_vector /*T=DictEntry*/ entries;
} Dict; } Dict;

View File

@ -124,3 +124,4 @@ bool METHOD(contains)(NAME* self, K key) {
#undef less #undef less
#undef partial_less #undef partial_less
#undef equal #undef equal
#undef hash

View File

@ -101,18 +101,24 @@ bool NameDict__del(NameDict* self, py_Name key) {
self->items[i].key = NULL; self->items[i].key = NULL;
self->items[i].value = *py_NIL(); self->items[i].value = *py_NIL();
self->length--; self->length--;
// tidy /* tidy */
uintptr_t pre_z = i; uint32_t posToRemove = i;
uintptr_t z = (i + 1) & self->mask; uint32_t posToShift = posToRemove;
while(self->items[z].key != NULL) { while(true) {
uintptr_t h = (uintptr_t)self->items[z].key & self->mask; posToShift = (posToShift + 1) & self->mask;
if(h != i) break; if(self->items[posToShift].key == NULL) break;
// std::swap(_items[pre_z], _items[z]); uintptr_t hash_z = (uintptr_t)self->items[posToShift].key;
NameDict_KV tmp = self->items[pre_z]; uintptr_t insertPos = hash_z & self->mask;
self->items[pre_z] = self->items[z]; bool cond1 = insertPos <= posToRemove;
self->items[z] = tmp; bool cond2 = posToRemove <= posToShift;
pre_z = z; if((cond1 && cond2) ||
z = (z + 1) & self->mask; // chain wrapped around capacity
(posToShift < insertPos && (cond1 || cond2))) {
NameDict_KV tmp = self->items[posToRemove];
self->items[posToRemove] = self->items[posToShift];
self->items[posToShift] = tmp;
posToRemove = posToShift;
}
} }
return true; return true;
} }

View File

@ -5,6 +5,16 @@
#include "pocketpy/interpreter/types.h" #include "pocketpy/interpreter/types.h"
#include "pocketpy/interpreter/vm.h" #include "pocketpy/interpreter/vm.h"
typedef struct {
Dict* dict; // weakref for slot 0
Dict dict_backup;
DictEntry* curr;
DictEntry* end;
int mode; // 0: keys, 1: values, 2: items
} DictIterator;
#define Dict__step(x) ((x) < mask ? (x) + 1 : 0)
static uint32_t Dict__next_cap(uint32_t cap) { static uint32_t Dict__next_cap(uint32_t cap) {
switch(cap) { switch(cap) {
case 7: return 17; case 7: return 17;
@ -51,27 +61,31 @@ static uint32_t Dict__next_cap(uint32_t cap) {
} }
} }
typedef struct { static uint64_t Dict__hash(uint64_t key) {
DictEntry* curr; // https://gist.github.com/badboy/6267743
DictEntry* end; key = (~key) + (key << 21); // key = (key << 21) - key - 1
int mode; // 0: keys, 1: values, 2: items key = key ^ (key >> 24);
} DictIterator; key = (key + (key << 3)) + (key << 8); // key * 265
key = key ^ (key >> 14);
key = (key + (key << 2)) + (key << 4); // key * 21
key = key ^ (key >> 28);
key = key + (key << 31);
return key;
}
static void Dict__ctor(Dict* self, uint32_t capacity, int entries_capacity) { static void Dict__ctor(Dict* self, uint32_t capacity, int entries_capacity) {
self->length = 0; self->length = 0;
self->capacity = capacity; // the 1st prime self->capacity = capacity;
size_t indices_size; size_t indices_size;
if(self->capacity < UINT16_MAX - 1) { if(self->capacity < UINT16_MAX) {
self->index_is_short = true; self->index_is_short = true;
indices_size = self->capacity * sizeof(uint16_t); indices_size = self->capacity * sizeof(uint16_t);
self->null_index_value = UINT16_MAX; self->null_index_value = UINT16_MAX;
self->deleted_index_value = UINT16_MAX - 1;
} else { } else {
self->index_is_short = false; self->index_is_short = false;
indices_size = self->capacity * sizeof(uint32_t); indices_size = self->capacity * sizeof(uint32_t);
self->null_index_value = UINT32_MAX; self->null_index_value = UINT32_MAX;
self->deleted_index_value = UINT32_MAX - 1;
} }
self->indices = PK_MALLOC(indices_size); self->indices = PK_MALLOC(indices_size);
@ -98,17 +112,17 @@ static uint32_t Dict__get_index(Dict* self, uint32_t index) {
} }
} }
static void Dict__swap_index(Dict* self, uint32_t x, uint32_t y) { static void Dict__swap_null_index(Dict* self, uint32_t pre_z, uint32_t z) {
if(self->index_is_short) { if(self->index_is_short) {
uint16_t* indices = self->indices; uint16_t* indices = self->indices;
uint16_t tmp = indices[x]; assert(indices[pre_z] == UINT16_MAX);
indices[x] = indices[y]; indices[pre_z] = indices[z];
indices[y] = tmp; indices[z] = UINT16_MAX;
} else { } else {
uint32_t* indices = self->indices; uint32_t* indices = self->indices;
uint32_t tmp = indices[x]; assert(indices[pre_z] == UINT32_MAX);
indices[x] = indices[y]; indices[pre_z] = indices[z];
indices[y] = tmp; indices[z] = UINT32_MAX;
} }
} }
@ -122,17 +136,21 @@ static void Dict__set_index(Dict* self, uint32_t index, uint32_t value) {
} }
} }
static bool static bool Dict__probe(Dict* self,
Dict__probe(Dict* self, py_TValue* key, py_i64* p_hash, uint32_t* p_idx, DictEntry** p_entry) { py_TValue* key,
if(!py_hash(key, p_hash)) return false; uint64_t* p_hash,
py_i64 hash = *p_hash; uint32_t* p_idx,
uint32_t idx = (uint64_t)hash % self->capacity; DictEntry** p_entry) {
const uint32_t max_idx = self->capacity - 1; py_i64 h_user;
if(!py_hash(key, &h_user)) return false;
*p_hash = Dict__hash((uint64_t)h_user);
uint32_t mask = self->capacity - 1;
uint32_t idx = (*p_hash) % self->capacity;
while(true) { while(true) {
uint32_t idx2 = Dict__get_index(self, idx); uint32_t idx2 = Dict__get_index(self, idx);
if(idx2 == self->null_index_value) break; if(idx2 == self->null_index_value) break;
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
if(entry->hash == (uint64_t)hash) { if(entry->hash == (*p_hash)) {
int res = py_equal(&entry->key, key); int res = py_equal(&entry->key, key);
if(res == 1) { if(res == 1) {
*p_idx = idx; *p_idx = idx;
@ -142,7 +160,7 @@ static bool
if(res == -1) return false; // error if(res == -1) return false; // error
} }
// try next index // try next index
idx = idx < max_idx ? idx + 1 : 0; idx = Dict__step(idx);
} }
// not found // not found
*p_idx = idx; *p_idx = idx;
@ -151,7 +169,7 @@ static bool
} }
static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) { static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) {
py_i64 hash; uint64_t hash;
uint32_t idx; uint32_t idx;
return Dict__probe(self, key, &hash, &idx, out); return Dict__probe(self, key, &hash, &idx, out);
} }
@ -166,14 +184,14 @@ static void Dict__clear(Dict* self) {
static void Dict__rehash_2x(Dict* self) { static void Dict__rehash_2x(Dict* self) {
Dict old_dict = *self; Dict old_dict = *self;
uint32_t new_capacity = Dict__next_cap(new_capacity); uint32_t new_capacity = Dict__next_cap(old_dict.capacity);
uint32_t mask = new_capacity - 1;
// create a new dict with new capacity // create a new dict with new capacity
Dict__ctor(self, new_capacity, old_dict.entries.capacity); Dict__ctor(self, new_capacity, old_dict.entries.capacity);
// move entries from old dict to new dict // move entries from old dict to new dict
const uint32_t max_idx = new_capacity - 1;
for(int i = 0; i < old_dict.entries.length; i++) { for(int i = 0; i < old_dict.entries.length; i++) {
DictEntry* old_entry = c11__at(DictEntry, &old_dict.entries, i); DictEntry* old_entry = c11__at(DictEntry, &old_dict.entries, i);
if(py_isnil(&old_entry->key)) continue; if(py_isnil(&old_entry->key)) continue; // skip deleted
uint32_t idx = old_entry->hash % new_capacity; uint32_t idx = old_entry->hash % new_capacity;
while(true) { while(true) {
uint32_t idx2 = Dict__get_index(self, idx); uint32_t idx2 = Dict__get_index(self, idx);
@ -184,14 +202,14 @@ static void Dict__rehash_2x(Dict* self) {
break; break;
} }
// try next index // try next index
idx = idx < max_idx ? idx + 1 : 0; idx = Dict__step(idx);
} }
} }
Dict__dtor(&old_dict); Dict__dtor(&old_dict);
} }
static void Dict__compact_entries(Dict* self) { static void Dict__compact_entries(Dict* self) {
int* mappings = PK_MALLOC(self->entries.length * sizeof(int)); uint32_t* mappings = PK_MALLOC(self->entries.length * sizeof(uint32_t));
int n = 0; int n = 0;
for(int i = 0; i < self->entries.length; i++) { for(int i = 0; i < self->entries.length; i++) {
@ -215,7 +233,7 @@ static void Dict__compact_entries(Dict* self) {
} }
static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) { static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
py_i64 hash; uint64_t hash;
uint32_t idx; uint32_t idx;
DictEntry* entry; DictEntry* entry;
if(!Dict__probe(self, key, &hash, &idx, &entry)) return false; if(!Dict__probe(self, key, &hash, &idx, &entry)) return false;
@ -226,21 +244,22 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
} }
// insert new entry // insert new entry
DictEntry* new_entry = c11_vector__emplace(&self->entries); DictEntry* new_entry = c11_vector__emplace(&self->entries);
new_entry->hash = (uint64_t)hash; new_entry->hash = hash;
new_entry->key = *key; new_entry->key = *key;
new_entry->val = *val; new_entry->val = *val;
Dict__set_index(self, idx, self->entries.length - 1); Dict__set_index(self, idx, self->entries.length - 1);
self->length++; self->length++;
// check if we need to rehash // check if we need to rehash
float load_factor = (float)self->length / self->capacity; float load_factor = (float)self->length / self->capacity;
if(load_factor > 4 / 7.0f) Dict__rehash_2x(self); if(load_factor > 0.572) Dict__rehash_2x(self);
return true; return true;
} }
/// Delete an entry from the dict. /// Delete an entry from the dict.
/// -1: error, 0: not found, 1: found and deleted /// -1: error, 0: not found, 1: found and deleted
static int Dict__pop(Dict* self, py_Ref key) { static int Dict__pop(Dict* self, py_Ref key) {
py_i64 hash; // Dict__log_index(self, "before pop");
uint64_t hash;
uint32_t idx; uint32_t idx;
DictEntry* entry; DictEntry* entry;
if(!Dict__probe(self, key, &hash, &idx, &entry)) return -1; if(!Dict__probe(self, key, &hash, &idx, &entry)) return -1;
@ -252,26 +271,50 @@ static int Dict__pop(Dict* self, py_Ref key) {
py_newnil(&entry->key); py_newnil(&entry->key);
py_newnil(&entry->val); py_newnil(&entry->val);
self->length--; self->length--;
// tidy indices
uint32_t pre_z = idx; /* tidy */
const uint32_t max_idx = self->capacity - 1; // https://github.com/OpenHFT/Chronicle-Map/blob/820573a68471509ffc1b0584454f4a67c0be1b84/src/main/java/net/openhft/chronicle/hash/impl/CompactOffHeapLinearHashTable.java#L156
uint32_t z = idx < max_idx ? idx + 1 : 0; uint32_t mask = self->capacity - 1;
uint32_t posToRemove = idx;
uint32_t posToShift = posToRemove;
// int probe_count = 0;
// int swap_count = 0;
while(true) { while(true) {
uint32_t idx2 = Dict__get_index(self, z); posToShift = Dict__step(posToShift);
if(idx2 == self->null_index_value) break; uint32_t idx_z = Dict__get_index(self, posToShift);
uint64_t h = c11__at(DictEntry, &self->entries, idx2)->hash; if(idx_z == self->null_index_value) break;
if(h != hash) break; uint64_t hash_z = c11__at(DictEntry, &self->entries, idx_z)->hash;
Dict__swap_index(self, pre_z, z); uint32_t insertPos = (uint64_t)hash_z % self->capacity;
pre_z = z; // the following condition essentially means circular permutations
z = z < max_idx ? z + 1 : 0; // of three (r = posToRemove, s = posToShift, i = insertPos)
// positions are accepted:
// [...i..r...s.] or
// [...r..s...i.] or
// [...s..i...r.]
bool cond1 = insertPos <= posToRemove;
bool cond2 = posToRemove <= posToShift;
if((cond1 && cond2) ||
// chain wrapped around capacity
(posToShift < insertPos && (cond1 || cond2))) {
Dict__swap_null_index(self, posToRemove, posToShift);
posToRemove = posToShift;
// swap_count++;
} }
// probe_count++;
}
// printf("Dict__pop: probe_count=%d, swap_count=%d\n", probe_count, swap_count);
// compact entries if necessary // compact entries if necessary
if(self->entries.length > 16 && self->length < self->entries.length / 2) if(self->entries.length > 16 && (self->length < self->entries.length >> 1)) {
Dict__compact_entries(self); Dict__compact_entries(self); // compact entries
}
// Dict__log_index(self, "after pop");
return 1; return 1;
} }
static void DictIterator__ctor(DictIterator* self, Dict* dict, int mode) { static void DictIterator__ctor(DictIterator* self, Dict* dict, int mode) {
assert(mode >= 0 && mode <= 2);
self->dict = dict;
self->dict_backup = *dict; // backup the dict
self->curr = dict->entries.data; self->curr = dict->entries.data;
self->end = self->curr + dict->entries.length; self->end = self->curr + dict->entries.length;
self->mode = mode; self->mode = mode;
@ -286,6 +329,10 @@ static DictEntry* DictIterator__next(DictIterator* self) {
return retval; return retval;
} }
static bool DictIterator__modified(DictIterator* self) {
return memcmp(self->dict, &self->dict_backup, sizeof(Dict)) != 0;
}
/////////////////////////////// ///////////////////////////////
static bool dict__new__(int argc, py_Ref argv) { static bool dict__new__(int argc, py_Ref argv) {
py_Type cls = py_totype(argv); py_Type cls = py_totype(argv);
@ -455,12 +502,17 @@ static bool dict_copy(int argc, py_Ref argv) {
PY_CHECK_ARGC(1); PY_CHECK_ARGC(1);
Dict* self = py_touserdata(argv); Dict* self = py_touserdata(argv);
Dict* new_dict = py_newobject(py_retval(), tp_dict, 0, sizeof(Dict)); Dict* new_dict = py_newobject(py_retval(), tp_dict, 0, sizeof(Dict));
new_dict->capacity = self->capacity;
new_dict->length = self->length; new_dict->length = self->length;
new_dict->capacity = self->capacity;
new_dict->null_index_value = self->null_index_value;
new_dict->index_is_short = self->index_is_short;
// copy entries
new_dict->entries = c11_vector__copy(&self->entries); new_dict->entries = c11_vector__copy(&self->entries);
// copy indices // copy indices
new_dict->indices = PK_MALLOC(new_dict->capacity * sizeof(DictIndex)); size_t indices_size = self->index_is_short ? self->capacity * sizeof(uint16_t)
memcpy(new_dict->indices, self->indices, new_dict->capacity * sizeof(DictIndex)); : self->capacity * sizeof(uint32_t);
new_dict->indices = PK_MALLOC(indices_size);
memcpy(new_dict->indices, self->indices, indices_size);
return true; return true;
} }
@ -557,6 +609,7 @@ py_Type pk_dict__register() {
static bool dict_items__next__(int argc, py_Ref argv) { static bool dict_items__next__(int argc, py_Ref argv) {
PY_CHECK_ARGC(1); PY_CHECK_ARGC(1);
DictIterator* iter = py_touserdata(py_arg(0)); DictIterator* iter = py_touserdata(py_arg(0));
if(DictIterator__modified(iter)) return RuntimeError("dictionary modified during iteration");
DictEntry* entry = (DictIterator__next(iter)); DictEntry* entry = (DictIterator__next(iter));
if(entry) { if(entry) {
switch(iter->mode) { switch(iter->mode) {
@ -670,4 +723,4 @@ bool py_dict_apply(py_Ref self, bool (*f)(py_Ref, py_Ref, void*), void* ctx) {
return true; return true;
} }
#undef PK_DICT_MAX_COLLISION #undef Dict__step

View File

@ -115,30 +115,7 @@ assert a.pop(1) == 2
assert a.pop(1, None) is None assert a.pop(1, None) is None
n = 2 ** 17 # test getitem
a = {}
for i in range(n):
a[str(i)] = i
for i in range(n):
y = a[str(i)]
for i in range(n):
del a[str(i)]
# namedict delete test
# class A: pass
# a = A()
# b = ['0', '1']
# for i in range(len(data)):
# z = data[i]
# setattr(a, str(z), i)
# b.append(z)
# if i % 3 == 0:
# y = b.pop()
# delattr(a, y)
d = {} d = {}
for i in range(-1000, 1000): for i in range(-1000, 1000):
d[i] = i d[i] = i
@ -155,3 +132,37 @@ assert list(d) == ['1', 222, '333']
assert list(d.keys()) == ['1', 222, '333'] assert list(d.keys()) == ['1', 222, '333']
assert list(d.values()) == [1, 2, 3] assert list(d.values()) == [1, 2, 3]
assert list(d.items()) == [('1', 1), (222, 2), ('333', 3)] assert list(d.items()) == [('1', 1), (222, 2), ('333', 3)]
# test del
n = 2 ** 17
a = {}
for i in range(n):
a[str(i)] = i
for i in range(n):
del a[str(i)]
assert len(a) == 0
# test del with int keys
if 0:
n = 2 ** 17
a = {}
for i in range(n):
a[i] = i
for i in range(n):
del a[i]
assert len(a) == 0
#######################
# namedict delete test
class A: pass
a = A()
b = ['0', '1']
for i in range(len(data)):
z = data[i]
setattr(a, str(z), i)
b.append(z)
if i % 3 == 0:
y = b.pop()
delattr(a, y)

View File

@ -103,9 +103,4 @@ class A:
bad_dict = {A(): 1, A(): 2, A(): 3, A(): 4} bad_dict = {A(): 1, A(): 2, A(): 3, A(): 4}
assert len(bad_dict) == 4 assert len(bad_dict) == 4
try:
bad_dict[A()] = 5 # error
exit(1)
except RuntimeError as e:
assert 'maximum collision reached' in str(e)