Deep Invertible Generalized Linear Model (DIGLM) in tensorflow_probability

I'm trying to reproduce the model described in this article using tensorflow_probability. The abstract explains the architecture and purpose of the model well:

We propose a neural hybrid model consisting of a linear model defined on a set of features computed by a deep, invertible transformation (i.e. a normalizing flow). An attractive property of our model is that both p(features), the density of the features, and p(targets | features), the predictive distribution, can be computed exactly in a single feed-forward pass.

In practice, using tensorflow_probability you can define p(features) with tfd.TransformedDistribution, using tfb.RealNVP or tfb.Glow as the bijector, and p(targets | features) with a generalized linear model. To train the model, the article maximizes the likelihood of the joint distribution p(targets, features), or slightly modified variants of it.
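To make my reading of that concrete, here is a minimal sketch of the two pieces; the single RealNVP coupling layer, the hidden-layer size, the Bernoulli head and the fixed coefficients [2., 3.] are just placeholders I chose, not anything from the article:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# p(features): a normalizing flow, i.e. a base Gaussian pushed through
# an invertible bijector (one toy RealNVP coupling layer here).
flow = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.]),
    bijector=tfb.RealNVP(
        num_masked=1,
        shift_and_log_scale_fn=tfb.real_nvp_default_template(hidden_layers=[8])))

# p(targets | features): a GLM head, e.g. a Bernoulli with logit link.
glm = tfp.glm.Bernoulli()

features = flow.sample(5)                 # log p(features) is exact
linear_response = tf.linalg.matvec(features, tf.constant([2., 3.]))
predictive = glm.as_distribution(linear_response)  # p(targets | features)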

My problem is that I can't figure out how to include a tfp.glm object in a tfd.JointDistribution. I tried to include the generalized linear model via its as_distribution() method, as in this minimal example:

import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

jd = tfd.JointDistributionNamedAutoBatched({
    # p(features): a 2-D Gaussian in place of the flow
    "x": tfd.MultivariateNormalDiag([0., 0.], [1., 1.]),
    # p(targets | features): Binomial GLM with fixed coefficients [2., 3.]
    "y": lambda x: tfp.glm.Binomial().as_distribution(
        tfp.glm.compute_predicted_linear_response(x, [2., 3.])),
})

but when I call jd I receive this error:

ValueError: Shape must be rank 2 but is rank 1 for '{{node compute_predicted_linear_response/MatVec/MatMul}} = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false](MultivariateNormalDiag/sample/chain_of_shift_of_scale_matvec_linear_operator/forward/shift/forward/add, compute_predicted_linear_response/MatVec/ExpandDims)' with input shapes: [2], [2,1].
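If I read the node name correctly, the failure happens in the matvec inside compute_predicted_linear_response, which apparently expects a rank-2 model matrix (one row per observation), while the sampled x here has shape [2]. The variant below is only my guess at the reshaping the error is asking for (the tf.newaxis expansion is an assumption, and I'm not sure it composes correctly with the auto-batching):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

jd = tfd.JointDistributionNamedAutoBatched({
    "x": tfd.MultivariateNormalDiag([0., 0.], [1., 1.]),
    # Guess: promote the sampled feature vector to a 1-row model matrix
    # so the internal matvec sees a rank-2 operand.
    "y": lambda x: tfp.glm.Binomial().as_distribution(
        tfp.glm.compute_predicted_linear_response(
            x[tf.newaxis, :], [2., 3.])),
})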

And, last but not least, how can I pass the GLM coefficients as trainable variables?
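For completeness, this is what I have in mind, assuming that closing over a tf.Variable and minimizing the negative joint log-likelihood is a legitimate way to do it; x_batch and y_batch are placeholders for observed features and targets:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Assumption: coefficients as a tf.Variable captured by the lambda.
beta = tf.Variable([2., 3.], name="glm_coefficients")

jd = tfd.JointDistributionNamedAutoBatched({
    "x": tfd.MultivariateNormalDiag([0., 0.], [1., 1.]),
    "y": lambda x: tfp.glm.Binomial().as_distribution(
        tfp.glm.compute_predicted_linear_response(x[tf.newaxis, :], beta)),
})

optimizer = tf.optimizers.Adam(learning_rate=1e-3)

# One gradient step on the negative joint log-likelihood;
# x_batch / y_batch stand for observed data.
with tf.GradientTape() as tape:
    loss = -tf.reduce_mean(jd.log_prob(x=x_batch, y=y_batch))
grads = tape.gradient(loss, [beta])
optimizer.apply_gradients(zip(grads, [beta]))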

P.S. I hope the question is well written; this is my first post on Stack Overflow, apologies if not.


