Skip to content

Commit

Permalink
Fix dtype mismatches in rejection_resample.
Browse files Browse the repository at this point in the history
Cast `initial_dist` and `target_dist` to `float32`.

Normally TensorFlow doesn't cast things for you, but I think in this case it's reasonable because:

- The cast is once per dataset, not in the inner loop.
- A float32 can represent numbers as small as roughly `1e-38`. If your sample probability is anywhere near that small, it is already effectively zero.

PiperOrigin-RevId: 373045036
Change-Id: I19de99280d44d02010c34514ea53a7ef4e2fedc5
  • Loading branch information
MarkDaoust authored and tensorflower-gardener committed May 11, 2021
1 parent a73de08 commit 130e833
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,30 @@ def testExhaustion(self):

self.assertAllClose(target_dist, bincount, atol=1e-2)

@parameterized.parameters(
    ("float32", "float64"),
    ("float64", "float32"),
    ("float64", "float64"),
    ("float64", None),
)
def testOtherDtypes(self, target_dtype, init_dtype):
  """Smoke-tests `rejection_resample` with mixed distribution dtypes.

  Each parameter pair picks the NumPy dtype of `target_dist` and of
  `initial_dist`; an `init_dtype` of `None` means no initial distribution
  is supplied at all. The test only checks that the pipeline can be built
  and that one element can be pulled from it; it makes no assertion about
  the values produced.

  Args:
    target_dtype: NumPy dtype name used for the target distribution.
    init_dtype: NumPy dtype name used for the initial distribution, or
      `None` to pass `initial_dist=None`.
  """
  uniform_probs = [0.5, 0.5]
  target_dist = np.array(uniform_probs, dtype=target_dtype)
  init_dist = (
      None if init_dtype is None else np.array(uniform_probs, dtype=init_dtype))

  source_ds = dataset_ops.Dataset.range(10)
  # Elements are classified by parity, giving two classes.
  resample_fn = resampling.rejection_resample(
      class_func=lambda x: x % 2,
      target_dist=target_dist,
      initial_dist=init_dist)
  resampled_ds = source_ds.apply(resample_fn)

  # Pulling a single element is enough to surface a dtype mismatch.
  get_next = self.getNext(resampled_ds)
  self.evaluate(get_next())


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
test.main()
12 changes: 9 additions & 3 deletions tensorflow/python/data/experimental/ops/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
target_dist_t = math_ops.cast(target_dist_t, dtypes.float32)

# Get initial distribution.
if initial_dist is not None:
initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32)

acceptance_dist, prob_of_original = (
_calculate_acceptance_probs_with_mixing(initial_dist_t,
target_dist_t))
Expand Down Expand Up @@ -168,9 +171,12 @@ def _gather_and_copy(acceptance_prob, data):

current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
(acceptance_dist_ds, dataset)).map(_gather_and_copy)
filtered_ds = (
current_probabilities_and_class_and_data_ds.filter(
lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))

def _reject(unused_class_val, p, unused_data):
  """Filter predicate: draws a uniform sample and compares it against `p`.

  The sample is drawn with the same dtype as `p`, so the comparison never
  mixes floating-point types.

  Args:
    unused_class_val: Class of the element; ignored.
    p: Scalar acceptance probability for the element.
    unused_data: The element itself; ignored.

  Returns:
    A boolean scalar that is true when the uniform draw is below `p`.
  """
  draw = random_ops.random_uniform([], seed=seed, dtype=p.dtype)
  return draw < p

filtered_ds = current_probabilities_and_class_and_data_ds.filter(_reject)

return filtered_ds.map(lambda class_value, _, data: (class_value, data))


Expand Down

0 comments on commit 130e833

Please sign in to comment.