Skip to content

Commit

Permalink
Clarify a comment in ParsedPartitionSpec.eq_given_rank
Browse files Browse the repository at this point in the history
Previous one used an overly lax definition of "semantic equality".

PiperOrigin-RevId: 415567823
  • Loading branch information
apaszke authored and jax authors committed Dec 10, 2021
1 parent 6880e2f commit 1b9a9ff
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,22 @@ def __eq__(self, other):
self.sync == other.sync)

def eq_given_rank(self, other, rank):
# ParsedPartitionSpecs may contain trailing empty tuples that don't change
# the semantic meaning of the spec but are still valid specs. For example,
# for a 2D array, (), ((),), and ((), ()) are all valid specs meaning the
# array is fully replicated (no dimension partitioned). This method compares
# two specs for semantic equivalence and asserts they are valid specs for
# the given array rank.
"""Determines whether the specs are equivalent when considering arrays of a given rank.
ParsedPartitionSpecs may contain trailing empty tuples, that make them
semantically different in general, and yet in some situations we prefer
to regard them as equivalent. For example, partitions of () and ((),)
cannot be always considered equivalent, since the first one is a valid
spec for a scalar value, while the second is not! However, when either of
those are applied to a 2D array, they both mean that the array is fully
replicated.
Because of those subtle differences, we use __eq__ to decide semantic
equality in general, while this method determines whether the two specs
are equivalent when applied to an array of a given rank. Note that this
relation has larger equivalence classes than __eq__ (i.e. x == y implies
x.eq_given_rank(y, rank)).
"""
assert len(self.partitions) <= rank and len(other.partitions) <= rank
min_length = min(len(self.partitions), len(other.partitions))
return (self.partitions[:min_length] == other.partitions[:min_length] and
Expand Down

0 comments on commit 1b9a9ff

Please sign in to comment.