diff --git a/docs/modules/sys.md b/docs/modules/sys.md index 005bc77f..fbccead0 100644 --- a/docs/modules/sys.md +++ b/docs/modules/sys.md @@ -20,3 +20,11 @@ May be one of: ### `sys.argv` The command line arguments. Set by `py_sys_setargv`. + +### `sys.setrecursionlimit(limit: int)` + +Set the maximum depth of the Python interpreter stack to `limit`. This limit prevents infinite recursion from causing an overflow of the C stack and crashing the interpreter. + +### `sys.getrecursionlimit() -> int` + +Return the current value of the recursion limit. diff --git a/include/pocketpy/interpreter/vm.h b/include/pocketpy/interpreter/vm.h index f40c1174..d4ed4686 100644 --- a/include/pocketpy/interpreter/vm.h +++ b/include/pocketpy/interpreter/vm.h @@ -38,6 +38,10 @@ typedef struct VM { py_TValue last_retval; py_TValue curr_exception; + + int recursion_depth; + int max_recursion_depth; + bool is_curr_exc_handled; // handled by try-except block but not cleared yet py_TValue reg[8]; // users' registers diff --git a/include/pocketpy/pocketpy.h b/include/pocketpy/pocketpy.h index bf702517..90e02ae1 100644 --- a/include/pocketpy/pocketpy.h +++ b/include/pocketpy/pocketpy.h @@ -741,7 +741,7 @@ enum py_PredefinedType { tp_KeyboardInterrupt, tp_StopIteration, tp_SyntaxError, - tp_StackOverflowError, + tp_RecursionError, tp_OSError, tp_NotImplementedError, tp_TypeError, diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index 90f9426f..56a67518 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -86,8 +86,8 @@ FrameResult VM__run_top_frame(VM* self) { while(true) { Bytecode byte; __NEXT_FRAME: - if(self->stack.sp > self->stack.end) { - py_exception(tp_StackOverflowError, ""); + if(self->recursion_depth >= self->max_recursion_depth) { + py_exception(tp_RecursionError, "maximum recursion depth exceeded"); goto __ERROR; } @@ -403,7 +403,7 @@ FrameResult VM__run_top_frame(VM* self) { if(!py_callcfunc(magic->_cfunc, 3, THIRD())) goto __ERROR; STACK_SHRINK(4); } else { - *FOURTH() = *magic; // [__selitem__, a, b, val] + *FOURTH() = *magic; // [__setitem__, a, b, val] if(!py_vectorcall(2, 0)) goto __ERROR; } DISPATCH(); diff --git a/src/interpreter/generator.c b/src/interpreter/generator.c index 5ca2cb6d..e5542736 100644 --- a/src/interpreter/generator.c +++ b/src/interpreter/generator.c @@ -64,6 +64,7 @@ static bool generator__next__(int argc, py_Ref argv) { } vm->stack.sp = ud->frame->p0; vm->top_frame = vm->top_frame->f_back; + vm->recursion_depth--; ud->state = 1; return true; } else { diff --git a/src/interpreter/vm.c b/src/interpreter/vm.c index e0900bc9..9c657ad2 100644 --- a/src/interpreter/vm.c +++ b/src/interpreter/vm.c @@ -71,6 +71,10 @@ void VM__ctor(VM* self) { self->last_retval = *py_NIL(); self->curr_exception = *py_NIL(); + + self->recursion_depth = 0; + self->max_recursion_depth = 1000; + self->is_curr_exc_handled = false; self->ctx = NULL; @@ -162,7 +166,7 @@ void VM__ctor(VM* self) { py_setdict(&self->builtins, py_name("StopIteration"), py_tpobject(tp_StopIteration)); INJECT_BUILTIN_EXC(SyntaxError, tp_Exception); - INJECT_BUILTIN_EXC(StackOverflowError, tp_Exception); + INJECT_BUILTIN_EXC(RecursionError, tp_Exception); INJECT_BUILTIN_EXC(OSError, tp_Exception); INJECT_BUILTIN_EXC(NotImplementedError, tp_Exception); INJECT_BUILTIN_EXC(TypeError, tp_Exception); @@ -265,6 +269,7 @@ void VM__dtor(VM* self) { void VM__push_frame(VM* self, py_Frame* frame) { frame->f_back = self->top_frame; self->top_frame = frame; + self->recursion_depth++; if(self->trace_info.func) self->trace_info.func(frame, TRACE_EVENT_PUSH); } @@ -277,6 +282,7 @@ void VM__pop_frame(VM* self) { // pop frame and delete self->top_frame = frame->f_back; Frame__delete(frame); + self->recursion_depth--; } static void _clip_int(int* value, int min, int max) { @@ -469,12 +475,6 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall py_Ref argv = p0 + 1 + (int)py_isnil(p0 + 1); if(p0->type == tp_function) { - // check stack overflow - if(self->stack.sp > self->stack.end) { - py_exception(tp_StackOverflowError, ""); - return RES_ERROR; - } - Function* fn = py_touserdata(p0); const CodeObject* co = &fn->decl->code; diff --git a/src/modules/os.c b/src/modules/os.c index 5de8a018..38cc3a78 100644 --- a/src/modules/os.c +++ b/src/modules/os.c @@ -240,9 +240,28 @@ void pk__add_module_io() {} #endif +static bool sys_setrecursionlimit(int argc, py_Ref argv) { + PY_CHECK_ARGC(1); + PY_CHECK_ARG_TYPE(0, tp_int); + int limit = py_toint(py_arg(0)); + if(limit <= pk_current_vm->recursion_depth) return ValueError("the limit is too low"); + pk_current_vm->max_recursion_depth = limit; + py_newnone(py_retval()); + return true; +} + +static bool sys_getrecursionlimit(int argc, py_Ref argv) { + PY_CHECK_ARGC(0); + py_newint(py_retval(), pk_current_vm->max_recursion_depth); + return true; +} + void pk__add_module_sys() { py_Ref mod = py_newmodule("sys"); py_newstr(py_emplacedict(mod, py_name("platform")), PY_SYS_PLATFORM_STRING); py_newstr(py_emplacedict(mod, py_name("version")), PK_VERSION); py_newlist(py_emplacedict(mod, py_name("argv"))); + + py_bindfunc(mod, "setrecursionlimit", sys_setrecursionlimit); + py_bindfunc(mod, "getrecursionlimit", sys_getrecursionlimit); }