diff --git a/include/pocketpy/dict.h b/include/pocketpy/dict.h index b59899ab..dadc1199 100644 --- a/include/pocketpy/dict.h +++ b/include/pocketpy/dict.h @@ -41,7 +41,8 @@ struct Dict{ int size() const { return _size; } - void _probe(PyObject* key, bool& ok, int& i) const; + void _probe_0(PyObject* key, bool& ok, int& i) const; + void _probe_1(PyObject* key, bool& ok, int& i) const; void set(PyObject* key, PyObject* val); void _rehash(); diff --git a/include/pocketpy/namedict.h b/include/pocketpy/namedict.h index 7871fc0f..3b180edc 100644 --- a/include/pocketpy/namedict.h +++ b/include/pocketpy/namedict.h @@ -29,11 +29,11 @@ struct NameDictImpl { uint16_t _mask; Item* _items; -#define HASH_PROBE(key, ok, i) \ -ok = false; \ -i = _hash(key, _mask, _hash_seed); \ -for(int _j=0; _j<_capacity; _j++) { \ - if(!_items[i].first.empty()){ \ +#define HASH_PROBE_0(key, ok, i) \ +ok = false; \ +i = _hash(key, _mask, _hash_seed); \ +for(int _j=0; _j<_capacity; _j++) { \ + if(!_items[i].first.empty()){ \ if(_items[i].first == (key)) { ok = true; break; } \ }else{ \ if(_items[i].second == 0) break; \ @@ -41,6 +41,14 @@ for(int _j=0; _j<_capacity; _j++) { \ i = (i + 1) & _mask; \ } +#define HASH_PROBE_1(key, ok, i) \ +ok = false; \ +i = _hash(key, _mask, _hash_seed); \ +while(!_items[i].first.empty()) { \ + if(_items[i].first == (key)) { ok = true; break; } \ + i = (i + 1) & _mask; \ +} + #define NAMEDICT_ALLOC() \ _items = (Item*)pool128_alloc(_capacity * sizeof(Item)); \ memset(_items, 0, _capacity * sizeof(Item)); \ @@ -73,19 +81,19 @@ for(int _j=0; _j<_capacity; _j++) { \ T operator[](StrName key) const { bool ok; uint16_t i; - HASH_PROBE(key, ok, i); + HASH_PROBE_0(key, ok, i); if(!ok) throw std::out_of_range(fmt("NameDict key not found: ", key)); return _items[i].second; } void set(StrName key, T val){ bool ok; uint16_t i; - HASH_PROBE(key, ok, i); + HASH_PROBE_1(key, ok, i); if(!ok) { _size++; if(_size > _capacity*_load_factor){ _rehash(true); - HASH_PROBE(key, ok, i); + HASH_PROBE_1(key, ok, i); } _items[i].first = key; } @@ -103,7 +111,7 @@ for(int _j=0; _j<_capacity; _j++) { \ for(uint16_t i=0; i) return nullptr; else if constexpr(std::is_same_v) return -1; @@ -128,14 +136,14 @@ for(int _j=0; _j<_capacity; _j++) { \ T* try_get_2(StrName key) { bool ok; uint16_t i; - HASH_PROBE(key, ok, i); + HASH_PROBE_0(key, ok, i); if(!ok) return nullptr; return &_items[i].second; } bool try_set(StrName key, T val){ bool ok; uint16_t i; - HASH_PROBE(key, ok, i); + HASH_PROBE_1(key, ok, i); if(!ok) return false; _items[i].second = val; return true; @@ -143,7 +151,7 @@ for(int _j=0; _j<_capacity; _j++) { \ bool contains(StrName key) const { bool ok; uint16_t i; - HASH_PROBE(key, ok, i); + HASH_PROBE_0(key, ok, i); return ok; } @@ -156,7 +164,7 @@ for(int _j=0; _j<_capacity; _j++) { \ void erase(StrName key){ bool ok; uint16_t i; - HASH_PROBE(key, ok, i); + HASH_PROBE_0(key, ok, i); if(!ok) throw std::out_of_range(fmt("NameDict key not found: ", key)); _items[i].first = StrName(); // _items[i].second = PY_DELETED_SLOT; // do not change .second if it is not zero, it means the slot is occupied by a deleted item diff --git a/src/dict.cpp b/src/dict.cpp index 8add915a..f94e9981 100644 --- a/src/dict.cpp +++ b/src/dict.cpp @@ -43,7 +43,7 @@ namespace pkpy{ // do possible rehash if(_size+1 > _critical_size) _rehash(); bool ok; int i; - _probe(key, ok, i); + _probe_1(key, ok, i); if(!ok) { _size++; _items[i].first = key; @@ -91,20 +91,20 @@ namespace pkpy{ PyObject* Dict::try_get(PyObject* key) const{ bool ok; int i; - _probe(key, ok, i); + _probe_0(key, ok, i); if(!ok) return nullptr; return _items[i].second; } bool Dict::contains(PyObject* key) const{ bool ok; int i; - _probe(key, ok, i); + _probe_0(key, ok, i); return ok; } bool Dict::erase(PyObject* key){ bool ok; int i; - _probe(key, ok, i); + _probe_0(key, ok, i); if(!ok) return false; _items[i].first = nullptr; // _items[i].second = PY_DELETED_SLOT; // do not change .second if it is not NULL, it means the slot is occupied by a deleted item diff --git a/src/vm.cpp b/src/vm.cpp index 51bd578b..9bbd2f67 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -1031,7 +1031,7 @@ void VM::bind__len__(Type type, i64 (*f)(VM*, PyObject*)){ PK_OBJ_GET(NativeFunc, nf).set_userdata(f); } -void Dict::_probe(PyObject *key, bool &ok, int &i) const{ +void Dict::_probe_0(PyObject *key, bool &ok, int &i) const{ ok = false; i64 hash = vm->py_hash(key); i = hash & _mask; @@ -1048,6 +1048,16 @@ void Dict::_probe(PyObject *key, bool &ok, int &i) const{ } } +void Dict::_probe_1(PyObject *key, bool &ok, int &i) const{ + ok = false; + i = vm->py_hash(key) & _mask; + while(_items[i].first != nullptr) { + if(vm->py_equals(_items[i].first, key)) { ok = true; break; } + // https://github.com/python/cpython/blob/3.8/Objects/dictobject.c#L166 + i = ((5*i) + 1) & _mask; + } +} + void CodeObjectSerializer::write_object(VM *vm, PyObject *obj){ if(is_int(obj)) write_int(_CAST(i64, obj)); else if(is_float(obj)) write_float(_CAST(f64, obj));