diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py
index f787ce5..6d4d4ca 100644
--- a/oryx/core/interpreters/harvest.py
+++ b/oryx/core/interpreters/harvest.py
@@ -156,6 +156,7 @@ def f(x):
 from jax._src import pjit
 from jax._src import sharding_impls
 from jax._src.lax import control_flow as lcf
+from jax.experimental import shard_map
 import jax.extend.linear_util as lu
 from jax.interpreters import ad
 from jax.interpreters import batching
@@ -443,6 +444,20 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
     return context.process_custom_jvp_call(self, primitive, fun, jvp, tracers,
                                            symbolic_zeros=symbolic_zeros)
 
+  def process_shard_map(self, primitive, f, tracers, **params):
+    """Binds a shard_map primitive on the values held by `tracers`.
+
+    The wrapped function `f` and `params` are forwarded to `primitive.bind`
+    unchanged; each output is lifted back into this trace with `self.pure`.
+
+    Returns:
+      A list of output tracers, one per value returned by `primitive.bind`.
+    """
+    out_flat = primitive.bind(f, *[t.val for t in tracers], **params)
+    # Materialize the outputs as a list (not a lazy iterator) so callers can
+    # index them and iterate over them more than once.
+    return [self.pure(out) for out in out_flat]
+
   def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
     context = trace_util.get_dynamic_context(self)
     return context.post_process_custom_jvp_call(self, out_tracers, jvp_was_run)
@@ -1704,3 +1719,27 @@ def harvest(f,
   kwargs = dict(
       tag=tag, allowlist=allowlist, blocklist=blocklist, exclusive=exclusive)
   return call_and_reap(plant(f, **kwargs), **kwargs)
+
+
+# Handle shard_map
+@shard_map.register_check(sow_p)
+def _sow_check(mesh, *in_rep, name, tag, mode, tree):
+  """Rule registered with `shard_map.register_check` for `sow_p`.
+
+  Returns the replication of the first operand only.
+  """
+  del mesh, name, tag, mode, tree
+  return in_rep[0]  # TODO(conmy): does this limit use to one output only?
+
+
+@shard_map.register_rewrite(sow_p)
+def _sow_rewrite(mesh, in_rep, *args, name, tag, mode, tree):
+  """Rule registered with `shard_map.register_rewrite` for `sow_p`.
+
+  Raises:
+    ValueError: unconditionally; sow inside a shard_map is not supported.
+  """
+  del mesh, in_rep, args, name, tag, mode, tree
+  raise ValueError(
+      'Detected sow calls inside a shard_map. This is not currently supported.'
+  )
diff --git a/oryx/core/interpreters/harvest_test.py b/oryx/core/interpreters/harvest_test.py
index f020a5a..1f22242 100644
--- a/oryx/core/interpreters/harvest_test.py
+++ b/oryx/core/interpreters/harvest_test.py
@@ -38,9 +38,10 @@
 from jax import config
 from jax import lax
 from jax._src import pjit
+from jax.experimental import mesh_utils
+from jax.experimental import shard_map
 import jax.numpy as jnp
 import numpy as np
-
 from oryx.core import trace_util
 from oryx.core.interpreters import harvest
 from oryx.internal import test_util
@@ -1019,6 +1020,142 @@ def branch3(x):
     self.assertEqual(out, 8.)
 
 
+class ShardMapTest(test_util.TestCase):
+  """Tests sow/reap/plant after, before and inside a shard_map-ed matmul."""
+
+  def setUp(self):
+    super().setUp()
+    self.devices = mesh_utils.create_device_mesh((1, 2))
+    self.mesh = jax.sharding.Mesh(self.devices, axis_names=('x', 'y'))
+    self.a = jnp.arange(8 * 16.0).reshape(8, 16)
+    self.b = jnp.arange(16 * 4.0).reshape(16, 4)
+
+    # All three fixtures shard the same matmul the same way, so the shard_map
+    # configuration is built once and reused as a decorator. On the (1, 2)
+    # mesh, `a` is split along its columns and `b` along its rows.
+    shmap = functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=(
+            jax.sharding.PartitionSpec('x', 'y'),
+            jax.sharding.PartitionSpec('y', None),
+        ),
+        out_specs=jax.sharding.PartitionSpec('x', None),
+    )
+
+    def _f(a, b):
+      @shmap
+      def oryx_shmap_matmul(a_block, b_block):
+        # a_block: f32[8, 8]
+        # b_block: f32[8, 4]
+        c_partialsum = jnp.dot(a_block, b_block)
+        c_block = jax.lax.psum(c_partialsum, 'y')
+        # c_block: f32[8, 4]
+        return c_block
+
+      shmapped_val = oryx_shmap_matmul(a, b)
+      sowed_val = sow(shmapped_val, name='shmapped_val', tag='intermediate')
+      return 2.0 * sowed_val
+
+    self.f = _f
+
+    def _f_with_sow_before_shmap(a, b):
+      sowed_val = sow(a, name='a', tag='intermediate')
+
+      @shmap
+      def oryx_shmap_matmul(a_block, b_block):
+        # a_block: f32[8, 8]
+        # b_block: f32[8, 4]
+        c_partial_sum = jnp.dot(a_block, b_block)
+        c_block = jax.lax.psum(c_partial_sum, 'y')
+        # c_block: f32[8, 4]
+        return c_block
+
+      shmapped_val = oryx_shmap_matmul(sowed_val, b)
+      return 2.0 * shmapped_val
+
+    self.f_with_sow_before_shmap = _f_with_sow_before_shmap
+
+    def _f_with_sow_inside_shmap(a, b):
+      @shmap
+      def oryx_shmap_matmul(a_block, b_block):
+        # a_block: f32[8, 8]
+        # b_block: f32[8, 4]
+        c_partial_sum = jnp.dot(a_block, b_block)
+        c_block = sow(
+            jax.lax.psum(c_partial_sum, 'y'), name='c_block', tag='intermediate'
+        )
+        # c_block: f32[8, 4]
+        return c_block
+
+      return 2.0 * oryx_shmap_matmul(a, b)
+
+    self.f_with_sow_inside_shmap = _f_with_sow_inside_shmap
+
+  def test_reap(self):
+    reap_dict = reap(self.f, tag='intermediate')(self.a, self.b)
+    self.assertEqual(
+        list(reap_dict.keys()), ['shmapped_val'], msg='Wrong reap dict keys'
+    )
+
+    self.assertFalse(
+        np.isclose(
+            reap_dict['shmapped_val'], 2.0 * jnp.dot(self.a, self.b)
+        ).any(),
+        msg=(
+            'Reaped value is close to 2.0 * matmul but that'
+            ' should be the output of the function, not the'
+            ' intermediate reaped value.'
+        ),
+    )
+    np.testing.assert_allclose(
+        reap_dict['shmapped_val'], jnp.dot(self.a, self.b)
+    )
+
+  def test_plant(self):
+    shmapped_val_for_planting = 0.5 * jnp.dot(self.a, self.b)
+    f_output_planted = plant(self.f, tag='intermediate')(
+        dict(shmapped_val=shmapped_val_for_planting), self.a, self.b
+    )
+    np.testing.assert_allclose(
+        f_output_planted, 2.0 * shmapped_val_for_planting
+    )
+
+  def test_reap_before_shmap(self):
+    reap_dict = reap(self.f_with_sow_before_shmap, tag='intermediate')(
+        self.a, self.b
+    )
+    self.assertEqual(list(reap_dict.keys()), ['a'], msg='Wrong reap dict keys')
+    np.testing.assert_allclose(reap_dict['a'], self.a)
+
+  def test_plant_before_shmap(self):
+    a_val_for_planting = 0.5 * self.a
+    f_output_planted = plant(self.f_with_sow_before_shmap, tag='intermediate')(
+        dict(a=a_val_for_planting), self.a, self.b
+    )
+    np.testing.assert_allclose(
+        f_output_planted, 2.0 * jnp.dot(a_val_for_planting, self.b)
+    )
+
+  def test_reap_inside_shmap_fails(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Detected sow calls inside a shard_map.'
+        ' This is not currently supported.',
+    ):
+      reap(self.f_with_sow_inside_shmap, tag='intermediate')(self.a, self.b)
+
+  def test_plant_inside_shmap_fails(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Detected sow calls inside a shard_map.'
+        ' This is not currently supported.',
+    ):
+      plant(self.f_with_sow_inside_shmap, tag='intermediate')(
+          dict(c_block=15.0 * jnp.dot(self.a, self.b)), self.a, self.b
+      )
+
+
 if __name__ == '__main__':
   os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
   absltest.main()