17
17
18
18
package org .apache .spark .sql .catalyst .analysis .resolver
19
19
20
- import java .util .IdentityHashMap
21
-
22
- import org .apache .spark .SparkException
23
20
import org .apache .spark .sql .AnalysisException
24
21
import org .apache .spark .sql .catalyst .analysis .{
25
22
AnsiTypeCoercion ,
26
23
CollationTypeCoercion ,
27
24
TypeCoercion
28
25
}
29
- import org .apache .spark .sql .catalyst .expressions .{Alias , Expression , OuterReference , SubExprUtils }
26
+ import org .apache .spark .sql .catalyst .expressions .{Expression , OuterReference , SubExprUtils }
30
27
import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , ListAgg }
31
- import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , Sort }
32
28
import org .apache .spark .sql .catalyst .util .toPrettySQL
33
29
import org .apache .spark .sql .errors .QueryCompilationErrors
34
30
@@ -90,8 +86,6 @@ class AggregateExpressionResolver(
90
86
* 1. Update the [[ExpressionResolver.expressionResolutionContextStack ]];
91
87
* 2. Handle [[OuterReference ]] in [[AggregateExpression ]], if there are any (see
92
88
* `handleOuterAggregateExpression`);
93
- * 3. Handle [[AggregateExpression ]] in [[Sort ]] operator (see
94
- * `handleAggregateExpressionInSort`);
95
89
* - Validation:
96
90
* 1. [[ListAgg ]] is not allowed in DISTINCT aggregates if it contains [[SortOrder ]] different
97
91
* from its child;
@@ -124,12 +118,7 @@ class AggregateExpressionResolver(
124
118
if (expressionResolutionContext.hasOuterReferences) {
125
119
handleOuterAggregateExpression(aggregateExpressionWithChildrenResolved)
126
120
} else {
127
- traversals.current.parentOperator match {
128
- case Sort (_, _, aggregate : Aggregate , _) =>
129
- handleAggregateExpressionInSort(aggregateExpressionWithChildrenResolved, aggregate)
130
- case other =>
131
- aggregateExpressionWithChildrenResolved
132
- }
121
+ aggregateExpressionWithChildrenResolved
133
122
}
134
123
}
135
124
@@ -163,12 +152,15 @@ class AggregateExpressionResolver(
163
152
* - Create a new subtree without [[OuterReference ]]s;
164
153
* - Alias this subtree and put it inside the current [[SubqueryScope ]];
165
154
* - If outer aggregates are allowed, replace the [[AggregateExpression ]] with an
166
- * [[OuterReference ]] to the auto-generated [[Alias ]] that we created. This alias will later
167
- * be injected into the outer [[Aggregate ]]; We store the name that needs to be used for the
168
- * [[OuterReference ]] in [[OuterReference.SINGLE_PASS_SQL_STRING_OVERRIDE ]] computed based on
169
- * the [[AggregateExpression ]] without [[OuterReference ]] pulled out.
155
+ * [[OuterReference ]] to the auto-generated [[Alias ]] that we created in case the subtree
156
+ * without [[OuterReference ]]s can't be found in the outer
157
+ * [[Aggregate.aggregateExpressions ]] list. Otherwise, use the [[Alias ]] from the outer
158
+ * [[Aggregate ]]. This alias will later be injected into the outer [[Aggregate ]];
159
+ * - Store the name that needs to be used for the [[OuterReference ]] in
160
+ * [[OuterReference.SINGLE_PASS_SQL_STRING_OVERRIDE ]] computed based on the
161
+ * [[AggregateExpression ]] without [[OuterReference ]] pulled out.
170
162
* - In case we have an [[AggregateExpression ]] inside a [[Sort ]] operator, we need to handle it
171
- * in a special way (see [[handleAggregateExpressionInSort ]] for more details).
163
+ * in a special way (see [[handleAggregateExpressionOutsideAggregate ]] for more details).
172
164
* - Return the original [[AggregateExpression ]] otherwise. This is done to stay compatible
173
165
* with the fixed-point Analyzer - a proper exception will be thrown later by
174
166
* [[ValidateSubqueryExpression ]].
@@ -183,19 +175,12 @@ class AggregateExpressionResolver(
183
175
}
184
176
185
177
val resolvedOuterAggregateExpression =
186
- if (subqueryRegistry.currentScope.isOuterAggregateAllowed) {
187
- val aggregateExpressionWithStrippedOuterReferences =
188
- SubExprUtils .stripOuterReference(aggregateExpression)
189
-
190
- val outerAggregateExpressionAlias = autoGeneratedAliasProvider.newOuterAlias(
191
- child = aggregateExpressionWithStrippedOuterReferences
192
- )
193
- subqueryRegistry.currentScope.addOuterAggregateExpression(
194
- outerAggregateExpressionAlias,
195
- aggregateExpressionWithStrippedOuterReferences
178
+ if (subqueryRegistry.currentScope.aggregateExpressionsExtractor.isDefined) {
179
+ extractOuterAggregateExpression(
180
+ aggregateExpression = aggregateExpression,
181
+ aggregateExpressionsExtractor =
182
+ subqueryRegistry.currentScope.aggregateExpressionsExtractor.get
196
183
)
197
-
198
- OuterReference (outerAggregateExpressionAlias.toAttribute)
199
184
} else {
200
185
aggregateExpression
201
186
}
@@ -211,41 +196,30 @@ class AggregateExpressionResolver(
211
196
}
212
197
}
213
198
214
- /**
215
- * If we order by an [[AggregateExpression ]] which is not present in the [[Aggregate ]] operator
216
- * (child of the [[Sort ]]) we have to extract it (by adding it to the
217
- * `extractedAggregateExpressionAliases` list of the current expression tree traversal) and add
218
- * it to the [[Aggregate ]] operator afterwards (this is done in the [[SortResolver ]]).
219
- */
220
- private def handleAggregateExpressionInSort (
221
- aggregateExpression : Expression ,
222
- aggregate : Aggregate ): Expression = {
223
- val aliasChildToAliasInAggregateExpressions = new IdentityHashMap [Expression , Alias ]
224
- val aggregateExpressionsSemanticComparator = new SemanticComparator (
225
- aggregate.aggregateExpressions.collect {
226
- case alias : Alias =>
227
- aliasChildToAliasInAggregateExpressions.put(alias.child, alias)
228
- alias.child
229
- }
199
+ private def extractOuterAggregateExpression (
200
+ aggregateExpression : AggregateExpression ,
201
+ aggregateExpressionsExtractor : GroupingAndAggregateExpressionsExtractor ): OuterReference = {
202
+ val aggregateExpressionWithStrippedOuterReferences =
203
+ SubExprUtils .stripOuterReference(aggregateExpression)
204
+
205
+ val outerAggregateExpressionAlias = autoGeneratedAliasProvider.newOuterAlias(
206
+ child = aggregateExpressionWithStrippedOuterReferences
230
207
)
231
208
232
- val referencedAggregateExpression =
233
- aggregateExpressionsSemanticComparator.collectFirst(aggregateExpression)
209
+ val (_, referencedAggregateExpressionAlias) =
210
+ aggregateExpressionsExtractor.collectFirstAggregateExpression(
211
+ aggregateExpressionWithStrippedOuterReferences
212
+ )
234
213
235
- referencedAggregateExpression match {
236
- case Some (expression) =>
237
- aliasChildToAliasInAggregateExpressions.get(expression) match {
238
- case null =>
239
- throw SparkException .internalError(
240
- s " No parent alias for expression $expression while extracting aggregate " +
241
- s " expressions in Sort operator. "
242
- )
243
- case alias : Alias => alias.toAttribute
244
- }
214
+ referencedAggregateExpressionAlias match {
215
+ case Some (alias) =>
216
+ subqueryRegistry.currentScope.addAliasForOuterAggregateExpression(alias)
217
+ OuterReference (alias.toAttribute)
245
218
case None =>
246
- val alias = autoGeneratedAliasProvider.newAlias(child = aggregateExpression)
247
- traversals.current.extractedAggregateExpressionAliases.add(alias)
248
- alias.toAttribute
219
+ subqueryRegistry.currentScope.addAliasForOuterAggregateExpression(
220
+ outerAggregateExpressionAlias
221
+ )
222
+ OuterReference (outerAggregateExpressionAlias.toAttribute)
249
223
}
250
224
}
251
225
0 commit comments