What is the TF equivalent of the "obs" argument in Pyro's pyro.sample()?
I am stuck on this while converting a VAE from PyTorch/Pyro to TFP/Edward2. The original code can be found here.
So the basic question is how to convert these code snippets to TFP/Edward2:
```python
with pyro.plate("data", x.size(0)):
    t = pyro.sample("t", self.t_dist(z), obs=t)
```

```python
with pyro.plate("data", size, subsample=x):
    t = pyro.sample("t", self.t_dist(z), obs=t, infer={"is_auxiliary": True})
```
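Passing `obs=t` to `pyro.sample` does not draw a new value. It conditions the site on the observed tensor, returns that tensor, and adds `log p(t | z)` to the model's log joint, summed over the conditionally independent data points declared by `pyro.plate`. When the plate subsamples a minibatch (as with `subsample=x` in the second snippet), Pyro scales that sum by `size / subsample_size`. TFP has no `obs` keyword on a sampling call; one common approach, assumed here, is to build the distribution yourself and add its `log_prob` at the observed tensor to the loss, e.g. `tfd.Bernoulli(logits=logits).log_prob(t)` reduced over the data axis. Edward2's `make_log_joint_fn` serves a similar purpose by letting observed values be passed in by random-variable name. The sketch below writes out that arithmetic in NumPy so it can be checked without TFP; `bernoulli_obs_log_prob` and the random logits are hypothetical stand-ins for `self.t_dist(z)` and its network output, which the question does not show.

```python
import numpy as np


def bernoulli_obs_log_prob(logits, t_obs, full_size=None):
    """Log-likelihood added by an observed Bernoulli site inside a data plate.

    Hypothetical helper: mirrors what pyro.sample("t", Bernoulli(logits=logits), obs=t)
    contributes to the log joint, i.e. tfd.Bernoulli(logits=logits).log_prob(t)
    summed over the leading data axis.
    """
    logits = np.asarray(logits, dtype=float)
    t_obs = np.asarray(t_obs, dtype=float)
    # log p(t | l) = t * log(sigmoid(l)) + (1 - t) * log(1 - sigmoid(l))
    # simplifies to the numerically stable form t * l - softplus(l).
    per_point = t_obs * logits - np.logaddexp(0.0, logits)
    # Data points in a plate are conditionally independent, so log-probs add.
    total = per_point.sum()
    if full_size is not None:
        # Subsampled plate: rescale the minibatch sum by N / B, as pyro.plate does.
        total *= full_size / per_point.shape[0]
    return total


# Usage with the shapes from the question (t.shape == [673, 1]).
rng = np.random.default_rng(0)
t = rng.integers(0, 2, size=(673, 1))
logits = rng.normal(size=(673, 1))  # stand-in for the t_dist network output
loglik = bernoulli_obs_log_prob(logits, t)
```

In a TFP training loop, the negative of this quantity (plus the KL or prior terms) would be what the optimizer minimizes; the observed `t` is never sampled, only scored.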
`t` has shape `t.shape = [673, 1]`, `x` has shape `x.shape = [673, 25]`, and `self.t_dist` is a Bernoulli distribution network. `z` is sampled from a normal distribution, `z = pyro.sample("z", self.z_dist())`, with:
```python
def z_dist(self):
    return dist.Normal(0, 1).expand([15]).to_event(1)
```
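`dist.Normal(0, 1).expand([15]).to_event(1)` is a 15-dimensional standard normal with diagonal covariance: `expand([15])` gives batch shape `[15]`, and `to_event(1)` moves that axis into the event shape, so `log_prob` sums over it and returns one value per `z` vector. In TFP the usual counterpart is `tfd.Independent(tfd.Normal(loc=tf.zeros(15), scale=1.0), reinterpreted_batch_ndims=1)`, or, equivalently for this case, `tfd.MultivariateNormalDiag(loc=tf.zeros(15), scale_diag=tf.ones(15))`. The NumPy sketch below only spells out the density both sides compute; the function name `standard_normal_event_log_prob` is hypothetical.

```python
import numpy as np

LATENT_DIM = 15  # matches .expand([15]) in the question


def standard_normal_event_log_prob(z):
    """log p(z) for Normal(0, 1).expand([15]).to_event(1).

    to_event(1) reinterprets the last batch axis as an event axis, so the
    per-coordinate log-densities are summed over it. TFP expresses the same
    reinterpretation with tfd.Independent(..., reinterpreted_batch_ndims=1).
    """
    z = np.asarray(z, dtype=float)
    if z.shape[-1] != LATENT_DIM:
        raise ValueError(f"expected last axis of size {LATENT_DIM}, got {z.shape[-1]}")
    # Log-density of each independent N(0, 1) coordinate.
    per_coord = -0.5 * (z ** 2 + np.log(2 * np.pi))
    # Sum over the event axis: one log-prob per latent vector.
    return per_coord.sum(axis=-1)


z = np.random.default_rng(0).standard_normal((673, LATENT_DIM))
log_pz = standard_normal_event_log_prob(z)  # shape (673,)
```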
Thanks for the help!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow