Jax benchmarks #784

Merged: 8 commits merged into master from jax-benchmarks on Jan 15, 2025
Conversation

@s3alfisc (Member) commented Jan 9, 2025:

Add JAX benchmarking notebook.

ReviewNB: Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.

codecov bot commented Jan 9, 2025:

Codecov Report

All modified and coverable lines are covered by tests ✅

    Flag        Coverage Δ
    core-tests  82.55% <ø> (ø)

Flags with carried forward coverage won't be shown.

@juanitorduz (Contributor) commented:

I left some minor comments.

Even though this is a dev notebook, I think it would be nice to have a title, a description of what's happening, and the motivation (and conclusions?). Think about a new dev reading this in 6 months ;) (it could be you! hehe)

@s3alfisc (Member, Author) commented:

Thanks Juan! I think I should have addressed all your comments in 279167a =)

@s3alfisc (Member, Author) commented:

I was also thinking that we might just add this nb to the docs; it might be interesting for users to see these benchmarks?

@juanitorduz (Contributor) commented on the notebook via ReviewNB, Jan 12, 2025:

Did the results change? I thought JAX on GPU was faster.

Also, the y-ticks are hard to read ;)

@s3alfisc (Member, Author) replied via ReviewNB, Jan 12, 2025:

On the y-ticks: I was hoping they'd get easier to read once the benchmarks run for longer, say with 1 million observations (though I'm not sure if seaborn actually supports this type of auto-scaling, but I would suppose so?).

@s3alfisc (Member, Author) commented:

Hi, no, this is still all run on my laptop, so no GPUs involved. The idea was that @apoorvalal could use this nb to run benchmarks on the GPU.

@juanitorduz (Contributor) commented:

Got it! Thanks :) (then just the comment about the y-ticks ;) )

@apoorvalal (Member) commented:

Working on it; immediately ran into some CUDA issues (JAX is great on Google's TPUs but much more fiddly on NVIDIA GPUs; torch is much better on those), so I had to get some sysadmins to fix it. Should be able to finish running today.

@s3alfisc (Member, Author) commented:

Awesome, I'm super curious 😀

@s3alfisc (Member, Author) commented:

And thanks!

@apoorvalal (Member) commented Jan 13, 2025:

Ran and added results (csv and notebook). As expected, pretty big performance gains from jax in large problems.

[benchmark results plot: numba vs. jax runtimes by problem size]

Much more promising than gpu_benchmarks.ipynb, which doesn't actually use a GPU: jnp.ones(10).devices() shows {CpuDevice(id=0)}, i.e. the array is created on the CPU, so JAX isn't using a GPU at all. That notebook should be renamed to reflect that it measures JAX performance on the CPU, and it probably suggests that the JAX backend is only worthwhile on a powerful GPU.
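(For context, a minimal way to verify which device JAX is putting arrays on; this uses only standard JAX calls and is not code from the notebook:)

    import jax
    import jax.numpy as jnp

    # Lists the devices JAX can see, e.g. [CpuDevice(id=0)] when no GPU is visible.
    print(jax.devices())

    # Shows where a freshly created array actually lives; on a working CUDA
    # install this should report a GPU device, not {CpuDevice(id=0)}.
    x = jnp.ones(10)
    print(x.devices())

If jax.devices() only lists CPU devices, the GPU build of JAX (or the CUDA setup) is not being picked up.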

Also, I noticed that the jax demeaning algo is potentially inefficient; it doesn't have early stopping (has_converged never triggers a break in apply_factor or body_fun), so it always goes over a fixed number of iterations. I can take a crack at improving it, but that might be best left for a separate PR.

@apoorvalal (Member) commented:

Also cc @iamlemec; any thoughts on how the alternating projections step is implemented here? This isn't quite as general as your MLE implementation in fastreg (since the primary problem has a closed-form solution).

@s3alfisc (Member, Author) commented:

Wow, this looks super good! Thanks @apoorvalal! Basically we see a 2-3x speed increase for big problems? Do you happen to know how many CPUs numba had available?

> Also, I noticed that the jax demeaning algo is potentially inefficient; it doesn't have early stopping (has_converged never triggers a break in apply_factor or body_fun), so it always goes over a fixed number of iterations.

😅 about this. So JAX really always runs the full maximum of 100_000 iterations? Then we shouldn't be surprised that we don't see any speed improvements on the CPU (but rather worse performance). Seeing the glass half full, I'd say there's even more scope for performance improvements :D

@s3alfisc (Member, Author) commented:

Wow, I wasn't aware that fastreg runs GLMs on JAX (cc @juanitorduz) - if you have any suggestions on how to optimize demean_jax, that would be super appreciated @iamlemec! I think we're all still starting out with JAX, and at least I still have a lot to learn about performance optimization (and even more basic things, as I apparently missed a break statement in the while clause...).

@s3alfisc (Member, Author) commented:

@apoorvalal I think that the while loop breaks here:

    final_i, final_x, _, converged = jax.lax.while_loop(

if _cond_fun returns False, or not?

    def _cond_fun(state):
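(For context, a minimal sketch of how jax.lax.while_loop stops early via its condition function; the names and the update rule here are illustrative stand-ins, not pyfixest's actual demeaning code:)

    import jax
    import jax.numpy as jnp

    def _cond_fun(state):
        i, x, converged = state
        # while_loop keeps iterating only while this returns True, so the loop
        # exits as soon as convergence is flagged (or max_iter is hit).
        return (~converged) & (i < 100_000)

    def _body_fun(state):
        i, x, _ = state
        x_new = 0.5 * x  # stand-in for one demeaning sweep
        converged = jnp.max(jnp.abs(x_new - x)) < 1e-8
        return i + 1, x_new, converged

    final_i, final_x, converged = jax.lax.while_loop(
        _cond_fun, _body_fun, (0, jnp.ones(5), jnp.array(False))
    )

So there is no explicit break statement; the early stopping lives entirely in the condition function.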

@apoorvalal (Member) commented Jan 13, 2025:

Oh right, that makes sense; I was tripped up by the Haskell-y separate conditional function.

Oh, also the instance I ran on had 12 CPU cores, so I'm assuming numba has access to all of them?

@s3alfisc (Member, Author) commented Jan 13, 2025:

> Oh also the instance I ran on had 12 CPU cores, so I'm assuming numba has access to all of them?

Yes, it should use all available CPUs by default 👍

> Oh right that makes sense; I was tripped up by the haskell-y separate conditional function.

I also think that reading JAX code takes some getting used to...

@apoorvalal (Member) left a review comment:

Looks ready to merge after some cleanup; maybe delete the old GPU benchmarks notebook, lightly edit the new GPU benchmark notebook, and integrate it into the docs page as well.

My conclusion from the benchmarking exercise is that the JAX backend is only really useful for very large datasets with access to fairly specialised GPUs, so we should keep the numba backend as the default and leave JAX as a power-user feature.

@iamlemec commented:

Thanks for tagging me in, @apoorvalal and @s3alfisc! Not surprisingly, this already looks really good to me. I tried my darnedest to eke out some more performance gains (see this gist: demean_jax_.py) and I maybe got 5-10%, but it's probably configuration-dependent.

My main takeaways were:

  • You can factor out the nested functions and just have one top-level jit, but it doesn't seem to make a huge difference performance-wise
  • For the call to scan, you can get away with slicing through flist directly, rather than using a range to address it
  • You actually don't need the x_prev part of the state in the current implementation (I'm assuming that was vestigial)
  • You can get away with not passing tol to static_argnums since it's a float (a sketch combining these points follows below)

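A sketch that combines these points (one top-level jit, scan slicing flist directly, no x_prev in the carry, tol left as a plain float). This is an illustrative reimplementation under those assumptions, not the gist or pyfixest's actual demean_jax:

    from functools import partial

    import jax
    import jax.numpy as jnp

    # n_groups must be static (segment_sum needs a concrete size);
    # tol and max_iter are plain numbers and need no static_argnums.
    @partial(jax.jit, static_argnums=(2,))
    def demean_sketch(x, flist, n_groups, tol=1e-8, max_iter=100_000):
        # x: (n_obs,) column to demean; flist: (n_factors, n_obs) integer group ids

        def subtract_group_means(xc, f):
            sums = jax.ops.segment_sum(xc, f, num_segments=n_groups)
            counts = jax.ops.segment_sum(jnp.ones_like(xc), f, num_segments=n_groups)
            return xc - sums[f] / counts[f]

        def body_fun(state):
            i, xc, _ = state
            # scan slices flist's leading axis directly; no range indexing needed
            x_new, _ = jax.lax.scan(
                lambda c, f: (subtract_group_means(c, f), None), xc, flist
            )
            # the carry holds only (i, x, delta) -- no x_prev required
            return i + 1, x_new, jnp.max(jnp.abs(x_new - xc))

        def cond_fun(state):
            i, _, delta = state
            return (i < max_iter) & (delta >= tol)

        _, x_out, _ = jax.lax.while_loop(
            cond_fun, body_fun, (0, x, jnp.array(jnp.inf, x.dtype))
        )
        return x_out

For example, demean_sketch(jnp.arange(6.0), jnp.array([[0, 0, 1, 1, 2, 2]]), 3) subtracts the mean within each of the three groups of a single factor.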
One thing I'm confused about: why don't the large gains on the demean step pass through one-to-one to the full feols estimation? I would think it's a subset of the full op. Also, are there plans to bring JAX to other parts of the code?

@s3alfisc (Member, Author) commented Jan 13, 2025:

Wow, thanks @iamlemec!

> One thing I'm confused about: why don't the large gains on the demean step pass through one-to-one to the full feols estimation?

Where do you see this? Could it be the cost of type conversions from numpy to jax and of moving data from CPU to GPU and back? Note that these benchmarks all ran on my private laptop / only on the CPU, as @apoorvalal points out.

> Also, are there plans to bring JAX to other parts of the code?

Yes! We already allow users to run the OLS fit via jnp.linalg.lstsq. It's not included in the benchmark as it leads to an error with CRV standard errors, which run on numba.

Besides demeaning and solving the OLS system of equations, I'd like to explore whether JAX can help in all the other performance-critical pieces that currently run through numba.

Maybe it also makes sense for computing HC2 and HC3 errors and covariance matrices in general?
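(For reference, the HC2/HC3 sandwich is only a few lines in JAX; an illustrative sketch of the textbook formula, not pyfixest's implementation:)

    import jax.numpy as jnp

    def hc_vcov_sketch(X, resid, kind="HC3"):
        # sandwich estimator: (X'X)^{-1} X' diag(omega) X (X'X)^{-1}
        XtX_inv = jnp.linalg.inv(X.T @ X)
        h = jnp.einsum("ij,jk,ik->i", X, XtX_inv, X)  # leverage values h_ii
        power = 1 if kind == "HC2" else 2             # HC2: 1/(1-h), HC3: 1/(1-h)^2
        omega = resid**2 / (1.0 - h) ** power
        meat = (X * omega[:, None]).T @ X             # X' diag(omega) X
        return XtX_inv @ meat @ XtX_inv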

The GLM implementations I would leave untouched for now (except for allowing them to call demean_jax() once we add fixed-effects support).

Generally, we need to think a little about where we want to go with JAX - based on the CPU benchmarks, a full rewrite to a JAX-first backend seems like a bad solution to me? As long as JAX doesn't beat numba on the CPU, I think that numba should remain the default backend (as I think most users will run their regressions on the CPU).

My preferred approach at the moment would be to define a nice API that gives users control over which performance-critical code segments to run in JAX / on the GPU if desired. Maybe something like this:

def feols(
    fml: str,
    data: DataFrameType,
    ...,
    gpu: bool = True,
    gpu_options: dict | None = None,
)

or even just one argument.

One drawback of this approach would potentially be repeated conversions from numpy to jax and vice versa.

Regarding your JAX implementation and the performance improvements - that's of course something I'm very interested in, would you be up to open a PR? 😄

@apoorvalal (Member) commented:

jnp.linalg.lstsq is likely not the fastest jax-based alternative for solving least-squares systems, FWIW. Lineax has some nice automatic methods to pick the right decomposition and seems to show pretty good improvements over scikit-learn's regression class (which uses numpy/scipy solvers).
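(For illustration, here is what that could look like, based on Lineax's documented API; treat this as a sketch to check against the current Lineax docs:)

    import jax
    import jax.numpy as jnp
    import lineax as lx

    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    X = jax.random.normal(key_x, (100, 3))
    y = jax.random.normal(key_y, (100,))

    # baseline: the jnp.linalg.lstsq call mentioned above
    beta_lstsq, *_ = jnp.linalg.lstsq(X, y)

    # Lineax: AutoLinearSolver picks a decomposition suited to the operator;
    # well_posed=False allows non-square (least-squares) problems
    operator = lx.MatrixLinearOperator(X)
    solution = lx.linear_solve(operator, y, solver=lx.AutoLinearSolver(well_posed=False))
    beta_lineax = solution.value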

Then again, given how mature pyfixest's internals are, I would be in favour of a separate experimental jax-based econometrics package that we develop and iterate on faster, and potentially integrate into pyfixest further down the line.

@s3alfisc (Member, Author) commented Jan 13, 2025:

> Then again, given how mature pyfixest's internals are, I would be in favour of a separate experimental jax-based econometrics package that we develop and iterate on faster, and potentially integrate into pyfixest further down the line.

I'd surely be up for it (it might really make more sense than trying to squeeze things into pf)! If we were to go ahead with this, I'd try to involve @janosg and @timmens - did the two of you ever get around to sketching out the design doc for the JAX-first stats package you had in mind?

@janosg commented Jan 14, 2025:

We did not get very far on this project and I don't have capacity to work on it right now. But we could have a call to exchange some ideas.

I think the ideal solution would be to make pyfixest more extensible, such that the jax implementations of the numeric parts don't have to be part of pyfixest, but the rest (e.g. formula parsing, visualizations, ...) does not have to be re-implemented in a jax package. It would be similar to an optimization package like scipy.optimize or optimagic, where you can either select pre-implemented solvers or bring your own, as long as it satisfies a certain interface.

@s3alfisc (Member, Author) commented:

> But we could have a call to exchange some ideas.

Sounds like a good first step @janosg - I'd be happy to try to arrange this and invite everyone interested (please respond with a 👍 if you are!).

> I think the ideal solution would be to make pyfixest more extensible

I think this is a great suggestion - to begin with, we could continue to port core pieces of the code to JAX and collect them in a submodule. We could try to write the code in a very functional style that would make it easy to reuse in different code bases, so that later on, we could decide to spin it out (or just the most useful pieces) into a standalone package?

@s3alfisc (Member, Author) commented Jan 14, 2025:

I've opened a PR to move all JAX code into a jax module and have implemented a POC JAXOLS class that does everything in JAX (demeaning, fitting, inference).

@iamlemec commented:

Thanks for the info @s3alfisc! About the timing stats, I think I was confused about the numbers coming out of the GPU benchmark notebook. Still getting a 5-10% boost, and things make sense now.

I'll wait till #790 settles and then submit a PR. I'll poke around a bit too and see if there are any good targets for JAX conversion; the HC* stuff does seem promising. I agree that repeated transfers back and forth might be an issue, which kind of pushes one towards a full-pipeline approach.

Would happily contribute to a JAX-first stats package too! Count me in if there's a call/discussion.

@s3alfisc (Member, Author) commented:

Thanks all for your responses! I'll send out an email this weekend then =) @iamlemec I'll ping you once I get to merge #790.

s3alfisc merged commit 2425240 into master on Jan 15, 2025; 9 checks passed. The jax-benchmarks branch was deleted on January 15, 2025 at 21:04.