diff --git a/src/codeobject.h b/src/codeobject.h index 3cad23a4..556eec91 100644 --- a/src/codeobject.h +++ b/src/codeobject.h @@ -51,6 +51,7 @@ struct CodeBlock { struct CodeObject { pkpy::shared_ptr src; Str name; + bool is_generator = false; CodeObject(pkpy::shared_ptr src, Str name) { this->src = src; diff --git a/src/compiler.h b/src/compiler.h index 6b7904d5..7058bc83 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -849,9 +849,16 @@ __LISTCOMP: if (!co()->_is_curr_block_loop()) SyntaxError("'continue' not properly in loop"); consume_end_stmt(); emit(OP_LOOP_CONTINUE); + } else if (match(TK("yield"))) { + if (codes.size() == 1) SyntaxError("'yield' outside function"); + co()->_rvalue = true; + EXPR_TUPLE(); + co()->_rvalue = false; + consume_end_stmt(); + co()->is_generator = true; + emit(OP_YIELD_VALUE, -1, true); } else if (match(TK("return"))) { - if (codes.size() == 1) - SyntaxError("'return' outside function"); + if (codes.size() == 1) SyntaxError("'return' outside function"); if(match_end_stmt()){ emit(OP_LOAD_NONE); }else{ diff --git a/src/iter.h b/src/iter.h index c0e3cc21..42def72a 100644 --- a/src/iter.h +++ b/src/iter.h @@ -11,11 +11,12 @@ public: this->current = r.start; } - bool has_next(){ + inline bool _has_next(){ return r.step > 0 ? current < r.stop : current > r.stop; } PyVar next(){ + if(!_has_next()) return nullptr; current += r.step; return vm->PyInt(current-r.step); } @@ -27,8 +28,10 @@ class ArrayIter : public BaseIter { const T* p; public: ArrayIter(VM* vm, PyVar _ref) : BaseIter(vm, _ref) { p = &OBJ_GET(T, _ref);} - bool has_next(){ return index < p->size(); } - PyVar next(){ return p->operator[](index++); } + PyVar next(){ + if(index == p->size()) return nullptr; + return p->operator[](index++); + } }; class StringIter : public BaseIter { @@ -39,6 +42,31 @@ public: str = OBJ_GET(Str, _ref); } - bool has_next(){ return index < str.u8_length(); } - PyVar next() { return vm->PyStr(str.u8_getitem(index++)); } + PyVar next() { + if(index == str.u8_length()) return nullptr; + return vm->PyStr(str.u8_getitem(index++)); + } }; + +class Generator: public BaseIter { + std::unique_ptr frame; + int state; // 0,1,2 +public: + Generator(VM* vm, std::unique_ptr&& frame) + : BaseIter(vm, nullptr), frame(std::move(frame)), state(0) {} + + PyVar next() { + if(state == 2) return nullptr; + vm->callstack.push(std::move(frame)); + PyVar ret = vm->_exec(); + if(ret == vm->_py_op_yield){ + frame = std::move(vm->callstack.top()); + vm->callstack.pop(); + state = 1; + return frame->pop_value(vm); + }else{ + state = 2; + return nullptr; + } + } +}; \ No newline at end of file diff --git a/src/obj.h b/src/obj.h index 703b1277..8b613b8c 100644 --- a/src/obj.h +++ b/src/obj.h @@ -68,7 +68,6 @@ protected: PyVar _ref; // keep a reference to the object so it will not be deleted while iterating public: virtual PyVar next() = 0; - virtual bool has_next() = 0; PyVarRef var; BaseIter(VM* vm, PyVar _ref) : vm(vm), _ref(_ref) {} virtual ~BaseIter() = default; diff --git a/src/opcodes.h b/src/opcodes.h index 73315332..fb1d891b 100644 --- a/src/opcodes.h +++ b/src/opcodes.h @@ -68,6 +68,7 @@ OPCODE(DELETE_REF) OPCODE(TRY_BLOCK_ENTER) OPCODE(TRY_BLOCK_EXIT) +OPCODE(YIELD_VALUE) //OPCODE(FAST_INDEX_0) // a[0] //OPCODE(FAST_INDEX_1) // a[i] diff --git a/src/parser.h b/src/parser.h index e7740ed1..11e25d3e 100644 --- a/src/parser.h +++ b/src/parser.h @@ -12,7 +12,7 @@ constexpr const char* kTokens[] = { "==", "!=", ">=", "<=", "+=", "-=", "*=", "/=", "//=", "%=", "&=", "|=", "^=", /** KW_BEGIN **/ - "class", "import", "as", "def", "lambda", "pass", "del", "from", "with", + "class", "import", "as", "def", "lambda", "pass", "del", "from", "with", "yield", "None", "in", "is", "and", "or", "not", "True", "False", "global", "try", "except", "finally", "goto", "label", // extended keywords, not available in cpython "while", "for", "if", "elif", "else", "break", "continue", "return", "assert", "raise", diff --git a/src/vm.h b/src/vm.h index d1aa9ed2..b344180d 100644 --- a/src/vm.h +++ b/src/vm.h @@ -13,17 +13,20 @@ } // static std::map _stats; +class Generator; class VM { +public: std::stack< std::unique_ptr > callstack; PyVar _py_op_call; + PyVar _py_op_yield; // PyVar _ascii_str_pool[128]; PyVar run_frame(Frame* frame){ while(frame->has_next_bytecode()){ const Bytecode& byte = frame->next_bytecode(); // if(true || frame->_module != builtins){ - // printf("%d: %s (%d) %s\n", frame->_ip, OP_NAMES[byte.op], byte.arg, frame->stack_info().c_str()); + // printf("%d: %s (%d) %s\n", frame->_ip, OP_NAMES[byte.op], byte.arg, frame->stack_info().c_str()); // } switch (byte.op) { @@ -246,23 +249,28 @@ class VM { case OP_GET_ITER: { PyVar obj = frame->pop_value(this); - PyVarOrNull iter_fn = getattr(obj, __iter__, false); - if(iter_fn != nullptr){ - PyVar tmp = call(iter_fn); - PyVarRef var = frame->pop(); - check_type(var, tp_ref); - PyIter_AS_C(tmp)->var = var; - frame->push(std::move(tmp)); + PyVar iter_obj = nullptr; + if(!obj->is_type(tp_native_iterator)){ + PyVarOrNull iter_f = getattr(obj, __iter__, false); + if(iter_f != nullptr) iter_obj = call(iter_f); }else{ + iter_obj = obj; + } + if(iter_obj == nullptr){ TypeError(OBJ_NAME(_t(obj)).escape(true) + " object is not iterable"); } + PyVarRef var = frame->pop(); + check_type(var, tp_ref); + PyIter_AS_C(iter_obj)->var = var; + frame->push(std::move(iter_obj)); } break; case OP_FOR_ITER: { // top() must be PyIter, so no need to try_deref() auto& it = PyIter_AS_C(frame->top()); - if(it->has_next()){ - PyRef_AS_C(it->var)->set(this, frame, it->next()); + PyVar obj = it->next(); + if(obj != nullptr){ + PyRef_AS_C(it->var)->set(this, frame, std::move(obj)); }else{ int blockEnd = frame->co->blocks[byte.block].end; frame->jump_abs_safe(blockEnd); @@ -319,6 +327,7 @@ class VM { frame->push(it->second); } } break; + case OP_YIELD_VALUE: return _py_op_yield; // TODO: using "goto" inside with block may cause __exit__ not called case OP_WITH_ENTER: call(frame->pop_value(this), __enter__); break; case OP_WITH_EXIT: call(frame->pop_value(this), __exit__); break; @@ -339,7 +348,6 @@ class VM { return None; } -public: pkpy::NameDict _types; pkpy::NameDict _modules; // loaded modules emhash8::HashMap _lazy_modules; // lazy loaded modules @@ -502,13 +510,16 @@ public: locals[key] = val; } - PyVar* it_m = (*callable)->attr().try_get(__module__); - PyVar _module = it_m != nullptr ? *it_m : top_frame()->_module; - if(opCall){ - _new_frame(fn->code, _module, _locals); - return _py_op_call; + PyVar* _m = (*callable)->attr().try_get(__module__); + PyVar _module = _m != nullptr ? *_m : top_frame()->_module; + auto _frame = _new_frame(fn->code, _module, _locals); + if(fn->code->is_generator){ + return PyIter(pkpy::make_shared( + this, std::move(_frame))); } - return _exec(fn->code, _module, _locals); + callstack.push(std::move(_frame)); + if(opCall) return _py_op_call; + return _exec(); } TypeError("'" + OBJ_NAME(_t(*callable)) + "' object is not callable"); return None; @@ -533,17 +544,21 @@ public: } template - Frame* _new_frame(Args&&... args){ + inline std::unique_ptr _new_frame(Args&&... args){ if(callstack.size() > maxRecursionDepth){ _error("RecursionError", "maximum recursion depth exceeded"); } - callstack.emplace(std::make_unique(std::forward(args)...)); - return callstack.top().get(); + return std::make_unique(std::forward(args)...); } template - PyVar _exec(Args&&... args){ - Frame* frame = _new_frame(std::forward(args)...); + inline PyVar _exec(Args&&... args){ + callstack.push(_new_frame(std::forward(args)...)); + return _exec(); + } + + PyVar _exec(){ + Frame* frame = top_frame(); i64 base_id = frame->id; PyVar ret = nullptr; bool need_raise = false; @@ -553,7 +568,7 @@ public: try{ if(need_raise){ need_raise = false; _raise(); } ret = run_frame(frame); - + if(ret == _py_op_yield) return _py_op_yield; if(ret != _py_op_call){ if(frame->id == base_id){ // [ frameBase<- ] callstack.pop(); @@ -884,6 +899,7 @@ public: this->builtins = new_module("builtins"); this->_main = new_module("__main__"); this->_py_op_call = new_object(_new_type_object("_internal"), DUMMY_VAL); + this->_py_op_yield = new_object(_new_type_object("_internal"), DUMMY_VAL); setattr(_t(tp_type), __base__, _t(tp_object)); setattr(_t(tp_object), __base__, None); diff --git a/tests/_yield.py b/tests/_yield.py new file mode 100644 index 00000000..4bb0d8c1 --- /dev/null +++ b/tests/_yield.py @@ -0,0 +1,7 @@ +def f(n): + for i in range(n): + yield i + +a = [i for i in f(6)] + +assert a == [0,1,2,3,4,5] \ No newline at end of file