diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py index cf46bd43f550..0db91491ee27 100644 --- a/python-package/xgboost/testing/updater.py +++ b/python-package/xgboost/testing/updater.py @@ -218,8 +218,13 @@ def check_extmem_qdm( ) booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8) - X, y, w = it.as_arrays() - Xy = xgb.QuantileDMatrix(X, y, weight=w) + it = tm.IteratorForTest( + *tm.make_batches( + n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu" + ), + cache=None, + ) + Xy = xgb.QuantileDMatrix(it) booster = xgb.train({"device": device}, Xy, num_boost_round=8) if device == "cpu": diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 95074553acd7..0fdf024c2b38 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -34,44 +34,48 @@ def test_socket_error(): tracker.free() -def run_rabit_ops(client, n_workers): - from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args - - workers = tm.get_client_workers(client) - rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) - assert not collective.is_distributed() - n_workers_from_dask = len(workers) - assert n_workers == n_workers_from_dask +def run_rabit_ops(pool, n_workers: int, address: str) -> None: + tracker = RabitTracker(host_ip=address, n_workers=n_workers) + tracker.start() + args = tracker.worker_args() - def local_test(worker_id): - with CommunicatorContext(**rabit_args): + def local_test(worker_id: int, rabit_args: dict) -> int: + with collective.CommunicatorContext(**rabit_args): a = 1 assert collective.is_distributed() - a = np.array([a]) - reduced = collective.allreduce(a, collective.Op.SUM) + arr = np.array([a]) + reduced = collective.allreduce(arr, collective.Op.SUM) assert reduced[0] == n_workers - worker_id = np.array([worker_id]) - reduced = collective.allreduce(worker_id, collective.Op.MAX) + arr = np.array([worker_id]) + reduced = collective.allreduce(arr, collective.Op.MAX) assert reduced == n_workers - 1 return 1 - futures = client.map(local_test, range(len(workers)), workers=workers) - results = client.gather(futures) + fn = update_wrapper(partial(local_test, rabit_args=args), local_test) + results = pool.map(fn, range(n_workers)) assert sum(results) == n_workers -@pytest.mark.skipif(**tm.no_dask()) +@pytest.mark.skipif(**tm.no_loky()) def test_rabit_ops(): - from distributed import Client, LocalCluster + from loky import get_reusable_executor - n_workers = 3 - with LocalCluster(n_workers=n_workers) as cluster: - with Client(cluster) as client: - run_rabit_ops(client, n_workers) + n_workers = 4 + with get_reusable_executor(max_workers=n_workers) as pool: + run_rabit_ops(pool, n_workers, "127.0.0.1") +@pytest.mark.skipif(**tm.no_ipv6()) +@pytest.mark.skipif(**tm.no_loky()) +def test_rabit_ops_ipv6(): + from loky import get_reusable_executor + + n_workers = 4 + with get_reusable_executor(max_workers=n_workers) as pool: + run_rabit_ops(pool, n_workers, "::1") + def run_allreduce(pool, n_workers: int) -> None: tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers) @@ -133,19 +137,6 @@ def test_broadcast(): run_broadcast(pool, n_workers) -@pytest.mark.skipif(**tm.no_ipv6()) -@pytest.mark.skipif(**tm.no_dask()) -def test_rabit_ops_ipv6(): - import dask - from distributed import Client, LocalCluster - - n_workers = 3 - with dask.config.set({"xgboost.scheduler_address": "[::1]"}): - with LocalCluster(n_workers=n_workers, host="[::1]") as cluster: - with Client(cluster) as client: - run_rabit_ops(client, n_workers) - - @pytest.mark.skipif(**tm.no_dask()) def test_rank_assignment() -> None: from distributed import Client, LocalCluster