diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index 6e89513b..f33eeafa 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -115,6 +115,7 @@ py_Type pk_range__register(); py_Type pk_range_iterator__register(); py_Type pk_BaseException__register(); py_Type pk_Exception__register(); +py_Type pk_StopIteration__register(); py_Type pk_super__register(); py_Type pk_property__register(); py_Type pk_staticmethod__register(); diff --git a/include/pocketpy/xmacros/opcodes.h b/include/pocketpy/xmacros/opcodes.h index be3fe906..5395832d 100644 --- a/include/pocketpy/xmacros/opcodes.h +++ b/include/pocketpy/xmacros/opcodes.h @@ -83,6 +83,7 @@ OPCODE(UNARY_INVERT) /**************************/ OPCODE(GET_ITER) OPCODE(FOR_ITER) +OPCODE(FOR_ITER_YIELD_VALUE) /**************************/ OPCODE(IMPORT_PATH) OPCODE(POP_IMPORT_STAR) diff --git a/src/compiler/compiler.c b/src/compiler/compiler.c index e5474dc3..35566882 100644 --- a/src/compiler/compiler.c +++ b/src/compiler/compiler.c @@ -1494,19 +1494,12 @@ static Error* pop_context(Compiler* self) { int codes_length = func->code.codes.length; for(int i = 0; i < codes_length; i++) { - if(codes[i].op == OP_YIELD_VALUE) { + if(codes[i].op == OP_YIELD_VALUE || codes[i].op == OP_FOR_ITER_YIELD_VALUE) { func->type = FuncType_GENERATOR; - for(int j = 0; j < codes_length; j++) { - if(codes[j].op == OP_RETURN_VALUE && codes[j].arg == BC_NOARG) { - Error* err = - SyntaxError(self, "'return' with argument inside generator function"); - err->lineno = c11__at(BytecodeEx, &func->code.codes_ex, j)->lineno; - return err; - } - } break; } } + if(func->type == FuncType_UNSET) { bool is_simple = true; if(func->kwargs.length > 0) is_simple = false; @@ -2035,6 +2028,20 @@ static Error* compile_for_loop(Compiler* self) { return NULL; } +static Error* compile_yield_from(Compiler* self, int kw_line) { + Error* err; + if(self->contexts.length <= 1) return SyntaxError(self, "'yield from' outside function"); + check(EXPR_TUPLE(self)); + Ctx__s_emit_top(ctx()); + Ctx__emit_(ctx(), OP_GET_ITER, BC_NOARG, kw_line); + Ctx__enter_block(ctx(), CodeBlockType_FOR_LOOP); + Ctx__emit_(ctx(), OP_FOR_ITER_YIELD_VALUE, ctx()->curr_iblock, kw_line); + Ctx__emit_(ctx(), OP_LOOP_CONTINUE, Ctx__get_loop(ctx()), kw_line); + Ctx__exit_block(ctx()); + // StopIteration.value will be pushed onto the stack + return NULL; +} + Error* try_compile_assignment(Compiler* self, bool* is_assign) { Error* err; switch(curr()->type) { @@ -2074,15 +2081,24 @@ Error* try_compile_assignment(Compiler* self, bool* is_assign) { return NULL; } case TK_ASSIGN: { + consume(TK_ASSIGN); int n = 0; - while(match(TK_ASSIGN)) { - check(EXPR_TUPLE(self)); - n += 1; + + if(match(TK_YIELD_FROM)) { + check(compile_yield_from(self, prev()->line)); + n = 1; + } else { + do { + check(EXPR_TUPLE(self)); + n += 1; + } while(match(TK_ASSIGN)); + + // stack size is n+1 + Ctx__s_emit_top(ctx()); + for(int j = 1; j < n; j++) + Ctx__emit_(ctx(), OP_DUP_TOP, BC_NOARG, BC_KEEPLINE); } - // stack size is n+1 - Ctx__s_emit_top(ctx()); - for(int j = 1; j < n; j++) - Ctx__emit_(ctx(), OP_DUP_TOP, BC_NOARG, BC_KEEPLINE); + for(int j = 0; j < n; j++) { if(Ctx__s_top(ctx())->vt->is_starred) return SyntaxError(self, "can't use starred expression here"); @@ -2488,16 +2504,8 @@ static Error* compile_stmt(Compiler* self) { consume_end_stmt(); break; case TK_YIELD_FROM: - if(self->contexts.length <= 1) - return SyntaxError(self, "'yield from' outside function"); - check(EXPR_TUPLE(self)); - Ctx__s_emit_top(ctx()); - Ctx__emit_(ctx(), OP_GET_ITER, BC_NOARG, kw_line); - Ctx__enter_block(ctx(), CodeBlockType_FOR_LOOP); - Ctx__emit_(ctx(), OP_FOR_ITER, ctx()->curr_iblock, kw_line); - Ctx__emit_(ctx(), OP_YIELD_VALUE, BC_NOARG, kw_line); - Ctx__emit_(ctx(), OP_LOOP_CONTINUE, Ctx__get_loop(ctx()), kw_line); - Ctx__exit_block(ctx()); + check(compile_yield_from(self, kw_line)); + Ctx__emit_(ctx(), OP_POP_TOP, BC_NOARG, kw_line); consume_end_stmt(); break; case TK_RETURN: diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index 8a8df2c2..e75d3039 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -781,10 +781,25 @@ FrameResult VM__run_top_frame(VM* self) { PUSH(py_retval()); DISPATCH(); } else { + assert(self->last_retval.type == tp_StopIteration); int target = Frame__prepare_loop_break(frame, &self->stack); DISPATCH_JUMP_ABSOLUTE(target); } } + case OP_FOR_ITER_YIELD_VALUE: { + int res = py_next(TOP()); + if(res == -1) goto __ERROR; + if(res) { + return RES_YIELD; + } else { + assert(self->last_retval.type == tp_StopIteration); + py_ObjectRef value = py_getslot(&self->last_retval, 0); + int target = Frame__prepare_loop_break(frame, &self->stack); + if(py_isnil(value)) value = py_None(); + PUSH(value); + DISPATCH_JUMP_ABSOLUTE(target); + } + } //////// case OP_IMPORT_PATH: { py_Ref path_object = c11__at(py_TValue, &frame->co->consts, byte.arg); diff --git a/src/interpreter/generator.c b/src/interpreter/generator.c index 39728d15..1ebdc099 100644 --- a/src/interpreter/generator.c +++ b/src/interpreter/generator.c @@ -67,7 +67,10 @@ static bool generator__next__(int argc, py_Ref argv) { } else { assert(res == RES_RETURN); ud->state = 2; - return StopIteration(); + // raise StopIteration() + bool ok = py_tpcall(tp_StopIteration, 1, py_retval()); + if(!ok) return false; + return py_raise(py_retval()); } } diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index 14d539d8..3614a9b8 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -147,7 +147,10 @@ void VM__ctor(VM* self) { INJECT_BUILTIN_EXC(SystemExit, tp_BaseException); INJECT_BUILTIN_EXC(KeyboardInterrupt, tp_BaseException); - INJECT_BUILTIN_EXC(StopIteration, tp_Exception); + // INJECT_BUILTIN_EXC(StopIteration, tp_Exception); + validate(tp_StopIteration, pk_StopIteration__register()); + py_setdict(&self->builtins, py_name("StopIteration"), py_tpobject(tp_StopIteration)); + INJECT_BUILTIN_EXC(SyntaxError, tp_Exception); INJECT_BUILTIN_EXC(StackOverflowError, tp_Exception); INJECT_BUILTIN_EXC(IOError, tp_Exception); diff --git a/src/public/internal.c b/src/public/internal.c index 64bae2b4..59584e97 100644 --- a/src/public/internal.c +++ b/src/public/internal.c @@ -255,4 +255,8 @@ bool pk_callmagic(py_Name name, int argc, py_Ref argv) { return py_call(tmp, argc, argv); } -bool StopIteration() { return py_exception(tp_StopIteration, ""); } \ No newline at end of file +bool StopIteration() { + bool ok = py_tpcall(tp_StopIteration, 0, NULL); + if(!ok) return false; + return py_raise(py_retval()); +} diff --git a/src/public/modules.c b/src/public/modules.c index a1c40743..d28d4432 100644 --- a/src/public/modules.c +++ b/src/public/modules.c @@ -247,7 +247,8 @@ static bool builtins_next(int argc, py_Ref argv) { int res = py_next(argv); if(res == -1) return false; if(res) return true; - return py_exception(tp_StopIteration, ""); + // StopIteration stored in py_retval() + return py_raise(py_retval()); } static bool builtins_hash(int argc, py_Ref argv) { diff --git a/src/public/py_exception.c b/src/public/py_exception.c index 784085d2..21fe7560 100644 --- a/src/public/py_exception.c +++ b/src/public/py_exception.c @@ -96,6 +96,17 @@ static bool BaseException_args(int argc, py_Ref argv){ return true; } +static bool StopIteration_value(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + py_Ref arg = py_getslot(argv, 0); + if(py_isnil(arg)) { + py_newnone(py_retval()); + }else{ + py_assign(py_retval(), arg); + } + return true; +} + py_Type pk_BaseException__register() { py_Type type = pk_newtype("BaseException", tp_object, NULL, BaseException__dtor, false, false); @@ -112,6 +123,12 @@ py_Type pk_Exception__register() { return type; } +py_Type pk_StopIteration__register() { + py_Type type = pk_newtype("StopIteration", tp_Exception, NULL, NULL, false, false); + py_bindproperty(type, "value", StopIteration_value, NULL); + return type; +} + ////////////////////////////////////////////////// bool py_checkexc(bool ignore_handled) { VM* vm = pk_current_vm; @@ -134,13 +151,10 @@ bool py_matchexc(py_Type type) { void py_clearexc(py_StackRef p0) { VM* vm = pk_current_vm; - vm->last_retval = *py_NIL(); vm->curr_exception = *py_NIL(); vm->is_curr_exc_handled = false; - /* Don't clear this, because StopIteration() may corrupt the class defination */ // vm->__curr_class = NULL; - vm->__curr_function = NULL; if(p0) vm->stack.sp = p0; } diff --git a/src/public/py_ops.c b/src/public/py_ops.c index 54eb812a..ace7b388 100644 --- a/src/public/py_ops.c +++ b/src/public/py_ops.c @@ -77,6 +77,7 @@ int py_next(py_Ref val) { } if(py_call(tmp, 1, val)) return 1; if(vm->curr_exception.type == tp_StopIteration) { + vm->last_retval = vm->curr_exception; py_clearexc(NULL); return 0; } diff --git a/tests/51_yield.py b/tests/51_yield.py index b570b905..3e534880 100644 --- a/tests/51_yield.py +++ b/tests/51_yield.py @@ -99,16 +99,27 @@ def f(): assert list(f()) == [1, 2] -src = ''' def g(): yield 1 yield 2 return 3 yield 4 -''' + +assert StopIteration().value == None +assert StopIteration(3).value == 3 try: - exec(src) + iter = g() + assert next(iter) == 1 + assert next(iter) == 2 + next(iter) # raises StopIteration + print('UNREACHABLE!!') exit(1) -except SyntaxError: - pass +except StopIteration as e: + assert e.value == 3 + +def f(): + a = yield from g() + yield a + +assert list(f()) == [1, 2, 3] \ No newline at end of file