From 0fc3b0c82bf9cc7963e8c4c8206c63a6d9480e03 Mon Sep 17 00:00:00 2001 From: emilyaf Date: Thu, 7 Sep 2023 20:37:28 -0700 Subject: [PATCH] Fix bug in pytree flattening of LinearOperatorLowRankUpdate in the JAX backend. PiperOrigin-RevId: 563623848 --- tensorflow_probability/python/internal/backend/numpy/linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/linalg.py b/tensorflow_probability/python/internal/backend/numpy/linalg.py index 4e51bb3941..3b15d1d947 100644 --- a/tensorflow_probability/python/internal/backend/numpy/linalg.py +++ b/tensorflow_probability/python/internal/backend/numpy/linalg.py @@ -65,7 +65,8 @@ def register_pytrees(env): 'LinearOperatorScaledIdentity': ('multiplier',), 'LinearOperatorInversion': ('operator',), 'LinearOperatorKronecker': ('operators',), - 'LinearOperatorLowRankUpdate': ('base_operator', 'diag_update'), + 'LinearOperatorLowRankUpdate': ( + 'base_operator', 'diag_update', 'u', 'v'), 'LinearOperatorLowerTriangular': ('tril',), 'LinearOperatorPermutation': ('perm',), 'LinearOperatorToeplitz': ('col', 'row'),