diff --git a/src/compiler/compiler.c b/src/compiler/compiler.c index 2702f82c..8717f643 100644 --- a/src/compiler/compiler.c +++ b/src/compiler/compiler.c @@ -2798,6 +2798,8 @@ static Error* compile_stmt(Compiler* self) { case TK_WITH: { check(EXPR(self)); // [ ] Ctx__s_emit_top(ctx()); + // Save context manager for later __exit__ call + Ctx__emit_(ctx(), OP_DUP_TOP, BC_NOARG, prev()->line); Ctx__enter_block(ctx(), CodeBlockType_WITH); NameExpr* as_name = NULL; if(match(TK_AS)) { @@ -2806,17 +2808,33 @@ static Error* compile_stmt(Compiler* self) { as_name = NameExpr__new(prev()->line, name, name_scope(self)); } Ctx__emit_(ctx(), OP_WITH_ENTER, BC_NOARG, prev()->line); - // [ .__enter__() ] if(as_name) { bool ok = vtemit_store((Expr*)as_name, ctx()); vtdelete((Expr*)as_name); if(!ok) return SyntaxError(self, "invalid syntax"); } else { - // discard `__enter__()`'s return value Ctx__emit_(ctx(), OP_POP_TOP, BC_NOARG, BC_KEEPLINE); } + // Wrap body in try-except to ensure __exit__ is called even on exception + Ctx__enter_block(ctx(), CodeBlockType_TRY); + Ctx__emit_(ctx(), OP_BEGIN_TRY, BC_NOARG, prev()->line); check(compile_block_body(self)); + Ctx__emit_(ctx(), OP_END_TRY, BC_NOARG, BC_KEEPLINE); + // Normal exit: call __exit__(None, None, None) + Ctx__emit_(ctx(), OP_LOAD_NONE, BC_NOARG, prev()->line); + Ctx__emit_(ctx(), OP_LOAD_NONE, BC_NOARG, prev()->line); + Ctx__emit_(ctx(), OP_LOAD_NONE, BC_NOARG, prev()->line); Ctx__emit_(ctx(), OP_WITH_EXIT, BC_NOARG, prev()->line); + int jump_patch = Ctx__emit_(ctx(), OP_JUMP_FORWARD, BC_NOARG, BC_KEEPLINE); + Ctx__exit_block(ctx()); + // Exception handler: call __exit__ with exception info, then re-raise + Ctx__emit_(ctx(), OP_PUSH_EXCEPTION, BC_NOARG, BC_KEEPLINE); + Ctx__emit_(ctx(), OP_LOAD_NONE, BC_NOARG, BC_KEEPLINE); // exc_type + Ctx__emit_(ctx(), OP_ROT_TWO, BC_NOARG, BC_KEEPLINE); // reorder: [cm, None, exc] + Ctx__emit_(ctx(), OP_LOAD_NONE, BC_NOARG, BC_KEEPLINE); // exc_tb + Ctx__emit_(ctx(), OP_WITH_EXIT, BC_NOARG, prev()->line); + Ctx__emit_(ctx(), OP_RE_RAISE, BC_NOARG, BC_KEEPLINE); + Ctx__patch_jump(ctx(), jump_patch); Ctx__exit_block(ctx()); } break; /*************************************************/ diff --git a/src/interpreter/ceval.c b/src/interpreter/ceval.c index 6076464b..95e0dc85 100644 --- a/src/interpreter/ceval.c +++ b/src/interpreter/ceval.c @@ -1122,14 +1122,35 @@ __NEXT_STEP: DISPATCH(); } case OP_WITH_EXIT: { - // [expr] - py_push(TOP()); + // Stack: [cm, exc_type, exc_val, exc_tb] + // Call cm.__exit__(exc_type, exc_val, exc_tb) + py_Ref exc_tb = TOP(); + py_Ref exc_val = SECOND(); + py_Ref exc_type = THIRD(); + py_Ref cm = FOURTH(); + + // Save all values from stack + py_TValue saved_cm = *cm; + py_TValue saved_exc_type = *exc_type; + py_TValue saved_exc_val = *exc_val; + py_TValue saved_exc_tb = *exc_tb; + self->stack.sp -= 4; + + // Push cm and get __exit__ method + py_push(&saved_cm); if(!py_pushmethod(__exit__)) { - TypeError("'%t' object does not support the context manager protocol", TOP()->type); + TypeError("'%t' object does not support the context manager protocol", saved_cm.type); goto __ERROR; } - if(!py_vectorcall(0, 0)) goto __ERROR; - POP(); + + // Push arguments: exc_type, exc_val, exc_tb + PUSH(&saved_exc_type); + PUSH(&saved_exc_val); + PUSH(&saved_exc_tb); + + // Call __exit__(exc_type, exc_val, exc_tb) + if(!py_vectorcall(3, 0)) goto __ERROR; + py_pop(); // discard return value DISPATCH(); } /////////// diff --git a/tests/520_context.py b/tests/520_context.py index 53d22336..5fc0ad30 100644 --- a/tests/520_context.py +++ b/tests/520_context.py @@ -27,4 +27,29 @@ assert path == ['enter', 'in', 'exit'] path.clear() +# Test that __exit__ is called even when an exception occurs +class B: + def __init__(self): + self.path = [] + + def __enter__(self): + path.append('enter') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + path.append('exit') + if exc_type is not None: + path.append('exception') + return False # propagate exception + +try: + with B(): + path.append('before_raise') + raise ValueError('test') + path.append('after_raise') # should not be reached +except ValueError: + pass + +assert path == ['enter', 'before_raise', 'exit', 'exception'], f"Expected ['enter', 'before_raise', 'exit', 'exception'], got {path}" +