Skip to content

Commit

Permalink
Optimize canonicalize_shape
Browse files Browse the repository at this point in the history
I was looking at some profiles and noticed canonicalize_shape showing up as a noticeable
overhead in certain cases. Which makes sense, given that we carefully check all possible
cases before trying to consider integers as plausible elements (which are the most popular
_by far_). And this function is pretty hot, because it gets called any time we create a new
`ShapedArray`.

I wrote a small benchmark that repeatedly calls canonicalize_shape on a 4-sized tuple of
integers.

Before:
7.62µs ± 8%

After:
1.42µs ± 2%

So a pretty easy 5x improvement overall. And in more real cases, when resharding an array
onto 8 TPUs, 50% of the time was spent on creating shapes for avals of device buffers.

PiperOrigin-RevId: 516795311
  • Loading branch information
apaszke authored and jax authors committed Mar 15, 2023
1 parent d978dcf commit 1301968
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,11 @@ def dimension_as_value(d: DimSize):
return operator.index(d)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
try:
return operator.index(dim)
except TypeError as e:
type_error = e
if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
return dim
elif (config.jax_dynamic_shapes and isinstance(dim, DArray) and
Expand All @@ -2007,7 +2012,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
elif is_special_dim_size(dim):
return dim
else:
return operator.index(dim)
raise type_error

def canonicalize_shape(shape: Shape, context: str="") -> Shape:
"""Canonicalizes and checks for errors in a user-provided shape value.
Expand Down

0 comments on commit 1301968

Please sign in to comment.