'How do I solve the error: Only one element tensors can be converted to python scalars
I am doing transfer learning with a vision transformer model and I have a bit of manipulation on the images therein.
for epoch in range(epochs):
with tqdm(train_loader) as p_bar:
for batch_idx, (data, target) in enumerate(p_bar):
data, target = data.to(device), target.to(device)
data.requires_grad = True
output = model(data, fine_tune=True)
init_pred = output.max(1, keepdim = True)[1]
# init_pred = torch.max(output.data, 1)
print(init_pred.shape, " and ", target.shape)
if init_pred.item() != target.item():
continue
loss = criterion(output, target)
model.zero_grad()
loss.backward()
data_grad = data.grad.data
uf = F.unfold(data, kernel_size=16, stride=16, padding=0)
patch = uf[..., 130]
perturbed_patch = fgsm_attack(patch, epsilon, data_grad)
uf[..., 130] = patch
perturbed_data = F.fold(uf, data.shape[-2:], kernel_size=16, stride=16, padding=0)
output = model(perturbed_data)
#check for success
final_pred = output.max(1, keepdim=True)[1] # get the index of max log-probability
if final_pred.item() == target.item():
correct += 1
#Special case for saving 0 epsilon examples
if (epsilon == 0) and (len(adv_examples) < 5):
adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
else:
# Save some adv examples for visualization later
if len(adv_examples) < 5:
adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
# Calculate final accuracy for this epsilon
final_acc = correct/float(len(test_loader))
print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))
Unfortunately I got the below error on the line with if init_pred.item() != target.item(): :
ValueError: only one element tensors can be converted to Python scalars
When I tried to print out the values of init_pred and target, I got a list of numbers. I do not know how to fix this error.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
