TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type

I'm having some issues testing a basic model in JAX. Specifically, I'm trying to use JAX's value_and_grad() function directly, rather than a higher-level training utility, for a binary classification problem. Here is my model definition:

from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp

class MLP(nn.Module):
    num_neurons_per_layer: Sequence[int]

    @nn.compact
    def __call__(self, x):
        activation = x
        for i, num_neurons in enumerate(self.num_neurons_per_layer):
            activation = nn.Dense(num_neurons)(activation)
            if i != len(self.num_neurons_per_layer) - 1:
                activation = nn.relu(activation)
        return nn.sigmoid(activation)

And here is my BCE loss, which uses vmap to batch over the samples and is wrapped in jax.jit:

def make_bce_loss(xs, ys):
    
    def bce_loss(params, model): 
        def cross_entropy(x, y):
            preds = model.apply(params, x)
            return y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds)
        return -jnp.mean(jax.vmap(cross_entropy)(xs, ys), axis=0)

    return jax.jit(bce_loss)

bce_loss = make_bce_loss(X, y)
value_and_grad_fn = jax.value_and_grad(bce_loss)
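The same kind of failure can be reproduced without Flax. The sketch below uses a hypothetical frozen dataclass, ScaleConfig, as a stand-in for a configuration object like the MLP module. It shows that jax.jit rejects a non-array Python object passed as an ordinary argument, and two ways around that: marking the argument static with static_argnums (static arguments must be hashable, which a frozen dataclass is), or closing over the object so the jitted function only receives arrays.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ScaleConfig:
    """Hypothetical stand-in for a config object such as a Flax module."""
    factor: float


def scaled_sum(x, cfg):
    return jnp.sum(x) * cfg.factor


x = jnp.arange(3.0)
cfg = ScaleConfig(factor=2.0)

# 1) Passing the object as an ordinary (traced) argument: jit cannot treat it
#    as an array, so it raises a TypeError similar to the one in the question.
try:
    jax.jit(scaled_sum)(x, cfg)
    traced_error = None
except TypeError as err:
    traced_error = err

# 2) Marking it static: jit treats cfg as a compile-time constant.
#    Static arguments must be hashable; a frozen dataclass is.
static_result = jax.jit(scaled_sum, static_argnums=1)(x, cfg)


# 3) Closing over it: the jitted function's only argument is the array.
def make_scaled_sum(cfg):
    return jax.jit(lambda x: jnp.sum(x) * cfg.factor)


closed_result = make_scaled_sum(cfg)(x)
```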

Then I proceed to create the model and init the parameters:

model = MLP(num_neurons_per_layer=[4, 1])
params = model.init(key, X)  # X is a jnp.array I created earlier

When I call value_and_grad_fn(params, model), which wraps the jitted loss, I get the following error:

TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type.

I'm not sure how to fix this. The error mentions [4, 1], but those values aren't involved in the computation at all; they are only used to configure the layers in the MLP class.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
