Skip to content

Commit 5ac8990

Browse files
committed
Add missing cases to treeUtils (TypeLambdaTree, Bind, Block, MatchType)
1 parent 3a3d5a3 commit 5ac8990

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

library/src/scala/tasty/reflect/TreeUtils.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ trait TreeUtils
1616
def foldTree(x: X, tree: Tree)(implicit ctx: Context): X
1717
def foldTypeTree(x: X, tree: TypeOrBoundsTree)(implicit ctx: Context): X
1818
def foldCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X
19+
def foldTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X
1920
def foldPattern(x: X, tree: Pattern)(implicit ctx: Context): X
2021

2122
def foldTrees(x: X, trees: Iterable[Tree])(implicit ctx: Context): X = (x /: trees)(foldTree)
2223
def foldTypeTrees(x: X, trees: Iterable[TypeOrBoundsTree])(implicit ctx: Context): X = (x /: trees)(foldTypeTree)
2324
def foldCaseDefs(x: X, trees: Iterable[CaseDef])(implicit ctx: Context): X = (x /: trees)(foldCaseDef)
25+
def foldTypeCaseDefs(x: X, trees: Iterable[TypeCaseDef])(implicit ctx: Context): X = (x /: trees)(foldTypeCaseDef)
2426
def foldPatterns(x: X, trees: Iterable[Pattern])(implicit ctx: Context): X = (x /: trees)(foldPattern)
2527
private def foldParents(x: X, trees: Iterable[TermOrTypeTree])(implicit ctx: Context): X = (x /: trees)(foldOverTermOrTypeTree)
2628

@@ -97,13 +99,25 @@ trait TreeUtils
9799
case TypeTree.Applied(tpt, args) => foldTypeTrees(foldTypeTree(x, tpt), args)
98100
case TypeTree.ByName(result) => foldTypeTree(x, result)
99101
case TypeTree.Annotated(arg, annot) => foldTree(foldTypeTree(x, arg), annot)
102+
case TypeTree.TypeLambdaTree(typedefs, arg) => foldTrees(foldTypeTree(x, arg), typedefs)
103+
case TypeTree.Bind(_, tbt) => foldTypeTree(x, tbt)
104+
case TypeTree.Block(typedefs, tpt) => foldTrees(foldTypeTree(x, tpt), typedefs)
105+
case TypeTree.MatchType(boundopt, selector, cases) => {
106+
val bound_fold_result = boundopt.map(foldTypeTree(x, _)).getOrElse(x)
107+
foldTypeCaseDefs(foldTypeTree(bound_fold_result, selector), cases)
108+
}
100109
case TypeBoundsTree(lo, hi) => foldTypeTree(foldTypeTree(x, lo), hi)
101110
}
102111

103112
def foldOverCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X = tree match {
104113
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldPattern(x, pat), guard), body)
105114
}
106115

116+
def foldOverTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X = tree match {
117+
case TypeCaseDef(pat, body) => foldTypeTree(foldTypeTree(x, pat), body)
118+
}
119+
120+
107121
def foldOverPattern(x: X, tree: Pattern)(implicit ctx: Context): X = tree match {
108122
case Pattern.Value(v) => foldTree(x, v)
109123
case Pattern.Bind(_, body) => foldPattern(x, body)
@@ -124,16 +138,19 @@ trait TreeUtils
124138
def traverseTree(tree: Tree)(implicit ctx: Context): Unit = traverseTreeChildren(tree)
125139
def traverseTypeTree(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = traverseTypeTreeChildren(tree)
126140
def traverseCaseDef(tree: CaseDef)(implicit ctx: Context): Unit = traverseCaseDefChildren(tree)
141+
def traverseTypeCaseDef(tree: TypeCaseDef)(implicit ctx: Context): Unit = traverseTypeCaseDefChildren(tree)
127142
def traversePattern(tree: Pattern)(implicit ctx: Context): Unit = traversePatternChildren(tree)
128143

129144
def foldTree(x: Unit, tree: Tree)(implicit ctx: Context): Unit = traverseTree(tree)
130145
def foldTypeTree(x: Unit, tree: TypeOrBoundsTree)(implicit ctx: Context) = traverseTypeTree(tree)
131146
def foldCaseDef(x: Unit, tree: CaseDef)(implicit ctx: Context) = traverseCaseDef(tree)
147+
def foldTypeCaseDef(x: Unit, tree: TypeCaseDef)(implicit ctx: Context) = traverseTypeCaseDef(tree)
132148
def foldPattern(x: Unit, tree: Pattern)(implicit ctx: Context) = traversePattern(tree)
133149

134150
protected def traverseTreeChildren(tree: Tree)(implicit ctx: Context): Unit = foldOverTree((), tree)
135151
protected def traverseTypeTreeChildren(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = foldOverTypeTree((), tree)
136152
protected def traverseCaseDefChildren(tree: CaseDef)(implicit ctx: Context): Unit = foldOverCaseDef((), tree)
153+
protected def traverseTypeCaseDefChildren(tree: TypeCaseDef)(implicit ctx: Context): Unit = foldOverTypeCaseDef((), tree)
137154
protected def traversePatternChildren(tree: Pattern)(implicit ctx: Context): Unit = foldOverPattern((), tree)
138155

139156
}

0 commit comments

Comments
 (0)