diff --git a/tensorflow_probability/python/internal/distribute_lib_test.py b/tensorflow_probability/python/internal/distribute_lib_test.py index 6452e40ad1..1518b7d2b0 100644 --- a/tensorflow_probability/python/internal/distribute_lib_test.py +++ b/tensorflow_probability/python/internal/distribute_lib_test.py @@ -313,7 +313,7 @@ def log_prob_parts(value): sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( log_prob_parts, [None, True]) - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])), self.evaluate([ normal.Normal(0., 1.).log_prob(0.), @@ -336,7 +336,7 @@ def log_prob_parts(value): sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( log_prob_parts, [True, True]) - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(sharded_log_prob_parts([tf.ones(4), data])), self.evaluate([ tf.reduce_sum(normal.Normal(0., 1.).log_prob(tf.ones(4))), @@ -365,7 +365,7 @@ def log_prob_parts(value): out_parts = self.per_replica_to_tensor( self.strategy_run(run, (x, sharded_data), in_axes=(None, 0))) - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(out_parts), self.evaluate([ tf.ones(4) * normal.Normal(0., 1.).log_prob(0.), @@ -395,7 +395,7 @@ def log_prob_parts(value): out_parts = self.per_replica_to_tensor( self.strategy_run(run, (sharded_x, sharded_data))) - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(out_parts), self.evaluate([ tf.ones(4) * tf.reduce_sum(normal.Normal(0., 1.).log_prob(x)), @@ -428,7 +428,7 @@ def log_prob_parts(values): self.strategy_run( run, (w, sharded_x, sharded_data), in_axes=(None, 0, 0))) - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(out_parts), self.evaluate([ tf.ones(4) * normal.Normal(0., 1.).log_prob(w), @@ -467,7 +467,7 @@ def true_log_prob(x): true_grad = self.evaluate(gradient.value_and_gradient(true_log_prob, x)[1]) - self.assertAllEqualNested(self.evaluate(out_grads), tf.ones(4) * true_grad) + self.assertAllCloseNested(self.evaluate(out_grads), tf.ones(4) * true_grad) def test_correct_gradient_for_local_variable(self): @@ -502,7 +502,7 @@ def true_log_prob(x): true_grad = self.evaluate(gradient.value_and_gradient(true_log_prob, x)[1]) - self.assertAllEqualNested(self.evaluate(out_grads), true_grad) + self.assertAllCloseNested(self.evaluate(out_grads), true_grad) def test_correct_gradient_for_global_and_local_variable(self): @@ -543,7 +543,7 @@ def true_log_prob(*value): true_grad = gradient.value_and_gradient(true_log_prob, [w, x])[1] true_grad[0] = tf.ones(4) * true_grad[0] - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(out_grads), self.evaluate(true_grad)) def test_correct_gradient_for_global_and_local_variable_batched(self): @@ -593,7 +593,7 @@ def true_log_prob(*value): true_grad = gradient.value_and_gradient(true_log_prob, [w, x])[1] true_grad[0] = tf.ones([batch_size, 4, 1]) * true_grad[0][:, tf.newaxis] - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(out_grads), self.evaluate(true_grad)) def test_correct_gradient_for_global_and_local_variable_dict(self): @@ -639,7 +639,7 @@ def true_log_prob(*value): true_grad = gradient.value_and_gradient(true_log_prob, [w, x])[1] true_grad[0] = tf.ones(4) * true_grad[0] - self.assertAllEqualNested( + self.assertAllCloseNested( self.evaluate(out_grads), self.evaluate(true_grad)) def test_correct_gradient_for_local_integer_variable(self): @@ -674,8 +674,7 @@ def true_log_prob(x): tf.reduce_sum(bernoulli.Bernoulli(logits=x).log_prob(data))) true_grad = self.evaluate(gradient.value_and_gradient(true_log_prob, x)[1]) - - self.assertAllEqualNested(self.evaluate(out_grads), true_grad) + self.assertAllCloseNested(self.evaluate(out_grads), true_grad) def test_correct_gradient_dtype_for_disconnected_variables(self): @@ -750,10 +749,10 @@ def true_log_prob(x, data1, data2): true_values, true_grads = self.evaluate( gradient.value_and_gradient(true_log_prob, (x, data1, data2))) - self.assertAllEqualNested(out_values, tf.ones([2, 2]) * true_values) - self.assertAllEqualNested(out_grads[0], tf.ones([2, 2]) * true_grads[0]) - self.assertAllEqualNested(out_grads[1], tf.ones([2, 2]) * true_grads[1]) - self.assertAllEqualNested(out_grads[2], tf.ones([2, 2]) * true_grads[2]) + self.assertAllCloseNested(out_values, tf.ones([2, 2]) * true_values) + self.assertAllCloseNested(out_grads[0], tf.ones([2, 2]) * true_grads[0]) + self.assertAllCloseNested(out_grads[1], tf.ones([2, 2]) * true_grads[1]) + self.assertAllCloseNested(out_grads[2], tf.ones([2, 2]) * true_grads[2]) def test_nested_shard_axes(self): if not JAX_MODE: @@ -795,9 +794,9 @@ def true_log_prob(x, data): true_values, true_grads = self.evaluate( gradient.value_and_gradient(true_log_prob, (x, data))) - self.assertAllEqualNested(out_values, tf.ones([2, 2]) * true_values) - self.assertAllEqualNested(out_grads[0], tf.ones([2, 2]) * true_grads[0]) - self.assertAllEqualNested(out_grads[1], tf.ones([2, 2]) * true_grads[1]) + self.assertAllCloseNested(out_values, tf.ones([2, 2]) * true_values) + self.assertAllCloseNested(out_grads[0], tf.ones([2, 2]) * true_grads[0]) + self.assertAllCloseNested(out_grads[1], tf.ones([2, 2]) * true_grads[1]) def test_gradient_is_correctly_reduced_with_multiple_axes(self): if not JAX_MODE: @@ -850,11 +849,11 @@ def true_log_prob(x, y, z): self.assertAllClose( out_values, tf.ones([2, 2]) * true_values, rtol=1e-6, atol=1e-6) - self.assertAllEqualNested(out_grads[0], + self.assertAllCloseNested(out_grads[0], tf.ones([2, 2]) * true_grads[0][:, None]) - self.assertAllEqualNested(out_grads[1], + self.assertAllCloseNested(out_grads[1], tf.ones([2, 2]) * true_grads[1][None]) - self.assertAllEqualNested(out_grads[2], tf.ones([2, 2]) * true_grads[2]) + self.assertAllCloseNested(out_grads[2], tf.ones([2, 2]) * true_grads[2]) if __name__ == '__main__':