From ebc43b5c1b9a581a10f6a5ff1aad0b89680806a1 Mon Sep 17 00:00:00 2001 From: phawkins Date: Tue, 1 Oct 2024 10:37:18 -0700 Subject: [PATCH] [numpy] Fix test failures under NumPy 2.1. PiperOrigin-RevId: 681078015 --- tensorflow_probability/python/internal/test_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index ef10557afd..1e44f811cb 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -613,7 +613,7 @@ def assertAllInRange(self, 'The value of %s does not have an ordered numeric type, instead it ' 'has type: %s' % (target, target.dtype)) - nan_subscripts = np.where(np.isnan(target)) + nan_subscripts = np.where(np.atleast_1d(np.isnan(target))) if np.size(nan_subscripts): raise AssertionError( '%d of the %d element(s) are NaN. ' @@ -631,7 +631,7 @@ def assertAllInRange(self, violations, np.greater_equal(target, upper_bound) if open_upper_bound else np.greater(target, upper_bound)) - violation_subscripts = np.where(violations) + violation_subscripts = np.where(np.atleast_1d(violations)) if np.size(violation_subscripts): raise AssertionError( '%d of the %d element(s) are outside the range %s. ' %