-
Notifications
You must be signed in to change notification settings - Fork 10
Refactor: add transforms and transformed emulator #474
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #474 +/- ##
=======================================
Coverage ? 78.05%
=======================================
Files ? 126
Lines ? 9077
Branches ? 0
=======================================
Hits ? 7085
Misses ? 1992
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Coverage reportClick to see where and how coverage changed
This report was generated by python-coverage-comment-action |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Co-authored-by: Radka Jersakova <[email protected]>
d16eda9
to
2902587
Compare
@radka-j: thanks for the review comments! I've aimed to address these now and also more generally:
|
# TODO: PCA/VAE both require StandardizeTransform for numerical stability | ||
# e.g. "ValueError: Input tensor y contains non-finite values" | ||
# TODO: check error when no target transforms are provided | ||
# None, | ||
# [StandardizeTransform()], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can probably be removed now since it seems reasonable to expect these transforms to work in this context with StandardizeTransform
applied first.
# TODO: PCA/VAE both require StandardizeTransform for numerical stability | ||
# e.g. "ValueError: Input tensor y contains non-finite values" | ||
# TODO: check error when no target transforms are provided | ||
# None, | ||
# [StandardizeTransform()], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert isinstance(y_pred_cov, TensorLike) | ||
assert isinstance(y_pred2_cov, TensorLike) | ||
print(y_pred2_cov - y_pred_cov) | ||
# TODO: consider if this is close enough for PCA case |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a large value for atol
though I think this seems a reasonable discrepancy given the scaling with n_samples
explored in this notebook
# TODO: these are not necessarily expected to be close since both approximate in | ||
# different ways | ||
# Most are within 50% error | ||
assert torch.quantile(diff_abs.flatten(), 0.9).item() < 0.25 | ||
assert torch.quantile(diff_abs.flatten(), 0.95).item() < 0.5 | ||
# Some large max differences so will not assert on these | ||
print("Max diff", diff_abs.abs().max()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to PCA above but to a greater extent - the similarity used in the assert here would ideally be closer and based on a specific expectation. It might be worth considering a test with an alternative dataset for this in case it relates to the quality of the VAE fit.
Closes #348 and initial impl of #376.
Summary
This PR includes:
AutoEmulateTransform
class subclassingtorch.distributions.Transform
that:TensorLike
,DistributionLike
and with custom functionality forGaussianLike
outputs.expanded_basis_matrix
override in application to the VAE.GaussianLike
make_positive_definite
: this function adds jitter (up to retries) on the diagonal and when this fails clamps the eigenvalues to be positive (up to retries)AutoEmulateTransforms
implemented for:PCATransform
,StandardizeTransform
(around mean and std_dev),VAETransform
(this copies the implementation from the current VAE used as a transform in the non-experimental module)TransformedEmulator
emulator class (the refactor version of theAutoEmulatePipeline
) with:transforms
for inputs andtarget_transforms
for outputs.GaussianLike
analytically / with delta method or through a sampling approachmax_targets
threshold to ensure computationally tractable.tune()
method may need to be directly added to theTransformedEmulator
so it has a slightly different API and perhaps shouldn't directly subclassEmulator
Questions
.fit()
method could be removed in favour of fitting upon initialization. We could passfunctools.partial
classes to theTransformedEmulator
in this case instead of fully initialized ones so that the init method can be performed within theTransformedEmulator
's initmake_positive_definite
be configurable in the emulator orTransformedEmulator
APITransformedEmulators
or should it be also included directly within emulators and replace the preprocessor APIComposeTransform
forAutoEmulateComposeTransform
to simplify composingAutoEmulateTransforms
Remaining tasks: