diff --git a/include/pocketpy/common/str.h b/include/pocketpy/common/str.h index 5ccaf5f7..fb6a52db 100644 --- a/include/pocketpy/common/str.h +++ b/include/pocketpy/common/str.h @@ -56,6 +56,7 @@ c11_sv c11_sv__slice2(c11_sv sv, int start, int stop); c11_sv c11_sv__strip(c11_sv sv, c11_sv chars, bool left, bool right); int c11_sv__index(c11_sv self, char c); int c11_sv__rindex(c11_sv self, char c); +c11_sv c11_sv__filename(c11_sv self); int c11_sv__index2(c11_sv self, c11_sv sub, int start); int c11_sv__count(c11_sv self, c11_sv sub); bool c11_sv__startswith(c11_sv self, c11_sv prefix); diff --git a/src/common/str.c b/src/common/str.c index bdd852a5..5a336664 100644 --- a/src/common/str.c +++ b/src/common/str.c @@ -147,6 +147,14 @@ int c11_sv__rindex(c11_sv self, char c) { return -1; } +c11_sv c11_sv__filename(c11_sv self) { + int sep_index_1 = c11_sv__rindex(self, '/'); + int sep_index_2 = c11_sv__rindex(self, '\\'); + int sep_index = c11__max(sep_index_1, sep_index_2); + if(sep_index == -1) return self; + return c11_sv__slice(self, sep_index + 1); +} + int c11_sv__index2(c11_sv self, c11_sv sub, int start) { if(sub.size == 0) return start; int max_end = self.size - sub.size; diff --git a/src/public/ModuleSystem.c b/src/public/ModuleSystem.c index c641e68c..4c362b59 100644 --- a/src/public/ModuleSystem.c +++ b/src/public/ModuleSystem.c @@ -78,8 +78,10 @@ int py_import(const char* path_cstr) { while(dot_count < path.size && path.data[dot_count] == '.') dot_count++; - c11_sv top_filename = c11_string__sv(vm->top_frame->co->src->filename); - int is_init = c11_sv__endswith(top_filename, (c11_sv){"__init__.py", 11}); + // */__init__.py[c] + c11_sv top_filepath = c11_string__sv(vm->top_frame->co->src->filename); + c11_sv top_filename = c11_sv__filename(top_filepath); + int is_init = c11__sveq2(top_filename, "__init__.py") || c11__sveq2(top_filename, "__init__.pyc"); py_ModuleInfo* mi = py_touserdata(vm->top_frame->module); c11_sv package_sv = c11_string__sv(mi->path);