JAX / Neural Tangents `linearize` induces CUDA_ERROR_OUT_OF_MEMORY

Cross-posting from GitHub: https://github.com/google/neural-tangents/issues/144

We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.

We're running into a really puzzling problem: after loading a model, we can train it, and after linearizing it, we can still train the pre-linearized model. However, when we try training the linearized model, we get:

RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory

The error appears regardless of whether we use one GPU or several, and regardless of whether the batch size is large (512) or small (1).

We manually verified that neither a forward pass nor a backward pass through the linearized model raises an error (a sketch of this check appears after the training loop below). We suspect that the error might arise from the following code (although we could be wrong!):

Their code:

def make_update_fn(*, apply_fn, accum_steps, lr_fn):
  """Returns update step for data parallel training."""

  def update_fn(opt, step, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts)
    # Note: This is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    l = jax.lax.pmean(l, axis_name='batch')

    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
    return opt, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
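
The pmapped update runs once per device: `jax.lax.pmean` averages the per-device gradients and loss across the `'batch'` axis, and `donate_argnums=(0,)` lets XLA reuse the old optimizer buffers for the new state. Here is a minimal, self-contained sketch of the collective (a toy function of our own, not the ViT code):

import jax
import jax.numpy as jnp

# Toy per-device step: each device squares its input, then pmean averages
# the results across the 'batch' axis, mirroring how the per-device
# gradients are combined in update_fn above.
def per_device_step(x):
  local = x ** 2
  return jax.lax.pmean(local, axis_name='batch')

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # one value per device
print(jax.pmap(per_device_step, axis_name='batch')(xs))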

That function is then called via:

# See make_update_fn above for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply, accum_steps=accum_steps, lr_fn=lr_fn)
# We use a momentum optimizer that keeps its state in half precision to save
# memory. It also implements gradient clipping.
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)
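
For clarity, `flax.jax_utils.replicate` stacks a copy of the pytree along a new leading device axis so that each pmap replica receives its own copy of the optimizer state. A quick illustration (toy pytree, not the actual optimizer):

import jax
import jax.numpy as jnp
import flax.jax_utils

# replicate() adds a leading axis of size local_device_count() to every leaf.
tree = {'w': jnp.ones((3,))}
tree_repl = flax.jax_utils.replicate(tree)
print(jax.tree_map(lambda x: x.shape, tree_repl))  # {'w': (n_devices, 3)}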

The training loop where the memory error arises:

losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):

  opt_repl, loss_repl, update_rng_repl = update_fn_repl(
      opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)  # ERROR IS HERE
  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))
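
For reference, the manual forward/backward check mentioned earlier looked roughly like this (a simplified sketch using an unreplicated sample batch; no pmap, no dropout RNG):

import jax
import jax.numpy as jnp

images, labels = batch['image'], batch['label']

# Forward pass through the linearized model: completes without error.
logits = f_lin(params, images)

def loss(p):
  logp = jax.nn.log_softmax(f_lin(p, images))
  return -jnp.mean(jnp.sum(logp * labels, axis=1))

# Backward pass: also completes without error.
grads = jax.grad(loss)(params)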

In order to linearize the ViT, we do the following:

import neural_tangents as nt

def vit_apply(params, inputs):
  return model.apply(dict(params=params), inputs, train=True)

f_lin = nt.linearize(vit_apply, params)
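
For context, `nt.linearize(f, params)` returns the first-order Taylor expansion of `f` around `params`, i.e. f_lin(θ, x) = f(θ₀, x) + ∂f/∂θ|_{θ₀} (θ - θ₀), evaluated with a `jax.jvp` under the hood. A roughly equivalent construction looks like this (our own simplification, not the library's actual implementation):

import jax

def linearize_manual(f, params0):
  """First-order Taylor expansion of f around params0 (sketch)."""
  def f_lin(params, x):
    # Tangent direction: displacement from the expansion point.
    delta = jax.tree_map(lambda p, p0: p - p0, params, params0)
    # jvp computes f(params0, x) and the directional derivative along delta,
    # so each call runs a forward pass plus a JVP through the whole network.
    primal, tangent = jax.jvp(lambda p: f(p, x), (params0,), (delta,))
    return primal + tangent
  return f_lin

The extra JVP roughly doubles the activation footprint of every forward pass, which may be relevant to the memory behavior we see, though we have not confirmed that this is the cause.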

