From 0912e88ac792ad1605c0efd167724c4016bd1bdf Mon Sep 17 00:00:00 2001 From: blueloveTH Date: Tue, 2 May 2023 21:20:01 +0800 Subject: [PATCH] ... --- src/ceval.h | 40 +++++++++++++++------------------------- src/pocketpy.h | 3 +-- src/str.h | 1 + src/vm.h | 15 ++++++--------- tests/28_iter.py | 26 ++++++++++++++++++++++++++ 5 files changed, 49 insertions(+), 36 deletions(-) diff --git a/src/ceval.h b/src/ceval.h index df63ed2b..f14a5373 100644 --- a/src/ceval.h +++ b/src/ceval.h @@ -408,22 +408,15 @@ __NEXT_STEP:; /*****************************************/ TARGET(GET_ITER) TOP() = asIter(TOP()); - check_type(TOP(), tp_iterator); DISPATCH(); - TARGET(FOR_ITER) { -#if DEBUG_EXTRA_CHECK - BaseIter* it = PyIter_AS_C(TOP()); -#else - BaseIter* it = _PyIter_AS_C(TOP()); -#endif - PyObject* obj = it->next(); - if(obj != StopIteration){ - PUSH(obj); + TARGET(FOR_ITER) + _0 = PyIterNext(TOP()); + if(_0 != StopIteration){ + PUSH(_0); }else{ - int target = co_blocks[byte.block].end; - frame->jump_abs_break(target); + frame->jump_abs_break(co_blocks[byte.block].end); } - } DISPATCH(); + DISPATCH(); /*****************************************/ TARGET(IMPORT_NAME) { StrName name(byte.arg); @@ -459,12 +452,10 @@ __NEXT_STEP:; /*****************************************/ TARGET(UNPACK_SEQUENCE) TARGET(UNPACK_EX) { - // asIter or iter->next may run bytecode, accidential gc may happen auto _lock = heap.gc_scope_lock(); // lock the gc via RAII!! - PyObject* obj = asIter(POPX()); - BaseIter* iter = PyIter_AS_C(obj); + PyObject* iter = asIter(POPX()); for(int i=0; inext(); + PyObject* item = PyIterNext(iter); if(item == StopIteration) ValueError("not enough values to unpack"); PUSH(item); } @@ -472,23 +463,22 @@ __NEXT_STEP:; if(byte.op == OP_UNPACK_EX){ List extras; while(true){ - PyObject* item = iter->next(); + PyObject* item = PyIterNext(iter); if(item == StopIteration) break; extras.push_back(item); } PUSH(VAR(extras)); }else{ - if(iter->next() != StopIteration) ValueError("too many values to unpack"); + if(PyIterNext(iter) != StopIteration) ValueError("too many values to unpack"); } } DISPATCH(); TARGET(UNPACK_UNLIMITED) { auto _lock = heap.gc_scope_lock(); // lock the gc via RAII!! - PyObject* obj = asIter(POPX()); - BaseIter* iter = PyIter_AS_C(obj); - obj = iter->next(); - while(obj != StopIteration){ - PUSH(obj); - obj = iter->next(); + PyObject* iter = asIter(POPX()); + _0 = PyIterNext(iter); + while(_0 != StopIteration){ + PUSH(_0); + _0 = PyIterNext(iter); } } DISPATCH(); /*****************************************/ diff --git a/src/pocketpy.h b/src/pocketpy.h index d563326d..0bfa339b 100644 --- a/src/pocketpy.h +++ b/src/pocketpy.h @@ -166,8 +166,7 @@ inline void init_builtins(VM* _vm) { }); _vm->bind_builtin_func<1>("next", [](VM* vm, ArgsView args) { - BaseIter* iter = vm->PyIter_AS_C(args[0]); - return iter->next(); + return vm->PyIterNext(args[0]); }); _vm->bind_builtin_func<1>("dir", [](VM* vm, ArgsView args) { diff --git a/src/str.h b/src/str.h index bbac8fb4..0f458e27 100644 --- a/src/str.h +++ b/src/str.h @@ -383,6 +383,7 @@ const StrName __class__ = StrName::get("__class__"); const StrName __base__ = StrName::get("__base__"); const StrName __new__ = StrName::get("__new__"); const StrName __iter__ = StrName::get("__iter__"); +const StrName __next__ = StrName::get("__next__"); const StrName __str__ = StrName::get("__str__"); const StrName __repr__ = StrName::get("__repr__"); const StrName __getitem__ = StrName::get("__getitem__"); diff --git a/src/vm.h b/src/vm.h index 3d6db7eb..9de8b734 100644 --- a/src/vm.h +++ b/src/vm.h @@ -321,15 +321,12 @@ public: return heap.gcnew

(tp_iterator, std::forward

(value)); } - BaseIter* PyIter_AS_C(PyObject* obj) - { - check_type(obj, tp_iterator); - return static_cast(obj->value()); - } - - BaseIter* _PyIter_AS_C(PyObject* obj) - { - return static_cast(obj->value()); + PyObject* PyIterNext(PyObject* obj){ + if(is_non_tagged_type(obj, tp_iterator)){ + BaseIter* iter = static_cast(obj->value()); + return iter->next(); + } + return call_method(obj, __next__); } /***** Error Reporter *****/ diff --git a/tests/28_iter.py b/tests/28_iter.py index 49786af6..e75c7e9a 100644 --- a/tests/28_iter.py +++ b/tests/28_iter.py @@ -10,3 +10,29 @@ while True: total += obj assert total == 6 + +class Task: + def __init__(self, n): + self.n = n + + def __iter__(self): + self.i = 0 + return self + + def __next__(self): + if self.i == self.n: + return StopIteration + self.i += 1 + return self.i + +a = Task(3) +assert sum(a) == 6 + +i = iter(Task(5)) +assert next(i) == 1 +assert next(i) == 2 +assert next(i) == 3 +assert next(i) == 4 +assert next(i) == 5 +assert next(i) == StopIteration +assert next(i) == StopIteration