diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index 8b19d635..1e4a87cd 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -288,12 +288,15 @@ bool KeyError(py_Ref key); /// Returns -1 if an error occurred. int py_bool(const py_Ref val); -int py_eq(const py_Ref, const py_Ref); -int py_ne(const py_Ref, const py_Ref); -int py_le(const py_Ref, const py_Ref); -int py_lt(const py_Ref, const py_Ref); -int py_ge(const py_Ref, const py_Ref); -int py_gt(const py_Ref, const py_Ref); +#define py_eq(lhs, rhs) py_binaryop(lhs, rhs, __eq__, __eq__) +#define py_ne(lhs, rhs) py_binaryop(lhs, rhs, __ne__, __ne__) +#define py_lt(lhs, rhs) py_binaryop(lhs, rhs, __lt__, __gt__) +#define py_le(lhs, rhs) py_binaryop(lhs, rhs, __le__, __ge__) +#define py_gt(lhs, rhs) py_binaryop(lhs, rhs, __gt__, __lt__) +#define py_ge(lhs, rhs) py_binaryop(lhs, rhs, __ge__, __le__) + +int py_equal(const py_Ref lhs, const py_Ref rhs); +int py_less(const py_Ref lhs, const py_Ref rhs); bool py_hash(const py_Ref, py_i64* out); diff --git a/src/public/py_array.c b/src/public/py_array.c index 62dd81ba..8c0047dc 100644 --- a/src/public/py_array.c +++ b/src/public/py_array.c @@ -26,7 +26,7 @@ py_TValue* pk_arrayview(py_Ref self, int* length) { int pk_arrayeq(py_TValue* lhs, int lhs_length, py_TValue* rhs, int rhs_length) { if(lhs_length != rhs_length) return false; for(int i = 0; i < lhs_length; i++) { - int res = py_eq(lhs + i, rhs + i); + int res = py_equal(lhs + i, rhs + i); if(res == -1) return -1; if(!res) return false; } diff --git a/src/public/py_dict.c b/src/public/py_dict.c index a6d19ef9..1f1f96b0 100644 --- a/src/public/py_dict.c +++ b/src/public/py_dict.c @@ -52,7 +52,7 @@ static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) { int idx2 = self->indices[idx]._[i]; if(idx2 == -1) continue; DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); - int res = py_eq(&entry->key, key); + int res = py_equal(&entry->key, key); if(res == 1) { *out = entry; return true; @@ -150,7 +150,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) { } // update existing entry DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); - int res = py_eq(&entry->key, key); + int res = py_equal(&entry->key, key); if(res == 1) { entry->val = *val; return true; @@ -174,7 +174,7 @@ static bool Dict__pop(Dict* self, py_Ref key) { int idx2 = self->indices[idx]._[i]; if(idx2 == -1) continue; DictEntry* entry = c11__at(DictEntry, &self->entries, idx2); - int res = py_eq(&entry->key, key); + int res = py_equal(&entry->key, key); if(res == 1) { *py_retval() = entry->val; py_newnil(&entry->key); @@ -318,7 +318,7 @@ static bool _py_dict__eq__(int argc, py_Ref argv) { py_newbool(py_retval(), false); return true; } - int res = py_eq(&entry->val, &other_entry->val); + int res = py_equal(&entry->val, &other_entry->val); if(res == -1) return false; if(!res) { py_newbool(py_retval(), false); diff --git a/src/public/py_list.c b/src/public/py_list.c index e4f985ea..18ba60fe 100644 --- a/src/public/py_list.c +++ b/src/public/py_list.c @@ -258,7 +258,7 @@ static bool _py_list__count(int argc, py_Ref argv) { PY_CHECK_ARGC(2); int count = 0; for(int i = 0; i < py_list__len(py_arg(0)); i++) { - int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1)); + int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1)); if(res == -1) return false; if(res) count++; } @@ -290,7 +290,7 @@ static bool _py_list__index(int argc, py_Ref argv) { start = py_toint(py_arg(2)); } for(int i = start; i < py_list__len(py_arg(0)); i++) { - int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1)); + int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1)); if(res == -1) return false; if(res) { py_newint(py_retval(), i); @@ -311,7 +311,7 @@ static bool _py_list__reverse(int argc, py_Ref argv) { static bool _py_list__remove(int argc, py_Ref argv) { PY_CHECK_ARGC(2); for(int i = 0; i < py_list__len(py_arg(0)); i++) { - int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1)); + int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1)); if(res == -1) return false; if(res) { py_list__delitem(py_arg(0), i); @@ -354,7 +354,7 @@ static bool _py_list__insert(int argc, py_Ref argv) { } static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) { - if(!key) return py_lt(a, b); + if(!key) return py_less(a, b); pk_VM* vm = pk_current_vm; // project a py_push(key); @@ -372,7 +372,7 @@ static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) { bool ok = pk_stack_binaryop(vm, __lt__, __gt__); if(!ok) return -1; py_shrink(2); - return py_tobool(py_retval()); + return py_bool(py_retval()); } // sort(self, key=None, reverse=False) @@ -436,7 +436,7 @@ py_Type pk_list__register() { return type; } -void pk_list__mark(void* ud, void (*marker)(py_TValue*)){ +void pk_list__mark(void* ud, void (*marker)(py_TValue*)) { List* self = ud; for(int i = 0; i < self->count; i++) { marker(c11__at(py_TValue, self, i)); diff --git a/src/public/py_ops.c b/src/public/py_ops.c index fc07e25e..7aa96305 100644 --- a/src/public/py_ops.c +++ b/src/public/py_ops.c @@ -26,14 +26,14 @@ int py_bool(const py_Ref val) { default: { py_Ref tmp = py_tpfindmagic(val->type, __bool__); if(tmp) { - bool ok = py_call(tmp, 1, val); - if(!ok) return -1; + if(!py_call(tmp, 1, val)) return -1; + if(!py_checkbool(py_retval())) return -1; return py_tobool(py_retval()); } else { tmp = py_tpfindmagic(val->type, __len__); if(tmp) { - bool ok = py_call(tmp, 1, val); - if(!ok) return -1; + if(!py_call(tmp, 1, val)) return -1; + if(!py_checkint(py_retval())) return -1; return py_toint(py_retval()); } else { return 1; // True @@ -51,8 +51,8 @@ bool py_hash(const py_Ref val, int64_t* out) { if(py_isnone(_hash)) break; py_Ref _eq = &types[t].magic[__eq__]; if(!py_isnil(_hash) && !py_isnil(_eq)) { - bool ok = py_call(_hash, 1, val); - if(!ok) return false; + if(!py_call(_hash, 1, val)) return false; + if(!py_checkint(py_retval())) return false; *out = py_toint(py_retval()); return true; } @@ -72,8 +72,7 @@ int py_next(const py_Ref val) { vm->is_stopiteration = false; py_Ref tmp = py_tpfindmagic(val->type, __next__); if(!tmp) return TypeError("'%t' object is not an iterator", val->type); - bool ok = py_call(tmp, 1, val); - if(ok) return true; + if(py_call(tmp, 1, val)) return true; return vm->is_stopiteration ? 0 : -1; } @@ -201,16 +200,12 @@ bool py_delitem(py_Ref self, const py_Ref key) { return ok; } -#define COMPARE_OP_IMPL(name, op, rop) \ - int py_##name(const py_Ref lhs, const py_Ref rhs) { \ - bool ok = py_binaryop(lhs, rhs, op, rop); \ - if(!ok) return -1; \ - return py_tobool(py_retval()); \ - } +int py_equal(const py_Ref lhs, const py_Ref rhs){ + if(!py_eq(lhs, rhs)) return -1; + return py_bool(py_retval()); +} -COMPARE_OP_IMPL(eq, __eq__, __eq__) -COMPARE_OP_IMPL(ne, __ne__, __ne__) -COMPARE_OP_IMPL(lt, __lt__, __gt__) -COMPARE_OP_IMPL(le, __le__, __ge__) -COMPARE_OP_IMPL(gt, __gt__, __lt__) -COMPARE_OP_IMPL(ge, __ge__, __le__) \ No newline at end of file +int py_less(const py_Ref lhs, const py_Ref rhs){ + if(!py_lt(lhs, rhs)) return -1; + return py_bool(py_retval()); +} \ No newline at end of file