Added scatter_sub_p
The new primitive is used for in-place subtract-and-update, i.e. `x.at[idx].subtract(y)`.

Closes #23933

PiperOrigin-RevId: 679670877
superbobry authored and Google-ML-Automation committed Oct 2, 2024
1 parent 816947b commit b830656
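For reference, a minimal usage sketch of the new in-place subtract (the values below are illustrative, not part of the change itself):

import jax.numpy as jnp

x = jnp.array([10., 20., 30., 40.])
idx = jnp.array([1, 3])
y = jnp.array([5., 5.])

# Pure functional equivalent of x[idx] -= y; lowers to the new scatter_sub_p primitive.
out = x.at[idx].subtract(y)
# out == [10., 15., 30., 35.]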
Showing 6 changed files with 189 additions and 35 deletions.
179 changes: 155 additions & 24 deletions jax/_src/lax/slicing.py
@@ -429,6 +429,67 @@ def scatter_add(
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=GatherScatterMode.from_any(mode))


def scatter_sub(
operand: ArrayLike,
scatter_indices: ArrayLike,
updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers,
*,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | GatherScatterMode | None = None,
) -> Array:
"""Scatter-sub operator.
Wraps `XLA's Scatter operator
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
subtraction is used to combine updates and values from `operand`.
The semantics of scatter are complicated, and its API might change in the
future. For most use cases, you should prefer the
:attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
the familiar NumPy indexing syntax.
Args:
operand: an array to which the scatter should be applied
scatter_indices: an array that gives the indices in `operand` to which each
update in `updates` should be applied.
updates: the updates that should be scattered onto `operand`.
dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how
dimensions of `operand`, `scatter_indices`, `updates` and the output relate.
indices_are_sorted: whether `scatter_indices` is known to be sorted. If
true, may improve performance on some backends.
unique_indices: whether the elements to be updated in ``operand`` are
guaranteed to not overlap with each other. If true, may improve
performance on some backends. JAX does not check this promise: if the
updated elements overlap when ``unique_indices`` is ``True`` the behavior
is undefined.
mode: how to handle indices that are out of bounds: when set to 'clip',
indices are clamped so that the slice is within bounds, and when set to
'fill' or 'drop' out-of-bounds updates are dropped. The behavior for
out-of-bounds indices when set to 'promise_in_bounds' is
implementation-defined.
Returns:
An array containing the result of subtracting the scattered updates from
`operand`.
"""
jaxpr, consts = lax._reduction_jaxpr(
lax.sub, lax._abstractify(lax._const(operand, 0))
)
return scatter_sub_p.bind(
operand,
scatter_indices,
updates,
update_jaxpr=jaxpr,
update_consts=consts,
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=GatherScatterMode.from_any(mode),
)
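For illustration, a hedged sketch of calling scatter_sub directly with explicit dimension numbers; the operand shape and dimension_numbers below are assumptions chosen for a 1-D operand with scalar updates:

import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((5,), dtype=jnp.float32)
indices = jnp.array([[1], [3]], dtype=jnp.int32)   # shape (2, 1): two scalar updates
updates = jnp.array([2.0, 7.0], dtype=jnp.float32)

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),            # each update is a scalar
    inserted_window_dims=(0,),        # the single operand dimension is indexed
    scatter_dims_to_operand_dims=(0,),
)
out = lax.scatter_sub(operand, indices, updates, dnums)
# out == [0., -2., 0., -7., 0.]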


def scatter_mul(
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
@@ -1991,32 +2052,66 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),
upper_bound)

def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted, unique_indices,
mode):

def _scatter_addsub_jvp(
prim,
primals,
tangents,
*,
update_jaxpr,
update_consts,
dimension_numbers,
indices_are_sorted,
unique_indices,
mode,
):
operand, indices, updates = primals
g_operand, g_indices, g_updates = tangents
del g_indices # ignored
val_out = scatter_add_p.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
val_out = prim.bind(
operand,
indices,
updates,
update_jaxpr=update_jaxpr,
update_consts=update_consts,
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
)
if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
g_operand = ad.instantiate_zeros(g_operand)
g_updates = ad.instantiate_zeros(g_updates)
tangent_out = scatter_add_p.bind(
g_operand, indices, g_updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
tangent_out = prim.bind(
g_operand,
indices,
g_updates,
update_jaxpr=update_jaxpr,
update_consts=update_consts,
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
)
return val_out, tangent_out
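The shared JVP rule re-binds the same primitive on the tangents, so differentiating through an indexed subtract stays a scatter; a small hedged sketch of the user-level behavior:

import jax
import jax.numpy as jnp

def f(x, y):
  return x.at[jnp.array([0, 2])].subtract(y)

x = jnp.array([1., 2., 3.])
y = jnp.array([10., 10.])
primal, tangent = jax.jvp(f, (x, y), (jnp.ones_like(x), jnp.ones_like(y)))
# tangent == dx scatter-subtracted by dy == [0., 1., 0.]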

def _scatter_add_transpose_rule(t, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):

def _scatter_addsub_transpose_rule(
prim,
t,
operand,
indices,
updates,
*,
update_jaxpr,
update_consts,
dimension_numbers,
indices_are_sorted,
unique_indices,
mode,
):
assert not ad.is_undefined_primal(indices)
if ad.is_undefined_primal(updates):
updates_shape = updates.aval.shape
@@ -2045,6 +2140,8 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *,
pos += 1
update_t = gather(t, indices, dimension_numbers=gather_dnums,
slice_sizes=slice_sizes, mode=mode, fill_value=0)
if prim is scatter_sub_p:
update_t = lax.neg(update_t)
return [operand_t, None, update_t]
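The transpose rule gathers the output cotangent for the updates and, for scatter_sub_p, negates it, so gradients with respect to the updates carry a sign flip; a hedged sketch:

import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.sum(x.at[jnp.array([1])].subtract(y))

x = jnp.array([1., 2., 3.])
y = jnp.array([5.])
gx, gy = jax.grad(f, argnums=(0, 1))(x, y)
# gx == [1., 1., 1.]  (operand cotangent passes through unchanged)
# gy == [-1.]         (updates cotangent is the negated gather of the output cotangent)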

def _scatter_mul_transpose_rule(t, operand, indices, updates, *,
@@ -2140,11 +2237,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
scatter_add_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p)
ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p)
batching.primitive_batchers[scatter_add_p] = (
partial(_scatter_batching_rule, scatter_add_p))

scatter_sub_p = standard_primitive(
_scatter_shape_rule,
_scatter_dtype_rule,
"scatter-sub",
weak_type_rule=_argnum_weak_type(0),
)
ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p)
ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p)
batching.primitive_batchers[scatter_sub_p] = partial(
_scatter_batching_rule, scatter_sub_p
)
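With the batching rule registered, the new primitive also composes with vmap; a small sketch (the shapes are assumptions):

import jax
import jax.numpy as jnp

xs = jnp.ones((4, 3))          # batch of 4 operands
ys = jnp.full((4, 1), 2.0)     # one scalar update per batch element

out = jax.vmap(lambda x, y: x.at[jnp.array([0])].subtract(y))(xs, ys)
# out[:, 0] == -1., all other entries remain 1.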

scatter_mul_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
weak_type_rule=_argnum_weak_type(0))
@@ -2513,16 +2622,29 @@ def _scatter_lower(ctx, operand, indices, updates, *,

mlir.register_lowering(scatter_p, _scatter_lower)
mlir.register_lowering(scatter_add_p, _scatter_lower)
mlir.register_lowering(scatter_sub_p, _scatter_lower)
mlir.register_lowering(scatter_mul_p, _scatter_lower)
mlir.register_lowering(scatter_min_p, _scatter_lower)
mlir.register_lowering(scatter_max_p, _scatter_lower)


def _real_dtype(dtype): return np.finfo(dtype).dtype

def _scatter_add_lower_gpu(ctx, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):

def _scatter_addsub_lower_gpu(
ctx,
operand,
indices,
updates,
*,
update_jaxpr,
update_consts,
dimension_numbers,
indices_are_sorted,
unique_indices,
mode,
reduce_op,
):
operand_aval_in, _, updates_aval_in = ctx.avals_in
if operand_aval_in.dtype != np.complex128:
return _scatter_lower(ctx, operand, indices, updates,
@@ -2566,15 +2688,24 @@ def _scatter(operand_part, updates_part):
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
add = hlo.AddOp(*reducer.arguments).result
hlo.return_([add])
hlo.return_([reduce_op(*reducer.arguments).result])
return scatter.result

real = _scatter(hlo.real(operand), hlo.real(updates))
imag = _scatter(hlo.imag(operand), hlo.imag(updates))
return [hlo.complex(real, imag)]

mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")

mlir.register_lowering(
scatter_add_p,
partial(_scatter_addsub_lower_gpu, reduce_op=hlo.AddOp),
platform="gpu",
)
mlir.register_lowering(
scatter_sub_p,
partial(_scatter_addsub_lower_gpu, reduce_op=hlo.SubtractOp),
platform="gpu",
)
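The GPU lowering above splits a complex128 scatter-add/sub into separate real and imaginary scatters. A hedged sketch of user code that would exercise this path (requires 64-bit mode; the special case only takes effect on GPU backends):

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.complex128)
u = jnp.array([1.0 + 2.0j], dtype=jnp.complex128)
out = x.at[jnp.array([1])].subtract(u)
# out == [0.+0.j, -1.-2.j, 0.+0.j]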


def _dynamic_slice_indices(
15 changes: 15 additions & 0 deletions jax/_src/numpy/array_methods.py
@@ -661,6 +661,7 @@ class _IndexUpdateHelper:
============================== ================================
``x = x.at[idx].set(y)`` ``x[idx] = y``
``x = x.at[idx].add(y)`` ``x[idx] += y``
``x = x.at[idx].subtract(y)`` ``x[idx] -= y``
``x = x.at[idx].multiply(y)`` ``x[idx] *= y``
``x = x.at[idx].divide(y)`` ``x[idx] /= y``
``x = x.at[idx].power(y)`` ``x[idx] **= y``
@@ -826,6 +827,20 @@ def add(self, values, *, indices_are_sorted=False, unique_indices=False,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)

def subtract(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] -= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] -= y``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_sub,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
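A short sketch of the keyword arguments accepted by the new subtract method, using mode='drop' so an out-of-bounds index is ignored (the values are illustrative):

import jax.numpy as jnp

x = jnp.array([1., 2., 3.])
# Index 5 is out of bounds; under mode='drop' its update is discarded.
out = x.at[jnp.array([0, 5])].subtract(jnp.array([1., 1.]), mode='drop')
# out == [0., 2., 3.]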

def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] *= y``.
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -3040,6 +3040,7 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:
tf_impl_with_avals[lax.scatter_max_p] = _scatter
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
tf_impl_with_avals[lax.scatter_add_p] = _scatter
tf_impl_with_avals[lax.scatter_sub_p] = _scatter


def _cond(
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
@@ -280,6 +280,8 @@
scatter_mul as scatter_mul,
scatter_mul_p as scatter_mul_p,
scatter_p as scatter_p,
scatter_sub as scatter_sub,
scatter_sub_p as scatter_sub_p,
slice as slice,
slice_in_dim as slice_in_dim,
slice_p as slice_p,
22 changes: 13 additions & 9 deletions tests/lax_numpy_indexing_test.py
@@ -1252,7 +1252,7 @@ def _can_cast(from_, to):


def _compatible_dtypes(op, dtype, inexact=False):
if op == UpdateOps.ADD:
if op == UpdateOps.ADD or op == UpdateOps.SUB:
return [dtype]
elif inexact:
return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
@@ -1263,17 +1263,19 @@ def _compatible_dtypes(op, dtype, inexact=False):
class UpdateOps(enum.Enum):
UPDATE = 0
ADD = 1
MUL = 2
DIV = 3
POW = 4
MIN = 5
MAX = 6
SUB = 2
MUL = 3
DIV = 4
POW = 5
MIN = 6
MAX = 7

def np_fn(op, indexer, x, y):
x = x.copy()
x[indexer] = {
UpdateOps.UPDATE: lambda: y,
UpdateOps.ADD: lambda: x[indexer] + y,
UpdateOps.SUB: lambda: x[indexer] - y,
UpdateOps.MUL: lambda: x[indexer] * y,
UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)(
lambda: x[indexer] / y.astype(x.dtype)),
@@ -1290,6 +1292,7 @@ def jax_fn(op, indexer, x, y, indices_are_sorted=False,
return {
UpdateOps.UPDATE: x.at[indexer].set,
UpdateOps.ADD: x.at[indexer].add,
UpdateOps.SUB: x.at[indexer].subtract,
UpdateOps.MUL: x.at[indexer].multiply,
UpdateOps.DIV: x.at[indexer].divide,
UpdateOps.POW: x.at[indexer].power,
@@ -1420,7 +1423,7 @@ def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape,
for update_shape in _broadcastable_shapes(index_shape)
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
],
@@ -1447,8 +1450,9 @@ def testStaticIndexingGrads(self, name, shape, dtype, update_shape,
],
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
for op in (
[UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices
else [UpdateOps.ADD])
[UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
if unique_indices
else [UpdateOps.ADD, UpdateOps.SUB])
for dtype in float_dtypes
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
],
5 changes: 3 additions & 2 deletions tests/lax_test.py
@@ -2782,14 +2782,15 @@ def testGatherShapeCheckingRule(self, operand_shape, indices_shape,
]],
dtype=lax_test_util.inexact_dtypes,
mode=["clip", "fill", None],
op=[lax.scatter_add, lax.scatter_sub],
)
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode):
def testScatterAddSub(self, arg_shape, dtype, idxs, update_shape, dnums, mode, op):
rng = jtu.rand_default(self.rng())
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
rng(update_shape, dtype)]
fun = partial(lax.scatter_add, dimension_numbers=dnums, mode=mode)
fun = partial(op, dimension_numbers=dnums, mode=mode)
self._CompileAndCheck(fun, args_maker)

@jtu.sample_product(
