diff --git a/tests/unittests/image/__init__.py b/tests/unittests/image/__init__.py index 8eea7d284b8..a31775c0bde 100644 --- a/tests/unittests/image/__init__.py +++ b/tests/unittests/image/__init__.py @@ -16,6 +16,8 @@ import torch import torch.distributed as dist +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure from unittests import _PATH_ALL_TESTS _SAMPLE_IMAGE = os.path.join(_PATH_ALL_TESTS, "_data", "image", "i01_01_5.bmp") @@ -33,3 +35,37 @@ def cleanup_ddp(): """Clean up the DDP process group if initialized.""" if dist.is_initialized(): dist.destroy_process_group() + + +def _run_ssim_ddp(rank: int, world_size: int, free_port: int): + """Run SSIM metric computation in a DDP setup.""" + try: + setup_ddp(rank, world_size, free_port) + device = torch.device(f"cuda:{rank}") + metric = StructuralSimilarityIndexMeasure(reduction="none").to(device) + + for _ in range(3): + x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) + metric.update(x, y) + + result = metric.compute() + assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" + finally: + cleanup_ddp() + + +def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int): + """Run MSSSIM metric computation in a DDP setup.""" + try: + setup_ddp(rank, world_size, free_port) + device = torch.device(f"cuda:{rank}") + metric = MultiScaleStructuralSimilarityIndexMeasure(reduction="none").to(device) + + for _ in range(3): + x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) + metric.update(x, y) + + result = metric.compute() + assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" + finally: + cleanup_ddp() diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 89b9b5778bd..dc779fd9510 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -23,7 +23,7 @@ from unittests import NUM_BATCHES, _Input from unittests._helpers import _IS_WINDOWS, seed_all from unittests._helpers.testers import MetricTester -from unittests.image import cleanup_ddp, setup_ddp +from unittests.image import _run_ms_ssim_ddp from unittests.utilities.test_utilities import find_free_port seed_all(42) @@ -110,23 +110,6 @@ def test_ms_ssim_contrast_sensitivity(): assert isinstance(out, torch.Tensor) -def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int): - """Run MSSSIM metric computation in a DDP setup.""" - try: - setup_ddp(rank, world_size, free_port) - device = torch.device(f"cuda:{rank}") - metric = MultiScaleStructuralSimilarityIndexMeasure(reduction="none").to(device) - - for _ in range(3): - x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) - metric.update(x, y) - - result = metric.compute() - assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" - finally: - cleanup_ddp() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.skipif(_IS_WINDOWS, reason="DDP not supported on Windows") def test_ms_ssim_reduction_none_ddp(): @@ -139,4 +122,14 @@ def test_ms_ssim_reduction_none_ddp(): free_port = find_free_port() if free_port == -1: pytest.skip("No free port available for DDP test.") - mp.spawn(_run_ms_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True) + # Use spawn context to avoid module reimport issues + ctx = mp.get_context("spawn") + processes = [] + for rank in range(world_size): + p = ctx.Process(target=_run_ms_ssim_ddp, args=(rank, world_size, free_port)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index e327e7d7f70..b1040ba911d 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -26,7 +26,7 @@ from unittests import NUM_BATCHES, _Input from unittests._helpers import _IS_WINDOWS, seed_all from unittests._helpers.testers import MetricTester -from unittests.image import cleanup_ddp, setup_ddp +from unittests.image import _run_ssim_ddp from unittests.utilities.test_utilities import find_free_port seed_all(42) @@ -365,23 +365,6 @@ def test_ssim_for_correct_padding(): assert structural_similarity_index_measure(preds, target) < 1.0 -def _run_ssim_ddp(rank: int, world_size: int, free_port: int): - """Run SSIM metric computation in a DDP setup.""" - try: - setup_ddp(rank, world_size, free_port) - device = torch.device(f"cuda:{rank}") - metric = StructuralSimilarityIndexMeasure(reduction="none").to(device) - - for _ in range(3): - x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) - metric.update(x, y) - - result = metric.compute() - assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" - finally: - cleanup_ddp() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.skipif(_IS_WINDOWS, reason="DDP not supported on Windows") def test_ssim_reduction_none_ddp(): @@ -394,4 +377,14 @@ def test_ssim_reduction_none_ddp(): free_port = find_free_port() if free_port == -1: pytest.skip("No free port available for DDP test.") - mp.spawn(_run_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True) + # Use spawn context to avoid module reimport issues + ctx = mp.get_context("spawn") + processes = [] + for rank in range(world_size): + p = ctx.Process(target=_run_ssim_ddp, args=(rank, world_size, free_port)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}"