Description
Required prerequisites
- I have read the documentation https://torchopt.readthedocs.io.
- I have searched the Issue Tracker and Discussions to verify that this hasn't already been reported. (+1 or comment there if it has.)
- Consider asking first in a Discussion.
What version of TorchOpt are you using?
0.7.3
System information
sys.version: 3.10.13 (main, Sep 18 2023, 17:18:13) [GCC 12.3.1 20230526]
sys.platform: linux
torchopt==0.7.3
torch==2.6.0
functorch==2.6.0
Problem description
Running the provided code shows a small difference between the parameters produced by torchopt.Adam and torch.optim.Adam (I expected them to match exactly). The difference grows with larger learning rates.
Reproducible example code
The Python snippets:
import torch
import torchopt

lr = 0.01

# First run: torchopt.Adam
torch.manual_seed(1)
model = torch.nn.Linear(3, 1)
optim = torchopt.Adam(model.parameters(), lr=lr)
print(next(model.parameters()))

n_updates = 30
x = torch.rand((n_updates, 3), requires_grad=True)
for i in range(n_updates):
    b_x = x[i]
    y = torch.rand((1,), requires_grad=True)
    out = model(b_x)
    loss = ((out - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
params_1 = next(model.parameters()).detach()
print(next(model.parameters()))

# Second run: torch.optim.Adam with the same seed, init, and data
torch.manual_seed(1)
model = torch.nn.Linear(3, 1)
optim = torch.optim.Adam(model.parameters(), lr=lr)
print(next(model.parameters()))

x = torch.rand((n_updates, 3), requires_grad=True)
for i in range(n_updates):
    b_x = x[i]
    y = torch.rand((1,), requires_grad=True)
    out = model(b_x)
    loss = ((out - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
params_2 = next(model.parameters()).detach()
print(next(model.parameters()))

print("All close:", torch.allclose(params_1, params_2))
print("Difference:", params_1 - params_2)
# Output:
# Parameter containing:
# tensor([[ 0.2975, -0.2548, -0.1119]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.4131, -0.1184, 0.0238]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.2975, -0.2548, -0.1119]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.4131, -0.1184, 0.0238]], requires_grad=True)
# All close: False
# Difference: tensor([[-1.4901e-07, -4.3213e-07, -2.9616e-07]])
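To characterize the divergence further, here is a minimal diagnostic sketch (the final_params helper is illustrative; it just condenses the repro above). It (a) repeats the comparison in float64, where a gap that collapses towards ~1e-15 would point to accumulated float32 round-off from a different operation order rather than a different update rule, and (b) sweeps the learning rate to quantify how the gap grows:

import torch
import torchopt

def final_params(optim_cls, lr, n_updates=30, dtype=torch.float32):
    # Same setup as the repro above, condensed into a helper.
    torch.manual_seed(1)
    model = torch.nn.Linear(3, 1).to(dtype)
    optim = optim_cls(model.parameters(), lr=lr)
    x = torch.rand((n_updates, 3), dtype=dtype)
    for i in range(n_updates):
        y = torch.rand((1,), dtype=dtype)
        loss = ((model(x[i]) - y) ** 2).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
    return next(model.parameters()).detach()

# (a) Double precision: does the gap shrink by many orders of magnitude?
d64 = (final_params(torchopt.Adam, 0.01, dtype=torch.float64)
       - final_params(torch.optim.Adam, 0.01, dtype=torch.float64)).abs().max()
print("float64 max |diff|:", d64.item())

# (b) Learning-rate sweep: quantifies how the gap grows with lr.
for lr in (1e-3, 1e-2, 1e-1):
    d32 = (final_params(torchopt.Adam, lr)
           - final_params(torch.optim.Adam, lr)).abs().max()
    print(f"lr={lr:g}  float32 max |diff| = {d32.item():.3e}")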
Steps to reproduce:
- Run the provided code.
- Use larger learning rates to observe larger divergences (quantified in the sweep sketch above).
Traceback
Expected behavior
I expected there to be (almost) zero difference between the parameters produced by torchopt.Adam and torch.optim.Adam.
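For reference, both libraries document the standard Adam rule (Kingma & Ba, 2015), so any difference should come from floating-point evaluation order rather than the math. A minimal sketch of that rule for a single parameter tensor (illustrative only, not either library's actual implementation):

import torch

def adam_update(p, g, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    # Textbook Adam step with bias correction (Kingma & Ba, 2015).
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - b2 ** t)  # bias-corrected second moment
    p = p - lr * m_hat / (v_hat.sqrt() + eps)
    return p, m, v

Mathematically equivalent rearrangements of this formula (for example, where eps is added or how the bias corrections are folded in) round differently in float32, which might account for gaps at the 1e-7 scale.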
Additional context
No response