
[JAX] Expert Parallelism: JAX primitives + VJPs #3036

Open
phu0ngng wants to merge 21 commits into NVIDIA:main from phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng phu0ngng commented May 22, 2026

Collaborator

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch
    ep_combine,      # custom_vjp-wrapped combine
)

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.
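The forward/backward split described above can be illustrated with a minimal single-process sketch. It is not the PR's implementation: `toy_dispatch` and `perm` are hypothetical names, and a local permutation stands in for the NCCL all-to-all. The point is only the pattern: the forward rule saves the routing state as a residual, and the backward rule reuses it instead of recomputing it.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def toy_dispatch(tokens, perm):
    # Stand-in for the dispatch primitive: reorder token rows by `perm`.
    # Assumption: `perm` is a permutation of range(tokens.shape[0]).
    return tokens[perm]


def _toy_dispatch_fwd(tokens, perm):
    # Forward rule: produce the output and keep the routing state as residual.
    return tokens[perm], perm


def _toy_dispatch_bwd(perm, g_recv):
    # Backward rule: route cotangents back to their source rows using the
    # saved routing state; nothing is recomputed here.
    g_tokens = jnp.zeros_like(g_recv).at[perm].add(g_recv)
    # Integer routing indices carry no gradient, so None marks a zero cotangent.
    return g_tokens, None


toy_dispatch.defvjp(_toy_dispatch_fwd, _toy_dispatch_bwd)
```

Differentiating a loss through `toy_dispatch` with `jax.grad` then calls `_toy_dispatch_bwd` with the residual saved by the forward rule, which mirrors how the real backward primitives are described as reading cached routing state.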

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
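As a rough illustration of what a mesh with a dedicated EP axis looks like in plain JAX, here is a sketch that does not use the PR's mesh-resource plumbing. The axis names `"dp"` and `"ep"`, the helper names, and the choice to shard the token dimension over both axes are assumptions made for the example.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_dp_ep_mesh(dp: int, ep: int) -> Mesh:
    # Arrange the first dp * ep visible devices into a 2D grid.
    devices = jax.devices()
    if len(devices) < dp * ep:
        raise ValueError(f"need {dp * ep} devices, found {len(devices)}")
    grid = np.asarray(devices[: dp * ep]).reshape(dp, ep)
    # Hypothetical axis names; the real names come from the mesh resource.
    return Mesh(grid, axis_names=("dp", "ep"))


def token_sharding(mesh: Mesh) -> NamedSharding:
    # Assumption: tokens [T, H] are split along T over both axes,
    # while the hidden dimension stays replicated.
    return NamedSharding(mesh, PartitionSpec(("dp", "ep"), None))
```

With four devices, `make_dp_ep_mesh(2, 2)` would give the dp=ep=2 layout mentioned for the example script; on a single device only `make_dp_ep_mesh(1, 1)` is valid.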

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.
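The HLO-inspection idea from the test list can be sketched as follows. This is not the test from the PR: `find_collectives` and the `COLLECTIVE_OPS` list are hypothetical, and the check simply scans the compiled HLO text for common XLA collective op names.

```python
import re

import jax
import jax.numpy as jnp

# Common XLA collective op names (assumed list, not taken from the PR).
COLLECTIVE_OPS = (
    "all-gather",
    "all-reduce",
    "all-to-all",
    "collective-permute",
    "reduce-scatter",
)


def find_collectives(fn, *args):
    # Compile the function and return the sorted set of collective ops
    # that appear as instructions in the optimized HLO text.
    hlo_text = jax.jit(fn).lower(*args).compile().as_text()
    pattern = re.compile(r"\b(" + "|".join(COLLECTIVE_OPS) + r")(-start)?\(")
    return sorted({match.group(1) for match in pattern.finditer(hlo_text)})
```

A test in this style would compile the EP forward and backward functions on the multi-device mesh and assert that the returned list is empty, so any resharding that XLA inserts around the FFI call shows up as a failure.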

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 22, 2026

Contributor

Greptile Summary

This PR lands the JAX bindings for Expert Parallelism: XLA FFI handlers over the nvte_ep_* C API, custom_vjp-wrapped ep_dispatch / ep_combine for autograd, mesh-aware SPMD partition rules, a 13-test multi-process suite, and an end-to-end MoE example. The NCCL EP comm lifetime is managed via a weak/shared-ptr anchor pattern so compiled XLA executables keep the communicator alive without leaking it at shutdown.

  • Five XLA FFI primitives (ep_prepare, ep_dispatch, ep_combine + their backwards) are registered with FFI_CudaGraph_Traits for CUDA-graph capture; int32 topk_idx is upcast to int64 on-stream via a pre-allocated workspace buffer.
  • ep_combine_fwd (exported in __all__) defaults out_partition_spec=None, but EpCombinePrimitive.partition unconditionally subscripts and unpacks it — any SPMD caller omitting this argument will get a cryptic TypeError at JIT time rather than a clear validation error.
  • nccl_ep_enabled() in build_tools/utils.py mirrors setup.py's arch-guard logic and correctly auto-disables EP for pre-Hopper targets.

Confidence Score: 4/5

Safe to merge for users going through the ep_dispatch / ep_combine custom_vjp wrappers; the raw ep_combine_fwd primitive crashes in multi-device mode when out_partition_spec is omitted.

The ep_dispatch / ep_combine autograd wrappers always set out_partition_spec before calling the primitive, so training workflows using the public API are not affected. However, ep_combine_fwd is exported in __all__ with out_partition_spec=None as the default, and EpCombinePrimitive.partition unconditionally subscripts and unpacks that value, so a direct SPMD caller omitting the argument gets a TypeError at JIT time. Aside from this crash path, the NCCL comm lifetime management, int32-to-int64 upcast, and sharding partition rules all look correct.

transformer_engine/jax/cpp_extensions/ep.py — specifically EpCombinePrimitive.partition and the ep_combine_fwd public helper, which need a guard against out_partition_spec=None in SPMD contexts.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/cpp_extensions/ep.py | Core JAX primitive definitions for EP ops (+951 lines). Contains abstract_eval, lowering, impl, partition, and shardy rules. The ep_combine_fwd public API exposes a None default for out_partition_spec that causes a cryptic TypeError in EpCombinePrimitive.partition in SPMD mode. |
| transformer_engine/jax/ep.py | Public Python API: ep_bootstrap, custom_vjp wrappers ep_dispatch and ep_combine. Bootstrap ordering is validated (mesh resource checks, expert divisibility). _combine_fwd and _dispatch_bwd correctly pin cotangent shardings before passing to backward primitives. |
| transformer_engine/jax/csrc/extensions/ep.cpp | XLA FFI handlers for all five EP ops. NCCL comm lifetime is managed via weak_ptr/shared_ptr EpResources anchored by EpInstanceState. int32->int64 upcast for topk_idx is done on-stream with workspace buffer. dtype assumptions are consistent with the documented float32-only router. |
| build_tools/utils.py | Adds nccl_ep_enabled() helper that mirrors setup.py arch-guard logic, silently disabling EP for sub-90 archs and raising only when NVTE_WITH_NCCL_EP=1 is explicit. |
| tests/jax/test_multi_process_ep.py | 13 multi-process tests covering bootstrap, prepare, dispatch/combine identity for uniform and skewed routing, custom_vjp fwd+bwd correctness, and HLO inspection for spurious XLA collectives. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User as User (JAX)
    participant EP as ep.py (custom_vjp)
    participant Prim as cpp_extensions/ep.py (Primitives)
    participant FFI as ep.cpp (XLA FFI)
    participant NCCL as NCCL EP Backend

    Note over User,NCCL: Bootstrap (once per process)
    User->>EP: ep_bootstrap(world_size, rank, ...)
    EP->>NCCL: ncclGetUniqueId (color root)
    EP->>EP: _allgather_uid (JAX collective)
    EP->>FFI: set_ep_bootstrap_params(uid, ep_size, ...)
    FFI->>NCCL: ncclCommInitRank → EpResources anchor

    Note over User,NCCL: Forward pass (per step, inside jax.jit)
    User->>EP: ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_cap)
    EP->>Prim: ep_prepare (EpPreparePrimitive)
    Prim->>FFI: EpPrepareFFI → nvte_ep_prepare
    FFI->>NCCL: ncclEpDispatch routing metadata
    Prim-->>EP: token_counts, handle_mem
    EP->>Prim: ep_dispatch_fwd (EpDispatchPrimitive)
    Prim->>FFI: EpDispatchFFI → nvte_ep_dispatch
    FFI->>NCCL: ncclEpDispatch tokens + weights
    Prim-->>EP: recv_tokens, recv_topk_weights
    EP-->>User: recv_tokens, recv_topk_weights, handle_mem, token_counts

    User->>User: Expert MLP (local computation)

    User->>EP: ep_combine(cfg, handle_mem, token_counts, expert_out, T)
    EP->>Prim: ep_combine_fwd (EpCombinePrimitive)
    Prim->>FFI: EpCombineFFI → nvte_ep_combine
    FFI->>NCCL: ncclEpCombine scatter-sum
    Prim-->>EP: combined_tokens
    EP-->>User: combined_tokens

    Note over User,NCCL: Backward pass (custom_vjp)
    User->>EP: _dispatch_bwd(g_recv_tokens, g_recv_weights)
    EP->>Prim: ep_dispatch_bwd (EpDispatchBwdPrimitive)
    Prim->>FFI: EpDispatchBwdFFI → nvte_ep_dispatch_bwd
    Prim-->>EP: grad_tokens, grad_topk_weights

    User->>EP: _combine_bwd(g_combined)
    EP->>Prim: ep_combine_bwd (EpCombineBwdPrimitive)
    Prim->>FFI: EpCombineBwdFFI → nvte_ep_combine_bwd
    Prim-->>EP: grad_expert_out
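
The dispatch, expert computation, and combine steps in the diagram can be mirrored by a single-process reference that needs no communication at all. The sketch below is an assumption-laden stand-in, not the reference used by the example script: `reference_moe` is a hypothetical name, each expert is a single linear layer rather than an MLP, and the top-k weights are applied before the final sum.

```python
import jax
import jax.numpy as jnp


def reference_moe(tokens, topk_idx, topk_weights, expert_w):
    """Dense single-process MoE reference.

    tokens: [T, H], topk_idx: [T, K] int, topk_weights: [T, K],
    expert_w: [E, H, D] (one linear layer per expert, an assumption).
    """
    # "Dispatch": gather the weights of each token's K selected experts.
    selected = expert_w[topk_idx]  # [T, K, H, D]
    # Expert computation for every (token, selected expert) pair.
    expert_out = jnp.einsum("th,tkhd->tkd", tokens, selected)
    # "Combine": weighted sum over the K expert outputs of each token.
    return jnp.einsum("tk,tkd->td", topk_weights, expert_out)
```

Because the function is ordinary differentiable JAX, `jax.grad` over it gives reference gradients that a distributed forward and backward pass could be compared against.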

Reviews (21): Last reviewed commit: "L0_jax_unittest: exclude multi-process E..."

Comment thread build_tools/jax.py Outdated
Comment thread build_tools/jax.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp Outdated
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,

Collaborator


nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Collaborator Author


This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng

Collaborator Author

I would appreciate your help reviewing this PR, @tdophung @jberchtold-nvidia!
Please focus on the JAX-side changes, as the TE/Common ones will be discussed in #3034.

Comment thread examples/jax/ep/ep_moe.py Outdated
Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread examples/jax/ep/ep_moe.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated

@jberchtold-nvidia jberchtold-nvidia left a comment

Collaborator


LGTM pending CI

Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions.h Outdated
jberchtold-nvidia pushed a commit to jberchtold-nvidia/TransformerEngine that referenced this pull request Jun 5, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
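
The width check quoted in this commit message can be restated in Python to show why a one-byte default rejects bf16 tokens. This is a sketch of the rule only: `check_token_dtype` is a hypothetical helper, and it compares dtype sizes through `itemsize` rather than through the backend's `typeToSize` or NVTEDType values.

```python
import jax.numpy as jnp


def check_token_dtype(tok_dtype, max_token_dtype):
    # Mirror of the rule: size(tok_dtype) <= size(max_token_dtype).
    tok_size = jnp.dtype(tok_dtype).itemsize
    max_size = jnp.dtype(max_token_dtype).itemsize
    if tok_size > max_size:
        raise ValueError(
            f"tokens dtype ({jnp.dtype(tok_dtype).name}, {tok_size} B) wider than "
            f"group max_token_dtype ({jnp.dtype(max_token_dtype).name}, {max_size} B)"
        )
```

With a one-byte maximum such as `jnp.uint8`, a `jnp.bfloat16` token tensor (two bytes per element) fails the check, which matches the failure described above; setting the maximum to `jnp.bfloat16` lets bf16 and fp16 through but still rejects fp32.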
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-jax branch 2 times, most recently from 06f8a13 to c34771d Compare June 10, 2026 15:24
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg).
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
@phu0ngng

Collaborator Author

/te-ci JAX L1

@tdophung tdophung mentioned this pull request Jun 12, 2026
13 tasks
@phu0ngng phu0ngng force-pushed the phuong/ep-3-jax branch 2 times, most recently from 1e4c3ae to 5b49ecc Compare June 23, 2026 16:39
@phu0ngng

Collaborator Author

/te-ci JAX L1

phu0ngng and others added 21 commits June 24, 2026 00:22
…16 max_token_dtype

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ck via nccl_ep_enabled()

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Collaborator Author

/te-ci JAX L1

Comment on lines +573 to +577
f" over [num_procs, recv_pr, H]; got spec={eo_spec}."
)
per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh)
arg_shardings = tuple(a.sharding for a in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec))

Contributor


P1 ep_combine_fwd crashes with TypeError in SPMD when out_partition_spec=None

ep_combine_fwd has out_partition_spec=None as its default (line 916). When JAX's SPMD machinery calls EpCombinePrimitive.partition, two sites immediately fail on None:

  • Line 575: out_partition_spec[0] → TypeError: 'NoneType' object is not subscriptable
  • Line 577: PartitionSpec(*out_partition_spec) → TypeError: argument after * must be an iterable, not NoneType

Since ep_combine_fwd is exported in __all__ with a None default, any user calling it in a multi-device context without specifying out_partition_spec will hit a cryptic crash at JIT compilation time. The _combine_fwd wrapper always sets this value, so the custom_vjp path is safe, but the raw primitive path is not. A guard like if out_partition_spec is None: raise ValueError("out_partition_spec must be specified in SPMD mode") at the top of partition would produce a helpful error message instead.
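
A minimal sketch of the suggested guard, written as a standalone helper so it can run outside the primitive. `resolve_out_partition_spec` is a hypothetical name; in the PR the check would sit at the top of `EpCombinePrimitive.partition`, whose full signature is not reproduced here.

```python
from jax.sharding import PartitionSpec


def resolve_out_partition_spec(out_partition_spec):
    # Fail early with a clear message instead of a TypeError from
    # subscripting or unpacking None further down.
    if out_partition_spec is None:
        raise ValueError("out_partition_spec must be specified in SPMD mode")
    return PartitionSpec(*out_partition_spec)
```

Calling the helper with `None` raises the `ValueError` immediately, while a tuple such as `("ep", None, None)` is converted to the equivalent `PartitionSpec`.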

@phu0ngng phu0ngng added the 2.7.0 label Jun 24, 2026

4 participants