I am trying to define some array shardings in JAX. Normally I would use NamedSharding, but I am working on a sharding scheme that I believe NamedSharding cannot express.
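For reference, this is the usual NamedSharding setup. It is a minimal sketch that assumes 4 simulated CPU devices and the mesh axis names 'x' and 'y'. Both are illustrative choices, not part of the question. With a PartitionSpec, each device receives one contiguous block along every sharded dimension:

```python
import os
# Assumption: simulate 4 CPU devices on one machine; this flag must be set
# before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2 x 2 device mesh; the axis names 'x' and 'y' are illustrative.
devices = np.array(jax.devices("cpu")[:4]).reshape(2, 2)
mesh = Mesh(devices, ("x", "y"))

x = jnp.arange(64).reshape(8, 8)
# Rows are split over mesh axis 'x', columns over mesh axis 'y'.
xs = jax.device_put(x, NamedSharding(mesh, P("x", "y")))

# Each device holds a single contiguous 4 x 4 block of the 8 x 8 array.
for shard in xs.addressable_shards:
    print(shard.device, shard.index)
```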
As an example, say I have an 8 x 8 array and a 2 x 2 device mesh. I want to shard the array so that, if it were reshaped to (2, 2, 2, 8), its 2nd and 4th dimensions would be sharded over the 1st and 2nd mesh axes, respectively. The catch is that the sharding should apply to the original 2D array, not to the reshaped one.
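To make the target layout concrete, here is a small NumPy computation (plain index bookkeeping, not a JAX API). It works out which rows and columns of the original 8 x 8 array each mesh position (i, j) would own under this scheme. The columns split into ordinary contiguous halves, but the rows come out strided: {0, 1, 4, 5} versus {2, 3, 6, 7}. NamedSharding always gives each device a contiguous slice along each dimension, which is presumably why the scheme does not map onto it directly.

```python
import numpy as np

# Row/column coordinates of every element of the original 8 x 8 array.
rows, cols = np.indices((8, 8))

# View the coordinates through the (2, 2, 2, 8) reshape. In C order,
# element (a, b, d, e) of the 4D view is row 4*a + 2*b + d, column e.
rows4 = rows.reshape(2, 2, 2, 8)
cols4 = cols.reshape(2, 2, 2, 8)

owned = {}
for i in range(2):      # position along mesh axis 0 (shards 4D axis 1)
    for j in range(2):  # position along mesh axis 1 (shards 4D axis 3)
        # Axis 1 splits into 2 blocks of size 1, axis 3 into 2 blocks of size 4.
        r = rows4[:, i, :, 4 * j:4 * (j + 1)]
        c = cols4[:, i, :, 4 * j:4 * (j + 1)]
        owned[(i, j)] = (np.unique(r).tolist(), np.unique(c).tolist())

for k, (r, c) in owned.items():
    print(k, "rows", r, "cols", c)
# (0, 0) rows [0, 1, 4, 5] cols [0, 1, 2, 3]
# (0, 1) rows [0, 1, 4, 5] cols [4, 5, 6, 7]
# (1, 0) rows [2, 3, 6, 7] cols [0, 1, 2, 3]
# (1, 1) rows [2, 3, 6, 7] cols [4, 5, 6, 7]
```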
What is an elegant, JAX-idiomatic way to do this?
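Two workarounds that stay within NamedSharding are sketched below. Both assume the same 4 simulated CPU devices and 'x'/'y' axis names as above. Neither is an official JAX recipe.

- **Option A** keeps the data stored in the (2, 2, 2, 8) view and shards that directly with P(None, "x", None, "y").
- **Option B** relabels the rows with a fixed permutation so that the strided row groups become contiguous, then applies an ordinary 2D P("x", "y").

The helper `holdings` is hypothetical and exists only to check that both options put the same original rows and columns on each device.

```python
import os
# Assumption: simulate 4 CPU devices; must be set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:4]).reshape(2, 2)
mesh = Mesh(devices, ("x", "y"))
x = np.arange(64, dtype=np.int32).reshape(8, 8)

# Option A: store the 4D view and shard it. Axis 1 -> mesh 'x',
# axis 3 -> mesh 'y'; axes 0 and 2 stay unsharded.
x4 = jax.device_put(x.reshape(2, 2, 2, 8),
                    NamedSharding(mesh, P(None, "x", None, "y")))

# Option B: permute rows so each row group is contiguous, then shard in 2D.
# perm == [0, 1, 4, 5, 2, 3, 6, 7]; inv undoes it (xp[inv] is the original).
perm = np.arange(8).reshape(2, 2, 2).transpose(1, 0, 2).ravel()
inv = np.argsort(perm)
xp = jax.device_put(x[perm], NamedSharding(mesh, P("x", "y")))

def holdings(arr):
    """Hypothetical helper: original (rows, cols) held by each device.

    Relies on x[r, c] == 8 * r + c, so value // 8 is the row, value % 8 the column.
    """
    out = {}
    for s in arr.addressable_shards:
        data = np.asarray(s.data)
        out[s.device.id] = (np.unique(data // 8).tolist(),
                            np.unique(data % 8).tolist())
    return out

hold_a, hold_b = holdings(x4), holdings(xp)
for dev in sorted(hold_a):
    print(dev, "A:", hold_a[dev], "B:", hold_b[dev])
```

With Option A, reshaping the 4D array back to 8 x 8 produces an ordinary contiguous layout, so XLA may have to move data. It is probably simplest to treat the 4D array as the stored form and write computations against that shape. With Option B, elementwise and column-wise operations are unaffected by the row relabeling, but anything that depends on row order has to account for `perm`/`inv`.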