diff --git a/src/transformation/visitors/errors.ts b/src/transformation/visitors/errors.ts index cc73b2ac0..77067dda8 100644 --- a/src/transformation/visitors/errors.ts +++ b/src/transformation/visitors/errors.ts @@ -40,35 +40,44 @@ const transformAsyncTry: FunctionVisitor = (statement, context) let catchScope: Scope | undefined; const chainCalls: lua.Statement[] = []; - if (statement.finallyBlock) { - const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally")); - const finallyFunction = lua.createFunctionExpression( - lua.createBlock(context.transformStatements(statement.finallyBlock.statements)) - ); - const finallyCall = lua.createCallExpression( - awaiterFinally, - [awaiterIdentifier, finallyFunction], - statement.finallyBlock - ); - chainCalls.push(lua.createExpressionStatement(finallyCall)); - } - if (statement.catchClause) { + // ____try = ____try.catch() const [catchFunction, cScope] = transformCatchClause(context, statement.catchClause); catchScope = cScope; if (catchFunction.params) { catchFunction.params.unshift(lua.createAnonymousIdentifier()); } + const catchBodyStatements = catchFunction.body ? catchFunction.body.statements : []; + const asyncWrappedCatch = wrapInAsyncAwaiter(context, [...catchBodyStatements], false); + catchFunction.body = lua.createBlock([lua.createReturnStatement([asyncWrappedCatch])]); + const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch")); const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]); - const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall); - chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); - } else { - const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier); - chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); + chainCalls.push(lua.createAssignmentStatement(lua.cloneIdentifier(awaiterIdentifier), catchCall)); } + if (statement.finallyBlock) { + // ____try = ____try.finally() + const finallyStatements = context.transformStatements(statement.finallyBlock.statements); + const asyncWrappedFinally = wrapInAsyncAwaiter(context, finallyStatements, false); + const finallyFunction = lua.createFunctionExpression( + lua.createBlock([lua.createReturnStatement([asyncWrappedFinally])]) + ); + + const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally")); + const finallyCall = lua.createCallExpression( + awaiterFinally, + [awaiterIdentifier, finallyFunction], + statement.finallyBlock + ); + chainCalls.push(lua.createAssignmentStatement(lua.cloneIdentifier(awaiterIdentifier), finallyCall)); + } + + // __TS__Await(____try) + const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier); + chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); + const hasReturn = tryScope.asyncTryHasReturn ?? catchScope?.asyncTryHasReturn; const hasBreak = tryScope.asyncTryHasBreak ?? catchScope?.asyncTryHasBreak; const hasContinue = tryScope.asyncTryHasContinue ?? catchScope?.asyncTryHasContinue; diff --git a/test/unit/builtins/async-await.spec.ts b/test/unit/builtins/async-await.spec.ts index 7a511d7b1..546589355 100644 --- a/test/unit/builtins/async-await.spec.ts +++ b/test/unit/builtins/async-await.spec.ts @@ -816,6 +816,120 @@ describe("try/catch in async function", () => { }); }); + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1659 + test("await inside catch handler resolves correctly (#1659)", () => { + util.testFunction` + let reject: (reason: string) => void = () => {}; + + async function failing() { + return new Promise((_, rej) => { reject = rej; }); + } + + async function run() { + try { + await failing(); + } catch (e) { + log("catch"); + const a = await Promise.resolve(true); + log("a", a); + } + } + + run(); + reject("error"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["catch", "a", true]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1659 + test("await inside finally handler resolves correctly (#1659)", () => { + util.testFunction` + let reject: (reason: string) => void = () => {}; + + async function failing() { + return new Promise((_, rej) => { reject = rej; }); + } + + async function run() { + try { + await failing(); + } finally { + log("finally"); + const a = await Promise.resolve(true); + log("a", a); + } + } + + run().catch(() => {}); + reject("error"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["finally", "a", true]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1659 + test("await inside both catch and finally handlers (#1659)", () => { + util.testFunction` + let reject: (reason: string) => void = () => {}; + + async function failing() { + return new Promise((_, rej) => { reject = rej; }); + } + + async function run() { + try { + await failing(); + } catch (e) { + log("catch"); + const a = await Promise.resolve("caught"); + log("a", a); + } finally { + log("finally"); + const b = await Promise.resolve("done"); + log("b", b); + } + } + + run(); + reject("error"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["catch", "a", "caught", "finally", "b", "done"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1659 + test("awaited value in catch is returned from async function (#1659)", () => { + util.testFunction` + const failing = defer(); + const recovery = defer(); + + async function run() { + try { + await failing.promise; + return "succeeded"; + } catch (e) { + return await recovery.promise; + } + } + + run().then(value => log("result", value)); + + failing.reject("error"); + recovery.resolve("recovered"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["result", "recovered"]); + }); + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 test("return inside try with deferred promise (#1706)", () => { util.testFunction`