PyTorch: Fine-tuning the inputs to a non-PyTorch model

I want to fine-tune the inputs to an XGBClassifier model using PyTorch.

I have an XGB classifier model that takes a vector of float and integer values. Instead of passing the input data directly to XGBClassifier.fit(), I want to first transform the values non-linearly with an MLP defined in PyTorch, pass the transformed inputs (the MLP output) to the XGB classifier, and then use the resulting loss to update the MLP weights.

The problem is that this breaks the computational graph. Passing the MLP output to the XGB classifier requires detaching it from the graph, which then prevents me from backpropagating the loss.
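A minimal sketch of why the graph breaks, using a tiny nn.Linear as a hypothetical stand-in for the MLP: once the output goes through .detach() and NumPy (as is needed to hand the values to XGBoost), the tensor rebuilt from it has no grad_fn, so backward() on anything computed from it fails with the kind of RuntimeError quoted in this question.

```python
import torch
import torch.nn as nn

mlp = nn.Linear(3, 2)  # hypothetical tiny stand-in for the real MLP
x = torch.randn(5, 3)

encoded = mlp(x)
print(encoded.grad_fn is not None)  # True: still attached to the graph

# Detaching and converting to NumPy discards the graph; the rebuilt tensor is a leaf
round_trip = torch.as_tensor(encoded.detach().cpu().numpy())
print(round_trip.requires_grad, round_trip.grad_fn)  # False None

raised = False
try:
    round_trip.sum().backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

Nothing computed from round_trip can reach mlp.weight, which is exactly the situation in the training loop below.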

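import torch
import torch.nn as nn
from xgboost import XGBClassifier

# params, inputs, config, data_loader, device, optimizer (over mlp_model.parameters())
# and criterion are set up elsewhere in my script
xgb_model = XGBClassifier(**params)
mlp_model = nn.Sequential(
    nn.Linear(inputs.shape[1], 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
).to(device)

for e in range(config.epochs):
    for i, data in enumerate(data_loader):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        # Forward pass through MLP model
        encoded_inputs = mlp_model(inputs)

        # XGBoost works on NumPy arrays, so the MLP output is detached here;
        # this is where the computational graph is broken
        encoded_np = encoded_inputs.detach().cpu().numpy()

        # Fit the XGB classifier on MLP output
        xgb_model.fit(encoded_np, labels.cpu().numpy())

        # Generate predictions (probability of class 1) from the fitted XGB model
        output = torch.tensor(xgb_model.predict_proba(encoded_np)[:, 1][:, None], device=device)

        # output has no grad_fn, so loss.backward() raises the error below
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()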

Calling .backward() on the loss raises the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Is there a way to attach the outputs of the XGB classifier used in the loss calculation to encoded_inputs from the MLP in the computational graph, so that PyTorch associates the XGB classifier's predictions with the MLP output?
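One generic way to put a non-PyTorch computation into the graph is a custom torch.autograd.Function whose backward() you write yourself. The sketch below is an illustration, not the accepted answer to this question. It assumes the wrapped function scores each row independently (as predict_proba does) and estimates the gradient by central finite differences. BlackBoxRowwise and fake_model are hypothetical names; fake_model stands in for something like lambda a: xgb_model.predict_proba(a)[:, 1][:, None].

```python
import numpy as np
import torch


class BlackBoxRowwise(torch.autograd.Function):
    """Wrap a NumPy function fn: (N, D) -> (N, 1) that scores rows independently.

    Gradients are estimated with central finite differences (an assumption of
    this sketch), costing 2 * D calls to fn per backward pass.
    """

    @staticmethod
    def forward(ctx, x, fn, eps):
        ctx.fn, ctx.eps = fn, eps
        ctx.save_for_backward(x)
        out = fn(x.detach().cpu().numpy())
        return torch.as_tensor(out, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x_np = x.detach().cpu().numpy()
        jac = np.zeros_like(x_np)
        for j in range(x_np.shape[1]):
            step = np.zeros_like(x_np)
            step[:, j] = ctx.eps
            # Rows are independent, so one pair of calls per column gives
            # d out[i] / d x[i, j] for every row i at once
            diff = (ctx.fn(x_np + step) - ctx.fn(x_np - step)) / (2 * ctx.eps)
            jac[:, j] = diff[:, 0]
        jac = torch.as_tensor(jac, dtype=x.dtype, device=x.device)
        # Chain rule: (N, 1) upstream grad broadcast over the (N, D) row Jacobians;
        # fn and eps are not tensors, so they get no gradient
        return grad_out * jac, None, None


# Hypothetical smooth stand-in for the XGBoost prediction call
def fake_model(a):
    return 1.0 / (1.0 + np.exp(-a.sum(axis=1, keepdims=True)))


x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
out = BlackBoxRowwise.apply(x, fake_model, 1e-5)
out.sum().backward()
print(x.grad)
```

This makes backward() run, but note that a tree ensemble's output is piecewise constant in its inputs, so with a small eps the estimated gradient through an XGBoost model is zero almost everywhere except near split thresholds, and the MLP may receive little or no useful signal.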



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
