'Pytorch `bachward()` updates multiple models

Can anyone tell me why the gradients of the discriminator change as well and if there is a way to avoid it?

for i in range(2):

    X_fake = gen_model(z)

    pred_real = disc_model(X)
    pred_fake = disc_model(X_fake.detach())
    disc_loss = (loss_fn(pred_real, y) + loss_fn(pred_fake, y)) / 2

    disc_optimizer.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    pred_fake = disc_model(X_fake)
    gen_loss = loss_fn(pred_fake, y)

    gen_optimizer.zero_grad()
    i == 1 and print_grads(disc_model) # Checkpoint 1
    gen_loss.backward()
    i == 1 and print_grads(disc_model) # Checkpoint 2
    gen_optimizer.step()

This is the rest of the code.

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self._linear = nn.Sequential( nn.Linear(1, 5) )

    def forward(self, X):
        return self._linear(X)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self._linear = nn.Sequential( nn.Linear(5, 1) )

    def forward(self, X):
        return self._linear(X)

def print_grads(model):
    for params in model.parameters():
        print(params.grad)

# Build the model and data.
gen_model = Generator()
gen_optimizer = torch.optim.Adam(gen_model.parameters(), 1)

disc_model = Discriminator()
disc_optimizer = torch.optim.Adam(disc_model.parameters(), 1)

loss_fn = torch.nn.BCEWithLogitsLoss()

z = torch.rand((1, 1))
X = torch.rand((1, 5))
y = torch.rand((1, 1))


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source