Update distribute_lib_tests to assertAllClose instead of assertEqual (since we shouldn't expect exact floating point equality).

PiperOrigin-RevId: 623605028
srvasude authored and tensorflower-gardener committed Apr 10, 2024
1 parent 63b2100 commit c6b3eb7
Showing 1 changed file with 21 additions and 22 deletions.
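The rationale, in brief: a sharded log-prob computation may reduce terms across devices in a different order than the single-device reference, and floating-point addition is not associative, so the two results can differ in the last few ulps even when they are mathematically equal. A tolerance-based assertion such as assertAllCloseNested absorbs that difference; exact equality does not. Below is a minimal standalone sketch of the effect, not part of the commit; np.testing.assert_allclose stands in for TFP's test assertion.

```python
# Minimal sketch (not from this commit): floating-point sums depend on
# the order of reduction, so exact-equality checks are brittle.
import numpy as np

# Classic one-ulp rounding difference in float64.
print(0.1 + 0.2 == 0.3)  # False

# Reassociating a sum changes the rounded result.
terms = [1e16, -1e16, 1.0]
left = (terms[0] + terms[1]) + terms[2]   # cancellation happens first -> 1.0
right = terms[0] + (terms[1] + terms[2])  # 1.0 is absorbed by -1e16  -> 0.0
print(left, right)  # 1.0 0.0

# A tolerance-based check accepts last-ulp differences that equality rejects.
np.testing.assert_allclose(0.1 + 0.2, 0.3, rtol=1e-12, atol=0.0)
```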
43 changes: 21 additions & 22 deletions tensorflow_probability/python/internal/distribute_lib_test.py
@@ -313,7 +313,7 @@ def log_prob_parts(value):

sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
log_prob_parts, [None, True])
- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])),
self.evaluate([
normal.Normal(0., 1.).log_prob(0.),
@@ -336,7 +336,7 @@ def log_prob_parts(value):

sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
log_prob_parts, [True, True])
- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(sharded_log_prob_parts([tf.ones(4), data])),
self.evaluate([
tf.reduce_sum(normal.Normal(0., 1.).log_prob(tf.ones(4))),
@@ -365,7 +365,7 @@ def log_prob_parts(value):
out_parts = self.per_replica_to_tensor(
self.strategy_run(run, (x, sharded_data), in_axes=(None, 0)))

- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(out_parts),
self.evaluate([
tf.ones(4) * normal.Normal(0., 1.).log_prob(0.),
@@ -395,7 +395,7 @@ def log_prob_parts(value):
out_parts = self.per_replica_to_tensor(
self.strategy_run(run, (sharded_x, sharded_data)))

- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(out_parts),
self.evaluate([
tf.ones(4) * tf.reduce_sum(normal.Normal(0., 1.).log_prob(x)),
@@ -428,7 +428,7 @@ def log_prob_parts(values):
self.strategy_run(
run, (w, sharded_x, sharded_data), in_axes=(None, 0, 0)))

- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(out_parts),
self.evaluate([
tf.ones(4) * normal.Normal(0., 1.).log_prob(w),
@@ -467,7 +467,7 @@ def true_log_prob(x):

true_grad = self.evaluate(gradient.value_and_gradient(true_log_prob, x)[1])

- self.assertAllEqualNested(self.evaluate(out_grads), tf.ones(4) * true_grad)
+ self.assertAllCloseNested(self.evaluate(out_grads), tf.ones(4) * true_grad)

def test_correct_gradient_for_local_variable(self):

@@ -502,7 +502,7 @@ def true_log_prob(x):

true_grad = self.evaluate(gradient.value_and_gradient(true_log_prob, x)[1])

- self.assertAllEqualNested(self.evaluate(out_grads), true_grad)
+ self.assertAllCloseNested(self.evaluate(out_grads), true_grad)

def test_correct_gradient_for_global_and_local_variable(self):

@@ -543,7 +543,7 @@ def true_log_prob(*value):
true_grad = gradient.value_and_gradient(true_log_prob, [w, x])[1]
true_grad[0] = tf.ones(4) * true_grad[0]

- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(out_grads), self.evaluate(true_grad))

def test_correct_gradient_for_global_and_local_variable_batched(self):
@@ -593,7 +593,7 @@ def true_log_prob(*value):
true_grad = gradient.value_and_gradient(true_log_prob, [w, x])[1]
true_grad[0] = tf.ones([batch_size, 4, 1]) * true_grad[0][:, tf.newaxis]

- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(out_grads), self.evaluate(true_grad))

def test_correct_gradient_for_global_and_local_variable_dict(self):
@@ -639,7 +639,7 @@ def true_log_prob(*value):
true_grad = gradient.value_and_gradient(true_log_prob, [w, x])[1]
true_grad[0] = tf.ones(4) * true_grad[0]

- self.assertAllEqualNested(
+ self.assertAllCloseNested(
self.evaluate(out_grads), self.evaluate(true_grad))

def test_correct_gradient_for_local_integer_variable(self):
@@ -674,8 +674,7 @@ def true_log_prob(x):
tf.reduce_sum(bernoulli.Bernoulli(logits=x).log_prob(data)))

true_grad = self.evaluate(gradient.value_and_gradient(true_log_prob, x)[1])
-
- self.assertAllEqualNested(self.evaluate(out_grads), true_grad)
+ self.assertAllCloseNested(self.evaluate(out_grads), true_grad)

def test_correct_gradient_dtype_for_disconnected_variables(self):

@@ -750,10 +749,10 @@ def true_log_prob(x, data1, data2):
true_values, true_grads = self.evaluate(
gradient.value_and_gradient(true_log_prob, (x, data1, data2)))

- self.assertAllEqualNested(out_values, tf.ones([2, 2]) * true_values)
- self.assertAllEqualNested(out_grads[0], tf.ones([2, 2]) * true_grads[0])
- self.assertAllEqualNested(out_grads[1], tf.ones([2, 2]) * true_grads[1])
- self.assertAllEqualNested(out_grads[2], tf.ones([2, 2]) * true_grads[2])
+ self.assertAllCloseNested(out_values, tf.ones([2, 2]) * true_values)
+ self.assertAllCloseNested(out_grads[0], tf.ones([2, 2]) * true_grads[0])
+ self.assertAllCloseNested(out_grads[1], tf.ones([2, 2]) * true_grads[1])
+ self.assertAllCloseNested(out_grads[2], tf.ones([2, 2]) * true_grads[2])

def test_nested_shard_axes(self):
if not JAX_MODE:
@@ -795,9 +794,9 @@ def true_log_prob(x, data):
true_values, true_grads = self.evaluate(
gradient.value_and_gradient(true_log_prob, (x, data)))

- self.assertAllEqualNested(out_values, tf.ones([2, 2]) * true_values)
- self.assertAllEqualNested(out_grads[0], tf.ones([2, 2]) * true_grads[0])
- self.assertAllEqualNested(out_grads[1], tf.ones([2, 2]) * true_grads[1])
+ self.assertAllCloseNested(out_values, tf.ones([2, 2]) * true_values)
+ self.assertAllCloseNested(out_grads[0], tf.ones([2, 2]) * true_grads[0])
+ self.assertAllCloseNested(out_grads[1], tf.ones([2, 2]) * true_grads[1])

def test_gradient_is_correctly_reduced_with_multiple_axes(self):
if not JAX_MODE:
@@ -850,11 +849,11 @@ def true_log_prob(x, y, z):

self.assertAllClose(
out_values, tf.ones([2, 2]) * true_values, rtol=1e-6, atol=1e-6)
- self.assertAllEqualNested(out_grads[0],
+ self.assertAllCloseNested(out_grads[0],
tf.ones([2, 2]) * true_grads[0][:, None])
- self.assertAllEqualNested(out_grads[1],
+ self.assertAllCloseNested(out_grads[1],
tf.ones([2, 2]) * true_grads[1][None])
- self.assertAllEqualNested(out_grads[2], tf.ones([2, 2]) * true_grads[2])
+ self.assertAllCloseNested(out_grads[2], tf.ones([2, 2]) * true_grads[2])


if __name__ == '__main__':
