Skip to content

Commit

Permalink
in ad_checkpoint WrapHashably.__eq__, check _both_ are hashable
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Aug 24, 2022
1 parent 0a51a5a commit 88a212e
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from functools import partial
import operator as op
from typing import Callable, Optional, List, Tuple, Sequence, Set, Union, Any
from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
FrozenSet)
import types

from absl import logging
Expand Down Expand Up @@ -302,7 +303,7 @@ def _remat_static_argnums(fun, static_argnums, args):

class WrapHashably:
val: Any
hash: Optional[int] = None
hash: int
hashable: bool

def __init__(self, val):
Expand All @@ -317,16 +318,18 @@ def __hash__(self):
return self.hash
def __eq__(self, other):
if isinstance(other, WrapHashably):
try: return self.val == other.val
except: return self.val is other.val
if self.hashable and other.hashable:
return self.val == other.val
else:
return self.val is other.val
return False

# This caching is useful to avoid retracing even when static_argnums is used.
# See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
# On that benchmark, including this caching makes a ~10x difference (which can
# be made arbitrary large by involving larger functions to be traced).
@weakref_lru_cache
def _dyn_args_fun(fun: Callable, static_argnums: Tuple[int, ...],
def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
def new_fun(*dyn_args, **kwargs):
static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
Expand Down

0 comments on commit 88a212e

Please sign in to comment.