Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/unittests/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import torch
import torch.distributed as dist

from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure
from unittests import _PATH_ALL_TESTS

_SAMPLE_IMAGE = os.path.join(_PATH_ALL_TESTS, "_data", "image", "i01_01_5.bmp")
Expand All @@ -33,3 +35,37 @@ def cleanup_ddp():
"""Clean up the DDP process group if initialized."""
if dist.is_initialized():
dist.destroy_process_group()


def _run_ssim_ddp(rank: int, world_size: int, free_port: int):
"""Run SSIM metric computation in a DDP setup."""
try:
setup_ddp(rank, world_size, free_port)
device = torch.device(f"cuda:{rank}")
metric = StructuralSimilarityIndexMeasure(reduction="none").to(device)

for _ in range(3):
x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2)
metric.update(x, y)

result = metric.compute()
assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor"
finally:
cleanup_ddp()


def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int):
"""Run MSSSIM metric computation in a DDP setup."""
try:
setup_ddp(rank, world_size, free_port)
device = torch.device(f"cuda:{rank}")
metric = MultiScaleStructuralSimilarityIndexMeasure(reduction="none").to(device)

for _ in range(3):
x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2)
metric.update(x, y)

result = metric.compute()
assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor"
finally:
cleanup_ddp()
31 changes: 12 additions & 19 deletions tests/unittests/image/test_ms_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from unittests import NUM_BATCHES, _Input
from unittests._helpers import _IS_WINDOWS, seed_all
from unittests._helpers.testers import MetricTester
from unittests.image import cleanup_ddp, setup_ddp
from unittests.image import _run_ms_ssim_ddp
from unittests.utilities.test_utilities import find_free_port

seed_all(42)
Expand Down Expand Up @@ -110,23 +110,6 @@ def test_ms_ssim_contrast_sensitivity():
assert isinstance(out, torch.Tensor)


def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int):
"""Run MSSSIM metric computation in a DDP setup."""
try:
setup_ddp(rank, world_size, free_port)
device = torch.device(f"cuda:{rank}")
metric = MultiScaleStructuralSimilarityIndexMeasure(reduction="none").to(device)

for _ in range(3):
x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2)
metric.update(x, y)

result = metric.compute()
assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor"
finally:
cleanup_ddp()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.skipif(_IS_WINDOWS, reason="DDP not supported on Windows")
def test_ms_ssim_reduction_none_ddp():
Expand All @@ -139,4 +122,14 @@ def test_ms_ssim_reduction_none_ddp():
free_port = find_free_port()
if free_port == -1:
pytest.skip("No free port available for DDP test.")
mp.spawn(_run_ms_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True)
# Use spawn context to avoid module reimport issues
ctx = mp.get_context("spawn")
processes = []
for rank in range(world_size):
p = ctx.Process(target=_run_ms_ssim_ddp, args=(rank, world_size, free_port))
p.start()
processes.append(p)

for p in processes:
p.join()
assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}"
31 changes: 12 additions & 19 deletions tests/unittests/image/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from unittests import NUM_BATCHES, _Input
from unittests._helpers import _IS_WINDOWS, seed_all
from unittests._helpers.testers import MetricTester
from unittests.image import cleanup_ddp, setup_ddp
from unittests.image import _run_ssim_ddp
from unittests.utilities.test_utilities import find_free_port

seed_all(42)
Expand Down Expand Up @@ -365,23 +365,6 @@ def test_ssim_for_correct_padding():
assert structural_similarity_index_measure(preds, target) < 1.0


def _run_ssim_ddp(rank: int, world_size: int, free_port: int):
"""Run SSIM metric computation in a DDP setup."""
try:
setup_ddp(rank, world_size, free_port)
device = torch.device(f"cuda:{rank}")
metric = StructuralSimilarityIndexMeasure(reduction="none").to(device)

for _ in range(3):
x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2)
metric.update(x, y)

result = metric.compute()
assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor"
finally:
cleanup_ddp()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.skipif(_IS_WINDOWS, reason="DDP not supported on Windows")
def test_ssim_reduction_none_ddp():
Expand All @@ -394,4 +377,14 @@ def test_ssim_reduction_none_ddp():
free_port = find_free_port()
if free_port == -1:
pytest.skip("No free port available for DDP test.")
mp.spawn(_run_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True)
# Use spawn context to avoid module reimport issues
ctx = mp.get_context("spawn")
processes = []
for rank in range(world_size):
p = ctx.Process(target=_run_ssim_ddp, args=(rank, world_size, free_port))
p.start()
processes.append(p)

for p in processes:
p.join()
assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}"
Loading