Skip to content

Commit

Permalink
[numpy] Fix test failures under NumPy 2.1.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681078015
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Oct 1, 2024
1 parent c6c86e7 commit ebc43b5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def assertAllInRange(self,
'The value of %s does not have an ordered numeric type, instead it '
'has type: %s' % (target, target.dtype))

nan_subscripts = np.where(np.isnan(target))
nan_subscripts = np.where(np.atleast_1d(np.isnan(target)))
if np.size(nan_subscripts):
raise AssertionError(
'%d of the %d element(s) are NaN. '
Expand All @@ -631,7 +631,7 @@ def assertAllInRange(self,
violations,
np.greater_equal(target, upper_bound)
if open_upper_bound else np.greater(target, upper_bound))
violation_subscripts = np.where(violations)
violation_subscripts = np.where(np.atleast_1d(violations))
if np.size(violation_subscripts):
raise AssertionError(
'%d of the %d element(s) are outside the range %s. ' %
Expand Down

0 comments on commit ebc43b5

Please sign in to comment.