Skip to content

Commit 32eaad5

Browse files
mihailom-dbyhuang-db
authored andcommitted
[SPARK-52233][SQL][FOLLOW-UP] Fix 0.0 and -0.0 issues in map_zip_with
### What changes were proposed in this pull request? This PR introduces normalization of floating point numbers before being inserted into java map of map_zip_with implementation. ### Why are the changes needed? Scala maps have problems with NaNs, but Java maps have problems with 0.0 and -0.0. We need to make sure our implementation works properly for all inputs. ### Does this PR introduce _any_ user-facing change? No, scala implementation did not have this problem and java fix was not released yet. ### How was this patch tested? Golden file test added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#51052 from mihailom-db/mapZipWith-followup. Authored-by: Mihailo Milosevic <mihailo.milosevic@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent ca30082 commit 32eaad5

File tree

4 files changed

+107
-63
lines changed

4 files changed

+107
-63
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
2929
import org.apache.spark.sql.catalyst.expressions.Cast._
3030
import org.apache.spark.sql.catalyst.expressions.codegen._
31+
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
3132
import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike}
3233
import org.apache.spark.sql.catalyst.trees.TreePattern._
3334
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -1147,6 +1148,12 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
11471148
hashMap
11481149
}
11491150

1151+
private lazy val normalizer: Any => Any = keyType match {
1152+
case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
1153+
case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
1154+
case _ => identity
1155+
}
1156+
11501157
private def getKeysWithIndexesFastAsJava(
11511158
keys1: ArrayData,
11521159
keys2: ArrayData
@@ -1155,7 +1162,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
11551162
for ((z, array) <- Array((0, keys1), (1, keys2))) {
11561163
var i = 0
11571164
while (i < array.numElements()) {
1158-
val key = array.get(i, keyType)
1165+
val key = normalizer(array.get(i, keyType))
11591166
Option(hashMap.get(key)) match {
11601167
case Some(indexes) =>
11611168
if (indexes(z).isEmpty) {

0 commit comments

Comments
 (0)