JAX / Neural Tangents `linearize` induces CUDA_ERROR_OUT_OF_MEMORY
Cross-posting from GitHub: https://github.com/google/neural-tangents/issues/144
We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.
We're running into a really puzzling problem: after we load a model we can train it, and after we linearize it we can still train the original, pre-linearized model. However, when we try training the linearized model, we get:
```
RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
```
The error appears regardless of whether we use one GPU or several, and with both a large batch size (512) and a small one (1). We manually verified that a single forward pass and a single backward pass each run without error. We suspect the error might arise from the following code (although we could be wrong!):
Their code:
```python
def make_update_fn(*, apply_fn, accum_steps, lr_fn):
  """Returns update step for data parallel training."""

  def update_fn(opt, step, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts)
    # Note: This is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    l = jax.lax.pmean(l, axis_name='batch')
    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
    return opt, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
```
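We are not sure exactly what `utils.accumulate_gradient` does internally, or whether it interacts badly with the linearized function. Our rough mental model (a hypothetical sketch, not the repo's implementation; `accumulate_gradient_sketch` is our own name) is that it splits each per-device batch into `accum_steps` microbatches and averages the losses and gradients:

```python
import jax
import jax.numpy as jnp

def accumulate_gradient_sketch(loss_and_grad_fn, params, images, labels,
                               accum_steps):
  """Hypothetical stand-in for utils.accumulate_gradient (unverified)."""
  if accum_steps <= 1:
    return loss_and_grad_fn(params, images, labels)
  # Assumes the batch dimension is divisible by accum_steps.
  image_chunks = jnp.split(images, accum_steps)
  label_chunks = jnp.split(labels, accum_steps)
  loss, grads = loss_and_grad_fn(params, image_chunks[0], label_chunks[0])
  for x, y in zip(image_chunks[1:], label_chunks[1:]):
    l, g = loss_and_grad_fn(params, x, y)
    loss = loss + l
    grads = jax.tree_map(lambda a, b: a + b, grads, g)
  scale = 1.0 / accum_steps
  return loss * scale, jax.tree_map(lambda a: a * scale, grads)
```

If that is accurate, accumulation alone should lower peak memory rather than raise it, which is part of why the OOM confuses us.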
`make_update_fn` is then called via:
```python
# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(
    total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply, accum_steps=accum_steps, lr_fn=lr_fn)

# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)
```
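For clarity on the last line: `flax.jax_utils.replicate` stacks every leaf of a pytree along a new leading device axis, so each local device gets its own copy in the layout `jax.pmap` expects. A minimal illustration:

```python
import jax
import jax.numpy as jnp
import flax.jax_utils

n_devices = jax.local_device_count()
tree = {'w': jnp.ones((3,))}
replicated = flax.jax_utils.replicate(tree)
# Every leaf gains a leading device axis:
# replicated['w'].shape == (n_devices, 3)
```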
The training loop where the memory error arises:
```python
losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):
  opt_repl, loss_repl, update_rng_repl = update_fn_repl(
      opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)  # ERROR IS HERE
  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))
```
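For anyone trying to reproduce this: JAX/XLA preallocates most of the GPU's memory up front by default, which can make OOM reports hard to interpret because the preallocated pool hides how much a given step actually allocates. The standard JAX environment variables (set before `jax` is imported) disable or cap that preallocation:

```python
import os

# Must be set before importing jax (standard JAX GPU memory flags).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # allocate on demand
# Alternatively, cap the preallocated pool at a fraction of GPU memory:
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.50'

import jax
```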
In order to linearize the ViT, we do the following:
```python
def vit_apply(params, input):
  return model.apply(dict(params=params), input, train=True)

f_lin = nt.linearize(vit_apply, params)
```
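For context on why the linearized model might need roughly twice the memory: as we understand it (we have not checked the library's source), `nt.linearize(f, params)` returns the first-order Taylor expansion of `f` around `params`, computed with a Jacobian-vector product, so every call runs both the primal forward pass and a tangent pass. A sketch of what we believe is the equivalent computation (`linearize_sketch` is our own name):

```python
import jax

def linearize_sketch(f, params_0):
  """Rough equivalent of nt.linearize as we understand it (unverified)."""
  def f_lin(params, *args, **kwargs):
    # Displacement of the current parameters from the linearization point.
    dparams = jax.tree_map(lambda p, p0: p - p0, params, params_0)
    # jax.jvp evaluates f at params_0 and its directional derivative along
    # dparams in one pass; the extra tangent computation is why we would
    # expect activation memory to roughly double.
    out, tangent = jax.jvp(
        lambda p: f(p, *args, **kwargs), (params_0,), (dparams,))
    return out + tangent
  return f_lin
```

If that is right, a training step through `f_lin` differentiates through both the primal and the tangent computation, so a roughly factor-of-two memory increase would be expected, but it would not explain an OOM that persists even at batch size 1.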