[Common] Support scaled & clamped swiglu, srelu for BF16 #3132
[Common] Support scaled & clamped swiglu, srelu for BF16 #3132zhongbozhu wants to merge 6 commits into
Conversation
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds six new CUDA kernels (
Confidence Score: 4/5Safe to merge; the new kernels are mathematically consistent with the existing utility functions and the test suite covers the primary code paths for both contiguous and interleaved GLU layouts. The core kernel math, alignment dispatch, and block reduction are correct. The only items worth addressing before shipping are: FP16 is absent from the test dtype sweep even though the dispatch macro includes it, the one-block-per-row launch casts
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
API_FWD["nvte_scaled_swiglu /\nnvte_scaled_clamped_swiglu /\nnvte_scaled_srelu"]
API_BWD["nvte_scaled_dswiglu /\nnvte_scaled_clamped_dswiglu /\nnvte_scaled_dsrelu"]
API_FWD --> CHK_GATED_FWD{Gated?}
CHK_GATED_FWD -- "SwiGLU / ClampedSwiGLU" --> ALIGN_FWD[check alignment & segment layout]
CHK_GATED_FWD -- "SReLU" --> ALIGN_SRELU_FWD[check alignment]
ALIGN_FWD -- "aligned" --> KFG_VEC["scaled_gated_forward_kernel nvec>1"]
ALIGN_FWD -- "unaligned" --> KFG_SCAL["scaled_gated_forward_kernel nvec=1"]
ALIGN_SRELU_FWD -- "aligned" --> KSF_VEC["scaled_srelu_forward_kernel nvec>1"]
ALIGN_SRELU_FWD -- "unaligned" --> KSF_SCAL["scaled_srelu_forward_kernel nvec=1"]
API_BWD --> CHK_GATED_BWD{Gated?}
CHK_GATED_BWD -- "SwiGLU / ClampedSwiGLU" --> CHK_SCALE_G[grad_act_scales?]
CHK_GATED_BWD -- "SReLU" --> CHK_SCALE_S[grad_act_scales?]
CHK_SCALE_G -- "null" --> KGB_FLAT["scaled_gated_backward_kernel flat grid"]
CHK_SCALE_G -- "present" --> KGB_RED["scaled_gated_backward_with_scale_grad_kernel one block per row + warp reduction"]
CHK_SCALE_S -- "null" --> KSB_FLAT["scaled_srelu_backward_kernel flat grid"]
CHK_SCALE_S -- "present" --> KSB_RED["scaled_srelu_backward_with_scale_grad_kernel one block per row + warp reduction"]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
API_FWD["nvte_scaled_swiglu /\nnvte_scaled_clamped_swiglu /\nnvte_scaled_srelu"]
API_BWD["nvte_scaled_dswiglu /\nnvte_scaled_clamped_dswiglu /\nnvte_scaled_dsrelu"]
API_FWD --> CHK_GATED_FWD{Gated?}
CHK_GATED_FWD -- "SwiGLU / ClampedSwiGLU" --> ALIGN_FWD[check alignment & segment layout]
CHK_GATED_FWD -- "SReLU" --> ALIGN_SRELU_FWD[check alignment]
ALIGN_FWD -- "aligned" --> KFG_VEC["scaled_gated_forward_kernel nvec>1"]
ALIGN_FWD -- "unaligned" --> KFG_SCAL["scaled_gated_forward_kernel nvec=1"]
ALIGN_SRELU_FWD -- "aligned" --> KSF_VEC["scaled_srelu_forward_kernel nvec>1"]
ALIGN_SRELU_FWD -- "unaligned" --> KSF_SCAL["scaled_srelu_forward_kernel nvec=1"]
API_BWD --> CHK_GATED_BWD{Gated?}
CHK_GATED_BWD -- "SwiGLU / ClampedSwiGLU" --> CHK_SCALE_G[grad_act_scales?]
CHK_GATED_BWD -- "SReLU" --> CHK_SCALE_S[grad_act_scales?]
CHK_SCALE_G -- "null" --> KGB_FLAT["scaled_gated_backward_kernel flat grid"]
CHK_SCALE_G -- "present" --> KGB_RED["scaled_gated_backward_with_scale_grad_kernel one block per row + warp reduction"]
CHK_SCALE_S -- "null" --> KSB_FLAT["scaled_srelu_backward_kernel flat grid"]
CHK_SCALE_S -- "present" --> KSB_RED["scaled_srelu_backward_with_scale_grad_kernel one block per row + warp reduction"]
|
| } | ||
| } |
There was a problem hiding this comment.
gated_unscaled computes unscaled on line 170, but gated_grads unconditionally writes *unscaled on line 171, overwriting it. The first call is dead code — every gated_grads case sets *unscaled before returning, so the result of gated_unscaled is never observed. This should simply be removed.
There was a problem hiding this comment.
either out-dated or not precise comment
There was a problem hiding this comment.
You're right that "dead code" was imprecise. The more accurate characterization: gated_unscaled is a redundant computation — its return value is assigned to unscaled, but gated_grads immediately overwrites *unscaled in all three switch cases before returning, so the first assignment is discarded. The function itself is reachable and executes, but its result is never observed.
The substance still holds: the call to gated_unscaled on line 170 can be removed without any behavioral change, since gated_grads on line 171 already computes and sets unscaled as a side-output needed for the scale gradient. Keeping both creates a mild readability issue — a reader might assume the two calls serve different purposes.
A cleaner form would be:
| } | |
| } | |
| gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled); |
Tip: You can customize Greptile's behavior for this repo with .greptile/rules.md and .greptile/config.json.
|
/te-ci pytorch |
vthumbe1503
left a comment
There was a problem hiding this comment.
LGTM. Left minor comments
| __device__ __forceinline__ float warp_reduce_sum(float value) { | ||
| #pragma unroll | ||
| for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) { | ||
| value += __shfl_down_sync(0xffffffff, value, offset); | ||
| } | ||
| return value; | ||
| } | ||
|
|
||
| __device__ __forceinline__ float block_reduce_sum(float value, float *smem) { | ||
| const int lane = threadIdx.x % THREADS_PER_WARP; | ||
| const int warp = threadIdx.x / THREADS_PER_WARP; | ||
|
|
||
| value = warp_reduce_sum(value); |
There was a problem hiding this comment.
I beleive we can reuse this from utils.cuh
TransformerEngine/transformer_engine/common/utils.cuh
Lines 491 to 492 in 77054fa
| void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, | ||
| int64_t glu_interleave_size, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_scaled_swiglu); | ||
| using namespace transformer_engine; | ||
| Empty empty = {}; | ||
| (void)empty; | ||
| ClampedSwiGLUParam param = {}; | ||
| launch_scaled_gated_forward<ScaledActivation::kSwiGLU>( | ||
| input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu"); | ||
| } | ||
|
|
||
| void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, | ||
| NVTETensor grad_input, NVTETensor grad_act_scales, | ||
| int64_t glu_interleave_size, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_scaled_dswiglu); | ||
| using namespace transformer_engine; | ||
| ClampedSwiGLUParam param = {}; | ||
| launch_scaled_gated_backward<ScaledActivation::kSwiGLU>(grad, input, act_scales, grad_input, | ||
| grad_act_scales, glu_interleave_size, | ||
| param, stream, "nvte_scaled_dswiglu"); | ||
| } | ||
|
|
||
| void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, | ||
| NVTETensor output, float limit, float alpha, | ||
| float glu_linear_offset, int64_t glu_interleave_size, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_scaled_clamped_swiglu); | ||
| using namespace transformer_engine; | ||
| ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; | ||
| launch_scaled_gated_forward<ScaledActivation::kClampedSwiGLU>( | ||
| input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_clamped_swiglu"); | ||
| } | ||
|
|
||
| void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, | ||
| const NVTETensor act_scales, NVTETensor grad_input, | ||
| NVTETensor grad_act_scales, float limit, float alpha, | ||
| float glu_linear_offset, int64_t glu_interleave_size, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_scaled_clamped_dswiglu); | ||
| using namespace transformer_engine; | ||
| ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; | ||
| launch_scaled_gated_backward<ScaledActivation::kClampedSwiGLU>( | ||
| grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, | ||
| "nvte_scaled_clamped_dswiglu"); | ||
| } | ||
|
|
||
| void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_scaled_srelu); | ||
| using namespace transformer_engine; | ||
| launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu"); | ||
| } | ||
|
|
||
| void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, | ||
| NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_scaled_dsrelu); | ||
| using namespace transformer_engine; | ||
| launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream, | ||
| "nvte_scaled_dsrelu"); | ||
| } |
There was a problem hiding this comment.
Might be good to move these NVTE API definitions into new files scaled_swiglu.cu and scaled_srelu.cu, following the footsteps of other activation definitions.
| const auto compute_grad_scales = std::get<5>(GetParam()); | ||
|
|
||
| if (activation == ScaledActivationCase::kSReLU && interleave != 0) { | ||
| GTEST_SKIP() << "SReLU is not a GLU activation."; |
There was a problem hiding this comment.
Nit:
| GTEST_SKIP() << "SReLU is not a GLU activation."; | |
| GTEST_SKIP() << "Interleave has no meaning for SReLU."; |
Description
Support Mega-C++ with Cublas BF16 Grouped GEMM backend: #3099
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: