Skip to content

Commit 7ea925d

Browse files
authored
[mypyc] Fix exception swallowing in async try/finally blocks with await (#19353)
When a try/finally block in an async function contains an await statement in the finally block, exceptions raised in the try block are silently swallowed if a context switch occurs. This happens because mypyc stores exception information in registers that don't survive across await points. The Problem: - mypyc's transform_try_finally_stmt uses error_catch_op to save exceptions - to a register, then reraise_exception_op to restore from that register - When await causes a context switch, register values are lost - The exception information is gone, causing silent exception swallowing The Solution: - Add new transform_try_finally_stmt_async for async-aware exception handling - Use sys.exc_info() to preserve exceptions across context switches instead - of registers - Check error indicator first to handle new exceptions raised in finally - Route to async version when finally block contains await expressions Implementation Details: - transform_try_finally_stmt_async uses get_exc_info_op/restore_exc_info_op - which work with sys.exc_info() that survives context switches - Proper exception priority: new exceptions in finally replace originals - Added has_await_in_block helper to detect await expressions Test Coverage: Added comprehensive async exception handling tests: - testAsyncTryExceptFinallyAwait: 8 test cases covering various scenarios - Simple try/finally with exception and await - Exception caught but not re-raised - Exception caught and re-raised - Different exception raised in except - Try/except inside finally block - Try/finally inside finally block - Control case without await - Normal flow without exceptions - testAsyncContextManagerExceptionHandling: Verifies async with still works - Basic exception propagation - Exception in **aexit** replacing original See mypyc/mypyc#1114.
1 parent 02c9766 commit 7ea925d

File tree

2 files changed

+346
-2
lines changed

2 files changed

+346
-2
lines changed

mypyc/irbuild/statement.py

Lines changed: 135 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections.abc import Sequence
1313
from typing import Callable
1414

15+
import mypy.nodes
1516
from mypy.nodes import (
1617
ARG_NAMED,
1718
ARG_POS,
@@ -104,6 +105,7 @@
104105
get_exc_info_op,
105106
get_exc_value_op,
106107
keep_propagating_op,
108+
no_err_occurred_op,
107109
propagate_if_error_op,
108110
raise_exception_op,
109111
reraise_exception_op,
@@ -683,7 +685,7 @@ def try_finally_resolve_control(
683685

684686

685687
def transform_try_finally_stmt(
686-
builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc
688+
builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc, line: int = -1
687689
) -> None:
688690
"""Generalized try/finally handling that takes functions to gen the bodies.
689691
@@ -719,6 +721,118 @@ def transform_try_finally_stmt(
719721
builder.activate_block(out_block)
720722

721723

724+
def transform_try_finally_stmt_async(
725+
builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc, line: int = -1
726+
) -> None:
727+
"""Async-aware try/finally handling for when finally contains await.
728+
729+
This version uses a modified approach that preserves exceptions across await."""
730+
731+
# We need to handle returns properly, so we'll use TryFinallyNonlocalControl
732+
# to track return values, similar to the regular try/finally implementation
733+
734+
err_handler, main_entry, return_entry, finally_entry = (
735+
BasicBlock(),
736+
BasicBlock(),
737+
BasicBlock(),
738+
BasicBlock(),
739+
)
740+
741+
# Track if we're returning from the try block
742+
control = TryFinallyNonlocalControl(return_entry)
743+
builder.builder.push_error_handler(err_handler)
744+
builder.nonlocal_control.append(control)
745+
builder.goto_and_activate(BasicBlock())
746+
try_body()
747+
builder.goto(main_entry)
748+
builder.nonlocal_control.pop()
749+
builder.builder.pop_error_handler()
750+
ret_reg = control.ret_reg
751+
752+
# Normal case - no exception or return
753+
builder.activate_block(main_entry)
754+
builder.goto(finally_entry)
755+
756+
# Return case
757+
builder.activate_block(return_entry)
758+
builder.goto(finally_entry)
759+
760+
# Exception case - need to catch to clear the error indicator
761+
builder.activate_block(err_handler)
762+
# Catch the error to clear Python's error indicator
763+
builder.call_c(error_catch_op, [], line)
764+
# We're not going to use old_exc since it won't survive await
765+
# The exception is now in sys.exc_info()
766+
builder.goto(finally_entry)
767+
768+
# Finally block
769+
builder.activate_block(finally_entry)
770+
771+
# Execute finally body
772+
finally_body()
773+
774+
# After finally, we need to handle exceptions carefully:
775+
# 1. If finally raised a new exception, it's in the error indicator - let it propagate
776+
# 2. If finally didn't raise, check if we need to reraise the original from sys.exc_info()
777+
# 3. If there was a return, return that value
778+
# 4. Otherwise, normal exit
779+
780+
# First, check if there's a current exception in the error indicator
781+
# (this would be from the finally block)
782+
no_current_exc = builder.call_c(no_err_occurred_op, [], line)
783+
finally_raised = BasicBlock()
784+
check_original = BasicBlock()
785+
builder.add(Branch(no_current_exc, check_original, finally_raised, Branch.BOOL))
786+
787+
# Finally raised an exception - let it propagate naturally
788+
builder.activate_block(finally_raised)
789+
builder.call_c(keep_propagating_op, [], NO_TRACEBACK_LINE_NO)
790+
builder.add(Unreachable())
791+
792+
# No exception from finally, check if we need to handle return or original exception
793+
builder.activate_block(check_original)
794+
795+
# Check if we have a return value
796+
if ret_reg:
797+
return_block, check_old_exc = BasicBlock(), BasicBlock()
798+
builder.add(Branch(builder.read(ret_reg), check_old_exc, return_block, Branch.IS_ERROR))
799+
800+
builder.activate_block(return_block)
801+
builder.nonlocal_control[-1].gen_return(builder, builder.read(ret_reg), -1)
802+
803+
builder.activate_block(check_old_exc)
804+
805+
# Check if we need to reraise the original exception from sys.exc_info
806+
exc_info = builder.call_c(get_exc_info_op, [], line)
807+
exc_type = builder.add(TupleGet(exc_info, 0, line))
808+
809+
# Check if exc_type is None
810+
none_obj = builder.none_object()
811+
has_exc = builder.binary_op(exc_type, none_obj, "is not", line)
812+
813+
reraise_block, exit_block = BasicBlock(), BasicBlock()
814+
builder.add(Branch(has_exc, reraise_block, exit_block, Branch.BOOL))
815+
816+
# Reraise the original exception
817+
builder.activate_block(reraise_block)
818+
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
819+
builder.add(Unreachable())
820+
821+
# Normal exit
822+
builder.activate_block(exit_block)
823+
824+
825+
# A simple visitor to detect await expressions
826+
class AwaitDetector(mypy.traverser.TraverserVisitor):
827+
def __init__(self) -> None:
828+
super().__init__()
829+
self.has_await = False
830+
831+
def visit_await_expr(self, o: mypy.nodes.AwaitExpr) -> None:
832+
self.has_await = True
833+
super().visit_await_expr(o)
834+
835+
722836
def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
723837
# Our compilation strategy for try/except/else/finally is to
724838
# treat try/except/else and try/finally as separate language
@@ -727,6 +841,17 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
727841
# body of a try/finally block.
728842
if t.is_star:
729843
builder.error("Exception groups and except* cannot be compiled yet", t.line)
844+
845+
# Check if we're in an async function with a finally block that contains await
846+
use_async_version = False
847+
if t.finally_body and builder.fn_info.is_coroutine:
848+
detector = AwaitDetector()
849+
t.finally_body.accept(detector)
850+
851+
if detector.has_await:
852+
# Use the async version that handles exceptions correctly
853+
use_async_version = True
854+
730855
if t.finally_body:
731856

732857
def transform_try_body() -> None:
@@ -737,7 +862,14 @@ def transform_try_body() -> None:
737862

738863
body = t.finally_body
739864

740-
transform_try_finally_stmt(builder, transform_try_body, lambda: builder.accept(body))
865+
if use_async_version:
866+
transform_try_finally_stmt_async(
867+
builder, transform_try_body, lambda: builder.accept(body), t.line
868+
)
869+
else:
870+
transform_try_finally_stmt(
871+
builder, transform_try_body, lambda: builder.accept(body), t.line
872+
)
741873
else:
742874
transform_try_except_stmt(builder, t)
743875

@@ -828,6 +960,7 @@ def finally_body() -> None:
828960
builder,
829961
lambda: transform_try_except(builder, try_body, [(None, None, except_body)], None, line),
830962
finally_body,
963+
line,
831964
)
832965

833966

0 commit comments

Comments
 (0)