
[BUG] Results of torchopt.Adam and torch.optim.Adam seem to diverge #235

@dilithjay

Description


Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

sys.version: 3.10.13 (main, Sep 18 2023, 17:18:13) [GCC 12.3.1 20230526]
sys.platform: linux
torchopt==0.7.3
torch==2.6.0
functorch==2.6.0

Problem description

Running the provided code shows a small difference between the parameters produced by torchopt.Adam and those produced by torch.optim.Adam, where I expected none. The difference grows with higher learning rates.

Reproducible example code

The Python snippets:

import torch
import torchopt

lr = 0.01
torch.manual_seed(1)

model = torch.nn.Linear(3, 1)
optim = torchopt.Adam(model.parameters(), lr=lr)  # first run: TorchOpt's Adam
print(next(model.parameters()))

n_updates = 30
x = torch.rand((n_updates, 3), requires_grad=True)
for i in range(n_updates):
    b_x = x[i]
    y = torch.rand((1,), requires_grad=True)
    out = model(b_x)
    loss = ((out - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
params_1 = next(model.parameters()).detach()
print(next(model.parameters()))


# Second run: same seed and data, but using torch.optim.Adam.
torch.manual_seed(1)
model = torch.nn.Linear(3, 1)
optim = torch.optim.Adam(model.parameters(), lr=lr)

print(next(model.parameters()))
x = torch.rand((n_updates, 3), requires_grad=True)
for i in range(n_updates):
    b_x = x[i]
    y = torch.rand((1,), requires_grad=True)
    out = model(b_x)
    loss = ((out - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
params_2 = next(model.parameters()).detach()
print(next(model.parameters()))

print("All close:", torch.allclose(params_1, params_2))
print("Difference:", params_1 - params_2)

# Output:
# Parameter containing:
# tensor([[ 0.2975, -0.2548, -0.1119]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.4131, -0.1184,  0.0238]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.2975, -0.2548, -0.1119]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.4131, -0.1184,  0.0238]], requires_grad=True)
# All close: False
# Difference: tensor([[-1.4901e-07, -4.3213e-07, -2.9616e-07]])
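
To see whether the gap appears on the first update or accumulates over steps, a side-by-side diagnostic along the following lines could be added. This is a sketch rather than part of the original repro: it deep-copies the freshly initialized model so both optimizers start from identical weights and see identical data, then prints the maximum absolute parameter difference after every update.

import copy

import torch
import torchopt

lr = 0.01
torch.manual_seed(1)

model_a = torch.nn.Linear(3, 1)
model_b = copy.deepcopy(model_a)  # identical initial weights

opt_a = torchopt.Adam(model_a.parameters(), lr=lr)
opt_b = torch.optim.Adam(model_b.parameters(), lr=lr)

n_updates = 30
x = torch.rand((n_updates, 3))
y = torch.rand((n_updates, 1))

for i in range(n_updates):
    # Apply one update with each optimizer on the same sample.
    for model, optim in ((model_a, opt_a), (model_b, opt_b)):
        loss = ((model(x[i]) - y[i]) ** 2).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
    diff = (next(model_a.parameters()) - next(model_b.parameters())).abs().max()
    print(f"step {i:2d}: max |param_torchopt - param_torch| = {diff.item():.3e}")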

Steps to reproduce:

  1. Run the provided code.
  2. Increase the learning rate; larger learning rates produce larger divergences (see the sweep sketch below).
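
To quantify step 2, a learning-rate sweep along these lines could be run. This is an illustrative sketch, not code from the original report; run_adam is a hypothetical helper that reruns the repro for a given optimizer class and learning rate and returns the final weights.

import torch
import torchopt


def run_adam(make_optim, lr, n_updates=30, seed=1):
    # Train a fresh Linear(3, 1) for n_updates steps and return its final weights.
    torch.manual_seed(seed)
    model = torch.nn.Linear(3, 1)
    optim = make_optim(model.parameters(), lr=lr)
    x = torch.rand((n_updates, 3))
    y = torch.rand((n_updates, 1))
    for i in range(n_updates):
        loss = ((model(x[i]) - y[i]) ** 2).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
    return next(model.parameters()).detach().clone()


for lr in (1e-3, 1e-2, 1e-1, 1.0):
    gap = (run_adam(torchopt.Adam, lr) - run_adam(torch.optim.Adam, lr)).abs().max()
    print(f"lr={lr:g}: max abs difference = {gap.item():.3e}")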

Traceback

Expected behavior

I expected (almost) zero difference between the results of torchopt.Adam and torch.optim.Adam.
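
Part of the question may be what counts as "almost zero" in float32: the reported gap (roughly 1e-7 to 4e-7) is in the range one might expect from accumulated single-precision rounding when mathematically equivalent updates are computed with a different operation order, and torch.allclose uses strict defaults (rtol=1e-5, atol=1e-8). The snippet below, built from the values printed in the output above, is only a suggested way to separate rounding-level noise from a genuine algorithmic difference; the tolerance choice is mine, not a TorchOpt recommendation.

import torch

# Final weights and difference copied from the report's output above.
params_1 = torch.tensor([[0.4131, -0.1184, 0.0238]])
params_2 = params_1 - torch.tensor([[-1.4901e-07, -4.3213e-07, -2.9616e-07]])

# With default tolerances, the ~3e-7 gap on the smallest weight is flagged.
print(torch.allclose(params_1, params_2))                         # False
# With a float32-scale absolute tolerance it is treated as rounding noise.
print(torch.allclose(params_1, params_2, rtol=1e-5, atol=1e-6))   # True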

Additional context

No response

Labels

bug (Something isn't working)
