diff --git a/src/obj.h b/src/obj.h index eb6bdcc2..18752ac8 100644 --- a/src/obj.h +++ b/src/obj.h @@ -14,14 +14,26 @@ class VM; typedef PyObject* (*NativeFuncC)(VM*, ArgsView); typedef int (*LuaStyleFuncC)(VM*); +union UserData{ + void* _p; + void (*_fp)(void); + char _char; + int _int; + float _float; + bool _bool; +}; + struct NativeFunc { NativeFuncC f; int argc; // DONOT include self bool method; // this is designed for lua style C bindings - // access it via `CAST(NativeFunc&, args[-2])._lua_f` + // access it via `_CAST(NativeFunc&, args[-2])._lua_f` + // (-2) or (-1) depends on the calling convention LuaStyleFuncC _lua_f; + + UserData userdata; NativeFunc(NativeFuncC f, int argc, bool method) : f(f), argc(argc), method(method), _lua_f(nullptr) {} PyObject* operator()(VM* vm, ArgsView args) const; @@ -389,4 +401,14 @@ struct Py_ final: PyObject { }; +template +T lambda_get_fp(ArgsView args){ + void (*f)(); + if(args[-1] != PY_NULL) f = OBJ_GET(NativeFunc, args[-1]).userdata._fp; + else f = OBJ_GET(NativeFunc, args[-2]).userdata._fp; + return reinterpret_cast(f); +} + + + } // namespace pkpy \ No newline at end of file diff --git a/src/vm.h b/src/vm.h index 65bbbbad..e06e1ead 100644 --- a/src/vm.h +++ b/src/vm.h @@ -339,11 +339,12 @@ public: } #define BIND_UNARY_SPECIAL(name) \ - void bind##name(Type type, PyObject* (*f)(VM* vm, PyObject*)){ \ + void bind##name(Type type, PyObject* (*f)(VM*, PyObject*)){ \ _all_types[type].m##name = f; \ - bind_method<0>(_t(type), #name, [](VM* vm, ArgsView args){ \ - return vm->_inst_type_info(args[0])->m##name(vm, args[0]); \ + PyObject* nf = bind_method<0>(_t(type), #name, [](VM* vm, ArgsView args){ \ + return lambda_get_fp(args)(vm, args[0]); \ }); \ + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); \ } BIND_UNARY_SPECIAL(__repr__) @@ -360,13 +361,14 @@ public: #define BIND_LOGICAL_SPECIAL(name) \ - void bind##name(Type type, bool (*f)(VM* vm, PyObject* lhs, PyObject* rhs)){ \ + void bind##name(Type type, bool (*f)(VM*, PyObject*, PyObject*)){ \ PyObject* obj = _t(type); \ _all_types[type].m##name = f; \ - bind_method<1>(obj, #name, [](VM* vm, ArgsView args){ \ - bool ok = vm->_inst_type_info(args[0])->m##name(vm, args[0], args[1]); \ + PyObject* nf = bind_method<1>(obj, #name, [](VM* vm, ArgsView args){ \ + bool ok = lambda_get_fp(args)(vm, args[0], args[1]); \ return ok ? vm->True : vm->False; \ }); \ + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); \ } BIND_LOGICAL_SPECIAL(__eq__) @@ -381,12 +383,13 @@ public: #define BIND_BINARY_SPECIAL(name) \ - void bind##name(Type type, PyObject* (*f)(VM* vm, PyObject* lhs, PyObject* rhs)){ \ + void bind##name(Type type, PyObject* (*f)(VM*, PyObject*, PyObject*)){ \ PyObject* obj = _t(type); \ _all_types[type].m##name = f; \ - bind_method<1>(obj, #name, [](VM* vm, ArgsView args){ \ - return vm->_inst_type_info(args[0])->m##name(vm, args[0], args[1]); \ + PyObject* nf = bind_method<1>(obj, #name, [](VM* vm, ArgsView args){ \ + return lambda_get_fp(args)(vm, args[0], args[1]); \ }); \ + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); \ } BIND_BINARY_SPECIAL(__add__) @@ -406,30 +409,33 @@ public: #undef BIND_BINARY_SPECIAL - void bind__getitem__(Type type, PyObject* (*f)(VM* vm, PyObject* lhs, PyObject* rhs)){ + void bind__getitem__(Type type, PyObject* (*f)(VM*, PyObject*, PyObject*)){ PyObject* obj = _t(type); _all_types[type].m__getitem__ = f; - bind_method<1>(obj, "__getitem__", [](VM* vm, ArgsView args){ - return vm->_inst_type_info(args[0])->m__getitem__(vm, args[0], args[1]); + PyObject* nf = bind_method<1>(obj, "__getitem__", [](VM* vm, ArgsView args){ + return lambda_get_fp(args)(vm, args[0], args[1]); }); + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); } - void bind__setitem__(Type type, void (*f)(VM* vm, PyObject* lhs, PyObject* rhs, PyObject* value)){ + void bind__setitem__(Type type, void (*f)(VM*, PyObject*, PyObject*, PyObject*)){ PyObject* obj = _t(type); _all_types[type].m__setitem__ = f; - bind_method<2>(obj, "__setitem__", [](VM* vm, ArgsView args){ - vm->_inst_type_info(args[0])->m__setitem__(vm, args[0], args[1], args[2]); + PyObject* nf = bind_method<2>(obj, "__setitem__", [](VM* vm, ArgsView args){ + lambda_get_fp(args)(vm, args[0], args[1], args[2]); return vm->None; }); + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); } - void bind__delitem__(Type type, void (*f)(VM* vm, PyObject* lhs, PyObject* rhs)){ + void bind__delitem__(Type type, void (*f)(VM*, PyObject*, PyObject*)){ PyObject* obj = _t(type); _all_types[type].m__delitem__ = f; - bind_method<1>(obj, "__delitem__", [](VM* vm, ArgsView args){ - vm->_inst_type_info(args[0])->m__delitem__(vm, args[0], args[1]); + PyObject* nf = bind_method<1>(obj, "__delitem__", [](VM* vm, ArgsView args){ + lambda_get_fp(args)(vm, args[0], args[1]); return vm->None; }); + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); } bool py_equals(PyObject* lhs, PyObject* rhs){ @@ -447,32 +453,32 @@ public: } template - void bind_func(Str type, Str name, NativeFuncC fn) { - bind_func(_find_type_object(type), name, fn); + PyObject* bind_func(Str type, Str name, NativeFuncC fn) { + return bind_func(_find_type_object(type), name, fn); } template - void bind_method(Str type, Str name, NativeFuncC fn) { - bind_method(_find_type_object(type), name, fn); + PyObject* bind_method(Str type, Str name, NativeFuncC fn) { + return bind_method(_find_type_object(type), name, fn); } template - void bind_constructor(__T&& type, NativeFuncC fn) { + PyObject* bind_constructor(__T&& type, NativeFuncC fn) { static_assert(ARGC==-1 || ARGC>=1); - bind_func(std::forward<__T>(type), "__new__", fn); + return bind_func(std::forward<__T>(type), "__new__", fn); } template - void bind_default_constructor(__T&& type) { - bind_constructor<-1>(std::forward<__T>(type), [](VM* vm, ArgsView args){ + PyObject* bind_default_constructor(__T&& type) { + return bind_constructor<-1>(std::forward<__T>(type), [](VM* vm, ArgsView args){ Type t = OBJ_GET(Type, args[0]); return vm->heap.gcnew(t, T()); }); } template - void bind_builtin_func(Str name, NativeFuncC fn) { - bind_func(builtins, name, fn); + PyObject* bind_builtin_func(Str name, NativeFuncC fn) { + return bind_func(builtins, name, fn); } int normalized_index(int index, int size){ @@ -581,9 +587,9 @@ public: PyObject* format(Str, PyObject*); void setattr(PyObject* obj, StrName name, PyObject* value); template - void bind_method(PyObject*, Str, NativeFuncC); + PyObject* bind_method(PyObject*, Str, NativeFuncC); template - void bind_func(PyObject*, Str, NativeFuncC); + PyObject* bind_func(PyObject*, Str, NativeFuncC); void _error(Exception); PyObject* _run_top_frame(); void post_init(); @@ -594,6 +600,9 @@ inline PyObject* NativeFunc::operator()(VM* vm, ArgsView args) const{ if(argc != -1 && args_size != argc) { vm->TypeError(fmt("expected ", argc, " arguments, but got ", args_size)); } +#if DEBUG_EXTRA_CHECK + if(f == nullptr) FATAL_ERROR(); +#endif return f(vm, args); } @@ -1351,20 +1360,24 @@ inline void VM::setattr(PyObject* obj, StrName name, PyObject* value){ } template -void VM::bind_method(PyObject* obj, Str name, NativeFuncC fn) { +PyObject* VM::bind_method(PyObject* obj, Str name, NativeFuncC fn) { check_non_tagged_type(obj, tp_type); if(obj->attr().contains(name)){ throw std::runtime_error(fmt("bind_method() failed: ", name.escape(), " already exists")); } - obj->attr().set(name, VAR(NativeFunc(fn, ARGC, true))); + PyObject* nf = VAR(NativeFunc(fn, ARGC, true)); + obj->attr().set(name, nf); + return nf; } template -void VM::bind_func(PyObject* obj, Str name, NativeFuncC fn) { +PyObject* VM::bind_func(PyObject* obj, Str name, NativeFuncC fn) { if(obj->attr().contains(name)){ throw std::runtime_error(fmt("bind_func() failed: ", name.escape(), " already exists")); } - obj->attr().set(name, VAR(NativeFunc(fn, ARGC, false))); + PyObject* nf = VAR(NativeFunc(fn, ARGC, false)); + obj->attr().set(name, nf); + return nf; } inline void VM::_error(Exception e){ @@ -1411,22 +1424,24 @@ PyObject* PyArrayGetItem(VM* vm, PyObject* obj, PyObject* index){ return self[i]; } -inline void VM::bind__hash__(Type type, i64 (*f)(VM* vm, PyObject*)){ +inline void VM::bind__hash__(Type type, i64 (*f)(VM*, PyObject*)){ PyObject* obj = _t(type); _all_types[type].m__hash__ = f; - bind_method<0>(obj, "__hash__", [](VM* vm, ArgsView args){ - i64 ret = vm->_inst_type_info(args[0])->m__hash__(vm, args[0]); + PyObject* nf = bind_method<0>(obj, "__hash__", [](VM* vm, ArgsView args){ + i64 ret = lambda_get_fp(args)(vm, args[0]); return VAR(ret); }); + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); } -inline void VM::bind__len__(Type type, i64 (*f)(VM* vm, PyObject*)){ +inline void VM::bind__len__(Type type, i64 (*f)(VM*, PyObject*)){ PyObject* obj = _t(type); _all_types[type].m__len__ = f; - bind_method<0>(obj, "__len__", [](VM* vm, ArgsView args){ - i64 ret = vm->_inst_type_info(args[0])->m__len__(vm, args[0]); + PyObject* nf = bind_method<0>(obj, "__len__", [](VM* vm, ArgsView args){ + i64 ret = lambda_get_fp(args)(vm, args[0]); return VAR(ret); }); + OBJ_GET(NativeFunc, nf).userdata._fp = reinterpret_cast(f); }