Jax Sharding Issue with Axis_Types #26818
Replies: 2 comments 9 replies
-
Can you paste the full stack trace and a minimal repro? From the information you have provided, the error doesn't make sense to me.
-
I was able to get the bug to go away (although I didn't really get to the root of the problem), by adding Many thanks for your help!
-
Using
mesh = jax.sharding.Mesh(devices=jax.devices(), axis_names="chains")
I get an error:
ValueError: Context mesh AbstractMesh('chains': 1, axis_types={Manual: ('chains',)}) should match the mesh of sharding AbstractMesh('chains': 1, axis_types={Auto: ('chains',)}) passed to broadcast_in_dim.
I can provide more details about exactly how I'm using the sharding, but I first wanted to check whether this could be resolved with a simple change. For example, I tried manually specifying the axis_types, but I don't know where to import Manual from.