PyTorch RNN gives accuracy 0.0
I tried to run the following code, but the output is `[0.0, 0.0]`. I don't think this should happen, but I cannot figure out what is causing it. Any idea what could have gone wrong? I'd appreciate any help, thanks.
def load_array(data_arrays, batch_size, is_train=True):
    """Return a DataLoader that yields mini-batches drawn from ``data_arrays``.

    Args:
        data_arrays: tensors sharing the same first dimension; they are
            zipped into a ``TensorDataset``.
        batch_size: number of examples per mini-batch.
        is_train: when true the examples are reshuffled every epoch.
    """
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
# One example per mini-batch; is_train defaults to True, so the order is reshuffled each epoch.
data_iter = load_array((train_x, train_y), batch_size=1)
class extractlastcell(nn.Module):
    """Reduce an ``nn.LSTM`` result ``(output, (h_n, c_n))`` to its last time step.

    The slice ``out[:, -1, :]`` indexes dim 1. That is the time axis only when
    the preceding ``nn.LSTM`` was built with ``batch_first=True``. With the
    default layout it would instead select the last *batch* element.
    """

    def forward(self, x):
        out, _ = x  # discard the (h_n, c_n) state tuple
        return out[:, -1, :]


# Embedding -> LSTM -> last time step -> Linear -> Softmax.
# DataLoader stacks examples along a new leading dim, so the indices reaching
# nn.Embedding are batch-first (presumably (batch, seq) if train_x is
# (N, seq) - TODO confirm), and the embedding output is (batch, seq, 256).
# batch_first=True makes the LSTM read that layout; without it the LSTM
# treats dim 0 as time and mixes different samples of a batch together.
net = nn.Sequential(
    nn.Embedding(5000, 256),
    nn.LSTM(256, 32, batch_first=True),
    extractlastcell(),
    nn.Linear(32, 16),
    # dim=1 is what the implicit choice picks for 2-D input; stating it
    # removes the implicit-dim deprecation warning.
    # NOTE(review): Softmax feeding nn.BCELoss is unusual. For single-label
    # classification prefer raw logits + nn.CrossEntropyLoss; for multi-label
    # targets use nn.Sigmoid (or logits + nn.BCEWithLogitsLoss). Confirm
    # against the shape/encoding of train_y.
    nn.Softmax(dim=1),
)
def init_weights(m, std=0.01):
    """Draw the weights of ``nn.Linear`` layers from a normal N(0, std**2).

    Intended for ``net.apply(init_weights)``, which calls it once per
    submodule. Other module types are left untouched. Linear biases keep
    their existing values.

    Args:
        m: the module currently being visited.
        std: standard deviation of the normal distribution (default 0.01).
    """
    # isinstance also covers nn.Linear subclasses, unlike an exact type check.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=std)
# Module.apply calls init_weights on every submodule of net (and on net
# itself), which re-initialises the weight of each nn.Linear layer.
net.apply(init_weights)
def train_model(train_dl, model, epoch, lr=0.1, evaluate=None):
    """Train ``model`` with Adam on a binary cross-entropy loss.

    Args:
        train_dl: iterable of ``(inputs, targets)`` mini-batches. It is
            re-iterated every epoch and is also passed to ``evaluate``.
            ``nn.BCELoss`` needs float targets with the same shape as the
            model output.
        model: the network to optimise and evaluate.
        epoch: number of passes over ``train_dl``.
        lr: Adam learning rate. The default of 0.1 keeps the original
            behaviour, but it is much larger than Adam's own default (1e-3).
        evaluate: callable ``(model, data_iter) -> float`` run after each
            epoch. Defaults to ``d2l.evaluate_accuracy``.

    Returns:
        A list holding one ``evaluate`` result per epoch.
    """
    if evaluate is None:
        evaluate = d2l.evaluate_accuracy
    train_ls = []
    # Summed element-wise BCE; equivalent to reduction='none' followed by .sum().
    loss = nn.BCELoss(reduction='sum')
    # Optimise the model that was passed in, not a module-level global.
    trainer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epoch):
        # The evaluation helper may switch the model to eval mode; switch back
        # so train-only layers (dropout, batch norm) behave correctly.
        model.train()
        for inputs, targets in train_dl:
            trainer.zero_grad()
            yhat = model(inputs)
            batch_loss = loss(yhat, targets)
            batch_loss.backward()
            trainer.step()
        train_ls.append(evaluate(model, train_dl))
    return train_ls
# Two training epochs over data_iter; returns one accuracy value per epoch.
train_model(data_iter, net, epoch=2)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
