Huggingface T5-base with Seq2SeqTrainer RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu

I am trying to run this Huggingface example: https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb

on my own system with 2 GPUs, with my own data that I load as a Huggingface Datasets dataset:

import numpy as np
import nltk
from datasets import Dataset, DatasetDict, load_metric
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

dataset = Dataset.from_pandas(df)

model_name = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

max_input_length = 256
max_target_length = 128

def preprocess_function(examples):
    model_inputs = tokenizer(examples["text"], max_length=max_input_length, padding=True, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["label"], max_length=max_target_length, padding=True, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(preprocess_function, batched=True)

tokenized_datasets = tokenized_datasets.remove_columns(["text", "label"])

train_test = tokenized_datasets.train_test_split(test_size=0.2)
tokenized_datasets_split = DatasetDict({
    'train': train_test['train'],
    'test': train_test['test']})

train_dataset = tokenized_datasets_split["train"].shuffle(seed=42)
test_dataset = tokenized_datasets_split["test"].shuffle(seed=42)
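
As far as I can tell the preprocessing itself works; a quick sanity check of the resulting splits (my own check, not part of the notebook) shows only the tokenized columns:

# After remove_columns, the splits should contain only
# input_ids, attention_mask and labels
print(train_dataset)
print(train_dataset[0].keys())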

and am trying to fine tune a t5-base with this data:

args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

metric = load_metric("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

Training runs fine, but as soon as evaluation starts (evaluation_strategy is "epoch", so at the end of the first epoch) I get an error:

     46     return {k: round(v, 4) for k, v in result.items()}
     48 trainer = Seq2SeqTrainer(
     49     model=model,
     50     args=args,
   (...)
     55     compute_metrics=compute_metrics,
     56 )
---> 58 trainer.train()

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:1391, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1388         break
   1390 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 1391 self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1393 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   1394     if is_torch_tpu_available():
   1395         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:1491, in Trainer._maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1489 metrics = None
   1490 if self.control.should_evaluate:
-> 1491     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   1492     self._report_to_hp_search(trial, epoch, metrics)
   1494 if self.control.should_save:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer_seq2seq.py:75, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, max_length, num_beams)
     73 self._max_length = max_length if max_length is not None else self.args.generation_max_length
     74 self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
---> 75 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:2113, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   2110 start_time = time.time()
   2112 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 2113 output = eval_loop(
   2114     eval_dataloader,
   2115     description="Evaluation",
   2116     # No point gathering the predictions if there are no metrics, otherwise we defer to
   2117     # self.args.prediction_loss_only
   2118     prediction_loss_only=True if self.compute_metrics is None else None,
   2119     ignore_keys=ignore_keys,
   2120     metric_key_prefix=metric_key_prefix,
   2121 )
   2123 total_batch_size = self.args.eval_batch_size * self.args.world_size
   2124 output.metrics.update(
   2125     speed_metrics(
   2126         metric_key_prefix,
   (...)
   2130     )
   2131 )

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer.py:2285, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   2282         batch_size = observed_batch_size
   2284 # Prediction step
-> 2285 loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   2287 # Update containers on host
   2288 if loss is not None:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/trainer_seq2seq.py:167, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
    160 # XXX: adapt synced_gpus for fairscale as well
    161 gen_kwargs = {
    162     "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
    163     "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
    164     "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
    165 }
--> 167 generated_tokens = self.model.generate(
    168     inputs["input_ids"],
    169     attention_mask=inputs["attention_mask"],
    170     **gen_kwargs,
    171 )
    172 # in case the batch is shorter than max length, the output should be padded
    173 if generated_tokens.shape[-1] < gen_kwargs["max_length"]:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/generation_utils.py:922, in GenerationMixin.generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
    918 encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
    920 if self.config.is_encoder_decoder:
    921     # add encoder_outputs to model_kwargs
--> 922     model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
    924     # set input_ids as decoder_input_ids
    925     if "decoder_input_ids" in model_kwargs:

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/generation_utils.py:417, in GenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs)
    411     encoder = self.get_encoder()
    412     encoder_kwargs = {
    413         argument: value
    414         for argument, value in model_kwargs.items()
    415         if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
    416     }
--> 417     model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
    418 return model_kwargs

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py:904, in T5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    902 if inputs_embeds is None:
    903     assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
--> 904     inputs_embeds = self.embed_tokens(input_ids)
    906 batch_size, seq_length = input_shape
    908 # required mask seq length can be calculated via length of past

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/modules/sparse.py:158, in Embedding.forward(self, input)
    157 def forward(self, input: Tensor) -> Tensor:
--> 158     return F.embedding(
    159         input, self.weight, self.padding_idx, self.max_norm,
    160         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/miniconda3/envs/.env/lib/python3.9/site-packages/torch/nn/functional.py:2183, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2177     # Note [embedding_renorm set_grad_enabled]
   2178     # XXX: equivalent to
   2179     # with torch.no_grad():
   2180     #   torch.embedding_renorm_
   2181     # remove once script supports set_grad_enabled
   2182     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2183 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
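
For context, the same RuntimeError can be reproduced in isolation by calling generate() on a GPU model with input tensors that are still on the CPU. This is only a minimal sketch of what the message means, not necessarily the exact cause inside the Trainer (the t5-small checkpoint and the toy input are my own choices for the repro):

from transformers import T5Tokenizer, T5ForConditionalGeneration

tok = T5Tokenizer.from_pretrained("t5-small")
m = T5ForConditionalGeneration.from_pretrained("t5-small").to("cuda:0")

# input_ids are created on the CPU and never moved to the GPU
batch = tok("summarize: a tiny example", return_tensors="pt")

# Raises the same error: embedding weights on cuda:0, index tensor on cpu
out = m.generate(batch["input_ids"], attention_mask=batch["attention_mask"])

# Moving the inputs to the model's device makes the call succeed
out = m.generate(
    batch["input_ids"].to(m.device),
    attention_mask=batch["attention_mask"].to(m.device),
)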

I have confirmed that training happens on the GPU, so I am not sure what is still left on the CPU for this error to appear. Any guidance would be appreciated.
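
In case it helps, this is roughly how I am checking where the model and the evaluation batches live. It is only a diagnostic sketch using standard Trainer attributes (trainer.args.device, trainer.model, trainer.get_eval_dataloader()); the variable names are my own:

import torch

# Device the Trainer intends to use (cuda:0 here)
print("trainer.args.device:", trainer.args.device)

# Device of the (unwrapped) model's parameters, e.g. the embedding weights
print("model device:", next(trainer.model.parameters()).device)

# A raw batch from the eval dataloader; the Trainer normally moves it to
# args.device inside prediction_step before the forward/generate call
eval_batch = next(iter(trainer.get_eval_dataloader()))
for name, tensor in eval_batch.items():
    if isinstance(tensor, torch.Tensor):
        print(name, tensor.device)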



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow