some fix.

This commit is contained in:
ykiko 2024-06-20 12:03:45 +08:00
parent 4ba0b7d2e2
commit f162cd308a
4 changed files with 76 additions and 17 deletions

View File

@ -17,7 +17,7 @@ Or explicitly call `py::interpreter::initialize()` and `py::interpreter::finaliz
#include <pybind11/pybind11.h>
namespace py = pybind11;
PYBIND11_MODULE(example, m) {
PYBIND11_EMBEDDED_MODULE(example, m) {
m.def("add", [](int a, int b) {
return a + b;
});

View File

@ -47,11 +47,11 @@ public:
/// bind constructor
template <typename... Args, typename... Extra>
class_& def(init<Args...>, const Extra&... extra) {
class_& def(impl::constructor<Args...>, const Extra&... extra) {
if constexpr(!std::is_constructible_v<T, Args...>) {
static_assert(std::is_constructible_v<T, Args...>, "Invalid constructor arguments");
} else {
impl::bind_function(
impl::bind_function<true>(
*this,
"__init__",
[](T* self, Args... args) {
@ -63,17 +63,29 @@ public:
}
}
template <typename Fn, typename... Extra>
class_& def(impl::factory<Fn> factory, const Extra&... extra) {
using ret = callable_return_t<Fn>;
if constexpr(!std::is_same_v<T, ret>) {
static_assert(std::is_same_v<T, ret>, "Factory function must return the class type");
} else {
impl::bind_function<true>(*this, "__init__", factory.make(), pkpy::BindType::DEFAULT, extra...);
return *this;
}
}
/// bind member function
template <typename Fn, typename... Extra>
class_& def(const char* name, Fn&& f, const Extra&... extra) {
using first = std::tuple_element_t<0, callable_args_t<remove_cvref_t<Fn>>>;
constexpr bool is_first_base_of_v = std::is_base_of_v<remove_cvref_t<first>, T>;
using first = remove_cvref_t<std::tuple_element_t<0, callable_args_t<remove_cvref_t<Fn>>>>;
constexpr bool is_first_base_of_v = std::is_base_of_v<first, T> || std::is_same_v<first, T>;
if constexpr(!is_first_base_of_v) {
static_assert(is_first_base_of_v,
"If you want to bind member function, the first argument must be the base class");
} else {
impl::bind_function(*this, name, std::forward<Fn>(f), pkpy::BindType::DEFAULT, extra...);
impl::bind_function<true>(*this, name, std::forward<Fn>(f), pkpy::BindType::DEFAULT, extra...);
}
return *this;
@ -91,7 +103,7 @@ public:
/// bind static function
template <typename Fn, typename... Extra>
class_& def_static(const char* name, Fn&& f, const Extra&... extra) {
impl::bind_function(*this, name, std::forward<Fn>(f), pkpy::BindType::STATICMETHOD, extra...);
impl::bind_function<false>(*this, name, std::forward<Fn>(f), pkpy::BindType::STATICMETHOD, extra...);
return *this;
}
@ -163,6 +175,15 @@ public:
template <typename... Args>
enum_(const handle& scope, const char* name, Args&&... args) :
class_<T, Others...>(scope, name, std::forward<Args>(args)...) {
Base::def(init([](int value) {
return static_cast<T>(value);
}));
Base::def("__eq__", [](T& self, T& other) {
return self == other;
});
Base::def_property_readonly("value", [](T& self) {
return int_(static_cast<std::underlying_type_t<T>>(self));
});

View File

@ -7,8 +7,37 @@ namespace pybind11 {
// append the overload to the beginning of the overload list
struct prepend {};
namespace impl {
template <typename... Args>
struct init {};
struct constructor {};
template <typename Fn, typename Args = callable_args_t<Fn>>
struct factory;
template <typename Fn, typename... Args>
struct factory<Fn, std::tuple<Args...>> {
Fn fn;
auto make() {
using Self = callable_return_t<Fn>;
return [fn = std::move(fn)](Self* self, Args... args) {
new (self) Self(fn(args...));
};
}
};
} // namespace impl
template <typename... Args>
impl::constructor<Args...> init() {
return {};
}
template <typename Fn>
impl::factory<Fn> init(Fn&& fn) {
return {std::forward<Fn>(fn)};
}
// TODO: support more customized tags
//
@ -256,6 +285,7 @@ struct template_parser<Callable, std::tuple<Extras...>, std::tuple<Args...>, std
constexpr auto named_argc = types_count_v<arg, Extras...>;
constexpr auto normal_argc =
sizeof...(Args) - (arguments_info.args_pos != -1) - (arguments_info.kwargs_pos != -1);
static_assert(named_argc == 0 || named_argc == normal_argc,
"named arguments must be the same as the number of function arguments");
@ -419,24 +449,32 @@ inline auto _wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
return record(view).ptr();
}
template <typename Fn, typename... Extras>
template <bool is_method, typename Fn, typename... Extras>
handle bind_function(const handle& obj, const char* name, Fn&& fn, pkpy::BindType type, const Extras&... extras) {
// do not use cpp_function directly to avoid unnecessary reference count change
pkpy::PyVar var = obj.ptr();
cpp_function callable = var->attr().try_get(name);
function_record* record = nullptr;
// if the function is not bound yet, bind it
if(!callable) {
auto record = function_record(std::forward<Fn>(fn), extras...);
void* data = interpreter::take_ownership(std::move(record));
callable = interpreter::bind_func(var, name, -1, _wrapper, data);
if constexpr(is_method && types_count_v<arg, Extras...> > 0) {
// if the function is a method and has named arguments
// prepend self to the arguments list
record = new function_record(std::forward<Fn>(fn), arg("self"), extras...);
} else {
function_record* record = new function_record(std::forward<Fn>(fn), extras...);
record = new function_record(std::forward<Fn>(fn), extras...);
}
if(!callable) {
// if the function is not bound yet, bind it
void* data = interpreter::take_ownership(std::move(*record));
callable = interpreter::bind_func(var, name, -1, _wrapper, data, type);
} else {
// if the function is already bound, append the new record to the function
function_record* last = callable.get_userdata_as<function_record*>();
if constexpr((types_count_v<prepend, Extras...> != 0)) {
// if prepend is specified, append the new record to the beginning of the list
fn.set_userdata(record);
callable.set_userdata(record);
record->append(last);
} else {
// otherwise, append the new record to the end of the list

View File

@ -29,7 +29,7 @@ public:
template <typename Fn, typename... Extras>
module_& def(const char* name, Fn&& fn, const Extras... extras) {
impl::bind_function(*this, name, std::forward<Fn>(fn), pkpy::BindType::DEFAULT, extras...);
impl::bind_function<false>(*this, name, std::forward<Fn>(fn), pkpy::BindType::DEFAULT, extras...);
return *this;
}
};