
Commit 2a03bf5

FindHao authored and pytorchmergebot committed
[inductor] fix grid z bug for large grid (pytorch#127448)
Fixes pytorch#123210

https://github.com/pytorch/pytorch/blob/2f3d3ddd70e553d4c5269df699489b82b3aa25ab/torch/_inductor/runtime/triton_heuristics.py#L1733-L1753

If a kernel's y_grid is larger than 65535, it is split into multiple z grids. The grid_fn above performs this split before the kernel launch; however, its computations of yoffset and y_grid are incorrect. For example, for an xy numel of `(1*XBLOCK, 65537*YBLOCK)`, this function returns an [xyz]_grid of (1, 32768, 2). XBLOCK and YBLOCK here are the block sizes used by the following `get_grid_dim`; let's use their default values (4, 1024).

https://github.com/pytorch/pytorch/blob/2f3d3ddd70e553d4c5269df699489b82b3aa25ab/torch/_inductor/runtime/triton_heuristics.py#L1734

[xyz]_grid = (1, 32768, 2) means the workload is divided across two z grids. Because the Triton kernel generation still follows the xy dimensions, one example of a generated kernel is shown below.

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 65537*1024
    xnumel = 1*4
    yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x2 = xindex
    y0 = yindex % 128
    y1 = (yindex // 128)
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + (128*x2) + (512*y1)), xmask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (x2 + (4*y3)), tmp0, xmask)
```

For a Triton block with xyz index (0, 0, 1), both its yoffset and xoffset are 0, based on the computations `yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK` and `xoffset = tl.program_id(0) * XBLOCK`. So this Triton block accesses the very first elements of the input. The correct yoffset should be `(y_index + z_index * y_grid) * YBLOCK`, which is the starting position of the 2nd z grid.

At the same time, because `y_grid = y_grid // div` is used to compute the maximum number of elements in the y dimension, y_grid becomes 32768. The total number of y grids is then 32768 * 2 = 65536, which is less than the required 65537. So we should use `y_grid = ceildiv(y_grid, div)` to compute the y grid and cover the remaining blocks.

pytorch#123210 is not about AOTInductor; the root cause is the Triton kernel generated by TorchInductor.

Pull Request resolved: pytorch#127448
Approved by: https://github.com/eellison
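To make the collision concrete, here is a minimal plain-Python sketch (not the actual inductor codegen) that evaluates both yoffset formulas for every (y, z) program id of a split grid. The toy sizes YBLOCK = 1024, y_grid = 4, z_grid = 2 are assumptions chosen for readability; the same argument applies to the (32768, 2) split above.

```python
# Assumed toy sizes for readability; the argument is identical for (32768, 2).
YBLOCK, y_grid, z_grid = 1024, 4, 2

def buggy_yoffset(y, z):
    # yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
    return y * (z + 1) * YBLOCK

def fixed_yoffset(y, z):
    # yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    return (y + z * y_grid) * YBLOCK

blocks = [(y, z) for z in range(z_grid) for y in range(y_grid)]

# Buggy: block (0, 1) starts at offset 0, re-reading the first input elements.
assert buggy_yoffset(0, 1) == buggy_yoffset(0, 0) == 0
# Fixed: every block gets a distinct, contiguous YBLOCK-sized slice.
assert sorted(fixed_yoffset(y, z) for y, z in blocks) == [
    i * YBLOCK for i in range(y_grid * z_grid)
]
```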
1 parent 4935a01 commit 2a03bf5

5 files changed, +33 -25 lines

test/inductor/test_aot_inductor.py

Lines changed: 8 additions & 19 deletions
```diff
@@ -968,29 +968,19 @@ class Model(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
-            def forward(self, primals_1, primals_2, primals_5):
-                view = torch.ops.aten.reshape.default(primals_5, [-1, 4, 128])
+            def forward(self, primals_5):
+                view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
                 primals_5 = None
                 permute = torch.ops.aten.permute.default(view, [0, 2, 1])
                 clone = torch.ops.aten.clone.default(
                     permute, memory_format=torch.contiguous_format
                 )
-                permute = None
-                view_1 = torch.ops.aten.reshape.default(clone, [-1, 4])
-                clone = None
-                permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0])
-                primals_1 = None
-                addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1)
-                primals_2 = None
-                return addmm
-
-        s0 = 727828
-        s1 = 512
-        example_inputs = (
-            torch.rand(2, 4, device=self.device),
-            torch.rand(2, device=self.device),
-            torch.rand(s0, s1, device=self.device),
-        )
+                return clone
+
+        # let y_grid = 65537
+        s0 = 16777472
+        s1 = 8
+        example_inputs = (torch.rand(s0, s1, device=self.device),)
         self.check_model(Model(), example_inputs)
 
     def test_cond_simple(self):
@@ -3065,7 +3055,6 @@ def fail_non_abi_compatible_cuda(is_skip=False):
 
 CUDA_TEST_FAILURES = {
     # test_failures, xfail by default, set is_skip=True to skip
-    "test_large_grid": fail_cuda(),
     "test_normal_functional": fail_abi_compatible_cuda(is_skip=True),
     # no runtime checks for non_abi_compatible mode
     "test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True),
```

test/inductor/test_torchinductor.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -10341,6 +10341,23 @@ def test_generate_rand_fp8(self):
         t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn)
         self.assertTrue(t.dtype is torch.float8_e4m3fn)
 
+    def test_large_grid(self):
+        # https://github.com/pytorch/pytorch/issues/123210
+        def fn(primals_5):
+            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
+            primals_5 = None
+            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
+            clone = torch.ops.aten.clone.default(
+                permute, memory_format=torch.contiguous_format
+            )
+            return clone
+
+        s0 = 16777472
+        s1 = 8
+        compiled_fn = torch._dynamo.optimize()(fn)
+        actual = compiled_fn(torch.ones(s0, s1))
+        self.assertTrue((actual == 1).all())
+
 
 @dataclasses.dataclass
 class TestFailure:
```

test/inductor/test_triton_heuristics.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -38,7 +38,7 @@ def test_triton_config(self):
 
     def _test_artificial_zgrid(self):
         def forward(primals_1, primals_2, primals_5):
-            view = torch.ops.aten.reshape.default(primals_5, [-1, 4, 128])
+            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
             primals_5 = None
             permute = torch.ops.aten.permute.default(view, [0, 2, 1])
             clone = torch.ops.aten.clone.default(
@@ -53,8 +53,8 @@ def forward(primals_1, primals_2, primals_5):
             primals_2 = None
             return addmm
 
-        s0 = 727828
-        s1 = 512
+        s0 = 16777472
+        s1 = 8
 
         args = [
             torch.rand([2, 4], device=GPU_TYPE),
@@ -73,7 +73,6 @@ def forward(primals_1, primals_2, primals_5):
         ]
         self.assertEqual(forward(*args), foo_c(*args))
 
-    @unittest.skip("https://github.com/pytorch/pytorch/issues/123210")
     @expectedFailureXPU
     def test_artificial_zgrid(self):
         self._test_artificial_zgrid()
```

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -2341,7 +2341,10 @@ def iteration_ranges_get_pid(self, entry):
             and not entry.has_zdim
             and not (isinstance(entry.numel, int) and entry.numel <= get_max_y_grid())
         ):
-            key = f"{key} * (tl.program_id({entry.grid_dim + 1}) + 1)"
+            # For ynumel larger than max_ygrid, we need to use zdim.
+            # For each z dimension, there are tl.num_programs(1) y blocks, which is passed by grid(x, y, z).
+            # So we need to add tl.program_id(z) * tl.num_programs(y) * YBLOCK to get the correct yoffset.
+            key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))"
         pid = entry.pid_cache.get(key, key)
         if self.index_dtype != "tl.int32":
             return f"{pid}.to({self.index_dtype})"
```
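For reference, a tiny standalone reproduction of the f-strings involved, assuming `entry.grid_dim == 1` (the y dimension) and ignoring the surrounding `pid_cache` and index-dtype handling:

```python
grid_dim = 1                      # y dimension
key = f"tl.program_id({grid_dim})"

# Old (buggy) expression: scales the y program id instead of offsetting it.
old_key = f"{key} * (tl.program_id({grid_dim + 1}) + 1)"
# New expression: offsets the y program id by one full y grid per z program id.
new_key = f"({key} + tl.program_id({grid_dim + 1}) * tl.num_programs({grid_dim}))"

assert old_key == "tl.program_id(1) * (tl.program_id(2) + 1)"
assert new_key == "(tl.program_id(1) + tl.program_id(2) * tl.num_programs(1))"
```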

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1738,7 +1738,7 @@ def grid_fn(meta):
         max_y_grid = get_max_y_grid()
         if znumel is None:
             div = ceildiv(y_grid, max_y_grid)
-            y_grid = y_grid // div
+            y_grid = ceildiv(y_grid, div)
             z_grid = div
         else:
             z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
```
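A quick arithmetic check of this one-line change, using the y_grid = 65537 example from the commit message and the 65535 y-grid limit; `ceildiv` below is a local stand-in for the helper of the same name used in this file:

```python
def ceildiv(a: int, b: int) -> int:
    # Local stand-in for the ceildiv helper used above.
    return -(a // -b)

y_grid, max_y_grid = 65537, 65535
div = ceildiv(y_grid, max_y_grid)   # 2 z grids are needed

old_split = y_grid // div           # 32768; 32768 * 2 = 65536 blocks, one short
new_split = ceildiv(y_grid, div)    # 32769; 32769 * 2 = 65538 blocks, covers all
assert old_split * div < y_grid
assert new_split * div >= y_grid
```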
