fix generator

This commit is contained in:
blueloveTH 2024-08-08 11:09:30 +08:00
parent f481337f98
commit cd3c28fdd8
12 changed files with 50 additions and 26 deletions

View File

@ -36,7 +36,7 @@ typedef struct Frame {
const Bytecode* ip;
const CodeObject* co;
py_GlobalRef module;
py_StackRef function; // a function object or NULL (global scope)
bool has_function; // is p0 a function?
py_StackRef p0; // unwinding base
py_StackRef locals; // locals base
UnwindTarget* uw_list;
@ -44,7 +44,7 @@ typedef struct Frame {
Frame* Frame__new(const CodeObject* co,
py_GlobalRef module,
py_StackRef function,
bool has_function,
py_StackRef p0,
py_StackRef locals);
void Frame__delete(Frame* self);

View File

@ -1,10 +1,13 @@
#pragma once
#include "pocketpy/interpreter/frame.h"
#include "pocketpy/pocketpy.h"
typedef struct Generator{
Frame* frame;
int state;
} Generator;
void pk_newgenerator(py_Ref out, Frame* frame);
void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length);
void Generator__dtor(Generator* ud);

View File

@ -19,6 +19,8 @@ typedef int16_t py_Type;
typedef int64_t py_i64;
typedef double py_f64;
typedef void (*py_Dtor)(void*);
#define PY_RAISE // mark a function that can raise an exception
typedef struct c11_sv {
@ -137,7 +139,7 @@ c11_sv py_name2sv(py_Name);
/// @param base base type.
/// @param module module where the type is defined. Use `NULL` for built-in types.
/// @param dtor destructor function. Use `NULL` if not needed.
py_Type py_newtype(const char* name, py_Type base, const py_GlobalRef module, void (*dtor)(void*));
py_Type py_newtype(const char* name, py_Type base, const py_GlobalRef module, py_Dtor dtor);
/// Create a new object.
/// @param out output reference.

View File

@ -285,7 +285,7 @@ FrameResult VM__run_top_frame(VM* self) {
case OP_STORE_FAST: frame->locals[byte.arg] = POPX(); DISPATCH();
case OP_STORE_NAME: {
py_Name name = byte.arg;
if(frame->function) {
if(frame->has_function) {
py_Ref slot = Frame__f_locals_try_get(frame, name);
if(slot != NULL) {
*slot = *TOP(); // store in locals if possible
@ -346,7 +346,7 @@ FrameResult VM__run_top_frame(VM* self) {
}
case OP_DELETE_NAME: {
py_Name name = byte.arg;
if(frame->function) {
if(frame->has_function) {
py_TValue* slot = Frame__f_locals_try_get(frame, name);
if(slot) {
py_newnil(slot);
@ -977,7 +977,7 @@ FrameResult VM__run_top_frame(VM* self) {
py_BaseException__stpush(&self->curr_exception,
frame->co->src,
lineno < 0 ? Frame__lineno(frame) : lineno,
frame->function ? frame->co->name->data : NULL);
frame->has_function ? frame->co->name->data : NULL);
int target = Frame__prepare_jump_exception_handler(frame, &self->stack);
if(target >= 0) {

View File

@ -37,7 +37,7 @@ void UnwindTarget__delete(UnwindTarget* self) { free(self); }
Frame* Frame__new(const CodeObject* co,
py_GlobalRef module,
py_StackRef function,
bool has_function,
py_StackRef p0,
py_StackRef locals) {
static_assert(sizeof(Frame) <= kPoolFrameBlockSize, "!(sizeof(Frame) <= kPoolFrameBlockSize)");
@ -46,7 +46,7 @@ Frame* Frame__new(const CodeObject* co,
self->ip = (Bytecode*)co->codes.data - 1;
self->co = co;
self->module = module;
self->function = function;
self->has_function = has_function;
self->p0 = p0;
self->locals = locals;
self->uw_list = NULL;
@ -131,8 +131,8 @@ void Frame__set_unwind_target(Frame* self, py_TValue* sp) {
}
py_TValue* Frame__f_closure_try_get(Frame* self, py_Name name) {
if(self->function == NULL) return NULL;
Function* ud = py_touserdata(self->function);
if(!self->has_function) return NULL;
Function* ud = py_touserdata(self->p0);
if(ud->closure == NULL) return NULL;
return NameDict__try_get(ud->closure, name);
}

View File

@ -5,11 +5,19 @@
#include "pocketpy/pocketpy.h"
#include <stdbool.h>
void pk_newgenerator(py_Ref out, Frame* frame) {
void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length) {
Generator* ud = py_newobject(out, tp_generator, 1, sizeof(Generator));
ud->frame = frame;
ud->state = 0;
py_newlist(py_getslot(out, 0));
py_Ref tmp = py_getslot(out, 0);
py_newlist(tmp);
for(int i = 0; i < backup_length; i++) {
py_list_append(tmp, &backup[i]);
}
}
void Generator__dtor(Generator* ud) {
if(ud->frame) { Frame__delete(ud->frame); }
}
static bool generator__next__(int argc, py_Ref argv) {
@ -33,6 +41,7 @@ static bool generator__next__(int argc, py_Ref argv) {
// push frame
VM__push_frame(vm, ud->frame);
ud->frame = NULL;
FrameResult res = VM__run_top_frame(vm);
@ -51,17 +60,19 @@ static bool generator__next__(int argc, py_Ref argv) {
for(py_StackRef p = ud->frame->p0; p != vm->stack.sp; p++) {
py_list_append(backup, p);
}
vm->stack.sp = ud->frame->p0;
vm->top_frame = vm->top_frame->f_back;
ud->state = 1;
return true;
} else {
assert(res == RES_RETURN);
ud->state = 2;
return StopIteration();
}
}
py_Type pk_generator__register() {
py_Type type = pk_newtype("generator", tp_object, NULL, NULL, false, true);
py_Type type = pk_newtype("generator", tp_object, NULL, (py_Dtor)Generator__dtor, false, true);
py_bindmagic(type, __iter__, pk_wrapper__self);
py_bindmagic(type, __next__, generator__next__);

View File

@ -427,7 +427,7 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall
memcpy(argv, self->__vectorcall_buffer, co->nlocals * sizeof(py_TValue));
// submit the call
if(!fn->cfunc) {
VM__push_frame(self, Frame__new(co, &fn->module, p0, p0, argv));
VM__push_frame(self, Frame__new(co, &fn->module, true, p0, argv));
return opcall ? RES_CALL : VM__run_top_frame(self);
} else {
bool ok = fn->cfunc(co->nlocals, argv);
@ -451,13 +451,13 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall
// initialize local variables to py_NIL
memset(p1, 0, (char*)self->stack.sp - (char*)p1);
// submit the call
VM__push_frame(self, Frame__new(co, &fn->module, p0, p0, argv));
VM__push_frame(self, Frame__new(co, &fn->module, true, p0, argv));
return opcall ? RES_CALL : VM__run_top_frame(self);
case FuncType_GENERATOR: {
bool ok = prepare_py_call(self->__vectorcall_buffer, argv, p1, kwargc, fn->decl);
if(!ok) return RES_ERROR;
Frame* frame = Frame__new(co, &fn->module, p0, p0, argv);
pk_newgenerator(py_retval(), frame);
Frame* frame = Frame__new(co, &fn->module, false, p0, argv);
pk_newgenerator(py_retval(), frame, self->__vectorcall_buffer, co->nlocals);
self->stack.sp = p0;
return RES_RETURN;
}
@ -592,7 +592,7 @@ void ManagedHeap__mark(ManagedHeap* self) {
}
void pk_print_stack(VM* self, Frame* frame, Bytecode byte) {
// return;
return;
if(frame == NULL || py_isnil(&self->main)) return;
py_TValue* sp = self->stack.sp;

View File

@ -93,7 +93,7 @@ bool py_exec(const char* source, const char* filename, enum py_CompileMode mode,
if(!module) module = &vm->main;
Frame* frame = Frame__new(&co, module, NULL, vm->stack.sp, vm->stack.sp);
Frame* frame = Frame__new(&co, module, false, vm->stack.sp, vm->stack.sp);
VM__push_frame(vm, frame);
FrameResult res = VM__run_top_frame(vm);
CodeObject__dtor(&co);

View File

@ -468,8 +468,8 @@ static bool super__new__(int argc, py_Ref argv) {
py_Ref self_arg = NULL;
if(argc == 1) {
// super()
if(frame->function) {
Function* func = py_touserdata(frame->function);
if(frame->has_function) {
Function* func = py_touserdata(frame->p0);
*class_arg = *(py_Type*)PyObject__userdata(func->clazz);
if(frame->co->nlocals > 0) self_arg = &frame->locals[0];
}

View File

@ -81,7 +81,9 @@ int py_next(py_Ref val) {
py_clearexc(p0);
vm->is_stopiteration = true;
}
return vm->is_stopiteration ? 0 : -1;
int retval = vm->is_stopiteration ? 0 : -1;
vm->is_stopiteration = false;
return retval;
}
bool py_getattr(py_Ref self, py_Name name) {

View File

@ -60,8 +60,8 @@ void py_setslot(py_Ref self, int i, py_Ref val) {
py_StackRef py_inspect_currentfunction(){
Frame* frame = pk_current_vm->top_frame;
if(!frame) return NULL;
return frame->function;
if(!frame || !frame->has_function) return NULL;
return frame->p0;
}
void py_assign(py_Ref dst, py_Ref src) { *dst = *src; }

View File

@ -6,6 +6,12 @@ a = g()
assert next(a) == 1
assert next(a) == 2
try:
next(a)
exit(1)
except StopIteration:
pass
def f(n):
for i in range(n):
yield i
@ -50,7 +56,7 @@ assert a == [1, 2, 3]
def f():
for i in range(5):
yield str(i)
assert '|'.join(f()) == '0|1|2|3|4'
assert '|'.join(list(f())) == '0|1|2|3|4'
def f(n):