
feat(triton): add rms_norm operator #736

Draft
fuyou4546 wants to merge 1 commit into InfiniTensor:master from fuyou4546:feat/triton-rms-norm

@fuyou4546

Contributor

Summary

  • Add an RMSNorm operator to the Triton backend (src/triton/ops/rms_norm/{rms_norm.py, build.py, rms_norm.h})

Motivation

Add an RMSNorm operator on the Triton backend, supporting fp16, bf16, and fp32.
Each row is handled by one Triton program, and accumulation is done in fp32.
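
For reviewers who want a quick reference point, here is a minimal pure-Python sketch of row-wise RMSNorm. It is not the Triton kernel from this PR. It assumes the usual definition y_i = x_i / sqrt(mean(x^2) + eps) * w_i. The names `rms_norm`, `rms_norm_row`, `to_fp16`, and the default `eps` are hypothetical and chosen only for illustration. Python floats (fp64) stand in for the fp32 accumulator, and fp16 stands in for the low-precision storage dtype.

```python
import math
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def rms_norm_row(row, weight, eps=1e-6):
    """Normalize one row: y_i = x_i / sqrt(mean(x^2) + eps) * w_i."""
    # Accumulate the sum of squares in a wider type than the storage type.
    # Python floats are fp64 here; the PR description states fp32.
    sum_sq = 0.0
    for x in row:
        sum_sq += x * x
    inv_rms = 1.0 / math.sqrt(sum_sq / len(row) + eps)
    # Round to the (assumed) fp16 storage dtype only when writing the result.
    return [to_fp16(x * inv_rms * w) for x, w in zip(row, weight)]


def rms_norm(matrix, weight, eps=1e-6):
    """Apply rms_norm_row to every row of a 2-D list."""
    # Rows do not depend on each other, so each one could be assigned
    # to its own program instance.
    return [rms_norm_row(row, weight, eps) for row in matrix]
```

Because the reduction runs only along the row, the rows are independent of one another, which is what makes a one-program-per-row launch a natural fit. Keeping the accumulator wider than the storage dtype limits rounding error in the sum of squares for fp16 and bf16 inputs.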

Closes N/A

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • docs — documentation only
  • build / ci — build system or CI configuration
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Smoke Test Result

python -m pytest tests -m smoke -q
.................ss..ss....s................................ssssss..ss................ [100%]
73 passed, 13 skipped, 20864 deselected in 3.17s

Test Results on Supported Platforms

| Platform | Affected | Build / Smoke Result | Full Result / Notes |
| --- | --- | --- | --- |
| NVIDIA | | Successfully installed InfiniOps-0.1.0 / 73 passed, 13 skipped, 20864 deselected | 36 passed, 72 deselected |
| Iluvatar | | | |
| MetaX | | | |
| Cambricon | | | |
| Moore | | | |
| Ascend | | | |

Full `pytest` output (optional)
pytest tests/test_rms_norm.py -k "cuda-8"
======== test session starts ========
platform linux -- Python 3.12.0, pytest-9.0.3, pluggy-1.6.0
rootdir: /home/zhangshuo/projects/InfiniTensor/InfiniOps
configfile: pyproject.toml
plugins: xdist-3.8.0, cov-7.1.0
collected 108 items / 72 deselected / 36 selected

tests/test_rms_norm.py .................................... [100%]

======== 36 passed, 72 deselected in 0.59s ========

Benchmark / Performance Impact

N/A

Notes for Reviewers

Depends on the AOT infrastructure in feat/triton-backend.
