Skip to content

Commit 850d15a

Browse files
authored
Fix mypy errors. (#11389)
1 parent c431916 commit 850d15a

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

python-package/xgboost/dask/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
import logging
5656
from collections import defaultdict
5757
from contextlib import contextmanager
58-
from functools import partial, update_wrapper, wraps
58+
from functools import partial, update_wrapper
5959
from threading import Thread
6060
from typing import (
6161
Any,
@@ -354,7 +354,7 @@ def __init__(
354354
label_upper_bound=label_upper_bound,
355355
)
356356

357-
def __await__(self) -> Generator:
357+
def __await__(self) -> Generator[None, None, "DaskDMatrix"]:
358358
return self._init.__await__()
359359

360360
async def _map_local_data(
@@ -1490,7 +1490,7 @@ async def _predict_async(
14901490
if isinstance(predts, dd.DataFrame):
14911491
predts = predts.to_dask_array()
14921492
else:
1493-
test_dmatrix: DaskDMatrix = await DaskDMatrix( # type: ignore
1493+
test_dmatrix: DaskDMatrix = await DaskDMatrix(
14941494
self.client,
14951495
data=data,
14961496
base_margin=base_margin,
@@ -1532,7 +1532,7 @@ async def _apply_async(
15321532
iteration_range: Optional[IterationRange] = None,
15331533
) -> Any:
15341534
iteration_range = self._get_iteration_range(iteration_range)
1535-
test_dmatrix: DaskDMatrix = await DaskDMatrix( # type: ignore
1535+
test_dmatrix: DaskDMatrix = await DaskDMatrix(
15361536
self.client,
15371537
data=X,
15381538
missing=self.missing,

tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -670,9 +670,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
670670
X = X.to_backend("cupy")
671671
y = y.to_backend("cupy")
672672

673-
m: dxgb.DaskDMatrix = await dxgb.DaskQuantileDMatrix(
674-
client, X, y
675-
) # type:ignore
673+
m: dxgb.DaskDMatrix = await dxgb.DaskQuantileDMatrix(client, X, y)
676674
output = await dxgb.train(
677675
client, {"tree_method": "hist", "device": "cuda"}, dtrain=m
678676
)

tests/test_distributed/test_with_dask/test_with_dask.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ def test_empty_dmatrix(tree_method) -> None:
953953
async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT:
954954
async with Client(scheduler_address, asynchronous=True) as client:
955955
X, y, _ = generate_array()
956-
m = await DaskDMatrix(client, X, y) # type: ignore
956+
m = await DaskDMatrix(client, X, y)
957957
output = await dxgb.train(client, {}, dtrain=m)
958958

959959
with_m = await dxgb.predict(client, output, m)
@@ -1058,8 +1058,8 @@ async def train() -> None:
10581058
) as cluster:
10591059
async with Client(cluster, asynchronous=True) as client:
10601060
X, y, w = generate_array(with_weights=True)
1061-
dtrain = await DaskDMatrix(client, X, y, weight=w) # type: ignore
1062-
dvalid = await DaskDMatrix(client, X, y, weight=w) # type: ignore
1061+
dtrain = await DaskDMatrix(client, X, y, weight=w)
1062+
dvalid = await DaskDMatrix(client, X, y, weight=w)
10631063
output = await dxgb.train(client, {}, dtrain=dtrain)
10641064
await dxgb.predict(client, output, data=dvalid)
10651065

@@ -2195,7 +2195,7 @@ async def test_worker_left(c: Client, s: Scheduler, a: Worker, b: Worker):
21952195
async with Worker(s.address):
21962196
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
21972197
dy = da.random.random((1000,)).rechunk(chunks=(10,))
2198-
d_train = await dxgb.DaskDMatrix( # type: ignore
2198+
d_train = await dxgb.DaskDMatrix(
21992199
c,
22002200
dx,
22012201
dy,

0 commit comments

Comments
 (0)