Skip to content

Commit dde2772

Browse files
authored
Reduce flakiness of the pyspark conf test. (#11456)
1 parent afe556a commit dde2772

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Collective module related utilities."""
2+
3+
import socket
4+
5+
6+
def get_avail_port() -> int:
7+
"""Returns a port that's available during the function call. It doesn't prevent the
8+
port from being used after the function returns as we can't reserve the port. The
9+
utility makes a test more likely to pass.
10+
11+
"""
12+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
13+
server.bind(("127.0.0.1", 0))
14+
port = server.getsockname()[1]
15+
return port

tests/test_distributed/test_with_spark/test_spark_local.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from xgboost.spark.core import _non_booster_params
3333
from xgboost.spark.data import pred_contribs
34+
from xgboost.testing.collective import get_avail_port
3435

3536
from .utils import SparkTestCase
3637

@@ -1772,16 +1773,17 @@ def test_collective_conf(self):
17721773

17731774
with tempfile.TemporaryDirectory() as tmpdir:
17741775
path = "file:" + tmpdir
1776+
port = get_avail_port()
17751777
classifier = SparkXGBClassifier(
17761778
launch_tracker_on_driver=True,
1777-
coll_cfg=Config(tracker_host_ip="127.0.0.1", tracker_port=58894),
1779+
coll_cfg=Config(tracker_host_ip="127.0.0.1", tracker_port=port),
17781780
num_workers=1,
17791781
n_estimators=1,
17801782
)
17811783

17821784
def check_conf(conf: Config) -> None:
17831785
assert conf.tracker_host_ip == "127.0.0.1"
1784-
assert conf.tracker_port == 58894
1786+
assert conf.tracker_port == port
17851787

17861788
check_conf(classifier.getOrDefault(classifier.coll_cfg))
17871789
classifier.write().overwrite().save(path)

0 commit comments

Comments
 (0)