# Owner(s): ["module: nn"]
import math
import unittest
from typing import List, Tuple, Union

import torch
from torch._inductor import config
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    parametrize,
    run_tests,
    TEST_CUDA,
)
from torch.utils._triton import has_triton

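# Tolerances used when comparing eager and compiled results. bfloat16 maps to
# infinity, which makes run_comp_nocomp skip the numerical comparison for that
# dtype.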
default_atol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1.3e-6,
}


def rand_math_tensor(
    shape: Tuple[Union[int, List[int]]],
    device: str,
    dtype: torch.dtype,
    requires_grad: bool = False,
    packed: bool = False,
) -> torch.Tensor:
    """Creates a random dense tensor with the given shape and dtype.

    Args:
        shape (Tuple[int]): Shape of Tensor to construct
        device (str): which device to create tensor on
        dtype (torch.dtype): Tensor's dtype
        requires_grad (bool, optional): Tensor's grad status. Defaults to False.
        packed (bool, optional): Accepted for API compatibility; currently unused. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)


def init_tensor(tensor_list, **kwargs) -> torch.Tensor:
    return torch.Tensor(tensor_list).to(**kwargs)
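

# Runs `function` both eagerly and under torch.compile, then checks that the
# two results match, unless the requested tolerances are infinite (in which
# case the comparison is intentionally skipped).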
def run_comp_nocomp(function, *inputs, **kwargs):
    c_function = torch.compile(function)

    f_res = function(*inputs)
    cf_res = c_function(*inputs)

    if not (math.isinf(kwargs.get("atol", 0.0)) or math.isinf(kwargs.get("rtol", 0.0))):
        torch.testing.assert_close(f_res, cf_res, **kwargs)


# The test functions are used by several tests
def torch_mm(a, b):
    return torch.mm(a, b)


def torch_addmm(add, b, c):
    return torch.addmm(add, b, c)


def torch_bmm(a, b):
    return torch.bmm(a, b)


def torch_baddbmm(add, b, c, alpha, beta):
    return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)


# The shapes we test on
ts_list = [
    (1, 32, 32, 1),
    (1, 10, 10, 1),
    (1, 3, 3, 1),
    (32, 1, 1, 32),
    (3, 1, 1, 3),
    (4, 1, 1, 9),
    (9, 1, 1, 4),
]
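# Note: each tuple is (M, K, K, N) for an (M, K) @ (K, N) product. The unit
# dimensions appear chosen to exercise inductor's small-matrix mm
# decompositions, though that intent is inferred from the file name rather
# than stated here.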


class TestDecomp(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_simple_mm(self, device, dtype):
        # Loosen the default tolerances: compiled and eager kernels may
        # accumulate in different orders.
        fudge = 10
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            ((a1_0, a1_1, a2_0, a2_1)) = t_size

            t1 = rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol)
            run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton")
    @parametrize(
        "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float]
    )
    @parametrize("bs", [1, 2, 4, 10])
    def test_batched_mm(self, device, dtype, bs):
        fudge = 3
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            ((a1_0, a1_1, a2_0, a2_1)) = t_size

            t1 = rand_math_tensor((bs, a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((bs, a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((bs, a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

            for alpha in (0, 1, -1, 0.5, -0.5):
                for beta in (0, 1, -1, 0.5, -0.5):
                    run_comp_nocomp(
                        torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol
                    )

    @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton")
    @config.patch(coordinate_descent_tuning=True)
    def test_bmm_batch2_last_dim_size_is_one(self, device):
        fudge = 3
        rtol = default_rtol[torch.float32] * fudge
        atol = default_atol[torch.float32] * fudge

        t1 = torch.randn(1, 32, 2, device=device)
        t2 = torch.randn(1, 2, 1, device=device)

        run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    def test_some(self, device, dtype):
        # This PyTorch data type is not fully supported on CUDA today.
        # Unfortunately we can't use skipIf because the actual params aren't
        # visible inside skipIf.
        if device.startswith("cuda") and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_mm,
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_mm,
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
        )

    @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    @parametrize("bs", [1, 2, 4, 10])
    def test_some_batched(self, device, dtype, bs):
        # This PyTorch data type is not fully supported on CUDA today.
        # Unfortunately we can't use skipIf because the actual params aren't
        # visible inside skipIf.
        if device.startswith("cuda") and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
        )
device_types = ("cpu", "cuda")
instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)

if __name__ == "__main__":
    # We don't support torch.compile() on Windows
    if not IS_WINDOWS:
        run_tests()