PyTorch `backward()` updates multiple models
Can anyone tell me why the gradients of the discriminator also change when I call `backward()` on the generator loss, and whether there is a way to avoid it? The training loop below prints the discriminator's gradients just before and just after the generator's backward pass (checkpoints 1 and 2):
```python
for i in range(2):
    # Discriminator step: X_fake is detached so no gradients flow into the generator.
    X_fake = gen_model(z)
    pred_real = disc_model(X)
    pred_fake = disc_model(X_fake.detach())
    disc_loss = (loss_fn(pred_real, y) + loss_fn(pred_fake, y)) / 2
    disc_optimizer.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    # Generator step: X_fake is *not* detached here, so the graph runs through disc_model.
    pred_fake = disc_model(X_fake)
    gen_loss = loss_fn(pred_fake, y)
    gen_optimizer.zero_grad()
    if i == 1:
        print_grads(disc_model)  # Checkpoint 1
    gen_loss.backward()
    if i == 1:
        print_grads(disc_model)  # Checkpoint 2
    gen_optimizer.step()
```
Here is the rest of the code:
```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self._linear = nn.Sequential(nn.Linear(1, 5))

    def forward(self, X):
        return self._linear(X)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self._linear = nn.Sequential(nn.Linear(5, 1))

    def forward(self, X):
        return self._linear(X)


def print_grads(model):
    for params in model.parameters():
        print(params.grad)


# Build the models and data.
gen_model = Generator()
gen_optimizer = torch.optim.Adam(gen_model.parameters(), 1)
disc_model = Discriminator()
disc_optimizer = torch.optim.Adam(disc_model.parameters(), 1)
loss_fn = torch.nn.BCEWithLogitsLoss()

z = torch.rand((1, 1))
X = torch.rand((1, 5))
y = torch.rand((1, 1))
```
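Why this happens: `backward()` computes gradients for every leaf tensor with `requires_grad=True` that participates in the graph of the loss, regardless of which optimizer you plan to step. In the generator step, `X_fake` is passed through `disc_model` without `.detach()` (which is necessary, otherwise the generator would receive no gradient), so the discriminator's weights are part of the graph of `gen_loss` and their `.grad` buffers are accumulated into at checkpoint 2. This is harmless as long as `disc_optimizer.zero_grad()` runs before the next `disc_loss.backward()`, as it does in the loop above. If you want the discriminator's `.grad` buffers to stay untouched during the generator update, one common approach is to switch off `requires_grad` on its parameters around that step. A minimal sketch, assuming the same `gen_model`, `disc_model`, `loss_fn`, and optimizers defined above:

```python
# Freeze the discriminator's parameters so gen_loss.backward() does not
# write into their .grad buffers; gradients still flow *through* the
# discriminator's operations to reach the generator.
for p in disc_model.parameters():
    p.requires_grad_(False)

pred_fake = disc_model(gen_model(z))  # graph: z -> gen_model -> disc_model
gen_loss = loss_fn(pred_fake, y)
gen_optimizer.zero_grad()
gen_loss.backward()                   # only gen_model's parameters receive .grad
gen_optimizer.step()

# Unfreeze for the next discriminator update.
for p in disc_model.parameters():
    p.requires_grad_(True)

print_grads(disc_model)  # unchanged by the generator's backward pass
```

Setting `requires_grad=False` only stops gradient accumulation into those parameters; the backward pass still propagates through the discriminator's operations, so the generator is trained exactly as before.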
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
