Jax benchmarks #784
Conversation
I left some minor comments. Even though this is a dev notebook, I think it would be nice to have a title, a description of what's happening, and the motivation (and conclusions?). Think about a new dev reading this in 6 months ;) (it could be you! hehe)
Thanks Juan! I think I should have addressed all your comments in 279167a =)
I was also thinking that we might just add this nb to the docs - it might be interesting for users to see these benchmarks?
Did the results change? I thought JAX on GPU was faster.
Also, the y-ticks are hard to read ;)
On the y-ticks - I was hoping they'd get easier to read once the benchmarks ran for longer, say with 1M observations (though I'm not sure if seaborn actually supports this type of auto-scaling - I would suppose so?).
Hi, no, this is still all run on my laptop, so no GPUs involved. The idea was that @apoorvalal could use this nb to run benchmarks on the GPU.
got it! thanks :) (then just the comment about the y-ticks ;) )
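In case it helps with the tick legibility: seaborn draws on matplotlib, so a tick formatter can be applied directly to the underlying axis. A minimal sketch with made-up values:

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter

fig, ax = plt.subplots()
ax.plot([1_000, 100_000, 1_000_000], [0.05, 1.2, 30.0])  # made-up benchmark timings
ax.set_yscale("log")                          # spreads out timings spanning magnitudes
ax.yaxis.set_major_formatter(EngFormatter())  # renders 0.05 as "50 m", 30 as "30"
plt.show()
```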
Working on it; immediately ran into some CUDA issues (jax is great on goog's TPUs but much more fiddly on nvidia GPUs; torch is much better on those), so I had to get some sysadmins to fix it. Should be able to finish running today.
Awesome, I'm super curious 😀
And thanks!
Ran and added results (csv and notebook). As expected, pretty big performance gains from jax in large problems - much more promising than the CPU runs. Also, I noticed that the jax demeaning algo is potentially inefficient; it doesn't have early stopping, so it appears to always run the full maximum number of iterations.
also cc @iamlemec; any thoughts on how the alternating projections step is implemented here? This isn't quite as general as your MLE implementation in fastreg (since the primary problem has a closed-form solution).
Wow this looks super good! Thanks @apoorvalal! Basically we see a 2-3x speed increase for big problems? Do you happen to know how many CPUs numba had available?
😅 about this - so JAX really always runs the full maximum of 100_000 iterations? Then we shouldn't be surprised that we don't see any speed improvements on the CPU (but rather worse performance). Seeing the glass half full, I'd say even more scope for performance improvements :D
Wow, I wasn't aware that ...
@apoorvalal I think that the while loop breaks here once `_cond_fun` is False, or not?
Oh right, that makes sense; I was tripped up by the haskell-y separate conditional function. Oh, also: the instance I ran on had 12 CPU cores, so I'm assuming numba has access to all of them?
Yes, it should use all available CPUs by default 👍
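For reference, numba exposes its thread count directly, so this is easy to verify on the benchmark machine; a minimal check:

```python
import numba

# NUMBA_NUM_THREADS defaults to the number of CPU cores on the machine
print(numba.config.NUMBA_NUM_THREADS)
# number of threads that parallel regions will actually use right now
print(numba.get_num_threads())
```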
I also think that reading JAX code takes some time getting used to...
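For readers new to the pattern being discussed: `jax.lax.while_loop(cond_fun, body_fun, init_val)` keeps calling `body_fun` only while `cond_fun` returns True, which is how early stopping works despite the "haskell-y" separation. A minimal sketch - the projection step here is a simplified stand-in, not pyfixest's actual demeaning implementation:

```python
import jax
import jax.numpy as jnp

def demean_sketch(x, tol=1e-8, max_iter=100_000):
    """Alternating-projections-style loop with early stopping."""

    def _cond_fun(state):
        # The loop exits as soon as this returns False, i.e. once the
        # update falls below tol or the iteration cap is reached.
        i, _, delta = state
        return jnp.logical_and(i < max_iter, delta > tol)

    def _body_fun(state):
        i, x_curr, _ = state
        x_next = x_curr - x_curr.mean(axis=0)  # simplified stand-in projection
        delta = jnp.max(jnp.abs(x_next - x_curr))
        return i + 1, x_next, delta

    init = (jnp.array(0), x, jnp.array(jnp.inf, dtype=x.dtype))
    _, x_out, _ = jax.lax.while_loop(_cond_fun, _body_fun, init)
    return x_out

x = jnp.ones((100, 3))
print(demean_sketch(x))  # exits after a couple of passes, not max_iter
```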
Looks ready to merge after some cleanup; maybe delete the old benchmarks notebook, lightly edit the GPU benchmark one, and integrate it into the docs page as well.
My conclusion from the benchmarking exercise is that the JAX backend is only really useful for very large datasets with access to fairly specialised GPUs, so we should keep the numba backend as the default and leave JAX as a power-user feature.
Thanks for tagging me in @apoorvalal and @s3alfisc! Not surprisingly, this already looks really good to me. I tried my darndest to eke out some more performance gains (see this gist: demean_jax_.py) and I maybe got 5-10%, but it's probably configuration dependent. My main takeaways were: ...
One thing I'm confused about: why don't the large gains on the demean step pass through one-to-one to the full feols estimation? I would think it's a subset of the full op. Also, are there plans to bring JAX to other parts of the code?
Wow, thanks @iamlemec!
Where do you see this? Could it be costs of type conversions from numpy to jax and moving data from CPU to GPU and back? Note that these benchmarks all run on my private laptop / only on the CPU, as @apoorvalal points out.
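On measuring those conversion/transfer costs: JAX dispatches asynchronously, so any timing needs `block_until_ready()` before stopping the clock. A minimal sketch:

```python
import time
import numpy as np
import jax

x_np = np.random.rand(1_000_000, 10)

t0 = time.perf_counter()
x_dev = jax.device_put(x_np)  # numpy -> JAX array (and host -> device on GPU)
x_dev.block_until_ready()     # dispatch is async; wait for the copy to finish
t1 = time.perf_counter()
print(f"transfer + conversion: {t1 - t0:.4f}s")

t0 = time.perf_counter()
x_back = np.asarray(x_dev)    # device -> host copy on the way back
t1 = time.perf_counter()
print(f"copy back: {t1 - t0:.4f}s")
```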
Yes! We already allow users to run the OLS fit via `jnp.linalg.lstsq`. It's not included in the benchmark as it leads to an error with CRV standard errors, which run on numba.

Besides the demeaning and solving the OLS system of equations, I'd like to explore if JAX can help in all other performance-critical pieces that currently run through numba. Maybe it also makes sense for computing HC2 and HC3 errors and covariance matrices in general? The GLM implementations I would for now leave untouched (except for allowing to call ...).

Generally we need to think a little bit about where we want to go with JAX - based on the CPU benchmarks, doing a full rewrite to a JAX-first backend seems like a bad solution to me? As long as JAX doesn't beat numba on the CPU, I think that numba should remain the default backend (as I think most users will run their regressions on the CPU). My preferred approach at the moment would be to define a nice API that gives users control over which performance-critical code segments to run in JAX / on the GPU if desired. Maybe something like this:

```python
def feols(
    fml: str,
    data: DataFrameType,
    ...,
    gpu: bool = True,
    gpu_options: dict = None,
)
```

or even just one argument. One drawback of this approach would potentially be repeated translations from numpy to jax and vice versa.

Regarding your JAX implementation and the performance improvements - that's of course something I'm very interested in; would you be up to open a PR? 😄
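Under that sketch, usage might look like the following (the `gpu` and `gpu_options` arguments are illustrative only, not existing pyfixest arguments):

```python
import pandas as pd
import pyfixest as pf

df = pd.DataFrame({
    "y": [1.0, 2.0, 3.0, 4.0],
    "x1": [0.5, 1.5, 2.5, 3.5],
    "f1": ["a", "a", "b", "b"],
    "f2": ["u", "v", "u", "v"],
})

# hypothetical call: run the performance-critical pieces on GPU device 0,
# falling back to the numba backend when gpu=False
fit = pf.feols("y ~ x1 | f1 + f2", data=df, gpu=True, gpu_options={"device": 0})
```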
`jnp.linalg.lstsq` is likely not the fastest jax-based alternative for solving least-squares systems fwiw. Lineax has some nice automatic methods to pick the right decomposition and seems to show pretty good improvements over scikit-learn's regression class (which uses numpy/scipy solvers). Then again, given how mature pyfixest's internals are, I would be in favour of a separate experimental jax-based econometrics package that we develop and iterate on faster, and potentially integrate into pyfixest further down the line.
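To make the comparison concrete, a sketch of both routes on a toy least-squares problem (the lineax calls follow its documented API; treat the solver choice as an assumption):

```python
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

key = jr.PRNGKey(0)
X = jr.normal(key, (10_000, 20))
y = jnp.sum(X, axis=1) + jr.normal(jr.PRNGKey(1), (10_000,))

# baseline: SVD-based least squares
beta_lstsq, *_ = jnp.linalg.lstsq(X, y)

# lineax: AutoLinearSolver picks a decomposition suited to the operator;
# well_posed=False admits rectangular (least-squares) systems
sol = lx.linear_solve(
    lx.MatrixLinearOperator(X),
    y,
    solver=lx.AutoLinearSolver(well_posed=False),
)
beta_lineax = sol.value
```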
I'd surely be up for it (might really make more sense than trying to squeeze things into pf)! If we were to go ahead with this, I'd try to involve @janosg and @timmens - did the two of you ever get around to sketching out the design doc for the JAX-first stats package you had in mind?
We did not get very far on this project and I don't have capacity to work on it right now, but we could have a call to exchange some ideas. I think the ideal solution would be to make pyfixest more extensible, such that the jax implementations of the numeric parts don't have to be part of pyfixest, while the rest (e.g. formula parsing, visualizations, ...) does not have to be re-implemented in a jax package. It would be similar to an optimization package like scipy optimize or optimagic, where you can either select pre-implemented solvers or bring your own, as long as it satisfies a certain interface.
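A rough sketch of what such a pluggable interface could look like (all names here are hypothetical, not an actual pyfixest API):

```python
from typing import Protocol
import numpy as np

class Demeaner(Protocol):
    """Any backend (numba, JAX, ...) matching this call signature could be
    passed in; the host package would depend only on the interface."""

    def __call__(
        self, x: np.ndarray, flist: np.ndarray, weights: np.ndarray, tol: float
    ) -> tuple[np.ndarray, bool]:
        ...

def feols(fml: str, data, demeaner: Demeaner | None = None):
    # fall back to the built-in numba implementation when None is passed
    ...
```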
Sounds like a good first step @janosg - I'd be happy to try to arrange this and invite everyone interested (please respond with a 👍 if you are!).
I think this is a great suggestion - for the beginning, we could continue to port core pieces of the code to JAX.
I've opened a PR (#790) to move all JAX code into a dedicated module.
Thanks for the info @s3alfisc! About the timing stats, I think I was confused about the numbers coming out of the GPU benchmark notebook. Still getting a 5-10% boost, and things make sense now. I'll wait till #790 settles and then submit a PR. I'll poke around a bit too and see if there are any good targets for JAX conversion; HC* stuff does seem promising. I agree that repeated transfers back and forth might be an issue, which kind of pushes one towards a full-pipeline approach. Would happily contribute to a JAX-first stats package too! Count me in if there's a call/discussion.
Add JAX benchmarking notebook.