Tensor shapes for FFJORD bijector

I want to fit an FFJORD bijector that transforms a two-dimensional dataset. The code is below (it is a simplified version of my original code, but it has the same problem).

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

class ODE(tf.keras.layers.Layer):
    def __init__(self):
        super(ODE, self).__init__()
        self.dense_layer1 = tf.keras.layers.Dense(4, activation = 'tanh')
        self.dense_layer2 = tf.keras.layers.Dense(2)
    def call(self, t, inputs):
        # t (the ODE time) is unused: the dynamics do not depend on time
        return self.dense_layer2(self.dense_layer1(inputs))

ode = ODE()
ffjord = tfb.FFJORD(state_time_derivative_fn = ode) 
base_distr = tfd.MultivariateNormalDiag(loc = tf.zeros(2), scale_diag = tf.ones(2))
td = tfd.TransformedDistribution(distribution = base_distr, bijector = ffjord)

# Symbolic Keras input: the batch dimension is unknown, so its shape is (None, 2)
x = tf.keras.Input(shape = (2,), dtype = tf.float32)
log_prob = td.log_prob(x)
model = tf.keras.Model(x, log_prob)

def NLL(y, log_prob):
    return -log_prob

model.compile(optimizer = tf.optimizers.Adam(1.0e-2), loss = NLL)

# X_train holds the two-dimensional training data; y is a dummy target that NLL ignores
history = model.fit(x = X_train, y = np.zeros(X_train.shape[0]), epochs = 100, verbose = 0, batch_size = 128)

I get an error on the line log_prob = td.log_prob(x):

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 2)
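The same message can be reproduced with plain TensorFlow, without FFJORD, by asking tf.zeros to build a tensor from a static shape whose batch dimension is unknown. That suggests (an assumption, not something confirmed here from the FFJORD source) that the bijector builds an auxiliary tensor from the static shape of its input, which for tf.keras.Input(shape=(2,)) is (None, 2). The runtime shape tf.shape(x) does not have this limitation; the names partial_shape and zeros_from_runtime_shape below are hypothetical.

```python
import tensorflow as tf

# Static shape with an unknown batch dimension, like tf.keras.Input(shape=(2,))
partial_shape = tf.TensorShape([None, 2])
print(partial_shape.is_fully_defined())  # False

try:
    # Building a tensor from a partially known static shape fails
    tf.zeros(partial_shape, dtype=tf.float32)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)])
def zeros_from_runtime_shape(x):
    # tf.shape(x) is a tensor evaluated at run time, so an unknown
    # batch size at trace time is not a problem here
    return tf.zeros(tf.shape(x), dtype=x.dtype)


print(zeros_from_runtime_shape(tf.constant([[2.0, 3.0]])).shape)  # (1, 2)
```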

If I try to draw a sample from the transformed distribution with td.sample(), I get a different error, but td.sample(1) works, as do some other calls, for example:

x = tf.constant([[2.0, 3.0]])
ode(-1.0, x)
ffjord.inverse(x)
ffjord.forward(x)
td.log_prob(td.sample(5))  

I suspect there is a problem with the shapes, but I can't figure out where it is.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
