Skip to content

Commit 624fb5a

Browse files
committed
Try fix lgbm
1 parent 2750b48 commit 624fb5a

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

mars/learn/contrib/lightgbm/_predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __call__(self):
7878
elif hasattr(self.model, "classes_"):
7979
dtype = np.array(self.model.classes_).dtype
8080
else:
81-
dtype = getattr(self.model, "out_dtype_", np.dtype("float"))
81+
dtype = getattr(self.model, "out_dtype_", [np.dtype("float")])[0]
8282

8383
if self.output_types[0] == OutputType.tensor:
8484
# tensor

mars/learn/contrib/lightgbm/_train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,11 @@ def execute(cls, ctx, op: "LGBMTrain"):
406406
op.model_type == LGBMModelType.RANKER
407407
or op.model_type == LGBMModelType.REGRESSOR
408408
):
409-
model.set_params(out_dtype_=np.dtype("float"))
409+
model.set_params(out_dtype_=[np.dtype("float")])
410410
elif hasattr(label_val, "dtype"):
411-
model.set_params(out_dtype_=label_val.dtype)
411+
model.set_params(out_dtype_=[label_val.dtype])
412412
else:
413-
model.set_params(out_dtype_=label_val.dtypes[0])
413+
model.set_params(out_dtype_=[label_val.dtypes[0]])
414414

415415
ctx[op.outputs[0].key] = pickle.dumps(model)
416416
finally:

0 commit comments

Comments
 (0)