From 53ea790caf50c6ca1d29247080c829d185b9e5c2 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Tue, 22 Aug 2023 23:34:39 +0800 Subject: [PATCH] reimpl `py_hash` --- docs/C-API/introduction.md | 2 +- src/pocketpy.cpp | 21 +++++++-------------- src/vm.cpp | 23 +++++++++++++++++++++-- tests/99_builtin_func.py | 14 +++++++------- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/docs/C-API/introduction.md b/docs/C-API/introduction.md index 33107251..1351be9f 100644 --- a/docs/C-API/introduction.md +++ b/docs/C-API/introduction.md @@ -16,7 +16,7 @@ Methods return a `bool` indicating if the operation succeeded or not. Special thanks for [@koltenpearson](https://github.com/koltenpearson)'s contribution. !!! -**C-APIs are always stable and backward compatible** in order to support the second use case. +C-APIs are always stable and backward compatible. !!! ### Basic functions diff --git a/src/pocketpy.cpp b/src/pocketpy.cpp index 3999e7df..b7f214b7 100644 --- a/src/pocketpy.cpp +++ b/src/pocketpy.cpp @@ -309,7 +309,6 @@ void init_builtins(VM* _vm) { }); _vm->bind__eq__(_vm->tp_object, [](VM* vm, PyObject* lhs, PyObject* rhs) { return VAR(lhs == rhs); }); - _vm->bind__hash__(_vm->tp_object, [](VM* vm, PyObject* obj) { return PK_BITS(obj); }); _vm->cached_object__new__ = _vm->bind_constructor<1>("object", [](VM* vm, ArgsView args) { vm->check_non_tagged_type(args[0], vm->tp_type); @@ -754,11 +753,6 @@ void init_builtins(VM* _vm) { _vm->bind_method<0>("list", "copy", PK_LAMBDA(VAR(_CAST(List, args[0])))); - _vm->bind__hash__(_vm->tp_list, [](VM* vm, PyObject* obj) { - vm->TypeError("unhashable type: 'list'"); - return (i64)0; - }); - _vm->bind__add__(_vm->tp_list, [](VM* vm, PyObject* lhs, PyObject* rhs) { const List& self = _CAST(List&, lhs); const List& other = CAST(List&, rhs); @@ -997,9 +991,13 @@ void init_builtins(VM* _vm) { return (i64)_CAST(MappingProxy&, obj).attr().size(); }); - _vm->bind__hash__(_vm->tp_mappingproxy, [](VM* vm, PyObject* obj) { - vm->TypeError("unhashable type: 'mappingproxy'"); - return (i64)0; + _vm->bind__eq__(_vm->tp_mappingproxy, [](VM* vm, PyObject* obj, PyObject* other){ + MappingProxy& a = _CAST(MappingProxy&, obj); + if(!is_non_tagged_type(other, vm->tp_mappingproxy)){ + return vm->NotImplemented; + } + MappingProxy& b = _CAST(MappingProxy&, other); + return VAR(a.obj == b.obj); }); _vm->bind__getitem__(_vm->tp_mappingproxy, [](VM* vm, PyObject* obj, PyObject* index) { @@ -1057,11 +1055,6 @@ void init_builtins(VM* _vm) { _vm->bind__len__(_vm->tp_dict, [](VM* vm, PyObject* obj) { return (i64)_CAST(Dict&, obj).size(); }); - - _vm->bind__hash__(_vm->tp_dict, [](VM* vm, PyObject* obj) { - vm->TypeError("unhashable type: 'dict'"); - return (i64)0; - }); _vm->bind__getitem__(_vm->tp_dict, [](VM* vm, PyObject* obj, PyObject* index) { Dict& self = _CAST(Dict&, obj); diff --git a/src/vm.cpp b/src/vm.cpp index fe3b6b80..13ba14bd 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -350,10 +350,29 @@ void VM::parse_int_slice(const Slice& s, int length, int& start, int& stop, int& } i64 VM::py_hash(PyObject* obj){ + // https://docs.python.org/3.10/reference/datamodel.html#object.__hash__ const PyTypeInfo* ti = _inst_type_info(obj); if(ti->m__hash__) return ti->m__hash__(this, obj); - PyObject* ret = call_method(obj, __hash__); - return CAST(i64, ret); + + PyObject* self; + PyObject* f = get_unbound_method(obj, __hash__, &self, false); + if(f != nullptr){ + PyObject* ret = call_method(self, f); + return CAST(i64, ret); + } + // it flow reaches here, obj must not be the trivial `object` type + bool has_custom_eq = false; + if(ti->m__eq__) has_custom_eq = true; + else{ + f = get_unbound_method(obj, __eq__, &self, false); + has_custom_eq = f != _t(tp_object)->attr(__eq__); + } + if(has_custom_eq){ + TypeError(fmt("unhashable type: ", ti->name.escape())); + return 0; + }else{ + return PK_BITS(obj); + } } PyObject* VM::format(Str spec, PyObject* obj){ diff --git a/tests/99_builtin_func.py b/tests/99_builtin_func.py index c559c64c..0e53819e 100644 --- a/tests/99_builtin_func.py +++ b/tests/99_builtin_func.py @@ -207,12 +207,12 @@ def f(): assert type(hash(a)) is int # 测试不可哈希对象 -# try: -# hash({1:1}) -# print('未能拦截错误') -# exit(1) -# except: -# pass +try: + hash({1:1}) + print('未能拦截错误') + exit(1) +except: + pass try: hash([1]) @@ -842,7 +842,7 @@ try: hash(my_mappingproxy) print('未能拦截错误, 在测试 mappingproxy.__hash__') exit(1) -except: +except TypeError: pass # 未完全测试准确性-----------------------------------------------