From 57cd40da6f9940d4646c76c96cf47fed8159bf50 Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Tue, 17 Jun 2025 23:22:13 +0800 Subject: [PATCH] add compile time func --- include/pocketpy/interpreter/vm.h | 1 + include/pocketpy/pocketpy.h | 5 ++ src/compiler/compiler.c | 82 ++++++++++++++++++++++++++++++- src/interpreter/vm.c | 8 +++ src/public/values.c | 20 ++++++-- src2/main.c | 2 +- 6 files changed, 113 insertions(+), 5 deletions(-) diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index 73650843..686ee963 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -51,6 +51,7 @@ typedef struct VM { void* ctx; // user-defined context BinTree cached_names; + NameDict compile_time_funcs; py_StackRef curr_class; py_StackRef curr_decl_based_function; diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index 8ff06569..8d3468d3 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -129,6 +129,11 @@ PK_API void py_watchdog_begin(py_i64 timeout); /// Reset the watchdog. PK_API void py_watchdog_end(); +/// Bind a compile-time function via "decl-based" style. +PK_API void py_compiletime_bind(const char* sig, py_CFunction f); +/// Find a compile-time function by name. +PK_API py_ItemRef py_compiletime_getfunc(py_Name name); + /// Get the current source location of the frame. PK_API const char* py_Frame_sourceloc(py_Frame* frame, int* lineno); /// Python equivalent to `globals()` with respect to the given frame. diff --git a/src/compiler/compiler.c b/src/compiler/compiler.c index c557630b..4a02bad3 100644 --- a/src/compiler/compiler.c +++ b/src/compiler/compiler.c @@ -368,6 +368,25 @@ Literal0Expr* Literal0Expr__new(int line, TokenIndex token) { return self; } +typedef struct LoadConstExpr { + EXPR_COMMON_HEADER + int index; +} LoadConstExpr; + +void LoadConstExpr__emit_(Expr* self_, Ctx* ctx) { + LoadConstExpr* self = (LoadConstExpr*)self_; + Ctx__emit_(ctx, OP_LOAD_CONST, self->index, self->line); +} + +LoadConstExpr* LoadConstExpr__new(int line, int index) { + const static ExprVt Vt = {.emit_ = LoadConstExpr__emit_}; + LoadConstExpr* self = PK_MALLOC(sizeof(LoadConstExpr)); + self->vt = &Vt; + self->line = line; + self->index = index; + return self; +} + typedef struct SliceExpr { EXPR_COMMON_HEADER Expr* start; @@ -1864,9 +1883,70 @@ static Error* exprMap(Compiler* self) { return NULL; } +static Error* read_literal(Compiler* self, py_Ref out); + +static Error* exprCompileTimeCall(Compiler* self, py_ItemRef func, int line) { + Error* err; + py_push(func); + py_pushnil(); + + uint16_t argc = 0; + uint16_t kwargc = 0; + // copied from `exprCall` + do { + match_newlines(); + if(curr()->type == TK_RPAREN) break; + if(curr()->type == TK_ID && next()->type == TK_ASSIGN) { + consume(TK_ID); + py_Name key = py_namev(Token__sv(prev())); + consume(TK_ASSIGN); + // k=v + py_pushname(key); + check(read_literal(self, py_pushtmp())); + kwargc += 1; + } else { + if(kwargc > 0) { + return SyntaxError(self, "positional argument follows keyword argument"); + } + check(read_literal(self, py_pushtmp())); + argc += 1; + } + match_newlines(); + } while(match(TK_COMMA)); + consume(TK_RPAREN); + + bool ok = py_vectorcall(argc, kwargc); + if(!ok) { + char* msg = py_formatexc(); + err = SyntaxError(self, "compile-time call error:\n%s", msg); + PK_FREE(msg); + return err; + } + + // TODO: optimize string dedup + int index = Ctx__add_const(ctx(), py_retval()); + Ctx__s_push(ctx(), (Expr*)LoadConstExpr__new(line, index)); + return NULL; +} + static Error* exprCall(Compiler* self) { Error* err; - CallExpr* e = CallExpr__new(prev()->line, Ctx__s_popx(ctx())); + Expr* callable = Ctx__s_popx(ctx()); + int line = prev()->line; + if(callable->vt->is_name) { + NameExpr* ne = (NameExpr*)callable; + if(ne->scope == NAME_GLOBAL) { + py_ItemRef func = py_compiletime_getfunc(ne->name); + if(func != NULL) { + py_StackRef p0 = py_peek(0); + err = exprCompileTimeCall(self, func, line); + if(err != NULL) py_clearexc(p0); + return err; + } + } + } + + CallExpr* e = CallExpr__new(line, callable); Ctx__s_push(ctx(), (Expr*)e); // push onto the stack in advance do { match_newlines(); diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index 8f8b038c..0865b7b3 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -116,6 +116,7 @@ void VM__ctor(VM* self) { .need_free_key = false, }; BinTree__ctor(&self->cached_names, NULL, py_NIL(), &cached_names_config); + NameDict__ctor(&self->compile_time_funcs, PK_TYPE_ATTR_LOAD_FACTOR); /* Init Builtin Types */ // 0: unused @@ -294,6 +295,7 @@ void VM__dtor(VM* self) { FixedMemoryPool__dtor(&self->pool_frame); ValueStack__dtor(&self->stack); BinTree__dtor(&self->cached_names); + NameDict__dtor(&self->compile_time_funcs); } void VM__push_frame(VM* self, py_Frame* frame) { @@ -673,6 +675,12 @@ void ManagedHeap__mark(ManagedHeap* self) { BinTree__apply_mark(&vm->modules, p_stack); // mark cached names BinTree__apply_mark(&vm->cached_names, p_stack); + // mark compile time functions + for(int i = 0; i < vm->compile_time_funcs.capacity; i++) { + NameDict_KV* kv = &vm->compile_time_funcs.items[i]; + if(kv->key == NULL) continue; + pk__mark_value(&kv->value); + } // mark types int types_length = vm->types.length; // 0-th type is placeholder diff --git a/src/public/values.c b/src/public/values.c index a05abde6..38ac4038 100644 --- a/src/public/values.c +++ b/src/public/values.c @@ -89,9 +89,23 @@ void py_bindmagic(py_Type type, py_Name name, py_CFunction f) { } void py_bind(py_Ref obj, const char* sig, py_CFunction f) { - py_TValue tmp; - py_Name name = py_newfunction(&tmp, sig, f, NULL, 0); - py_setdict(obj, name, &tmp); + py_Ref tmp = py_pushtmp(); + py_Name name = py_newfunction(tmp, sig, f, NULL, 0); + py_setdict(obj, name, tmp); + py_pop(); +} + +void py_compiletime_bind(const char* sig, py_CFunction f) { + py_Ref tmp = py_pushtmp(); + py_Name name = py_newfunction(tmp, sig, f, NULL, 0); + NameDict__set(&pk_current_vm->compile_time_funcs, name, tmp); + py_pop(); +} + +PK_API py_ItemRef py_compiletime_getfunc(py_Name name) { + NameDict* d = &pk_current_vm->compile_time_funcs; + if(d->length == 0) return NULL; + return NameDict__try_get(d, name); } py_Name py_newfunction(py_OutRef out, diff --git a/src2/main.c b/src2/main.c index 357ebe2e..109d15c6 100644 --- a/src2/main.c +++ b/src2/main.c @@ -53,7 +53,7 @@ int main(int argc, char** argv) { py_initialize(); py_sys_setargv(argc, argv); - assert(!profile); // not implemented yet + assert(!profile); // not implemented yet // if(profile) py_sys_settrace(LineProfiler__tracefunc, true); if(filename == NULL) {