Skip to content

Commit ee7bbb4

Browse files
committed
fix(convertConditionalReturns): several bugs
fix #2 fix #3
1 parent e5aa05e commit ee7bbb4

25 files changed

+812
-160
lines changed

src/util/canUnwindAsIs.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import * as t from '@babel/types'
2+
import { NodePath } from '@babel/traverse'
3+
4+
export default function canUnwindAsIs(
5+
path: NodePath<t.CallExpression>
6+
): boolean {
7+
let parent: NodePath<any> = path.parentPath
8+
let child: NodePath<any> = path
9+
while (parent && !parent.isFunction()) {
10+
if (parent.isBlockStatement()) {
11+
const body = (parent as NodePath<t.BlockStatement>).get('body')
12+
if (child !== body[body.length - 1]) return false
13+
} else if (parent.isLoop()) {
14+
return false
15+
} else if (parent.isSwitchCase()) {
16+
if ((parent.node as t.SwitchCase).test != null) return false
17+
} else if (!parent.isStatement() && !parent.isAwaitExpression()) {
18+
return false
19+
}
20+
child = parent
21+
parent = child.parentPath
22+
}
23+
return true
24+
}

src/util/convertConditionalReturns.ts

Lines changed: 59 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import * as t from '@babel/types'
22
import { NodePath } from '@babel/traverse'
3-
import restOfBlockStatement from './restOfBlockStatement'
3+
import replaceWithStatements from './replaceWithStatements'
4+
import removeRestOfBlockStatement from './removeRestOfBlockStatement'
45

56
function isLastStatementInBlock(path: NodePath<any>): boolean {
67
const { parentPath } = path
@@ -9,81 +10,74 @@ function isLastStatementInBlock(path: NodePath<any>): boolean {
910
return (path as NodePath<any>) === body[body.length - 1]
1011
}
1112

12-
function isInBranch<T extends t.Statement>(
13-
path: NodePath<T>,
14-
branch: 'consequent' | 'alternate'
15-
): boolean {
16-
const { parentPath } = path
17-
if (parentPath.isIfStatement())
18-
return (path as NodePath<any>) === parentPath.get(branch)
19-
if (parentPath.isBlockStatement()) {
20-
const grandparent = parentPath.parentPath
21-
return grandparent.isIfStatement() && parentPath === grandparent.get(branch)
13+
function hasReturn(path: NodePath<t.Statement>): boolean {
14+
if (path.isReturnStatement()) return true
15+
if (path.isBlockStatement()) {
16+
for (const child of (path as NodePath<t.BlockStatement>).get('body')) {
17+
if (child.isReturnStatement()) return true
18+
}
2219
}
2320
return false
2421
}
2522

26-
const isInConsequent = <T extends t.Statement>(path: NodePath<T>) =>
27-
isInBranch(path, 'consequent')
28-
const isInAlternate = <T extends t.Statement>(path: NodePath<T>) =>
29-
isInBranch(path, 'alternate')
30-
31-
function convertToBlockStatement(
32-
blockOrExpression: NodePath<any>
33-
): NodePath<t.BlockStatement> {
34-
if (blockOrExpression.isBlockStatement()) return blockOrExpression
35-
return (blockOrExpression.replaceWith(
36-
t.blockStatement(
37-
blockOrExpression.node == null
38-
? []
39-
: [
40-
blockOrExpression.isStatement()
41-
? blockOrExpression.node
42-
: t.expressionStatement(blockOrExpression.node),
43-
]
23+
function splitBranches(
24+
path: NodePath<t.IfStatement>
25+
): {
26+
returning: NodePath<t.Statement>[]
27+
notReturning: (NodePath<t.Statement> | NodePath<null>)[]
28+
} {
29+
const returning: NodePath<t.Statement>[] = []
30+
const notReturning: (NodePath<t.Statement> | NodePath<null>)[] = []
31+
let p: NodePath<t.IfStatement> | NodePath<null> = path
32+
while (p.isIfStatement()) {
33+
const consequent = (p as NodePath<t.IfStatement>).get('consequent')
34+
const alternate: NodePath<any> = (p as NodePath<t.IfStatement>).get(
35+
'alternate'
4436
)
45-
) as any)[0]
46-
}
37+
;(hasReturn(consequent) ? returning : notReturning).push(consequent)
38+
if (!alternate.isIfStatement()) {
39+
;(hasReturn(alternate) ? returning : notReturning).push(alternate)
40+
}
4741

48-
function addRestToConsequent<T extends t.Statement>(path: NodePath<T>): void {
49-
const ifStatement = path.findParent(p => p.isIfStatement())
50-
if (!ifStatement) throw new Error('failed to find parent IfStatement')
51-
const rest = restOfBlockStatement(ifStatement)
52-
if (!rest.length) return
53-
const consequent = (ifStatement as NodePath<t.IfStatement>).get('consequent')
54-
const restNodes = rest.map((path: NodePath<t.Statement>) => path.node)
55-
convertToBlockStatement(consequent).pushContainer('body', restNodes)
56-
rest.forEach((path: NodePath<t.Statement>) => path.remove())
57-
}
58-
59-
function addRestToAlternate<T extends t.Statement>(path: NodePath<T>): void {
60-
const ifStatement = path.findParent(p => p.isIfStatement())
61-
if (!ifStatement) throw new Error('failed to find parent IfStatement')
62-
const rest = restOfBlockStatement(ifStatement)
63-
if (!rest.length) return
64-
let alternate: NodePath<any> = ifStatement
65-
while (alternate.isIfStatement()) {
66-
alternate = (alternate as NodePath<t.IfStatement>).get('alternate')
42+
p = alternate
6743
}
68-
const restNodes = rest.map((path: NodePath<t.Statement>) => path.node)
69-
convertToBlockStatement(alternate).pushContainer('body', restNodes)
70-
rest.forEach((path: NodePath<t.Statement>) => path.remove())
44+
return { returning, notReturning }
7145
}
7246

7347
export default function convertConditionalReturns(
7448
parent: NodePath<t.BlockStatement>
7549
): boolean {
50+
let ifDepth = 0
7651
let isUnwindable = true
52+
const ifStatements: NodePath<t.IfStatement>[] = []
7753
const returnStatements: NodePath<t.ReturnStatement>[] = []
7854
parent.traverse(
7955
{
56+
IfStatement: {
57+
enter(path: NodePath<t.IfStatement>) {
58+
if (path.parentPath.isIfStatement()) return
59+
ifDepth++
60+
const { returning, notReturning } = splitBranches(path)
61+
if (returning.length > 0) {
62+
if (notReturning.length === 1) {
63+
ifStatements.push(path)
64+
} else if (notReturning.length > 1) {
65+
isUnwindable = false
66+
path.stop()
67+
return
68+
}
69+
}
70+
},
71+
exit(path: NodePath<t.IfStatement>) {
72+
if (path.parentPath.isIfStatement()) return
73+
ifDepth--
74+
},
75+
},
8076
ReturnStatement(path: NodePath<t.ReturnStatement>) {
8177
let { parentPath } = path
82-
let ifDepth = 0
8378
let loopDepth = 0
8479
while (parentPath && parentPath !== parent) {
85-
if (parentPath.isIfStatement()) ifDepth++
86-
else if (parentPath.isLoop()) loopDepth++
80+
if (parentPath.isLoop()) loopDepth++
8781
if (
8882
loopDepth > 1 ||
8983
(!isLastStatementInBlock(parentPath) &&
@@ -104,11 +98,15 @@ export default function convertConditionalReturns(
10498
parent.state
10599
)
106100
if (!isUnwindable) return false
107-
let returnStatement
108-
while ((returnStatement = returnStatements.pop())) {
109-
if (isInConsequent(returnStatement)) addRestToAlternate(returnStatement)
110-
else if (isInAlternate(returnStatement))
111-
addRestToConsequent(returnStatement)
101+
let ifStatement
102+
while ((ifStatement = ifStatements.pop())) {
103+
const {
104+
notReturning: [branch],
105+
} = splitBranches(ifStatement)
106+
const rest = removeRestOfBlockStatement(ifStatement)
107+
if (branch.isBlockStatement())
108+
(branch as NodePath<t.BlockStatement>).pushContainer('body', rest)
109+
else replaceWithStatements(branch, rest)
112110
}
113111
return true
114112
}

src/util/finalCleanup.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,36 @@ function unwrapPromiseResolves(
1717
return node as t.Expression
1818
}
1919

20+
function isEmptyBlock(path: NodePath<any>): boolean {
21+
return path.isBlockStatement() && path.node.body.length === 0
22+
}
23+
2024
export default function finalCleanup(path: NodePath<t.Function>): void {
2125
path.traverse(
2226
{
27+
IfStatement: {
28+
exit(path: NodePath<t.IfStatement>) {
29+
const consequent = path.get('consequent')
30+
const alternate = path.get('alternate')
31+
if (isEmptyBlock(consequent)) {
32+
if (alternate.node == null) {
33+
path.remove()
34+
} else if (isEmptyBlock(alternate)) {
35+
path.remove()
36+
} else {
37+
path.replaceWith(
38+
t.ifStatement(
39+
t.unaryExpression('!', path.node.test),
40+
alternate.node
41+
)
42+
)
43+
}
44+
} else if (isEmptyBlock(alternate)) {
45+
path.node.alternate = null
46+
alternate.remove()
47+
}
48+
},
49+
},
2350
AwaitExpression(path: NodePath<t.AwaitExpression>) {
2451
const argument = path.get('argument')
2552
const { parentPath } = path

src/util/findAwaitedExpression.ts

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/util/findNextLinkToUnwind.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import * as t from '@babel/types'
2+
import { NodePath } from '@babel/traverse'
3+
import { isPromiseMethodCall } from './predicates'
4+
5+
export default function findNextLinkToUnwind(
6+
paths: NodePath | NodePath[],
7+
except?: t.CallExpression
8+
): NodePath<t.CallExpression> | null {
9+
if (Array.isArray(paths)) {
10+
for (const path of paths) {
11+
const result = findNextLinkToUnwind(path, except)
12+
if (result) return result
13+
}
14+
return null
15+
}
16+
const path = paths
17+
if (path.node !== except && isPromiseMethodCall(path.node))
18+
return path as NodePath<t.CallExpression>
19+
let result: NodePath<t.CallExpression> | null = null
20+
path.traverse(
21+
{
22+
CallExpression(path: NodePath<t.CallExpression>) {
23+
if (path.node !== except && isPromiseMethodCall(path.node)) {
24+
if (result == null) result = path
25+
path.stop()
26+
}
27+
},
28+
Function(path: NodePath<t.Function>) {
29+
path.skip()
30+
},
31+
},
32+
path.state
33+
)
34+
return result
35+
}

src/util/restOfBlockStatement.ts renamed to src/util/removeRestOfBlockStatement.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ import { NodePath } from '@babel/traverse'
33

44
import parentStatement from './parentStatement'
55

6-
export default function restOfBlockStatement(
6+
export default function removeRestOfBlockStatement(
77
path: NodePath<any>
8-
): NodePath<t.Statement>[] {
8+
): t.Statement[] {
99
const statement = parentStatement(path)
1010
const blockStatement = statement.parentPath
1111
if (!blockStatement.isBlockStatement())
@@ -14,5 +14,8 @@ export default function restOfBlockStatement(
1414
const index = body.indexOf(statement)
1515
if (index < 0)
1616
throw new Error('failed to get index of Statement within BlockStatement')
17-
return body.slice(index + 1)
17+
const rest = body.slice(index + 1)
18+
const statements = rest.map(p => p.node)
19+
rest.forEach(p => p.remove())
20+
return statements
1821
}

src/util/unwindCatch.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { NodePath } from '@babel/traverse'
33

44
import getPreceedingLink from './getPreceedingLink'
55
import { isNullish } from './predicates'
6+
import canUnwindAsIs from './canUnwindAsIs'
67
import replaceLink from './replaceLink'
78
import renameBoundIdentifiers from './renameBoundIdentifiers'
89
import unboundIdentifier from './unboundIdentifier'
@@ -39,7 +40,11 @@ export default function unwindCatch(
3940
}
4041
const handlerFunction = handler as NodePath<t.Function>
4142
const body = handlerFunction.get('body')
42-
if (body.isBlockStatement() && !convertConditionalReturns(body)) {
43+
if (
44+
body.isBlockStatement() &&
45+
!canUnwindAsIs(link) &&
46+
!convertConditionalReturns(body)
47+
) {
4348
return getPreceedingLink(link)
4449
}
4550

src/util/unwindPromiseChain.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import * as t from '@babel/types'
22
import { NodePath } from '@babel/traverse'
33

4-
import findAwaitedExpression from './findAwaitedExpression'
54
import getThenHandler from './getThenHandler'
65
import getCatchHandler from './getCatchHandler'
76
import getFinallyHandler from './getFinallyHandler'
@@ -10,6 +9,7 @@ import { unwindThen } from './unwindThen'
109
import unwindFinally from './unwindFinally'
1110
import parentStatement from './parentStatement'
1211
import replaceWithImmediatelyInvokedAsyncArrowFunction from './replaceWithImmediatelyInvokedAsyncArrowFunction'
12+
import findNextLinkToUnwind from './findNextLinkToUnwind'
1313

1414
export default function unwindPromiseChain(
1515
path: NodePath<t.CallExpression>
@@ -24,10 +24,11 @@ export default function unwindPromiseChain(
2424

2525
const { scope } = parentStatement(path)
2626

27-
let link: NodePath<t.Expression> | null = path as any
27+
let link: NodePath<t.CallExpression> | null = path as any
2828

29-
while (link && link.isCallExpression()) {
30-
const callee = (link as NodePath<t.CallExpression>).get('callee')
29+
while (link) {
30+
const origNode = link.node
31+
const callee = link.get('callee')
3132
if (!callee.isMemberExpression()) break
3233

3334
const thenHandler = getThenHandler(link)
@@ -42,7 +43,7 @@ export default function unwindPromiseChain(
4243
} else if (finallyHandler) {
4344
replacements = unwindFinally(finallyHandler)
4445
}
45-
link = replacements ? findAwaitedExpression(replacements) : null
46+
link = replacements ? findNextLinkToUnwind(replacements, origNode) : null
4647
;(scope as any).crawl()
4748
}
4849
}

src/util/unwindThen.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { NodePath } from '@babel/traverse'
44
import getPreceedingLink from './getPreceedingLink'
55
import { awaited } from './builders'
66
import { isNullish } from './predicates'
7+
import canUnwindAsIs from './canUnwindAsIs'
78
import renameBoundIdentifiers from './renameBoundIdentifiers'
89
import hasMutableIdentifiers from './hasMutableIdentifiers'
910
import prependBodyStatement from './prependBodyStatement'
@@ -24,7 +25,11 @@ export function unwindThen(
2425
const handlerFunction = handler as NodePath<t.Function>
2526
const input = handlerFunction.get('params')[0]
2627
const body = handlerFunction.get('body')
27-
if (body.isBlockStatement() && !convertConditionalReturns(body)) {
28+
if (
29+
body.isBlockStatement() &&
30+
!canUnwindAsIs(link) &&
31+
!convertConditionalReturns(body)
32+
) {
2833
return getPreceedingLink(link)
2934
}
3035

0 commit comments

Comments
 (0)