Skip to content

Commit

Permalink
BUG: return one covariance matrix per rhs in polyfit
Browse files Browse the repository at this point in the history
Addresses jax-ml#24073

Taking the first residual from `resids` means that only the first
set of coefficients would get a covariance matrix.
This moves the line to make a `(1,)` shape array into an `int`
to the branch where there is only one `rhs`.
  • Loading branch information
scottstanie committed Oct 2, 2024
1 parent e212c77 commit d7663c8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/numpy/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
>>> p, C = jnp.polyfit(x, y, 2, cov=True)
>>> p.shape, C.shape
((3, 3), (3, 3, 1))
((3, 3), (3, 3, 3))
"""
if w is None:
check_arraylike("polyfit", x, y)
Expand Down Expand Up @@ -278,8 +278,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
raise ValueError("the number of data points must exceed order "
"to scale the covariance matrix")
fac = resids / (len(x_arr) - order)
fac = fac[0] #making np.array() of shape (1,) to int
if y_arr.ndim == 1:
fac = fac[0] #making np.array() of shape (1,) to int
return c, Vbase * fac
else:
return c, Vbase[:, :, np.newaxis] * fac
Expand Down

0 comments on commit d7663c8

Please sign in to comment.