diff --git a/include/pybind11/internal/module.h b/include/pybind11/internal/module.h index 71a5d3b6..f42e9e5b 100644 --- a/include/pybind11/internal/module.h +++ b/include/pybind11/internal/module.h @@ -19,6 +19,11 @@ class module_ : public object { return steal(m); } + void reload() { + bool ok = py_importlib_reload(ptr()); + if(!ok) { throw error_already_set(); } + } + module_ def_submodule(const char* name, const char* doc = nullptr) { // auto package = (attr("__package__").cast() += ".") += // attr("__name__").cast(); diff --git a/include/pybind11/tests/module.cpp b/include/pybind11/tests/module.cpp index 15e559f9..feaf8aa1 100644 --- a/include/pybind11/tests/module.cpp +++ b/include/pybind11/tests/module.cpp @@ -79,4 +79,62 @@ TEST_F(PYBIND11_TEST, dynamic_module) { EXPECT_EQ(math.attr("sub")(4, 3).cast(), 1); } +struct import_callback { + using cb_type = decltype(py_callbacks()->importfile); + + import_callback() { + assert(_importfile == nullptr); + _importfile = py_callbacks()->importfile; + py_callbacks()->importfile = importfile; + }; + + ~import_callback() { + assert(_importfile != nullptr); + py_callbacks()->importfile = _importfile; + _importfile = nullptr; + }; + + static char* importfile(const char* path) { + if(value.empty()) return _importfile(path); + // +1 for the null terminator + char* cstr = new char[value.size() + 1]; + + std::strcpy(cstr, value.c_str()); + return cstr; + } + + static std::string value; + +private: + static cb_type _importfile; +}; + +import_callback::cb_type import_callback::_importfile = nullptr; +std::string import_callback::value = ""; + +TEST_F(PYBIND11_TEST, reload_module) { + import_callback cb; + + import_callback::value = "value = 1\n"; + auto mod = py::module::import("reload_module"); + EXPECT_EQ(mod.attr("value").cast(), 1); + + import_callback::value = "value = 2\n"; + mod.reload(); + EXPECT_EQ(mod.attr("value").cast(), 2); + + import_callback::value = "raise ValueError()"; + // Reload in Python raises a ValueError + py::exec( + "import importlib\nimport reload_module\ntry:\n importlib.reload(reload_module)\nexcept ValueError:\n pass"); + + // Reload in C++ raises a ValueError + try { + mod.reload(); + } catch(py::error_already_set& e) { + if(e.match(tp_ValueError)) { return; } + std::rethrow_exception(std::current_exception()); + } +} + } // namespace