
jnp.polyfit with cov=True only returns one covariance matrix for multiple right hand sides #24073

Open
scottstanie opened this issue Oct 2, 2024 · 0 comments

Description

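The covariance matrix `C` that `jax.numpy.polyfit` returns when `cov=True, full=False` always has shape `(n, n, 1)`, where `n = deg + 1`, regardless of how many columns `y` has. This is unexpected and doesn't match `numpy.polyfit`, which returns shape `(n, n, K)` when `y` has shape `(M, K)`, i.e. one covariance matrix per right-hand side.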

```python
import numpy as np
import jax.numpy as jnp

np.random.seed(1)
x = np.arange(10).astype('float32')
num_rhs = 3
y = np.random.randn(x.shape[0], num_rhs)

# numpy: returns 3 covariance matrices for the 3 right-hand sides, shape (2, 2, 3)
print(np.polyfit(x, y, 1, cov=True)[1])

# jax: returns 1 matrix, shape (2, 2, 1), matching numpy's first one and dropping the rest
print(jnp.polyfit(x, y, 1, cov=True)[1])
```
```
>>> print(np.polyfit(x, y, 1, cov=True)[1])
[[[ 0.01583048  0.00850339  0.0145006 ]
  [-0.07123718 -0.03826527 -0.0652527 ]]

 [[-0.07123718 -0.03826527 -0.0652527 ]
  [ 0.45116882  0.24234668  0.41326709]]]

>>> print(jnp.polyfit(x, y, 1, cov=True)[1])
[[[ 0.01583048]
  [-0.07123714]]

 [[-0.07123714]
  [ 0.45116854]]]
```
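For context, here is how NumPy's implementation (as read from its source, ignoring weights) builds the scaled covariance. Let $A$ be the $M \times n$ Vandermonde design matrix with $n = \text{deg} + 1$, and let $c_{:,k}$ be the fitted coefficients for column $k$ of $y$. Then every slice shares the same base matrix $(A^\top A)^{-1}$ and differs only by a per-column scalar factor. That is why the three NumPy columns in the output above are proportional to one another, and why keeping only the first residual reproduces exactly the first slice:

$$
C_{:,:,k} = \left(A^\top A\right)^{-1} \, \frac{\lVert y_{:,k} - A\, c_{:,k} \rVert_2^2}{M - n}, \qquad k = 0, \dots, K-1
$$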

System info (python version, jaxlib version, accelerator, etc.)

```
>>> import jax; jax.print_environment_info()
jax:    0.4.34.dev20241002+e212c7733
jaxlib: 0.4.33
numpy:  2.1.1
python: 3.12.6 | packaged by conda-forge | (main, Sep 30 2024, 17:55:20) [Clang 17.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MT-317120', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:49:39 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_ARM64_T6000', machine='arm64')
```

scottstanie added the bug (Something isn't working) label on Oct 2, 2024
scottstanie added a commit to scottstanie/jax that referenced this issue Oct 2, 2024
Addresses jax-ml#24073

Taking only the first element of `resids` means that only the first set of coefficients gets a correctly scaled covariance matrix.
This commit moves the line that turns the `(1,)`-shaped residual array into a scalar
into the branch that handles a single right-hand side (1-D `y`).
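To illustrate the behaviour the commit message describes, here is a minimal NumPy sketch, assuming the goal is to match `numpy.polyfit`: keep the whole vector of residuals and broadcast it against the base covariance instead of taking `resids[0]`. `polyfit_cov_sketch` is a hypothetical helper written for this example, not JAX's actual code; it omits the weights, `rcond` handling and column scaling that the real `polyfit` applies.

```python
import numpy as np


def polyfit_cov_sketch(x, y, deg):
    """Hypothetical helper: least-squares polynomial fit that returns one
    scaled covariance matrix per column of ``y`` (NumPy's convention).

    Omits the weights, rcond handling and column scaling of the real polyfit.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    order = deg + 1
    if len(x) <= order:
        raise ValueError("the number of data points must exceed order "
                         "to scale the covariance matrix")

    lhs = np.vander(x, order)                          # (M, order) design matrix
    coef, *_ = np.linalg.lstsq(lhs, y, rcond=None)     # (order,) or (order, K)

    # Residual sum of squares per right-hand side: shape (K,) for 2-D y,
    # a scalar for 1-D y. Keep the whole vector, not just element 0.
    rss = np.sum((y - lhs @ coef) ** 2, axis=0)
    fac = rss / (len(x) - order)

    vbase = np.linalg.inv(lhs.T @ lhs)                 # (order, order)
    if y.ndim == 1:
        return coef, vbase * fac
    # Broadcast to (order, order, K): slice [:, :, k] belongs to column k of y.
    return coef, vbase[:, :, np.newaxis] * fac


np.random.seed(1)
x = np.arange(10, dtype=np.float64)
y = np.random.randn(x.shape[0], 3)
coef, cov = polyfit_cov_sketch(x, y, 1)
print(cov.shape)  # (2, 2, 3)
```

Comparing `cov` against `np.polyfit(x, y, 1, cov=True)[1]` should agree to floating-point tolerance; the two differ only in numerical details such as the column scaling NumPy applies before calling `lstsq`.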
dfm self-assigned this on Oct 2, 2024
scottstanie added a commit to scottstanie/jax that referenced this issue Oct 2, 2024
Addresses jax-ml#24073

Taking only the first element of `resids` means that only the first set of coefficients gets a correctly scaled covariance matrix.
This commit moves the line that turns the `(1,)`-shaped residual array into a scalar
into the branch that handles a single right-hand side (1-D `y`).

Fixes typos in `polyfit` docstring