Error in transformed variables with manual model vectorization #297

@jdehning

There seems to be an error when using manual model vectorization (use_auto_batching=False) with transformed variables whose event shape is non-scalar (here, event_stack=3). A minimal example:

import numpy as np
import pymc4 as pm
import tensorflow as tf

means = np.random.random((3,1))*5+5
noise = np.random.random((3, 10))

data = means + noise
data = data.astype('float32')


# We want to infer the means of the data:
@pm.model
def model():
    means = yield pm.HalfNormal(name='means', loc=0, scale=10, event_stack=3)
    means = tf.repeat(tf.expand_dims(means, axis=-1), axis=-1, repeats=10)
    likelihood = yield pm.Normal(name='likeli', loc=means, scale=5, observed=data,
                                 reinterpreted_batch_ndims=2)

trace = pm.sample(model(), num_samples=50, burn_in=200, use_auto_batching=False, num_chains=2)

print(means)
print(np.median(trace.posterior['model/means'], axis=(0,1)))

which leads to the following shape error:

ValueError: Dimensions must be equal, but are 2 and 3 for '{{node add_2}} = AddV2[T=DT_FLOAT](add_1, mul)' with input shapes: [2], [2,3].

When pm.HalfNormal is replaced by pm.Normal, it works without problems.

If I understand the organization of the source code correctly, the error arises because the number of event dimensions is not passed to TensorFlow Probability's inverse_log_det_jacobian and forward_log_det_jacobian, for example here: https://github.com/pymc-devs/pymc4/blob/master/pymc4/distributions/transforms.py#L131. The Transform class should probably also store the number of dimensions of the event shape as an attribute, so that the log-determinant of the Jacobian can be computed correctly.

But perhaps I am using PyMC4 incorrectly and there is another way to specify the model.
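
The suspected mechanism can be illustrated with a NumPy-only sketch. This is not PyMC4 or TensorFlow Probability code: the helper exp_forward_log_det_jacobian and its event_ndims argument are hypothetical stand-ins, and the sketch assumes an elementwise exp transform maps the unconstrained value to the positive one. It only shows how a Jacobian term that is not summed over the event dimension ends up with shape [2, 3], which cannot be added to a per-chain log-probability of shape [2].

```python
import numpy as np

num_chains, event_size = 2, 3

# Unconstrained values of 'means': one row per chain, one column per event element.
rng = np.random.default_rng(0)
x = rng.normal(size=(num_chains, event_size))


def exp_forward_log_det_jacobian(x, event_ndims):
    """Hypothetical helper: log|d exp(x)/dx| for an elementwise exp transform.

    Per element the log-Jacobian is simply x. The trailing `event_ndims`
    axes are summed so that one scalar remains per batch member (chain).
    """
    ldj = x
    if event_ndims > 0:
        ldj = ldj.sum(axis=tuple(range(-event_ndims, 0)))
    return ldj


# Per-chain log-probability of the rest of the model (placeholder values).
log_prob = np.zeros(num_chains)

# Event dimension ignored: one Jacobian term per element, shape (2, 3).
ldj_unreduced = exp_forward_log_det_jacobian(x, event_ndims=0)

# Event dimension accounted for: one Jacobian term per chain, shape (2,).
ldj_reduced = exp_forward_log_det_jacobian(x, event_ndims=1)

try:
    log_prob + ldj_unreduced  # shapes (2,) and (2, 3) do not broadcast
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

total = log_prob + ldj_reduced  # shapes (2,) and (2,) add cleanly
```

With event_ndims=0 the addition fails with the same kind of shape mismatch as in the traceback above, while event_ndims=1 yields one value per chain. Whether this is exactly what happens inside the Transform class is an assumption.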
