diff --git a/docs/quick-start/modules.md b/docs/quick-start/modules.md index 0fa0e8af..1d982941 100644 --- a/docs/quick-start/modules.md +++ b/docs/quick-start/modules.md @@ -53,16 +53,16 @@ When you do `import` a module, the VM will try to find it in the following order 1. Search `vm->_modules`, if found, return it. 2. Search `vm->_lazy_modules`, if found, compile and execute it, then return it. -3. Search the working directory and try to load it from file system via `read_file_cwd`. +3. Search the working directory and try to load it from file system via `vm->_import_handler`. -### Filesystem hook +### Customized import handler -You can use `set_read_file_cwd` to provide a custom filesystem hook, which is used for `import` (3rd step). -The default implementation is: +You can use `vm->_import_handler` to provide a custom import handler for the 3rd step. +if both `enable_os` and `PK_ENABLE_OS` are `true`, the default `import_handler` is as follows: ```cpp -set_read_file_cwd([](const Str& name){ +inline Bytes _default_import_handler(const Str& name){ std::filesystem::path path(name.sv()); bool exists = std::filesystem::exists(path); if(!exists) return Bytes(); @@ -75,5 +75,5 @@ set_read_file_cwd([](const Str& name){ fread(buffer.data(), 1, buffer.size(), fp); fclose(fp); return Bytes(std::move(buffer)); -}); +}; ``` \ No newline at end of file diff --git a/src/io.h b/src/io.h index 3806412a..44deb7fb 100644 --- a/src/io.h +++ b/src/io.h @@ -11,7 +11,7 @@ namespace pkpy{ -inline int _ = set_read_file_cwd([](const Str& name){ +inline Bytes _default_import_handler(const Str& name){ std::filesystem::path path(name.sv()); bool exists = std::filesystem::exists(path); if(!exists) return Bytes(); @@ -24,7 +24,7 @@ inline int _ = set_read_file_cwd([](const Str& name){ fread(buffer.data(), 1, buffer.size(), fp); fclose(fp); return Bytes(std::move(buffer)); -}); +}; struct FileIO { PY_CLASS(FileIO, io, FileIO) @@ -183,6 +183,7 @@ inline void add_module_os(VM* vm){ namespace pkpy{ inline void add_module_io(void* vm){} inline void add_module_os(void* vm){} +inline Bytes _default_import_handler(const Str& name) { return Bytes(); } } // namespace pkpy #endif \ No newline at end of file diff --git a/src/pocketpy.h b/src/pocketpy.h index 24ae27a5..43a75a13 100644 --- a/src/pocketpy.h +++ b/src/pocketpy.h @@ -1458,6 +1458,7 @@ inline void VM::post_init(){ add_module_io(this); add_module_os(this); add_module_requests(this); + _import_handler = _default_import_handler; } add_module_linalg(this); diff --git a/src/vm.h b/src/vm.h index f72e2ca0..1d246222 100644 --- a/src/vm.h +++ b/src/vm.h @@ -25,10 +25,6 @@ namespace pkpy{ #define POPX() (s_data.popx()) #define STACK_VIEW(n) (s_data.view(n)) -typedef Bytes (*ReadFileCwdFunc)(const Str& name); -inline ReadFileCwdFunc _read_file_cwd = [](const Str& name) { return Bytes(); }; -inline int set_read_file_cwd(ReadFileCwdFunc func) { _read_file_cwd = func; return 0; } - #define DEF_NATIVE_2(ctype, ptype) \ template<> inline ctype py_cast(VM* vm, PyObject* obj) { \ vm->check_non_tagged_type(obj, vm->ptype); \ @@ -127,6 +123,7 @@ public: PrintFunc _stdout; PrintFunc _stderr; + Bytes (*_import_handler)(const Str& name); // for quick access Type tp_object, tp_type, tp_int, tp_float, tp_bool, tp_str; @@ -145,6 +142,7 @@ public: callstack.reserve(8); _main = nullptr; _last_exception = nullptr; + _import_handler = [](const Str& name) { return Bytes(); }; init_builtin_types(); } @@ -604,10 +602,10 @@ public: Str source; auto it = _lazy_modules.find(name); if(it == _lazy_modules.end()){ - Bytes b = _read_file_cwd(filename); + Bytes b = _import_handler(filename); if(!relative && !b){ filename = fmt(name, kPlatformSep, "__init__.py"); - b = _read_file_cwd(filename); + b = _import_handler(filename); if(b) type = 1; } if(!b) _error("ImportError", fmt("module ", name.escape(), " not found")); diff --git a/tests/30_import.py b/tests/30_import.py index 112cc69c..501b672a 100644 --- a/tests/30_import.py +++ b/tests/30_import.py @@ -1,3 +1,8 @@ +try: + import os +except ImportError: + exit(0) + import test1 assert test1.add(1, 2) == 13 \ No newline at end of file