
Add .at[].subtract #23933

Open
carlosgmartin opened this issue Sep 26, 2024 · 2 comments · May be fixed by #23998
Labels
enhancement New feature or request

Comments

@carlosgmartin
Contributor

Feature request: Add a subtract method to jax.numpy.ndarray.at.
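As a point of reference (an assumption about the intended semantics, not something stated in the thread), NumPy's unbuffered `np.subtract.at` already performs this operation in place: repeated indices accumulate, and unsigned dtypes are subtracted directly rather than through negation. A JAX `x.at[idx].subtract(vals)` would presumably return the same values out of place:

```python
import numpy as np

# Reference semantics from NumPy: unbuffered, in-place scatter-subtract.
a = np.array([10, 10, 10], dtype=np.uint8)
idx = np.array([0, 2, 2])
vals = np.array([1, 2, 3], dtype=np.uint8)

# Index 2 appears twice, so both 2 and 3 are subtracted from a[2].
np.subtract.at(a, idx, vals)
# a is now [9, 10, 5] and keeps dtype uint8.
```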

@carlosgmartin carlosgmartin added the enhancement New feature or request label Sep 26, 2024
@superbobry
Collaborator

This sounds reasonable, even though you can do .add(-x) for signed types.

@jakevdp wdyt?
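A minimal sketch of the `.add(-x)` workaround mentioned above, using a floating-point array (the values are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([5.0, 5.0, 5.0, 5.0])
idx = jnp.array([0, 2, 2])
vals = jnp.array([1.0, 2.0, 3.0])

# Workaround for signed/floating dtypes: negate the updates, then scatter-add.
# Like .add, repeated indices accumulate (x[2] receives -2 and -3).
y = x.at[idx].add(-vals)
# y holds [4., 5., 0., 5.]; x itself is unchanged (JAX arrays are immutable).
```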

@jakevdp
Collaborator

jakevdp commented Sep 26, 2024

I think this is reasonable. We could do it the easy way, with something like this:

def subtract(self, x):
  # Negate the update and reuse the existing scatter-add path.
  return self.add(-x)

but as @superbobry points out, this would fail for unsigned types.

The full solution is a bit more work: it would involve defining a lax.scatter_sub primitive with appropriate autodiff rules, and dispatching to that. The implementation would look very similar to scatter_add:

def scatter_add(

The only really involved pieces there would be the scatter_sub_jvp and scatter_sub_transpose rules, but those could probably be implemented in terms of the existing scatter_add_jvp and scatter_add_transpose.
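Why reusing scatter_add's rules is plausible can be seen in a small sketch. Assumption: `scatter_sub_ref` below is a hypothetical reference defined as `lax.scatter_add` on negated updates, valid only for signed and floating dtypes; it is not the dedicated `lax.scatter_sub` primitive proposed here. Because the operation is linear in `operand` and `updates`, its JVP is the same scatter-subtract applied to the tangents, and its transpose passes the cotangent through unchanged for `operand` and gathers it with a sign flip for `updates`:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Scatter 1-D scalars: each row of `indices` selects one element of `operand`.
DNUMS = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)

def scatter_sub_ref(operand, indices, updates):
  """Hypothetical reference: scatter-subtract as scatter_add of -updates."""
  return lax.scatter_add(operand, indices, -updates, DNUMS)

operand = jnp.arange(5.0)
indices = jnp.array([[0], [2], [2]])
updates = jnp.array([1.0, 2.0, 3.0])

out = scatter_sub_ref(operand, indices, updates)  # [-1., 1., -3., 3., 4.]

# JVP: the tangent output is scatter_sub_ref applied to the input tangents.
f = lambda o, u: scatter_sub_ref(o, indices, u)
t_op, t_up = jnp.ones(5), jnp.ones(3)
_, tangent_out = jax.jvp(f, (operand, updates), (t_op, t_up))
expected_tangent = scatter_sub_ref(t_op, indices, t_up)

# Transpose (VJP): operand cotangent is ct itself; updates cotangent is
# -ct gathered at the scatter indices.
_, vjp_fn = jax.vjp(f, operand, updates)
ct = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
ct_op, ct_up = vjp_fn(ct)
```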

@superbobry superbobry self-assigned this Sep 27, 2024
copybara-service bot pushed a commit that referenced this issue Sep 30, 2024
The new primitive is used for in-place subtract and update.

Closes #23933

PiperOrigin-RevId: 679670877
@copybara-service copybara-service bot linked a pull request Sep 30, 2024 that will close this issue
copybara-service bot pushed a commit that referenced this issue Oct 1, 2024
copybara-service bot pushed a commit that referenced this issue Oct 2, 2024
copybara-service bot pushed a commit that referenced this issue Oct 2, 2024
copybara-service bot pushed a commit that referenced this issue Oct 2, 2024

3 participants