Open
Issue Description
Passing device='mps' to torchsummary's summary() raises the following error:
AssertionError: Input device is not valid, please specify 'cuda' or 'cpu'
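The message suggests that summary() checks the device string against a fixed allowlist that contains only 'cuda' and 'cpu'. The sketch below is a hypothetical check of that kind. check_device and its allowed parameter are illustrative names, not torchsummary's API, and stripping an index such as "cuda:0" is an added assumption. The sketch reproduces the same message for 'mps' and shows that adding 'mps' to the allowlist would let the string through.

```python
# Hypothetical sketch, NOT torchsummary's source: a device-name check of the
# kind that produces the error above, parameterized by an allowlist.
def check_device(device, allowed=("cuda", "cpu")):
    """Return the normalized device type, or raise if it is not allowed."""
    device = device.lower()
    # Assumption: drop an optional index such as "cuda:0" before comparing.
    device_type = device.split(":", 1)[0]
    if device_type not in allowed:
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        raise AssertionError(
            "Input device is not valid, please specify "
            + " or ".join(f"'{d}'" for d in allowed)
        )
    return device_type


try:
    check_device("mps")  # rejected with the default two-entry allowlist
except AssertionError as err:
    print(err)

print(check_device("mps", allowed=("cuda", "cpu", "mps")))  # accepted
```

Accepting the string is probably only part of a fix. The library would likely also need to create its dummy input tensors on the MPS device instead of defaulting to CPU tensors whenever the device is not 'cuda'.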
To Reproduce
```python
from torchsummary import summary
import torch
import torch.nn as nn


class AddNet(nn.Module):
    def __init__(self):
        super(AddNet, self).__init__()
        self.fc1 = nn.Linear(2, 1, bias=False)

    def forward(self, x):
        return self.fc1(x)


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = AddNet().to(device)

# Attempt to use summary with MPS device
summary(model, input_size=(1, 2), device="mps")
```
Additional Information:
The MPS backend is supported in PyTorch 1.12 and later on macOS with Apple Silicon chips (e.g., an M1 MacBook).
SUPERustam