Skip to content

Commit 8f699a4

Browse files
committed
[SPARK-52259][ML][CONNECT] Fix Param class binary compatibility
### What changes were proposed in this pull request? This PR fixes the following issue: Integrating third-party ML libraries with Spark 4.0.0 may encounter compatibility issues due to an interface change. As an example, users may encounter the following error when using xgboost4j with Spark: ``` NoSuchMethodError: 'void org.apache.spark.ml.param.Param.<init>(org.apache.spark.ml.util.Identifiable, java.lang.String, java.lang.String)' ``` ### Why are the changes needed? Fix binary compatibility ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50981 from WeichenXu123/SPARK-52259. Lead-authored-by: Weichen Xu <weichen.xu@databricks.com> Co-authored-by: WeichenXu <weichen.xu@databricks.com> Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
1 parent 7fee291 commit 8f699a4

File tree

6 files changed

+59
-32
lines changed

6 files changed

+59
-32
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
191191
* @group expertParam
192192
*/
193193
@Since("2.2.0")
194-
val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients",
195-
"The lower bounds on coefficients if fitting under bound constrained optimization.")
194+
val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this.uid, "lowerBoundsOnCoefficients",
195+
"The lower bounds on coefficients if fitting under bound constrained optimization.",
196+
classOf[Matrix]
197+
)
196198

197199
/** @group expertGetParam */
198200
@Since("2.2.0")
@@ -208,8 +210,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
208210
* @group expertParam
209211
*/
210212
@Since("2.2.0")
211-
val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients",
212-
"The upper bounds on coefficients if fitting under bound constrained optimization.")
213+
val upperBoundsOnCoefficients: Param[Matrix] = new Param(this.uid, "upperBoundsOnCoefficients",
214+
"The upper bounds on coefficients if fitting under bound constrained optimization.",
215+
classOf[Matrix]
216+
)
213217

214218
/** @group expertGetParam */
215219
@Since("2.2.0")
@@ -224,8 +228,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
224228
* @group expertParam
225229
*/
226230
@Since("2.2.0")
227-
val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts",
228-
"The lower bounds on intercepts if fitting under bound constrained optimization.")
231+
val lowerBoundsOnIntercepts: Param[Vector] = new Param(this.uid, "lowerBoundsOnIntercepts",
232+
"The lower bounds on intercepts if fitting under bound constrained optimization.",
233+
classOf[Vector]
234+
)
229235

230236
/** @group expertGetParam */
231237
@Since("2.2.0")
@@ -240,8 +246,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
240246
* @group expertParam
241247
*/
242248
@Since("2.2.0")
243-
val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts",
244-
"The upper bounds on intercepts if fitting under bound constrained optimization.")
249+
val upperBoundsOnIntercepts: Param[Vector] = new Param(this.uid, "upperBoundsOnIntercepts",
250+
"The upper bounds on intercepts if fitting under bound constrained optimization.",
251+
classOf[Vector]
252+
)
245253

246254
/** @group expertGetParam */
247255
@Since("2.2.0")

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ private[classification] trait MultilayerPerceptronParams extends ProbabilisticCl
7474
* @group expertParam
7575
*/
7676
@Since("2.0.0")
77-
final val initialWeights: Param[Vector] = new Param[Vector](this, "initialWeights",
78-
"The initial weights of the model")
77+
final val initialWeights: Param[Vector] = new Param[Vector](this.uid, "initialWeights",
78+
"The initial weights of the model", classOf[Vector])
7979

8080
/** @group expertGetParam */
8181
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
4343
* @group param
4444
*/
4545
@Since("2.0.0")
46-
val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
46+
val scalingVec: Param[Vector] = new Param(
47+
this.uid, "scalingVec", "vector for hadamard product", classOf[Vector]
48+
)
4749

4850
/** @group setParam */
4951
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import java.util.{List => JList}
2323
import scala.annotation.varargs
2424
import scala.collection.mutable
2525
import scala.jdk.CollectionConverters._
26-
import scala.reflect.ClassTag
2726

2827
import org.json4s._
2928
import org.json4s.jackson.JsonMethods._
@@ -46,20 +45,23 @@ import org.apache.spark.util.SizeEstimator
4645
* See [[ParamValidators]] for factory methods for common validation functions.
4746
* @tparam T param value type
4847
*/
49-
class Param[T: ClassTag](
50-
val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
51-
extends Serializable {
48+
class Param[T](
49+
val parent: String, val name: String, val doc: String, val isValid: T => Boolean,
50+
val dataClass: Class[T]
51+
) extends Serializable {
5252

53-
// Spark Connect ML needs T type information which has been erased when compiling,
54-
// Use classTag to preserve the T type.
55-
val paramValueClassTag = implicitly[ClassTag[T]]
53+
def this(parent: String, name: String, doc: String, isValid: T => Boolean) =
54+
this(parent, name, doc, isValid, null)
5655

5756
def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
5857
this(parent.uid, name, doc, isValid)
5958

6059
def this(parent: String, name: String, doc: String) =
6160
this(parent, name, doc, ParamValidators.alwaysTrue[T])
6261

62+
def this(parent: String, name: String, doc: String, dataClass: Class[T]) =
63+
this(parent, name, doc, ParamValidators.alwaysTrue[T], dataClass)
64+
6365
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
6466

6567
/**
@@ -329,7 +331,7 @@ object ParamValidators {
329331
* Specialized version of `Param[Double]` for Java.
330332
*/
331333
class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean)
332-
extends Param[Double](parent, name, doc, isValid) {
334+
extends Param[Double](parent, name, doc, isValid, classOf[Double]) {
333335

334336
def this(parent: String, name: String, doc: String) =
335337
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -387,7 +389,7 @@ private[param] object DoubleParam {
387389
* Specialized version of `Param[Int]` for Java.
388390
*/
389391
class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)
390-
extends Param[Int](parent, name, doc, isValid) {
392+
extends Param[Int](parent, name, doc, isValid, classOf[Int]) {
391393

392394
def this(parent: String, name: String, doc: String) =
393395
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -414,7 +416,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
414416
* Specialized version of `Param[Float]` for Java.
415417
*/
416418
class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean)
417-
extends Param[Float](parent, name, doc, isValid) {
419+
extends Param[Float](parent, name, doc, isValid, classOf[Float]) {
418420

419421
def this(parent: String, name: String, doc: String) =
420422
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -473,7 +475,7 @@ private object FloatParam {
473475
* Specialized version of `Param[Long]` for Java.
474476
*/
475477
class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean)
476-
extends Param[Long](parent, name, doc, isValid) {
478+
extends Param[Long](parent, name, doc, isValid, classOf[Long]) {
477479

478480
def this(parent: String, name: String, doc: String) =
479481
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -500,7 +502,8 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
500502
* Specialized version of `Param[Boolean]` for Java.
501503
*/
502504
class BooleanParam(parent: String, name: String, doc: String) // No need for isValid
503-
extends Param[Boolean](parent, name, doc) {
505+
extends Param[Boolean](parent, name, doc, ParamValidators.alwaysTrue[Boolean], classOf[Boolean])
506+
{
504507

505508
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
506509

@@ -521,7 +524,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
521524
* Specialized version of `Param[Array[String]]` for Java.
522525
*/
523526
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
524-
extends Param[Array[String]](parent, name, doc, isValid) {
527+
extends Param[Array[String]](parent.uid, name, doc, isValid, classOf[Array[String]]) {
525528

526529
def this(parent: Params, name: String, doc: String) =
527530
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -544,7 +547,7 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
544547
* Specialized version of `Param[Array[Double]]` for Java.
545548
*/
546549
class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
547-
extends Param[Array[Double]](parent, name, doc, isValid) {
550+
extends Param[Array[Double]](parent.uid, name, doc, isValid, classOf[Array[Double]]) {
548551

549552
def this(parent: Params, name: String, doc: String) =
550553
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -576,7 +579,9 @@ class DoubleArrayArrayParam(
576579
name: String,
577580
doc: String,
578581
isValid: Array[Array[Double]] => Boolean)
579-
extends Param[Array[Array[Double]]](parent, name, doc, isValid) {
582+
extends Param[Array[Array[Double]]](
583+
parent.uid, name, doc, isValid, classOf[Array[Array[Double]]]
584+
) {
580585

581586
def this(parent: Params, name: String, doc: String) =
582587
this(parent, name, doc, ParamValidators.alwaysTrue)
@@ -610,7 +615,7 @@ class DoubleArrayArrayParam(
610615
* Specialized version of `Param[Array[Int]]` for Java.
611616
*/
612617
class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean)
613-
extends Param[Array[Int]](parent, name, doc, isValid) {
618+
extends Param[Array[Int]](parent.uid, name, doc, isValid, classOf[Array[Int]]) {
614619

615620
def this(parent: Params, name: String, doc: String) =
616621
this(parent, name, doc, ParamValidators.alwaysTrue)

mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.List;
2222

2323
import org.apache.spark.ml.util.Identifiable$;
24-
import scala.reflect.ClassTag;
2524

2625
/**
2726
* A subclass of Params for testing.
@@ -111,7 +110,7 @@ private void init() {
111110
ParamValidators.inRange(0.0, 1.0));
112111
List<String> validStrings = Arrays.asList("a", "b");
113112
myStringParam_ = new Param<>(this, "myStringParam", "this is a string param",
114-
ParamValidators.inArray(validStrings), ClassTag.apply(String.class));
113+
ParamValidators.inArray(validStrings));
115114
myDoubleArrayParam_ =
116115
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
117116

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,22 @@ private[ml] object MLUtils {
157157
}
158158

159159
case _ =>
160-
reconcileParam(
161-
p.paramValueClassTag.runtimeClass,
162-
LiteralValueProtoConverter.toCatalystValue(literal))
160+
val paramValue = LiteralValueProtoConverter.toCatalystValue(literal)
161+
val paramType: Class[_] = if (p.dataClass == null) {
162+
if (paramValue.isInstanceOf[String]) {
163+
classOf[String]
164+
} else if (paramValue.isInstanceOf[Boolean]) {
165+
classOf[Boolean]
166+
} else {
167+
throw MlUnsupportedException(
168+
"Spark Connect ML requires the customized ML Param class setting 'dataClass' " +
169+
"parameter if the param value type is not String or Boolean type, " +
170+
s"but the param $name does not have the required dataClass.")
171+
}
172+
} else {
173+
p.dataClass
174+
}
175+
reconcileParam(paramType, paramValue)
163176
}
164177
instance.set(p, value)
165178
}

0 commit comments

Comments
 (0)