Skip to content

Commit

Permalink
Add basic support for shard_map to oryx
Browse files Browse the repository at this point in the history
This doesn't allow sowing inside shard_map, but does allow sowing if it occurs before or after the shard_map. It adds tests for all these three cases, ensuring errors are thrown if sows are inside shard_map

PiperOrigin-RevId: 668673795
  • Loading branch information
The oryx Authors committed Aug 29, 2024
1 parent 21722fa commit 93fc9ef
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 1 deletion.
20 changes: 20 additions & 0 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def f(x):
from jax._src import pjit
from jax._src import sharding_impls
from jax._src.lax import control_flow as lcf
from jax.experimental import shard_map
import jax.extend.linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
Expand Down Expand Up @@ -443,6 +444,11 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
return context.process_custom_jvp_call(self, primitive, fun, jvp, tracers,
symbolic_zeros=symbolic_zeros)

def process_shard_map(self, primitive, f, tracers, **params):
out_flat = primitive.bind(f, *[t.val for t in tracers], **params)
out_tracers = map(self.pure, out_flat)
return out_tracers

def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
context = trace_util.get_dynamic_context(self)
return context.post_process_custom_jvp_call(self, out_tracers, jvp_was_run)
Expand Down Expand Up @@ -1704,3 +1710,17 @@ def harvest(f,
kwargs = dict(
tag=tag, allowlist=allowlist, blocklist=blocklist, exclusive=exclusive)
return call_and_reap(plant(f, **kwargs), **kwargs)


# Handle shard_map
@shard_map.register_check(sow_p)
def _sow_check(mesh, *in_rep, name, tag, mode, tree):
del mesh, name, tag, mode, tree
return in_rep[0] # TODO(conmy): does this limit use to one output only?


@shard_map.register_rewrite(sow_p)
def _sow_rewrite(mesh, in_rep, *args, name, tag, mode, tree):
raise ValueError(
'Detected sow calls inside a shard_map. This is not currently supported.'
)
149 changes: 148 additions & 1 deletion oryx/core/interpreters/harvest_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
from jax import config
from jax import lax
from jax._src import pjit
from jax.experimental import mesh_utils
from jax.experimental import shard_map
import jax.numpy as jnp
import numpy as np

from oryx.core import trace_util
from oryx.core.interpreters import harvest
from oryx.internal import test_util
Expand Down Expand Up @@ -1019,6 +1020,152 @@ def branch3(x):
self.assertEqual(out, 8.)


class ShardMapTest(test_util.TestCase):

def setUp(self):
super().setUp()
self.devices = mesh_utils.create_device_mesh((1, 2))
self.mesh = jax.sharding.Mesh(self.devices, axis_names=('x', 'y'))
self.a = jnp.arange(8 * 16.0).reshape(8, 16)
self.b = jnp.arange(16 * 4.0).reshape(16, 4)

def _f(a, b):
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
jax.sharding.PartitionSpec('x', 'y'),
jax.sharding.PartitionSpec('y', None),
),
out_specs=jax.sharding.PartitionSpec('x', None),
)
def oryx_shmap_matmul(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block

shmapped_val = oryx_shmap_matmul(a, b)
sowed_val = sow(shmapped_val, name='shmapped_val', tag='intermediate')
return 2.0 * sowed_val

self.f = _f

def _f_with_sow_before_shmap(a, b):
sowed_val = sow(a, name='a', tag='intermediate')

@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
jax.sharding.PartitionSpec('x', 'y'),
jax.sharding.PartitionSpec('y', None),
),
out_specs=jax.sharding.PartitionSpec('x', None),
)
def oryx_shmap_matmul(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partial_sum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partial_sum, 'y')
# c_block: f32[2, 4]
return c_block

shmapped_val = oryx_shmap_matmul(sowed_val, b)
return 2.0 * shmapped_val

self.f_with_sow_before_shmap = _f_with_sow_before_shmap

def _f_with_sow_inside_shmap(a, b):
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
jax.sharding.PartitionSpec('x', 'y'),
jax.sharding.PartitionSpec('y', None),
),
out_specs=jax.sharding.PartitionSpec('x', None),
)
def oryx_shmap_matmul(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partial_sum = jnp.dot(a_block, b_block)
c_block = sow(
jax.lax.psum(c_partial_sum, 'y'), name='c_block', tag='intermediate'
)
# c_block: f32[2, 4]
return c_block

return 2.0 * oryx_shmap_matmul(a, b)

self.f_with_sow_inside_shmap = _f_with_sow_inside_shmap

def test_reap(self):
reap_dict = reap(self.f, tag='intermediate')(self.a, self.b)
self.assertEqual(
list(reap_dict.keys()), ['shmapped_val'], msg='Wrong reap dict keys'
)

self.assertFalse(
np.isclose(
reap_dict['shmapped_val'], 2.0 * jnp.dot(self.a, self.b)
).any(),
msg=(
'Reaped value is close to 2.0 * matmul but that'
' should be the output of the function, not the'
' intermediate reaped value.'
),
)
np.testing.assert_allclose(
reap_dict['shmapped_val'], jnp.dot(self.a, self.b)
)

def test_plant(self):
shampped_val_for_planting = 0.5 * jnp.dot(self.a, self.b)
f_output_planted = plant(self.f, tag='intermediate')(
dict(shmapped_val=shampped_val_for_planting), self.a, self.b
)
np.testing.assert_allclose(
f_output_planted, 2.0 * shampped_val_for_planting
)

def test_reap_before_shmap(self):
reap_dict = reap(self.f_with_sow_before_shmap, tag='intermediate')(
self.a, self.b
)
self.assertEqual(list(reap_dict.keys()), ['a'], msg='Wrong reap dict keys')
np.testing.assert_allclose(reap_dict['a'], self.a)

def test_plant_before_shmap(self):
a_val_for_planting = 0.5 * self.a
f_output_planted = plant(self.f_with_sow_before_shmap, tag='intermediate')(
dict(a=a_val_for_planting), self.a, self.b
)
np.testing.assert_allclose(
f_output_planted, 2.0 * jnp.dot(a_val_for_planting, self.b)
)

def test_reap_inside_shmap_fails(self):
with self.assertRaisesRegex(
ValueError,
'Detected sow calls inside a shard_map.'
' This is not currently supported.',
):
reap(self.f_with_sow_inside_shmap, tag='intermediate')(self.a, self.b)

def test_plant_inside_shmap_fails(self):
with self.assertRaisesRegex(
ValueError,
'Detected sow calls inside a shard_map.'
' This is not currently supported.',
):
plant(self.f_with_sow_inside_shmap, tag='intermediate')(
dict(c_block=15.0 * jnp.dot(self.a, self.b)), self.a, self.b
)


if __name__ == '__main__':
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
absltest.main()

0 comments on commit 93fc9ef

Please sign in to comment.