File tree Expand file tree Collapse file tree 2 files changed +19
-2
lines changed
python-package/xgboost/testing
tests/test_distributed/test_with_spark Expand file tree Collapse file tree 2 files changed +19
-2
lines changed Original file line number Diff line number Diff line change
1
+ """Collective module related utilities."""
2
+
3
+ import socket
4
+
5
+
6
+ def get_avail_port () -> int :
7
+ """Returns a port that's available during the function call. It doesn't prevent the
8
+ port from being used after the function returns as we can't reserve the port. The
9
+ utility makes a test more likely to pass.
10
+
11
+ """
12
+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as server :
13
+ server .bind (("127.0.0.1" , 0 ))
14
+ port = server .getsockname ()[1 ]
15
+ return port
Original file line number Diff line number Diff line change 31
31
)
32
32
from xgboost .spark .core import _non_booster_params
33
33
from xgboost .spark .data import pred_contribs
34
+ from xgboost .testing .collective import get_avail_port
34
35
35
36
from .utils import SparkTestCase
36
37
@@ -1772,16 +1773,17 @@ def test_collective_conf(self):
1772
1773
1773
1774
with tempfile .TemporaryDirectory () as tmpdir :
1774
1775
path = "file:" + tmpdir
1776
+ port = get_avail_port ()
1775
1777
classifier = SparkXGBClassifier (
1776
1778
launch_tracker_on_driver = True ,
1777
- coll_cfg = Config (tracker_host_ip = "127.0.0.1" , tracker_port = 58894 ),
1779
+ coll_cfg = Config (tracker_host_ip = "127.0.0.1" , tracker_port = port ),
1778
1780
num_workers = 1 ,
1779
1781
n_estimators = 1 ,
1780
1782
)
1781
1783
1782
1784
def check_conf (conf : Config ) -> None :
1783
1785
assert conf .tracker_host_ip == "127.0.0.1"
1784
- assert conf .tracker_port == 58894
1786
+ assert conf .tracker_port == port
1785
1787
1786
1788
check_conf (classifier .getOrDefault (classifier .coll_cfg ))
1787
1789
classifier .write ().overwrite ().save (path )
You can’t perform that action at this time.
0 commit comments