Optimize convolution batching rule performance #365
Comments
Or maybe we want something like a global configuration switch for whether group convolution calls unfold+mm or not.
IMO having an option for which route to take (unfold+mm or cudnn group conv) would be better and more user-friendly than optimizing for certain "old" hardware. Optimizing only for older hardware, I think, is not a good idea...
An option to choose the route sounds reasonable.
Some code pointers for the implementation: Opacus' version for per-sample gradients
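For context, the unfold-based trick Opacus uses for per-sample Conv2d weight gradients looks roughly like the sketch below (not Opacus' actual code; the function name is made up and the shapes assume an ungrouped Conv2d):

```python
import torch
import torch.nn.functional as F

def per_sample_conv2d_weight_grad(x, grad_output, kernel_size, stride=1, padding=0):
    # x:           (N, C_in, H, W)            layer input
    # grad_output: (N, C_out, H_out, W_out)   gradient w.r.t. the layer output
    # kernel_size: (kh, kw)
    N, C_in = x.shape[:2]
    C_out = grad_output.shape[1]
    # Unfold the input into sliding-window columns: (N, C_in*kh*kw, L), L = H_out*W_out
    cols = F.unfold(x, kernel_size, stride=stride, padding=padding)
    # Flatten grad_output spatially: (N, C_out, L)
    g = grad_output.reshape(N, C_out, -1)
    # One matmul per sample gives per-sample weight gradients: (N, C_out, C_in*kh*kw)
    gw = torch.einsum("ncl,nkl->nck", g, cols)
    return gw.reshape(N, C_out, C_in, *kernel_size)
```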
Flagging that within the past ~month there's also been a substantial perf regression using group convolutions on A100s (the newest hardware). I can check what the comparison is on V100s + P100s to get data across the board. I still agree that we should have the flag; we just may want the default to be unfold.
It's unclear whether unfold + matmul is actually faster as a replacement for group convolution. I did an experiment where I replaced all group convolutions with unfold + mm. For our ensembling example on a CNN, it is still not as performant as running a for-loop.
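For what it's worth, a timing harness comparing the for-loop route against the grouped-convolution route on an ensembling workload could look like this sketch (the sizes and helper names are made up; it assumes a CUDA device and is not the actual experiment):

```python
import torch
import torch.nn.functional as F

# Hypothetical ensembling workload: E independent conv layers applied to the same batch.
E, N, C_in, C_out, H, W, k = 8, 32, 64, 64, 28, 28, 3
x = torch.randn(N, C_in, H, W, device="cuda")
weights = torch.randn(E, C_out, C_in, k, k, device="cuda")

def route_for_loop():
    # One cudnn convolution per ensemble member.
    return torch.stack([F.conv2d(x, w, padding=1) for w in weights])

def route_group_conv():
    # What the batching rule does today: fold the ensemble dim into channels
    # and issue a single grouped convolution.
    x_rep = x.repeat(1, E, 1, 1)                      # (N, E*C_in, H, W)
    w = weights.reshape(E * C_out, C_in, k, k)
    out = F.conv2d(x_rep, w, padding=1, groups=E)     # (N, E*C_out, H, W)
    return out.reshape(N, E, C_out, H, W).transpose(0, 1)

def time_route(fn, iters=50):
    # CUDA-event timing in milliseconds, averaged over `iters` runs.
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

torch.testing.assert_close(route_for_loop(), route_group_conv(), rtol=1e-4, atol=1e-4)
print("for-loop:  ", time_route(route_for_loop), "ms")
print("group conv:", time_route(route_group_conv), "ms")
```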
Suggestion from Horace: add a flag to disable the batching rule for convolution
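One possible shape for that flag (purely illustrative; none of these names exist in functorch today) is a context manager around a module-level switch that the batching rule consults before taking the group-convolution route:

```python
import contextlib

# Hypothetical switch -- not an existing functorch/PyTorch API.
_CONV_BATCH_RULE_USES_GROUP_CONV = True

@contextlib.contextmanager
def conv_batch_rule_uses_group_conv(enabled: bool):
    """Temporarily pick the conv batching-rule route: grouped conv vs. a fallback
    (e.g. a for-loop or unfold+mm)."""
    global _CONV_BATCH_RULE_USES_GROUP_CONV
    prev = _CONV_BATCH_RULE_USES_GROUP_CONV
    _CONV_BATCH_RULE_USES_GROUP_CONV = enabled
    try:
        yield
    finally:
        _CONV_BATCH_RULE_USES_GROUP_CONV = prev

# Hypothetical usage:
# with conv_batch_rule_uses_group_conv(False):
#     outs = vmap(model)(batched_weights, x)
```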
On CUDA, when the convolution batching rule uses group convolutions, this sometimes ends up being slower than we expect on older hardware. This is probably because PyTorch's group convolution calls the cudnn group convolution, which is poorly optimized on older hardware.
We should try to optimize the performance of the group convolution path. I remember that unfold+matmul can be faster at times.
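For reference, the unfold+matmul route for a grouped 2D convolution looks roughly like the sketch below (assuming no dilation and symmetric integer stride/padding; the function name is made up):

```python
import torch
import torch.nn.functional as F

def group_conv2d_unfold_mm(x, weight, groups, stride=1, padding=0):
    # x:      (N, G*C_in, H, W)
    # weight: (G*C_out, C_in, kh, kw)
    N, _, H, W = x.shape
    G = groups
    C_out, C_in = weight.shape[0] // G, weight.shape[1]
    kh, kw = weight.shape[-2:]
    # Sliding windows: (N, G*C_in*kh*kw, L) with L = H_out*W_out, then split out the group dim.
    cols = F.unfold(x, (kh, kw), stride=stride, padding=padding)
    cols = cols.reshape(N, G, C_in * kh * kw, -1)
    w = weight.reshape(G, C_out, C_in * kh * kw)
    # Batched matmul per (sample, group): out[n,g,o,l] = sum_k w[g,o,k] * cols[n,g,k,l]
    out = torch.einsum("gok,ngkl->ngol", w, cols)
    H_out = (H + 2 * padding - kh) // stride + 1
    W_out = (W + 2 * padding - kw) // stride + 1
    return out.reshape(N, G * C_out, H_out, W_out)

# Quick sanity check against the native grouped convolution:
# x = torch.randn(2, 8, 16, 16); w = torch.randn(12, 2, 3, 3)
# torch.allclose(group_conv2d_unfold_mm(x, w, groups=4, padding=1),
#                F.conv2d(x, w, groups=4, padding=1), atol=1e-5)
```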