Commit 57dd6a0

RandySheriff authored and pytorchmergebot committed
[OC][Torch] Extend autotune options for OC OBA 200x shapes (pytorch#166931)
Summary: Add four best configs for shapes of the OC OBA 200x model: ``` M=2048 N=2048 K=12288 triton_mm_35 0.1526 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 M=2048 N=52416 K=1536 triton_mm_12 0.4604 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 M=2048 N=12288 K=2048 triton_mm_9 0.1444 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 M=2048 N=2048 K=52416 triton_mm_35 0.6505 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 ``` Test Plan: Run tritonbench for torch fp8(_scaled_mm) for all above shapes, e.g. ``` TRITON_PRINT_AUTOTUNING=1 buck2 run mode/opt-amd-gpu -c fbcode.enable_gpu_sections=true //pytorch/tritonbench:run -- --op fp8_gemm --only pt2_fp8_gemm --metrics tflops,accuracy --m 2048 --n 2048 --k 12288 ``` Differential Revision: D86158497 Pull Request resolved: pytorch#166931 Approved by: https://github.com/jananisriram
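The autotune winners above can be related back to the `GemmConfig` tuples added in the diff. The sketch below is illustrative only: the field order `(block_m, block_n, block_k, num_stages, num_warps)` is an assumption inferred from the diff context and the `BLOCK_M`/`BLOCK_N`/`BLOCK_K`/`num_stages`/`num_warps` values in the autotune log, not a copy of the PyTorch class.

```python
# Hedged sketch of a GemmConfig-like container; field order is an assumption
# inferred from the autotune log in this commit, not from the PyTorch source.
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmConfig:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int


# The four configs this commit adds to the mm heuristics list.
new_configs = [
    GemmConfig(64, 128, 128, 2, 4),
    GemmConfig(128, 128, 128, 2, 4),
    GemmConfig(256, 128, 128, 2, 4),
    GemmConfig(256, 128, 128, 2, 8),
]


def tiles(m: int, n: int, cfg: GemmConfig) -> int:
    # Each config tiles the (M, N) output into block_m x block_n tiles,
    # stepping over K in block_k chunks; ceil-divide to count tiles.
    return -(-m // cfg.block_m) * -(-n // cfg.block_n)


print(tiles(2048, 2048, new_configs[0]))  # 32 * 16 = 512 tiles
```

For the M=2048, N=2048, K=12288 winner (BLOCK_M=64, BLOCK_N=128), this gives the 512 output tiles the kernel launch would cover.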
1 parent 7318ed6 commit 57dd6a0

File tree

1 file changed: +4 −0 lines changed


torch/_inductor/template_heuristics/triton.py

Lines changed: 4 additions & 0 deletions
```
@@ -422,8 +422,12 @@ def __init__(self) -> None:
             GemmConfig(32, 256, 64, 6, 4),
             GemmConfig(64, 16, 256, 5, 4),
             GemmConfig(64, 32, 256, 5, 4),
+            GemmConfig(64, 128, 128, 2, 4),
             GemmConfig(64, 128, 128, 3, 4),
+            GemmConfig(128, 128, 128, 2, 4),
             GemmConfig(128, 256, 128, 4, 8),
+            GemmConfig(256, 128, 128, 2, 4),
+            GemmConfig(256, 128, 128, 2, 8),
         ]

         self.scaled_persistent_mm_configs: list[BaseConfig] = [
```
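As a quick cross-check of the summary, the quoted autotune timings can be tabulated per shape. This is a minimal sketch: the timings are copied from the commit message above, and the selection helper is illustrative, not part of the Inductor autotuner.

```python
# Winning kernel and time (ms) per (M, N, K) shape, copied from the
# autotune log quoted in the commit summary; the helper is illustrative only.
best_ms = {
    (2048, 2048, 12288): ("triton_mm_35", 0.1526),
    (2048, 52416, 1536): ("triton_mm_12", 0.4604),
    (2048, 12288, 2048): ("triton_mm_9", 0.1444),
    (2048, 2048, 52416): ("triton_mm_35", 0.6505),
}


def fastest(shapes):
    # Return the shape whose winning kernel had the lowest measured time.
    return min(shapes, key=lambda s: best_ms[s][1])


print(fastest(best_ms))  # (2048, 12288, 2048)
```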
