[JAX] Expert Parallelism: JAX primitives + VJPs#3036
Conversation
| Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, | ||
| Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { | ||
| auto topk_dims = topk_idx.dimensions(); | ||
| NVTE_CHECK(topk_dims.size() >= 2, |
There was a problem hiding this comment.
nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?
There was a problem hiding this comment.
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
|
I would appreciate your help to review this PR @tdophung @jberchtold-nvidia! |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM pending CI
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
06f8a13 to
c34771d
Compare
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF )
c34771d to
351b9df
Compare
|
/te-ci JAX L1 |
1e4c3ae to
5b49ecc
Compare
|
/te-ci JAX L1 |
…16 max_token_dtype Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ck via nccl_ep_enabled() Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
4bb76dc to
9df769a
Compare
|
/te-ci JAX L1 |
| f" over [num_procs, recv_pr, H]; got spec={eo_spec}." | ||
| ) | ||
| per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) | ||
| arg_shardings = tuple(a.sharding for a in arg_infos) | ||
| out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) |
There was a problem hiding this comment.
ep_combine_fwd crashes with TypeError in SPMD when out_partition_spec=None
ep_combine_fwd has out_partition_spec=None as its default (line 916). When JAX's SPMD machinery calls EpCombinePrimitive.partition, two sites immediately fail on None:
- Line 575:
out_partition_spec[0]→TypeError: 'NoneType' object is not subscriptable - Line 577:
PartitionSpec(*out_partition_spec)→TypeError: argument after * must be an iterable, not NoneType
Since ep_combine_fwd is exported in __all__ with a None default, any user calling it in a multi-device context without specifying out_partition_spec will hit a cryptic crash at JIT compilation time. The _combine_fwd wrapper always sets this value, so the custom_vjp path is safe, but the raw primitive path is not. A guard like if out_partition_spec is None: raise ValueError("out_partition_spec must be specified in SPMD mode") at the top of partition would produce a helpful error message instead.
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the
nvte_ep_*C API, a Python wrapper withcustom_vjpfor autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCLncclEpDispatch/ncclEpCombineare exposed as XLA primitives and work with CUDA-graph capture.Implementation
Public Python API (
transformer_engine/jax/ep.py)ep_dispatch/ep_combinearejax.custom_vjpfunctions: forward is the FFI primitive, backward calls the matchingnvte_ep_*_bwdFFI primitive directly (noep_preparein the bwd — routing state is already cached inhandle.mem). Note thatep_dispatchalso callsep_preparein the forward path, which all-gathers and preprocesses routing maps.XLA FFI bindings (
transformer_engine/jax/csrc/extensions/ep.cpp)Five
XLA_FFI_DEFINE_HANDLER_SYMBOLentries —EpPrepareHandler,EpDispatchHandler,EpCombineHandler,EpDispatchBwdHandler,EpCombineBwdHandler— each calling the correspondingnvte_ep_*C entry point. All markedFFI_CudaGraph_Traitsso they capture cleanly.handle_idis a static FFI attribute baked at jit trace time.Primitives + Python layer (
transformer_engine/jax/cpp_extensions/ep.py, +951 lines)Standard TE primitive plumbing:
abstract_eval(shape/dtype inference),lowering,impl,outer_primitiveregistration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).Sharding (
transformer_engine/jax/sharding.py, +12 lines)Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
Build wiring (
build_tools/jax.py, +41 lines)Threads NCCL EP linkage through the JAX
transformer_engine_jaxextension. No new top-level build flags; rides on the parent PR'sNVTE_BUILD_WITH_NCCL_EP.Tests & example
tests/jax/test_multi_process_ep.py(+690 lines): 13 tests covering bootstrap,ep_prepareshape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing),custom_vjpfwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).tests/jax/multi_process_launch_ep.sh: 4-rank launcher; setsXLA_FLAGSto keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).examples/jax/ep/ep_moe.py(+394 lines) +run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison--checkthat verifies fwd+bwd vs a single-process reference.Type of change
Checklist: