mirror of
https://github.com/pocketpy/pocketpy
synced 2025-10-20 11:30:18 +00:00
add pybind11 implementation for module reload
This commit is contained in:
parent
03a780dd59
commit
5bebf3e2df
@ -19,6 +19,11 @@ class module_ : public object {
|
||||
return steal<module_>(m);
|
||||
}
|
||||
|
||||
void reload() {
|
||||
bool ok = py_importlib_reload(ptr());
|
||||
if(!ok) { throw error_already_set(); }
|
||||
}
|
||||
|
||||
module_ def_submodule(const char* name, const char* doc = nullptr) {
|
||||
// auto package = (attr("__package__").cast<std::string>() += ".") +=
|
||||
// attr("__name__").cast<std::string_view>();
|
||||
|
@ -79,4 +79,62 @@ TEST_F(PYBIND11_TEST, dynamic_module) {
|
||||
EXPECT_EQ(math.attr("sub")(4, 3).cast<int>(), 1);
|
||||
}
|
||||
|
||||
struct import_callback {
|
||||
using cb_type = decltype(py_callbacks()->importfile);
|
||||
|
||||
import_callback() {
|
||||
assert(_importfile == nullptr);
|
||||
_importfile = py_callbacks()->importfile;
|
||||
py_callbacks()->importfile = importfile;
|
||||
};
|
||||
|
||||
~import_callback() {
|
||||
assert(_importfile != nullptr);
|
||||
py_callbacks()->importfile = _importfile;
|
||||
_importfile = nullptr;
|
||||
};
|
||||
|
||||
static char* importfile(const char* path) {
|
||||
if(value.empty()) return _importfile(path);
|
||||
// +1 for the null terminator
|
||||
char* cstr = new char[value.size() + 1];
|
||||
|
||||
std::strcpy(cstr, value.c_str());
|
||||
return cstr;
|
||||
}
|
||||
|
||||
static std::string value;
|
||||
|
||||
private:
|
||||
static cb_type _importfile;
|
||||
};
|
||||
|
||||
import_callback::cb_type import_callback::_importfile = nullptr;
|
||||
std::string import_callback::value = "";
|
||||
|
||||
TEST_F(PYBIND11_TEST, reload_module) {
|
||||
import_callback cb;
|
||||
|
||||
import_callback::value = "value = 1\n";
|
||||
auto mod = py::module::import("reload_module");
|
||||
EXPECT_EQ(mod.attr("value").cast<int>(), 1);
|
||||
|
||||
import_callback::value = "value = 2\n";
|
||||
mod.reload();
|
||||
EXPECT_EQ(mod.attr("value").cast<int>(), 2);
|
||||
|
||||
import_callback::value = "raise ValueError()";
|
||||
// Reload in Python raises a ValueError
|
||||
py::exec(
|
||||
"import importlib\nimport reload_module\ntry:\n importlib.reload(reload_module)\nexcept ValueError:\n pass");
|
||||
|
||||
// Reload in C++ raises a ValueError
|
||||
try {
|
||||
mod.reload();
|
||||
} catch(py::error_already_set& e) {
|
||||
if(e.match(tp_ValueError)) { return; }
|
||||
std::rethrow_exception(std::current_exception());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user