
Commit 14dbc7d

mihailom-db authored and yhuang-db committed
[SPARK-52233][SQL] Fix map_zip_with for Floating Point Types
### What changes were proposed in this pull request?

Fix the `map_zip_with` expression's handling of floating-point map keys.

### Why are the changes needed?

Previously we would run `getKeysWithIndexesFast`, which deduplicates keys with `scala.collection.mutable.LinkedHashMap`. That implementation does not use proper equality for floating-point keys: its key comparison treats `NaN` as unequal to `NaN`, so every `NaN` key ended up as a separate entry. This PR fixes the behaviour by using `java.util.LinkedHashMap` instead, which compares keys with the boxed type's `equals()` (under which `NaN` equals `NaN`) rather than primitive `==`.

Example:
```
select map_zip_with(map(float('NaN'), 1), map(float('NaN'), 2), (k, v1, v2) -> (v1, v2))
```

Output before:
```
{"NaN":{"v1":1,"v2":null},"NaN":{"v1":null,"v2":2}}
```

Output after:
```
{"NaN":{"v1":1,"v2":2}}
```

### Does this PR introduce _any_ user-facing change?

Yes, the result of `map_zip_with` changes for floating-point keys: duplicate `NaN` entries are now merged.

### How was this patch tested?

Added tests to the golden files for both edge cases, `NaN` and `Infinity`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#50967 from mihailom-db/FixMapZipWith.

Authored-by: Mihailo Milosevic <mihailo.milosevic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 5b8ae08 commit 14dbc7d
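The root cause comes down to how the two map implementations compare boxed floating-point keys. Below is a standalone Scala sketch (not part of the patch, runnable in a Scala 2.13 REPL) illustrating the difference the description refers to:

```scala
import scala.collection.mutable

// Scala's mutable LinkedHashMap compares keys with universal equality (==), which for
// boxed floats follows IEEE semantics: NaN != NaN, so each NaN key becomes its own entry.
val scalaMap = new mutable.LinkedHashMap[Any, Int]
scalaMap.put(Float.NaN, 1)
scalaMap.put(Float.NaN, 2)
println(scalaMap.size)    // 2

// java.util.LinkedHashMap compares keys with equals(); java.lang.Float.equals treats
// NaN as equal to NaN, so the second put updates the existing entry.
val javaMap = new java.util.LinkedHashMap[Any, Int]
javaMap.put(Float.NaN, 1)
javaMap.put(Float.NaN, 2)
println(javaMap.size())   // 1
```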

File tree: 5 files changed, +170 -57 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 31 additions & 3 deletions
@@ -21,6 +21,7 @@ import java.util.Comparator
 import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 
 import scala.collection.mutable
+import scala.jdk.CollectionConverters.MapHasAsScala
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
@@ -1109,8 +1110,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
    */
   @transient private lazy val getKeysWithValueIndexes:
       (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
-    if (TypeUtils.typeWithProperEquals(keyType)) {
-      getKeysWithIndexesFast
+    if (TypeUtils.typeWithProperEquals(keyType) && SQLConf.get.mapZipWithUsesJavaCollections) {
+      getKeysWithIndexesFastAsJava
+    } else if (TypeUtils.typeWithProperEquals(keyType)) {
+      getKeysWithIndexesFastUsingScala
     } else {
       getKeysWithIndexesBruteForce
     }
@@ -1122,7 +1125,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
     }
   }
 
-  private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = {
+  private def getKeysWithIndexesFastUsingScala(keys1: ArrayData, keys2: ArrayData) = {
     val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]]
     for ((z, array) <- Array((0, keys1), (1, keys2))) {
       var i = 0
@@ -1144,6 +1147,31 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
     hashMap
   }
 
+  private def getKeysWithIndexesFastAsJava(
+      keys1: ArrayData,
+      keys2: ArrayData
+  ): scala.collection.mutable.LinkedHashMap[Any, Array[Option[Int]]] = {
+    val hashMap = new java.util.LinkedHashMap[Any, Array[Option[Int]]]
+    for ((z, array) <- Array((0, keys1), (1, keys2))) {
+      var i = 0
+      while (i < array.numElements()) {
+        val key = array.get(i, keyType)
+        Option(hashMap.get(key)) match {
+          case Some(indexes) =>
+            if (indexes(z).isEmpty) {
+              indexes(z) = Some(i)
+            }
+          case None =>
+            val indexes = Array[Option[Int]](None, None)
+            indexes(z) = Some(i)
+            hashMap.put(key, indexes)
+        }
+        i += 1
+      }
+    }
+    scala.collection.mutable.LinkedHashMap(hashMap.asScala.toSeq: _*)
+  }
+
   private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = {
     val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
     for ((z, array) <- Array((0, keys1), (1, keys2))) {
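The new `getKeysWithIndexesFastAsJava` helper deduplicates keys in a `java.util.LinkedHashMap` and then copies the result into a `scala.collection.mutable.LinkedHashMap`, keeping the return type the callers expect. A small hypothetical REPL check (not from the patch) that this copy preserves insertion order:

```scala
import scala.collection.mutable
import scala.jdk.CollectionConverters.MapHasAsScala

// Fill a Java LinkedHashMap in a deliberately non-sorted order.
val jmap = new java.util.LinkedHashMap[Any, Int]
jmap.put("b", 1)
jmap.put("a", 2)
jmap.put("c", 3)

// Copy it into a Scala LinkedHashMap the same way the new helper does.
val smap = mutable.LinkedHashMap(jmap.asScala.toSeq: _*)
println(smap.keys.mkString(", "))  // b, a, c (insertion order is preserved)
```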

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 0 deletions
@@ -1109,6 +1109,14 @@ object SQLConf {
     .stringConf
     .createOptional
 
+  val MAP_ZIP_WITH_USES_JAVA_COLLECTIONS =
+    buildConf("spark.sql.mapZipWithUsesJavaCollections")
+      .doc("When true, the `map_zip_with` function uses Java collections instead of Scala " +
+        "collections. This is useful for avoiding NaN equality issues.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
       .internal()
@@ -6383,6 +6391,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
    */
  def hintErrorHandler: HintErrorHandler = HintErrorLogger
 
+  def mapZipWithUsesJavaCollections: Boolean =
+    getConf(MAP_ZIP_WITH_USES_JAVA_COLLECTIONS)
+
  def subexpressionEliminationEnabled: Boolean =
    getConf(SUBEXPRESSION_ELIMINATION_ENABLED)
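The new flag defaults to `true`, so the fixed behaviour is active out of the box, while setting it to `false` falls back to the original Scala-collection path. A minimal usage sketch (the session setup is illustrative; the config key and its default come from the diff above):

```scala
import org.apache.spark.sql.SparkSession

// Local session for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("map_zip_with-demo").getOrCreate()

val query =
  "SELECT map_zip_with(map(float('NaN'), 1), map(float('NaN'), 2), (k, v1, v2) -> (v1, v2))"

// Default (spark.sql.mapZipWithUsesJavaCollections = true): the two NaN keys merge
// into a single entry carrying both values, as in "Output after" above.
spark.sql(query).show(false)

// Turning the flag off selects the Scala LinkedHashMap path, restoring the pre-fix
// behaviour with two separate NaN entries.
spark.conf.set("spark.sql.mapZipWithUsesJavaCollections", "false")
spark.sql(query).show(false)
```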
