From fef705ac12478515c4d8afd7a40e93c6ada105ba Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 27 Sep 2024 11:42:47 -0700 Subject: [PATCH] Added `scatter_sub_p` The new primitive is used for in-place subtract and update. Closes #23933 PiperOrigin-RevId: 679670877 --- jax/_src/lax/slicing.py | 179 ++++++++++++++++++++++++++---- jax/_src/numpy/array_methods.py | 15 +++ jax/experimental/jax2tf/jax2tf.py | 1 + jax/lax/__init__.py | 2 + tests/lax_numpy_indexing_test.py | 22 ++-- tests/lax_test.py | 5 +- 6 files changed, 189 insertions(+), 35 deletions(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 60dfa0e1b3d2..f1921395d92b 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -426,6 +426,67 @@ def scatter_add( indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) + +def scatter_sub( + operand: ArrayLike, + scatter_indices: ArrayLike, + updates: ArrayLike, + dimension_numbers: ScatterDimensionNumbers, + *, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | GatherScatterMode | None = None, +) -> Array: + """Scatter-sub operator. + + Wraps `XLA's Scatter operator + `_, where + subtraction is used to combine updates and values from `operand`. + + The semantics of scatter are complicated, and its API might change in the + future. For most use cases, you should prefer the + :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses + the familiar NumPy indexing syntax. + + Args: + operand: an array to which the scatter should be applied + scatter_indices: an array that gives the indices in `operand` to which each + update in `updates` should be applied. + updates: the updates that should be scattered onto `operand`. + dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how + dimensions of `operand`, `start_indices`, `updates` and the output relate. + indices_are_sorted: whether `scatter_indices` is known to be sorted. If + true, may improve performance on some backends. + unique_indices: whether the elements to be updated in ``operand`` are + guaranteed to not overlap with each other. If true, may improve + performance on some backends. JAX does not check this promise: if the + updated elements overlap when ``unique_indices`` is ``True`` the behavior + is undefined. + mode: how to handle indices that are out of bounds: when set to 'clip', + indices are clamped so that the slice is within bounds, and when set to + 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for + out-of-bounds indices when set to 'promise_in_bounds' is + implementation-defined. + + Returns: + An array containing the sum of `operand` and the scattered updates. + """ + jaxpr, consts = lax._reduction_jaxpr( + lax.sub, lax._abstractify(lax._const(operand, 0)) + ) + return scatter_sub_p.bind( + operand, + scatter_indices, + updates, + update_jaxpr=jaxpr, + update_consts=consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=GatherScatterMode.from_any(mode), + ) + + def scatter_mul( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, @@ -1988,32 +2049,66 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64), upper_bound) -def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts, - dimension_numbers, indices_are_sorted, unique_indices, - mode): + +def _scatter_addsub_jvp( + prim, + primals, + tangents, + *, + update_jaxpr, + update_consts, + dimension_numbers, + indices_are_sorted, + unique_indices, + mode, +): operand, indices, updates = primals g_operand, g_indices, g_updates = tangents del g_indices # ignored - val_out = scatter_add_p.bind( - operand, indices, updates, update_jaxpr=update_jaxpr, - update_consts=update_consts, dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + val_out = prim.bind( + operand, + indices, + updates, + update_jaxpr=update_jaxpr, + update_consts=update_consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + ) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) - tangent_out = scatter_add_p.bind( - g_operand, indices, g_updates, update_jaxpr=update_jaxpr, - update_consts=update_consts, dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + tangent_out = prim.bind( + g_operand, + indices, + g_updates, + update_jaxpr=update_jaxpr, + update_consts=update_consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + ) return val_out, tangent_out -def _scatter_add_transpose_rule(t, operand, indices, updates, *, - update_jaxpr, update_consts, dimension_numbers, - indices_are_sorted, unique_indices, mode): + +def _scatter_addsub_transpose_rule( + prim, + t, + operand, + indices, + updates, + *, + update_jaxpr, + update_consts, + dimension_numbers, + indices_are_sorted, + unique_indices, + mode, +): assert not ad.is_undefined_primal(indices) if ad.is_undefined_primal(updates): updates_shape = updates.aval.shape @@ -2042,6 +2137,8 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *, pos += 1 update_t = gather(t, indices, dimension_numbers=gather_dnums, slice_sizes=slice_sizes, mode=mode, fill_value=0) + if prim is scatter_sub_p: + update_t = lax.neg(update_t) return [operand_t, None, update_t] def _scatter_mul_transpose_rule(t, operand, indices, updates, *, @@ -2137,11 +2234,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', weak_type_rule=_argnum_weak_type(0)) -ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp -ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule +ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) +ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.primitive_batchers[scatter_add_p] = ( partial(_scatter_batching_rule, scatter_add_p)) +scatter_sub_p = standard_primitive( + _scatter_shape_rule, + _scatter_dtype_rule, + "scatter-sub", + weak_type_rule=_argnum_weak_type(0), +) +ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) +ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) +batching.primitive_batchers[scatter_sub_p] = partial( + _scatter_batching_rule, scatter_sub_p +) + scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', weak_type_rule=_argnum_weak_type(0)) @@ -2510,6 +2619,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, mlir.register_lowering(scatter_p, _scatter_lower) mlir.register_lowering(scatter_add_p, _scatter_lower) +mlir.register_lowering(scatter_sub_p, _scatter_lower) mlir.register_lowering(scatter_mul_p, _scatter_lower) mlir.register_lowering(scatter_min_p, _scatter_lower) mlir.register_lowering(scatter_max_p, _scatter_lower) @@ -2517,9 +2627,21 @@ def _scatter_lower(ctx, operand, indices, updates, *, def _real_dtype(dtype): return np.finfo(dtype).dtype -def _scatter_add_lower_gpu(ctx, operand, indices, updates, - *, update_jaxpr, update_consts, dimension_numbers, - indices_are_sorted, unique_indices, mode): + +def _scatter_addsub_lower_gpu( + ctx, + operand, + indices, + updates, + *, + update_jaxpr, + update_consts, + dimension_numbers, + indices_are_sorted, + unique_indices, + mode, + reduce_op, +): operand_aval_in, _, updates_aval_in = ctx.avals_in if operand_aval_in.dtype != np.complex128: return _scatter_lower(ctx, operand, indices, updates, @@ -2563,15 +2685,24 @@ def _scatter(operand_part, updates_part): scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype)) reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer): - add = hlo.AddOp(*reducer.arguments).result - hlo.return_([add]) + hlo.return_([reduce_op(*reducer.arguments).result]) return scatter.result real = _scatter(hlo.real(operand), hlo.real(updates)) imag = _scatter(hlo.imag(operand), hlo.imag(updates)) return [hlo.complex(real, imag)] -mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu") + +mlir.register_lowering( + scatter_add_p, + partial(_scatter_addsub_lower_gpu, reduce_op=hlo.AddOp), + platform="gpu", +) +mlir.register_lowering( + scatter_sub_p, + partial(_scatter_addsub_lower_gpu, reduce_op=hlo.SubtractOp), + platform="gpu", +) def _dynamic_slice_indices( diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 95d681cad8e5..24b8d315d8ac 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -661,6 +661,7 @@ class _IndexUpdateHelper: ============================== ================================ ``x = x.at[idx].set(y)`` ``x[idx] = y`` ``x = x.at[idx].add(y)`` ``x[idx] += y`` + ``x = x.at[idx].subtract(y)`` ``x[idx] -= y`` ``x = x.at[idx].multiply(y)`` ``x[idx] *= y`` ``x = x.at[idx].divide(y)`` ``x[idx] /= y`` ``x = x.at[idx].power(y)`` ``x[idx] **= y`` @@ -826,6 +827,20 @@ def add(self, values, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) + def subtract(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] -= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] -= y``. + + See :mod:`jax.ops` for details. + """ + return scatter._scatter_update(self.array, self.index, values, + lax.scatter_sub, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] *= y``. diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index f01a3ab7a036..d320a531e511 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3040,6 +3040,7 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: tf_impl_with_avals[lax.scatter_max_p] = _scatter tf_impl_with_avals[lax.scatter_mul_p] = _scatter tf_impl_with_avals[lax.scatter_add_p] = _scatter +tf_impl_with_avals[lax.scatter_sub_p] = _scatter def _cond( diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index bb72abb2ec32..420d268d9526 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -280,6 +280,8 @@ scatter_mul as scatter_mul, scatter_mul_p as scatter_mul_p, scatter_p as scatter_p, + scatter_sub as scatter_sub, + scatter_sub_p as scatter_sub_p, slice as slice, slice_in_dim as slice_in_dim, slice_p as slice_p, diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index d58a5c2c3866..392af2688c1d 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1252,7 +1252,7 @@ def _can_cast(from_, to): def _compatible_dtypes(op, dtype, inexact=False): - if op == UpdateOps.ADD: + if op == UpdateOps.ADD or op == UpdateOps.SUB: return [dtype] elif inexact: return [dt for dt in float_dtypes if _can_cast(dt, dtype)] @@ -1263,17 +1263,19 @@ def _compatible_dtypes(op, dtype, inexact=False): class UpdateOps(enum.Enum): UPDATE = 0 ADD = 1 - MUL = 2 - DIV = 3 - POW = 4 - MIN = 5 - MAX = 6 + SUB = 2 + MUL = 3 + DIV = 4 + POW = 5 + MIN = 6 + MAX = 7 def np_fn(op, indexer, x, y): x = x.copy() x[indexer] = { UpdateOps.UPDATE: lambda: y, UpdateOps.ADD: lambda: x[indexer] + y, + UpdateOps.SUB: lambda: x[indexer] - y, UpdateOps.MUL: lambda: x[indexer] * y, UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( lambda: x[indexer] / y.astype(x.dtype)), @@ -1290,6 +1292,7 @@ def jax_fn(op, indexer, x, y, indices_are_sorted=False, return { UpdateOps.UPDATE: x.at[indexer].set, UpdateOps.ADD: x.at[indexer].add, + UpdateOps.SUB: x.at[indexer].subtract, UpdateOps.MUL: x.at[indexer].multiply, UpdateOps.DIV: x.at[indexer].divide, UpdateOps.POW: x.at[indexer].power, @@ -1420,7 +1423,7 @@ def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape, for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) - for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] + for op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE] for dtype in float_dtypes for update_dtype in _compatible_dtypes(op, dtype, inexact=True) ], @@ -1447,8 +1450,9 @@ def testStaticIndexingGrads(self, name, shape, dtype, update_shape, ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in ( - [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices - else [UpdateOps.ADD]) + [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE] + if unique_indices + else [UpdateOps.ADD, UpdateOps.SUB]) for dtype in float_dtypes for update_dtype in _compatible_dtypes(op, dtype, inexact=True) ], diff --git a/tests/lax_test.py b/tests/lax_test.py index c8f3ca797903..c75058194653 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2782,14 +2782,15 @@ def testGatherShapeCheckingRule(self, operand_shape, indices_shape, ]], dtype=lax_test_util.inexact_dtypes, mode=["clip", "fill", None], + op=[lax.scatter_add, lax.scatter_sub], ) - def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode): + def testScatterAddSub(self, arg_shape, dtype, idxs, update_shape, dnums, mode, op): rng = jtu.rand_default(self.rng()) rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape)) rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype) args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(), rng(update_shape, dtype)] - fun = partial(lax.scatter_add, dimension_numbers=dnums, mode=mode) + fun = partial(op, dimension_numbers=dnums, mode=mode) self._CompileAndCheck(fun, args_maker) @jtu.sample_product(