HELP! PRNGKey with shard_map not working #22862

Answered by AshishKumar4
AshishKumar4 asked this question in Q&A
Found something. This seems to work:

out2 = jax.jit(shard_map(func, mesh=mesh, in_specs=(P('i'), P(), P()), out_specs=(P('i'), P(), P())))(data, rngs, state)

Basically, if I don't give an axis name at all for those arguments (i.e. I use P()), then there's no need to replicate anything myself and no reshaping etc. is required. I'm still able to use collectives inside, and this works on multi-host. I just want to confirm the caveats of this method with you. Is this alright?
Once again, thank you so much for this!

EDIT: it should actually be this:

out2 = jax.jit(shard_map(func, mesh=mesh, in_specs=(P('i'), P()), out_specs=(P('i'), P())))(data, rngs, state)
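For anyone landing here later, this is a minimal sketch of the pattern above. The original func isn't shown in this thread, so the func below is a hypothetical stand-in, and the mesh is just a 1-D mesh over whatever devices are available. It assumes a key passed with P() is seen in full (replicated) by every shard, so every shard draws the same noise from it; if you want different randomness per shard you would typically fold something like jax.lax.axis_index('i') into the key inside the function. The try/except import is only there because shard_map has moved between JAX releases.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# 1-D mesh over all available devices, with the axis named 'i' as in the post.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('i',))

def func(data, key):
    # Hypothetical stand-in for the function from the question.
    # `data` is the local shard; `key` is the full, replicated key.
    noise = jax.random.normal(key, data.shape)  # same draw on every shard
    total = jax.lax.psum(jnp.sum(data), 'i')    # collective over axis 'i'
    return data + noise, total

data = jnp.arange(8 * len(devices), dtype=jnp.float32)
key = jax.random.PRNGKey(0)

sharded = jax.jit(shard_map(
    func, mesh=mesh,
    in_specs=(P('i'), P()),    # shard data along 'i', replicate the key
    out_specs=(P('i'), P()),   # sharded output, replicated psum result
))
out, total = sharded(data, key)
```

Note that in_specs has one entry per positional argument here (two specs for two arguments). The psum result is the same on every shard, which is why P() is a valid out_spec for it.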
