This repository was archived by the owner on May 6, 2025. It is now read-only.

High Memory Usage on Infinite-Width NTK for GPU Only #204

@deoliveirajoshua

Description

Hello,

I implemented a deliberately simple infinite-width model and call kernel_fn on a batch containing a single vector.

When I run this on CPU, I don't see any excessive memory usage.

However, when I run it on an A100 GPU, just under 60 GB is allocated by this tiny calculation!

The same GPU-only behavior also shows up when infinite-width CNNs are used on image datasets (CIFAR, MNIST, etc.); I sketch that setup below.

Does anyone know what could be causing this to happen?
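For reference, the CNN case I mentioned looks roughly like the sketch below (hypothetical layer widths and filter sizes; not my exact script):

import numpy as np
import neural_tangents as nt

def conv_model():
    # Small infinite-width CNN; inputs are expected in NHWC layout.
    return nt.stax.serial(
        nt.stax.Conv(64, (3, 3), padding='SAME'), nt.stax.Relu(),
        nt.stax.Conv(64, (3, 3), padding='SAME'), nt.stax.Relu(),
        nt.stax.Flatten(),
        nt.stax.Dense(1)
    )

_, _, cnn_kernel_fn = conv_model()

# A single CIFAR-sized image is enough to trigger the same large GPU allocation.
X_img = np.ones((1, 32, 32, 3), dtype=np.float32)
print(cnn_kernel_fn(X_img, None, 'ntk'))

And here is the minimal fully-connected repro where I first noticed it: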

import jax
import numpy as np
import neural_tangents as nt

# Confirm we are actually running on the GPU.
print(jax.devices())

def linear_model():
    return nt.stax.serial(
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(1)
    )

init_fn, apply_fn, kernel_fn = linear_model()

total = 1
X = np.ones((total, 200), dtype=np.float32)

!nvidia-smi  # GPU memory before the kernel computation (notebook shell escape)
ntk = kernel_fn(X, None, 'ntk')
!nvidia-smi  # GPU memory after
print(ntk)

Usage reported by the first nvidia-smi call: 426 MiB / 81920 MiB
Usage reported by the second nvidia-smi call: 61352 MiB / 81920 MiB
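One thing I notice is that 61352 MiB is almost exactly 75% of the card's 81920 MiB, which is what JAX/XLA preallocates by default on GPU, so I'm not sure whether nvidia-smi is showing memory actually used by kernel_fn or just the allocator's reservation. A way to distinguish the two (assuming the standard XLA_PYTHON_CLIENT_PREALLOCATE flag behaves the usual way here; I haven't verified this yet) would be:

import os

# Disable XLA's default GPU preallocation (~75% of device memory) so nvidia-smi
# reflects memory the computation actually uses. Must be set before the first
# `import jax`.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import numpy as np
import neural_tangents as nt
# ... rest of the script as above ...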
