allow __eq__ returns non-bool

This commit is contained in:
blueloveTH 2024-07-21 20:35:22 +08:00
parent 9d9674d171
commit 29a989f09a
5 changed files with 35 additions and 37 deletions

View File

@ -288,12 +288,15 @@ bool KeyError(py_Ref key);
/// Returns -1 if an error occurred. /// Returns -1 if an error occurred.
int py_bool(const py_Ref val); int py_bool(const py_Ref val);
int py_eq(const py_Ref, const py_Ref); #define py_eq(lhs, rhs) py_binaryop(lhs, rhs, __eq__, __eq__)
int py_ne(const py_Ref, const py_Ref); #define py_ne(lhs, rhs) py_binaryop(lhs, rhs, __ne__, __ne__)
int py_le(const py_Ref, const py_Ref); #define py_lt(lhs, rhs) py_binaryop(lhs, rhs, __lt__, __gt__)
int py_lt(const py_Ref, const py_Ref); #define py_le(lhs, rhs) py_binaryop(lhs, rhs, __le__, __ge__)
int py_ge(const py_Ref, const py_Ref); #define py_gt(lhs, rhs) py_binaryop(lhs, rhs, __gt__, __lt__)
int py_gt(const py_Ref, const py_Ref); #define py_ge(lhs, rhs) py_binaryop(lhs, rhs, __ge__, __le__)
int py_equal(const py_Ref lhs, const py_Ref rhs);
int py_less(const py_Ref lhs, const py_Ref rhs);
bool py_hash(const py_Ref, py_i64* out); bool py_hash(const py_Ref, py_i64* out);

View File

@ -26,7 +26,7 @@ py_TValue* pk_arrayview(py_Ref self, int* length) {
int pk_arrayeq(py_TValue* lhs, int lhs_length, py_TValue* rhs, int rhs_length) { int pk_arrayeq(py_TValue* lhs, int lhs_length, py_TValue* rhs, int rhs_length) {
if(lhs_length != rhs_length) return false; if(lhs_length != rhs_length) return false;
for(int i = 0; i < lhs_length; i++) { for(int i = 0; i < lhs_length; i++) {
int res = py_eq(lhs + i, rhs + i); int res = py_equal(lhs + i, rhs + i);
if(res == -1) return -1; if(res == -1) return -1;
if(!res) return false; if(!res) return false;
} }

View File

@ -52,7 +52,7 @@ static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) {
int idx2 = self->indices[idx]._[i]; int idx2 = self->indices[idx]._[i];
if(idx2 == -1) continue; if(idx2 == -1) continue;
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
int res = py_eq(&entry->key, key); int res = py_equal(&entry->key, key);
if(res == 1) { if(res == 1) {
*out = entry; *out = entry;
return true; return true;
@ -150,7 +150,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
} }
// update existing entry // update existing entry
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
int res = py_eq(&entry->key, key); int res = py_equal(&entry->key, key);
if(res == 1) { if(res == 1) {
entry->val = *val; entry->val = *val;
return true; return true;
@ -174,7 +174,7 @@ static bool Dict__pop(Dict* self, py_Ref key) {
int idx2 = self->indices[idx]._[i]; int idx2 = self->indices[idx]._[i];
if(idx2 == -1) continue; if(idx2 == -1) continue;
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
int res = py_eq(&entry->key, key); int res = py_equal(&entry->key, key);
if(res == 1) { if(res == 1) {
*py_retval() = entry->val; *py_retval() = entry->val;
py_newnil(&entry->key); py_newnil(&entry->key);
@ -318,7 +318,7 @@ static bool _py_dict__eq__(int argc, py_Ref argv) {
py_newbool(py_retval(), false); py_newbool(py_retval(), false);
return true; return true;
} }
int res = py_eq(&entry->val, &other_entry->val); int res = py_equal(&entry->val, &other_entry->val);
if(res == -1) return false; if(res == -1) return false;
if(!res) { if(!res) {
py_newbool(py_retval(), false); py_newbool(py_retval(), false);

View File

@ -258,7 +258,7 @@ static bool _py_list__count(int argc, py_Ref argv) {
PY_CHECK_ARGC(2); PY_CHECK_ARGC(2);
int count = 0; int count = 0;
for(int i = 0; i < py_list__len(py_arg(0)); i++) { for(int i = 0; i < py_list__len(py_arg(0)); i++) {
int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1)); int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
if(res == -1) return false; if(res == -1) return false;
if(res) count++; if(res) count++;
} }
@ -290,7 +290,7 @@ static bool _py_list__index(int argc, py_Ref argv) {
start = py_toint(py_arg(2)); start = py_toint(py_arg(2));
} }
for(int i = start; i < py_list__len(py_arg(0)); i++) { for(int i = start; i < py_list__len(py_arg(0)); i++) {
int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1)); int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
if(res == -1) return false; if(res == -1) return false;
if(res) { if(res) {
py_newint(py_retval(), i); py_newint(py_retval(), i);
@ -311,7 +311,7 @@ static bool _py_list__reverse(int argc, py_Ref argv) {
static bool _py_list__remove(int argc, py_Ref argv) { static bool _py_list__remove(int argc, py_Ref argv) {
PY_CHECK_ARGC(2); PY_CHECK_ARGC(2);
for(int i = 0; i < py_list__len(py_arg(0)); i++) { for(int i = 0; i < py_list__len(py_arg(0)); i++) {
int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1)); int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
if(res == -1) return false; if(res == -1) return false;
if(res) { if(res) {
py_list__delitem(py_arg(0), i); py_list__delitem(py_arg(0), i);
@ -354,7 +354,7 @@ static bool _py_list__insert(int argc, py_Ref argv) {
} }
static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) { static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) {
if(!key) return py_lt(a, b); if(!key) return py_less(a, b);
pk_VM* vm = pk_current_vm; pk_VM* vm = pk_current_vm;
// project a // project a
py_push(key); py_push(key);
@ -372,7 +372,7 @@ static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) {
bool ok = pk_stack_binaryop(vm, __lt__, __gt__); bool ok = pk_stack_binaryop(vm, __lt__, __gt__);
if(!ok) return -1; if(!ok) return -1;
py_shrink(2); py_shrink(2);
return py_tobool(py_retval()); return py_bool(py_retval());
} }
// sort(self, key=None, reverse=False) // sort(self, key=None, reverse=False)

View File

@ -26,14 +26,14 @@ int py_bool(const py_Ref val) {
default: { default: {
py_Ref tmp = py_tpfindmagic(val->type, __bool__); py_Ref tmp = py_tpfindmagic(val->type, __bool__);
if(tmp) { if(tmp) {
bool ok = py_call(tmp, 1, val); if(!py_call(tmp, 1, val)) return -1;
if(!ok) return -1; if(!py_checkbool(py_retval())) return -1;
return py_tobool(py_retval()); return py_tobool(py_retval());
} else { } else {
tmp = py_tpfindmagic(val->type, __len__); tmp = py_tpfindmagic(val->type, __len__);
if(tmp) { if(tmp) {
bool ok = py_call(tmp, 1, val); if(!py_call(tmp, 1, val)) return -1;
if(!ok) return -1; if(!py_checkint(py_retval())) return -1;
return py_toint(py_retval()); return py_toint(py_retval());
} else { } else {
return 1; // True return 1; // True
@ -51,8 +51,8 @@ bool py_hash(const py_Ref val, int64_t* out) {
if(py_isnone(_hash)) break; if(py_isnone(_hash)) break;
py_Ref _eq = &types[t].magic[__eq__]; py_Ref _eq = &types[t].magic[__eq__];
if(!py_isnil(_hash) && !py_isnil(_eq)) { if(!py_isnil(_hash) && !py_isnil(_eq)) {
bool ok = py_call(_hash, 1, val); if(!py_call(_hash, 1, val)) return false;
if(!ok) return false; if(!py_checkint(py_retval())) return false;
*out = py_toint(py_retval()); *out = py_toint(py_retval());
return true; return true;
} }
@ -72,8 +72,7 @@ int py_next(const py_Ref val) {
vm->is_stopiteration = false; vm->is_stopiteration = false;
py_Ref tmp = py_tpfindmagic(val->type, __next__); py_Ref tmp = py_tpfindmagic(val->type, __next__);
if(!tmp) return TypeError("'%t' object is not an iterator", val->type); if(!tmp) return TypeError("'%t' object is not an iterator", val->type);
bool ok = py_call(tmp, 1, val); if(py_call(tmp, 1, val)) return true;
if(ok) return true;
return vm->is_stopiteration ? 0 : -1; return vm->is_stopiteration ? 0 : -1;
} }
@ -201,16 +200,12 @@ bool py_delitem(py_Ref self, const py_Ref key) {
return ok; return ok;
} }
#define COMPARE_OP_IMPL(name, op, rop) \ int py_equal(const py_Ref lhs, const py_Ref rhs){
int py_##name(const py_Ref lhs, const py_Ref rhs) { \ if(!py_eq(lhs, rhs)) return -1;
bool ok = py_binaryop(lhs, rhs, op, rop); \ return py_bool(py_retval());
if(!ok) return -1; \
return py_tobool(py_retval()); \
} }
COMPARE_OP_IMPL(eq, __eq__, __eq__) int py_less(const py_Ref lhs, const py_Ref rhs){
COMPARE_OP_IMPL(ne, __ne__, __ne__) if(!py_lt(lhs, rhs)) return -1;
COMPARE_OP_IMPL(lt, __lt__, __gt__) return py_bool(py_retval());
COMPARE_OP_IMPL(le, __le__, __ge__) }
COMPARE_OP_IMPL(gt, __gt__, __lt__)
COMPARE_OP_IMPL(ge, __ge__, __le__)