From cd3c28fdd87a675303aa6ae2f34e435621d3104f Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Thu, 8 Aug 2024 11:09:30 +0800 Subject: [PATCH] fix generator --- include/pocketpy/interpreter/frame.h | 4 ++-- include/pocketpy/interpreter/generator.h | 5 ++++- include/pocketpy/pocketpy.h | 4 +++- src/interpreter/ceval.c | 6 +++--- src/interpreter/frame.c | 8 ++++---- src/interpreter/generator.c | 17 ++++++++++++++--- src/interpreter/vm.c | 10 +++++----- src/public/internal.c | 2 +- src/public/modules.c | 4 ++-- src/public/py_ops.c | 4 +++- src/public/stack_ops.c | 4 ++-- tests/51_yield.py | 8 +++++++- 12 files changed, 50 insertions(+), 26 deletions(-) diff --git a/include/pocketpy/interpreter/frame.h b/include/pocketpy/interpreter/frame.h index ff39c622..deb3f895 100644 --- a/include/pocketpy/interpreter/frame.h +++ b/include/pocketpy/interpreter/frame.h @@ -36,7 +36,7 @@ typedef struct Frame { const Bytecode* ip; const CodeObject* co; py_GlobalRef module; - py_StackRef function; // a function object or NULL (global scope) + bool has_function; // is p0 a function? py_StackRef p0; // unwinding base py_StackRef locals; // locals base UnwindTarget* uw_list; @@ -44,7 +44,7 @@ typedef struct Frame { Frame* Frame__new(const CodeObject* co, py_GlobalRef module, - py_StackRef function, + bool has_function, py_StackRef p0, py_StackRef locals); void Frame__delete(Frame* self); diff --git a/include/pocketpy/interpreter/generator.h b/include/pocketpy/interpreter/generator.h index e37f623d..6adb5986 100644 --- a/include/pocketpy/interpreter/generator.h +++ b/include/pocketpy/interpreter/generator.h @@ -1,10 +1,13 @@ #pragma once #include "pocketpy/interpreter/frame.h" +#include "pocketpy/pocketpy.h" typedef struct Generator{ Frame* frame; int state; } Generator; -void pk_newgenerator(py_Ref out, Frame* frame); \ No newline at end of file +void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length); + +void Generator__dtor(Generator* ud); \ No newline at end of file diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index 3f2363d3..d1c42e86 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -19,6 +19,8 @@ typedef int16_t py_Type; typedef int64_t py_i64; typedef double py_f64; +typedef void (*py_Dtor)(void*); + #define PY_RAISE // mark a function that can raise an exception typedef struct c11_sv { @@ -137,7 +139,7 @@ c11_sv py_name2sv(py_Name); /// @param base base type. /// @param module module where the type is defined. Use `NULL` for built-in types. /// @param dtor destructor function. Use `NULL` if not needed. -py_Type py_newtype(const char* name, py_Type base, const py_GlobalRef module, void (*dtor)(void*)); +py_Type py_newtype(const char* name, py_Type base, const py_GlobalRef module, py_Dtor dtor); /// Create a new object. /// @param out output reference. diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index 0a32a600..3dc4e1f0 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -285,7 +285,7 @@ FrameResult VM__run_top_frame(VM* self) { case OP_STORE_FAST: frame->locals[byte.arg] = POPX(); DISPATCH(); case OP_STORE_NAME: { py_Name name = byte.arg; - if(frame->function) { + if(frame->has_function) { py_Ref slot = Frame__f_locals_try_get(frame, name); if(slot != NULL) { *slot = *TOP(); // store in locals if possible @@ -346,7 +346,7 @@ FrameResult VM__run_top_frame(VM* self) { } case OP_DELETE_NAME: { py_Name name = byte.arg; - if(frame->function) { + if(frame->has_function) { py_TValue* slot = Frame__f_locals_try_get(frame, name); if(slot) { py_newnil(slot); @@ -977,7 +977,7 @@ FrameResult VM__run_top_frame(VM* self) { py_BaseException__stpush(&self->curr_exception, frame->co->src, lineno < 0 ? Frame__lineno(frame) : lineno, - frame->function ? frame->co->name->data : NULL); + frame->has_function ? frame->co->name->data : NULL); int target = Frame__prepare_jump_exception_handler(frame, &self->stack); if(target >= 0) { diff --git a/src/interpreter/frame.c b/src/interpreter/frame.c index 401b98f8..73eab4a4 100644 --- a/src/interpreter/frame.c +++ b/src/interpreter/frame.c @@ -37,7 +37,7 @@ void UnwindTarget__delete(UnwindTarget* self) { free(self); } Frame* Frame__new(const CodeObject* co, py_GlobalRef module, - py_StackRef function, + bool has_function, py_StackRef p0, py_StackRef locals) { static_assert(sizeof(Frame) <= kPoolFrameBlockSize, "!(sizeof(Frame) <= kPoolFrameBlockSize)"); @@ -46,7 +46,7 @@ Frame* Frame__new(const CodeObject* co, self->ip = (Bytecode*)co->codes.data - 1; self->co = co; self->module = module; - self->function = function; + self->has_function = has_function; self->p0 = p0; self->locals = locals; self->uw_list = NULL; @@ -131,8 +131,8 @@ void Frame__set_unwind_target(Frame* self, py_TValue* sp) { } py_TValue* Frame__f_closure_try_get(Frame* self, py_Name name) { - if(self->function == NULL) return NULL; - Function* ud = py_touserdata(self->function); + if(!self->has_function) return NULL; + Function* ud = py_touserdata(self->p0); if(ud->closure == NULL) return NULL; return NameDict__try_get(ud->closure, name); } diff --git a/src/interpreter/generator.c b/src/interpreter/generator.c index f069ed74..5b4d5009 100644 --- a/src/interpreter/generator.c +++ b/src/interpreter/generator.c @@ -5,11 +5,19 @@ #include "pocketpy/pocketpy.h" #include -void pk_newgenerator(py_Ref out, Frame* frame) { +void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length) { Generator* ud = py_newobject(out, tp_generator, 1, sizeof(Generator)); ud->frame = frame; ud->state = 0; - py_newlist(py_getslot(out, 0)); + py_Ref tmp = py_getslot(out, 0); + py_newlist(tmp); + for(int i = 0; i < backup_length; i++) { + py_list_append(tmp, &backup[i]); + } +} + +void Generator__dtor(Generator* ud) { + if(ud->frame) { Frame__delete(ud->frame); } } static bool generator__next__(int argc, py_Ref argv) { @@ -33,6 +41,7 @@ static bool generator__next__(int argc, py_Ref argv) { // push frame VM__push_frame(vm, ud->frame); + ud->frame = NULL; FrameResult res = VM__run_top_frame(vm); @@ -51,17 +60,19 @@ static bool generator__next__(int argc, py_Ref argv) { for(py_StackRef p = ud->frame->p0; p != vm->stack.sp; p++) { py_list_append(backup, p); } + vm->stack.sp = ud->frame->p0; vm->top_frame = vm->top_frame->f_back; ud->state = 1; return true; } else { + assert(res == RES_RETURN); ud->state = 2; return StopIteration(); } } py_Type pk_generator__register() { - py_Type type = pk_newtype("generator", tp_object, NULL, NULL, false, true); + py_Type type = pk_newtype("generator", tp_object, NULL, (py_Dtor)Generator__dtor, false, true); py_bindmagic(type, __iter__, pk_wrapper__self); py_bindmagic(type, __next__, generator__next__); diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index d25bd95c..a928f5c4 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -427,7 +427,7 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall memcpy(argv, self->__vectorcall_buffer, co->nlocals * sizeof(py_TValue)); // submit the call if(!fn->cfunc) { - VM__push_frame(self, Frame__new(co, &fn->module, p0, p0, argv)); + VM__push_frame(self, Frame__new(co, &fn->module, true, p0, argv)); return opcall ? RES_CALL : VM__run_top_frame(self); } else { bool ok = fn->cfunc(co->nlocals, argv); @@ -451,13 +451,13 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall // initialize local variables to py_NIL memset(p1, 0, (char*)self->stack.sp - (char*)p1); // submit the call - VM__push_frame(self, Frame__new(co, &fn->module, p0, p0, argv)); + VM__push_frame(self, Frame__new(co, &fn->module, true, p0, argv)); return opcall ? RES_CALL : VM__run_top_frame(self); case FuncType_GENERATOR: { bool ok = prepare_py_call(self->__vectorcall_buffer, argv, p1, kwargc, fn->decl); if(!ok) return RES_ERROR; - Frame* frame = Frame__new(co, &fn->module, p0, p0, argv); - pk_newgenerator(py_retval(), frame); + Frame* frame = Frame__new(co, &fn->module, false, p0, argv); + pk_newgenerator(py_retval(), frame, self->__vectorcall_buffer, co->nlocals); self->stack.sp = p0; return RES_RETURN; } @@ -592,7 +592,7 @@ void ManagedHeap__mark(ManagedHeap* self) { } void pk_print_stack(VM* self, Frame* frame, Bytecode byte) { - // return; + return; if(frame == NULL || py_isnil(&self->main)) return; py_TValue* sp = self->stack.sp; diff --git a/src/public/internal.c b/src/public/internal.c index 58ae5ab0..f4c91bbc 100644 --- a/src/public/internal.c +++ b/src/public/internal.c @@ -93,7 +93,7 @@ bool py_exec(const char* source, const char* filename, enum py_CompileMode mode, if(!module) module = &vm->main; - Frame* frame = Frame__new(&co, module, NULL, vm->stack.sp, vm->stack.sp); + Frame* frame = Frame__new(&co, module, false, vm->stack.sp, vm->stack.sp); VM__push_frame(vm, frame); FrameResult res = VM__run_top_frame(vm); CodeObject__dtor(&co); diff --git a/src/public/modules.c b/src/public/modules.c index 0b3e5e4b..11bb0da5 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -468,8 +468,8 @@ static bool super__new__(int argc, py_Ref argv) { py_Ref self_arg = NULL; if(argc == 1) { // super() - if(frame->function) { - Function* func = py_touserdata(frame->function); + if(frame->has_function) { + Function* func = py_touserdata(frame->p0); *class_arg = *(py_Type*)PyObject__userdata(func->clazz); if(frame->co->nlocals > 0) self_arg = &frame->locals[0]; } diff --git a/src/public/py_ops.c b/src/public/py_ops.c index 55f99b87..ac4bfe47 100644 --- a/src/public/py_ops.c +++ b/src/public/py_ops.c @@ -81,7 +81,9 @@ int py_next(py_Ref val) { py_clearexc(p0); vm->is_stopiteration = true; } - return vm->is_stopiteration ? 0 : -1; + int retval = vm->is_stopiteration ? 0 : -1; + vm->is_stopiteration = false; + return retval; } bool py_getattr(py_Ref self, py_Name name) { diff --git a/src/public/stack_ops.c b/src/public/stack_ops.c index 73bab5ed..f68be579 100644 --- a/src/public/stack_ops.c +++ b/src/public/stack_ops.c @@ -60,8 +60,8 @@ void py_setslot(py_Ref self, int i, py_Ref val) { py_StackRef py_inspect_currentfunction(){ Frame* frame = pk_current_vm->top_frame; - if(!frame) return NULL; - return frame->function; + if(!frame || !frame->has_function) return NULL; + return frame->p0; } void py_assign(py_Ref dst, py_Ref src) { *dst = *src; } diff --git a/tests/51_yield.py b/tests/51_yield.py index 39342b40..03f54215 100644 --- a/tests/51_yield.py +++ b/tests/51_yield.py @@ -6,6 +6,12 @@ a = g() assert next(a) == 1 assert next(a) == 2 +try: + next(a) + exit(1) +except StopIteration: + pass + def f(n): for i in range(n): yield i @@ -50,7 +56,7 @@ assert a == [1, 2, 3] def f(): for i in range(5): yield str(i) -assert '|'.join(f()) == '0|1|2|3|4' +assert '|'.join(list(f())) == '0|1|2|3|4' def f(n):