allow __eq__ returns non-bool

This commit is contained in:
blueloveTH 2024-07-21 20:35:22 +08:00
parent 9d9674d171
commit 29a989f09a
5 changed files with 35 additions and 37 deletions

View File

@ -288,12 +288,15 @@ bool KeyError(py_Ref key);
/// Returns -1 if an error occurred.
int py_bool(const py_Ref val);
int py_eq(const py_Ref, const py_Ref);
int py_ne(const py_Ref, const py_Ref);
int py_le(const py_Ref, const py_Ref);
int py_lt(const py_Ref, const py_Ref);
int py_ge(const py_Ref, const py_Ref);
int py_gt(const py_Ref, const py_Ref);
#define py_eq(lhs, rhs) py_binaryop(lhs, rhs, __eq__, __eq__)
#define py_ne(lhs, rhs) py_binaryop(lhs, rhs, __ne__, __ne__)
#define py_lt(lhs, rhs) py_binaryop(lhs, rhs, __lt__, __gt__)
#define py_le(lhs, rhs) py_binaryop(lhs, rhs, __le__, __ge__)
#define py_gt(lhs, rhs) py_binaryop(lhs, rhs, __gt__, __lt__)
#define py_ge(lhs, rhs) py_binaryop(lhs, rhs, __ge__, __le__)
int py_equal(const py_Ref lhs, const py_Ref rhs);
int py_less(const py_Ref lhs, const py_Ref rhs);
bool py_hash(const py_Ref, py_i64* out);

View File

@ -26,7 +26,7 @@ py_TValue* pk_arrayview(py_Ref self, int* length) {
int pk_arrayeq(py_TValue* lhs, int lhs_length, py_TValue* rhs, int rhs_length) {
if(lhs_length != rhs_length) return false;
for(int i = 0; i < lhs_length; i++) {
int res = py_eq(lhs + i, rhs + i);
int res = py_equal(lhs + i, rhs + i);
if(res == -1) return -1;
if(!res) return false;
}

View File

@ -52,7 +52,7 @@ static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) {
int idx2 = self->indices[idx]._[i];
if(idx2 == -1) continue;
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
int res = py_eq(&entry->key, key);
int res = py_equal(&entry->key, key);
if(res == 1) {
*out = entry;
return true;
@ -150,7 +150,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
}
// update existing entry
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
int res = py_eq(&entry->key, key);
int res = py_equal(&entry->key, key);
if(res == 1) {
entry->val = *val;
return true;
@ -174,7 +174,7 @@ static bool Dict__pop(Dict* self, py_Ref key) {
int idx2 = self->indices[idx]._[i];
if(idx2 == -1) continue;
DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
int res = py_eq(&entry->key, key);
int res = py_equal(&entry->key, key);
if(res == 1) {
*py_retval() = entry->val;
py_newnil(&entry->key);
@ -318,7 +318,7 @@ static bool _py_dict__eq__(int argc, py_Ref argv) {
py_newbool(py_retval(), false);
return true;
}
int res = py_eq(&entry->val, &other_entry->val);
int res = py_equal(&entry->val, &other_entry->val);
if(res == -1) return false;
if(!res) {
py_newbool(py_retval(), false);

View File

@ -258,7 +258,7 @@ static bool _py_list__count(int argc, py_Ref argv) {
PY_CHECK_ARGC(2);
int count = 0;
for(int i = 0; i < py_list__len(py_arg(0)); i++) {
int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1));
int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
if(res == -1) return false;
if(res) count++;
}
@ -290,7 +290,7 @@ static bool _py_list__index(int argc, py_Ref argv) {
start = py_toint(py_arg(2));
}
for(int i = start; i < py_list__len(py_arg(0)); i++) {
int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1));
int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
if(res == -1) return false;
if(res) {
py_newint(py_retval(), i);
@ -311,7 +311,7 @@ static bool _py_list__reverse(int argc, py_Ref argv) {
static bool _py_list__remove(int argc, py_Ref argv) {
PY_CHECK_ARGC(2);
for(int i = 0; i < py_list__len(py_arg(0)); i++) {
int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1));
int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
if(res == -1) return false;
if(res) {
py_list__delitem(py_arg(0), i);
@ -354,7 +354,7 @@ static bool _py_list__insert(int argc, py_Ref argv) {
}
static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) {
if(!key) return py_lt(a, b);
if(!key) return py_less(a, b);
pk_VM* vm = pk_current_vm;
// project a
py_push(key);
@ -372,7 +372,7 @@ static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) {
bool ok = pk_stack_binaryop(vm, __lt__, __gt__);
if(!ok) return -1;
py_shrink(2);
return py_tobool(py_retval());
return py_bool(py_retval());
}
// sort(self, key=None, reverse=False)
@ -436,7 +436,7 @@ py_Type pk_list__register() {
return type;
}
void pk_list__mark(void* ud, void (*marker)(py_TValue*)){
void pk_list__mark(void* ud, void (*marker)(py_TValue*)) {
List* self = ud;
for(int i = 0; i < self->count; i++) {
marker(c11__at(py_TValue, self, i));

View File

@ -26,14 +26,14 @@ int py_bool(const py_Ref val) {
default: {
py_Ref tmp = py_tpfindmagic(val->type, __bool__);
if(tmp) {
bool ok = py_call(tmp, 1, val);
if(!ok) return -1;
if(!py_call(tmp, 1, val)) return -1;
if(!py_checkbool(py_retval())) return -1;
return py_tobool(py_retval());
} else {
tmp = py_tpfindmagic(val->type, __len__);
if(tmp) {
bool ok = py_call(tmp, 1, val);
if(!ok) return -1;
if(!py_call(tmp, 1, val)) return -1;
if(!py_checkint(py_retval())) return -1;
return py_toint(py_retval());
} else {
return 1; // True
@ -51,8 +51,8 @@ bool py_hash(const py_Ref val, int64_t* out) {
if(py_isnone(_hash)) break;
py_Ref _eq = &types[t].magic[__eq__];
if(!py_isnil(_hash) && !py_isnil(_eq)) {
bool ok = py_call(_hash, 1, val);
if(!ok) return false;
if(!py_call(_hash, 1, val)) return false;
if(!py_checkint(py_retval())) return false;
*out = py_toint(py_retval());
return true;
}
@ -72,8 +72,7 @@ int py_next(const py_Ref val) {
vm->is_stopiteration = false;
py_Ref tmp = py_tpfindmagic(val->type, __next__);
if(!tmp) return TypeError("'%t' object is not an iterator", val->type);
bool ok = py_call(tmp, 1, val);
if(ok) return true;
if(py_call(tmp, 1, val)) return true;
return vm->is_stopiteration ? 0 : -1;
}
@ -201,16 +200,12 @@ bool py_delitem(py_Ref self, const py_Ref key) {
return ok;
}
#define COMPARE_OP_IMPL(name, op, rop) \
int py_##name(const py_Ref lhs, const py_Ref rhs) { \
bool ok = py_binaryop(lhs, rhs, op, rop); \
if(!ok) return -1; \
return py_tobool(py_retval()); \
}
int py_equal(const py_Ref lhs, const py_Ref rhs){
if(!py_eq(lhs, rhs)) return -1;
return py_bool(py_retval());
}
COMPARE_OP_IMPL(eq, __eq__, __eq__)
COMPARE_OP_IMPL(ne, __ne__, __ne__)
COMPARE_OP_IMPL(lt, __lt__, __gt__)
COMPARE_OP_IMPL(le, __le__, __ge__)
COMPARE_OP_IMPL(gt, __gt__, __lt__)
COMPARE_OP_IMPL(ge, __ge__, __le__)
int py_less(const py_Ref lhs, const py_Ref rhs){
if(!py_lt(lhs, rhs)) return -1;
return py_bool(py_retval());
}