HMC sampler #449

Draft: wants to merge 7 commits into base: main

furibec (Collaborator) commented Jul 1, 2024

  • The file _hmc.py actually contained the NUTS implementation, so I renamed it to _nuts.py.
  • Reworked the NUTS implementation a bit (there was no typo, contrary to what was said before). To avoid confusion, I created the neg_Hamiltonian function and added some explanation, improved the epsilon adaptation functions, and made a few changes in BuildTree while checking that everything looks correct.
  • Created the HMC sampler.
  • Added demo38, which implements one of the examples from Neal's reference and compares it with NUTS. The results are correct.
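
For readers following along, the relationship between the negative Hamiltonian and a single HMC transition can be sketched roughly as below. This is a minimal illustration, not the code in this PR: the names `neg_hamiltonian`, `leapfrog`, and `hmc_step`, their signatures, and the unit mass matrix are all assumptions made for the example.

```python
import math
import random


def neg_hamiltonian(logd, r):
    """Negative Hamiltonian: log target density minus kinetic energy.

    Assumes a unit mass matrix, so the kinetic energy is 0.5 * r.r.
    """
    return logd - 0.5 * sum(ri * ri for ri in r)


def leapfrog(theta, r, grad, eps, grad_logd):
    """One leapfrog step. `grad` is the log-density gradient at `theta`."""
    r = [ri + 0.5 * eps * gi for ri, gi in zip(r, grad)]  # half step in momentum
    theta = [ti + eps * ri for ti, ri in zip(theta, r)]   # full step in position
    grad = grad_logd(theta)
    r = [ri + 0.5 * eps * gi for ri, gi in zip(r, grad)]  # second half step
    return theta, r, grad


def hmc_step(theta, logd_fn, grad_logd, eps, n_steps, rng):
    """One HMC transition with a fixed number of leapfrog steps (hypothetical helper)."""
    r = [rng.gauss(0.0, 1.0) for _ in theta]              # fresh momentum
    h0 = neg_hamiltonian(logd_fn(theta), r)
    theta_new, grad = list(theta), grad_logd(theta)
    for _ in range(n_steps):
        theta_new, r, grad = leapfrog(theta_new, r, grad, eps, grad_logd)
    h1 = neg_hamiltonian(logd_fn(theta_new), r)
    # Metropolis test: accept with probability min(1, exp(h1 - h0)).
    # 1 - random() lies in (0, 1], so the log is always defined.
    if math.log(1.0 - rng.random()) < h1 - h0:
        return theta_new, True
    return theta, False
```

With this sign convention a proposal is more likely to be accepted when the negative Hamiltonian increases, which is why the acceptance test compares `h1 - h0` rather than `h0 - h1`.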

furibec added the enhancement label (New feature or request) Jul 1, 2024
furibec (Collaborator, Author) commented Jul 2, 2024

These are the step size (epsilon) results for NUTS and HMC. Both use the same dual averaging algorithm to adapt epsilon. Note, however, that in HMC the step sizes keep changing after warm-up. This is because we have to jitter the epsilons to avoid pathological behavior.
image
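
As a rough sketch of the two mechanisms mentioned above, the snippet below shows a dual averaging update for epsilon during warm-up, followed by jittering of the adapted value afterwards. It is not the implementation in this PR: the class `DualAveraging`, the function `jitter`, the target acceptance rate of 0.65, and the ±20% jitter range are assumptions for illustration. The constants `gamma`, `t0`, and `kappa` are the defaults suggested in Hoffman and Gelman's NUTS paper.

```python
import math
import random


class DualAveraging:
    """Step-size adaptation by dual averaging (illustrative sketch)."""

    def __init__(self, eps0, delta=0.65, gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * eps0)  # value that log(eps) is shrunk towards
        self.delta = delta               # target acceptance rate (assumed)
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.h_bar = 0.0                 # running average of (delta - accept_prob)
        self.log_eps_bar = 0.0           # log of the averaged step size
        self.m = 0                       # number of warm-up iterations so far

    def update(self, accept_prob):
        """Update with this iteration's acceptance probability; return the next eps."""
        self.m += 1
        w = 1.0 / (self.m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.delta - accept_prob)
        log_eps = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        eta = self.m ** (-self.kappa)
        self.log_eps_bar = eta * log_eps + (1.0 - eta) * self.log_eps_bar
        return math.exp(log_eps)

    def final_step_size(self):
        """Averaged step size to use once warm-up is over."""
        return math.exp(self.log_eps_bar)


def jitter(eps_bar, rng, spread=0.2):
    """Draw a step size uniformly within +/- spread of eps_bar (assumed range)."""
    return eps_bar * rng.uniform(1.0 - spread, 1.0 + spread)
```

During warm-up one would call `update` once per iteration with `min(1, exp(h_new - h_old))`; after warm-up each HMC iteration would use `jitter(adapter.final_step_size(), rng)`, which is why the plotted HMC step sizes keep fluctuating around a fixed level.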

furibec (Collaborator, Author) commented Jul 2, 2024

And this is the result for the main statistics estimated by NUTS and HMC. For HMC we fixed the trajectory length to 150 (as per Neal's example). Adapting this parameter with the no-U-turn BuildTree is the main contribution of NUTS (other than that, NUTS is plain HMC). Here we are sampling a 100-dimensional Gaussian with zero mean and diagonal covariance. Both samplers estimate the reference means and standard deviations very well.

Explanation: in the left plot, closer to zero is better; in the right plot, closer to the (0,0)-(1,1) line is better. We can also compare with RWM and see that it cannot estimate this problem as well as HMC/NUTS.
image
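
A dependency-free sketch of this kind of experiment is given below; it is not demo38 itself. The specific standard deviations (0.01, 0.02, ..., 1.00) and the jittered step size range (0.0104, 0.0156) are what I recall from Neal's example and are assumptions here, as are the function names `make_target`, `hmc`, and `summarize`.

```python
import math
import random


def make_target(dim=100):
    """Zero-mean Gaussian with diagonal covariance; sds 1/dim, ..., 1 (assumed)."""
    sds = [(i + 1) / dim for i in range(dim)]
    inv_var = [1.0 / (s * s) for s in sds]

    def logd(theta):
        return -0.5 * sum(t * t * w for t, w in zip(theta, inv_var))

    def grad(theta):
        return [-t * w for t, w in zip(theta, inv_var)]

    return sds, logd, grad


def hmc(logd, grad, theta, n_samples, n_leapfrog=150,
        eps_range=(0.0104, 0.0156), seed=0):
    """Plain HMC with a fixed number of leapfrog steps and a jittered step size."""
    rng = random.Random(seed)
    samples, n_accept = [], 0
    for _ in range(n_samples):
        eps = rng.uniform(*eps_range)                 # new step size each iteration
        r = [rng.gauss(0.0, 1.0) for _ in theta]      # fresh momentum
        h0 = logd(theta) - 0.5 * sum(x * x for x in r)
        q, g = theta, grad(theta)
        for _ in range(n_leapfrog):
            r = [ri + 0.5 * eps * gi for ri, gi in zip(r, g)]
            q = [qi + eps * ri for qi, ri in zip(q, r)]
            g = grad(q)
            r = [ri + 0.5 * eps * gi for ri, gi in zip(r, g)]
        h1 = logd(q) - 0.5 * sum(x * x for x in r)
        # Accept with probability min(1, exp(h1 - h0)).
        if math.log(1.0 - rng.random()) < h1 - h0:
            theta, n_accept = q, n_accept + 1
        samples.append(theta)
    return samples, n_accept / n_samples


def summarize(samples):
    """Per-coordinate sample means and sample standard deviations."""
    n, dim = len(samples), len(samples[0])
    means = [sum(s[j] for s in samples) / n for j in range(dim)]
    sds = [math.sqrt(sum((s[j] - means[j]) ** 2 for s in samples) / (n - 1))
           for j in range(dim)]
    return means, sds
```

Plotting the returned means against zero and the estimated standard deviations against the reference `sds` would give the two panels described above; the same `summarize` helper could be applied to NUTS or RWM samples for the comparison.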

furibec (Collaborator, Author) commented Jul 2, 2024

I believe the samplers are working as expected, and they are also in CUQIpy format. Nevertheless, some tests are failing due to the changes I made in NUTS. I hope someone can take it from here @jakobsj

jakobsj (Contributor) commented Jul 4, 2024

Thank you very much @furibec for adding your HMC sampler, the fixes as well as a demo validating the implementation, that is most helpful! I cannot promise when we'll be able to look at this but I hope it will be soon.
