import io
import logging
import sys
import zipfile
from pathlib import Path
from typing import Set
import torch
# Use a star import so developers don't need to add an import here when they add tests for upgraders.
from test.jit.fixtures_srcs.fixtures_src import * # noqa: F403
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
"""
This file is used to generate model for test operator change. Please refer to
https://github.com/pytorch/rfcs/blob/master/RFC-0017-PyTorch-Operator-Versioning.md for more details.
A systematic workflow to change operator is needed to ensure
Backwards Compatibility (BC) / Forwards Compatibility (FC) for operator changes. For BC-breaking operator change,
an upgrader is needed. Here is the flow to properly land a BC-breaking operator change.
1. Write an upgrader in caffe2/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp file. The softly enforced
naming format is <operator_name>_<operator_overload>_<start>_<end>. For example, the below example means that
div.Tensor at version from 0 to 3 needs to be replaced by this upgrader.
```
/*
div_Tensor_0_3 is added for a change of operator div in pr xxxxxxx.
Create date: 12/02/2021
Expire date: 06/02/2022
*/
{"div_Tensor_0_3", R"SCRIPT(
def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
  if (self.is_floating_point() or other.is_floating_point()):
    return self.true_divide(other)
  return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
```
2. In caffe2/torch/csrc/jit/operator_upgraders/version_map.h, add an entry like the one below.
You will need to make sure that the entries stay SORTED by the version bump number.
```
{"div.Tensor",
 {{4,
   "div_Tensor_0_3",
   "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
```
3. After rebuilding PyTorch, run the following command, and it will auto-generate a change to
fbcode/caffe2/torch/csrc/jit/mobile/upgrader_mobile.cpp:
```
python pytorch/torchgen/operator_versions/gen_mobile_upgraders.py
```
4. Generate a test to cover the upgrader.
4.1 Check out the commit before the operator change, and add a module in
`test/jit/fixtures_srcs/fixtures_src.py`. Switching to the older commit is necessary because
an old model using the operator before the change is needed to ensure the upgrader
works as expected. The module might look like the sketch below (the forward body is
an assumption; any use of the changed operator works):
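```
class TestVersionedDivTensorExampleV7(torch.nn.Module):
    def forward(self, a, b):
        # Any call that lowers to aten::div.Tensor keeps the old operator in the model.
        return a / b
```
Then, in `test/jit/fixtures_srcs/generate_models.py`, register the module together with
its corresponding changed operator like the following: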
```
ALL_MODULES = {
    TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
}
```
This module must include the changed operator. If the operator isn't covered in the model,
the model export process in step 4.2 will fail.
4.2 Export the model to `test/jit/fixtures` by running
```
python test/jit/fixtures_srcs/generate_models.py
```
4.3 In `test/jit/test_save_load_for_op_version.py`, add a test to cover the old model and
ensure its results are equivalent to those of the current module (old model + upgrader),
as in the sketch below.
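```
# A minimal sketch, assuming `pytorch_test_dir` points at the repo's test/ directory;
# the real tests in test_save_load_for_op_version.py are structured differently.
def test_versioned_div_tensor_example(self):
    model_path = pytorch_test_dir / "jit" / "fixtures" / "test_versioned_div_tensor_example_v7.ptl"
    loaded = torch.jit.load(str(model_path))  # the upgrader is applied at load time
    a, b = torch.rand(5), torch.rand(5)
    # The old model plus upgrader should match the current aten::div.Tensor behavior.
    assert torch.allclose(loaded(a, b), torch.div(a, b))
```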
4.4 Save all changes from 4.1, 4.2, and 4.3, together with the changes made in steps 1, 2, and 3,
and submit them as a PR.
"""
"""
A map of test modules and it's according changed operator
key: test module
value: changed operator
"""
ALL_MODULES = {
    TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
    TestVersionedLinspaceV7(): "aten::linspace",
    TestVersionedLinspaceOutV7(): "aten::linspace.out",
    TestVersionedLogspaceV8(): "aten::logspace",
    TestVersionedLogspaceOutV8(): "aten::logspace.out",
    TestVersionedGeluV9(): "aten::gelu",
    TestVersionedGeluOutV9(): "aten::gelu.out",
    TestVersionedRandomV10(): "aten::random_.from",
    TestVersionedRandomFuncV10(): "aten::random.from",
    TestVersionedRandomOutV10(): "aten::random.from_out",
}
"""
Get the path to `test/jit/fixtures`, where all test models for operator changes
(upgrader/downgrader) are stored
"""
def get_fixtures_path() -> Path:
    pytorch_dir = Path(__file__).resolve().parents[3]
    fixtures_path = pytorch_dir / "test" / "jit" / "fixtures"
    return fixtures_path
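# Note: this file lives at test/jit/fixtures_srcs/generate_models.py, so
# parents[3] walks up to the repository root and the call resolves to
# <repo>/test/jit/fixtures.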
"""
Get all models' name in `test/jit/fixtures`
"""
def get_all_models(model_directory_path: Path) -> Set[str]:
    files_in_fixtures = model_directory_path.glob("**/*")
    all_models_from_fixtures = [
        fixture.stem for fixture in files_in_fixtures if fixture.is_file()
    ]
    return set(all_models_from_fixtures)
"""
Check if a given model already exist in `test/jit/fixtures`
"""
def model_exist(model_file_name: str, all_models: Set[str]) -> bool:
    return model_file_name in all_models
"""
Get the operator list given a module
"""
def get_operator_list(script_module: torch.jit.ScriptModule) -> Set[str]:
    # Round-trip through the lite-interpreter format to get the operator list
    # actually recorded in the mobile model.
    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
    buffer.seek(0)
    mobile_module = _load_for_lite_interpreter(buffer)
    operator_list = _export_operator_list(mobile_module)
    return operator_list
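# For example, for the TestVersionedDivTensorExampleV7 module registered above,
# the returned set should contain "aten::div.Tensor":
#   get_operator_list(torch.jit.script(TestVersionedDivTensorExampleV7()))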
"""
Get the output model operator version, given a module
"""
def get_output_model_version(script_module: torch.nn.Module) -> int:
    buffer = io.BytesIO()
    torch.jit.save(script_module, buffer)
    buffer.seek(0)
    zipped_model = zipfile.ZipFile(buffer)
    try:
        version = int(zipped_model.read("archive/version").decode("utf-8"))
        return version
    except KeyError:
        # Newer archives store the version file under .data instead.
        version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
        return version
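# The version file holds the number as plain text, so a saved module can also be
# inspected by hand (sketch; model.pt is a placeholder path):
#   unzip -p model.pt archive/version   # prints e.g. "7"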
"""
Loop through all test modules. If the corresponding model doesn't exist in
`test/jit/fixtures`, generate one. For the following reason, a model won't be exported:
1. The test module doens't cover the changed operator. For example, test_versioned_div_tensor_example_v4
is supposed to test the operator aten::div.Tensor. If the model doesn't include this operator, it will fail.
The error message includes the actual operator list from the model.
2. The output model version is not the same as expected version. For example, test_versioned_div_tensor_example_v4
is used to test an operator change aten::div.Tensor, and the operator version will be bumped to v5. This script is
supposed to run before the operator change (before the commit to make the change). If the actual model version is v5,
likely this script is running with the commit to make the change.
3. The model already exists in `test/jit/fixtures`.
"""
def generate_models(model_directory_path: Path):
    all_models = get_all_models(model_directory_path)
    for a_module, expect_operator in ALL_MODULES.items():
        # For example: TestVersionedDivTensorExampleV7
        torch_module_name = type(a_module).__name__
        if not isinstance(a_module, torch.nn.Module):
            logger.error(
                "The module %s "
                "is not a torch.nn.Module instance. "
                "Please ensure it's a subclass of torch.nn.Module in fixtures_src.py "
                "and that it's registered as an instance in ALL_MODULES in generate_models.py",
                torch_module_name,
            )
            continue
        # The corresponding model name is: test_versioned_div_tensor_example_v7
        model_name = "".join(
            [
                "_" + char.lower() if char.isupper() else char
                for char in torch_module_name
            ]
        ).lstrip("_")
        # Some models may not compile anymore, so skip the ones
        # that already have a .ptl file.
        logger.info("Processing %s", torch_module_name)
        if model_exist(model_name, all_models):
            logger.info("Model %s already exists, skipping", model_name)
            continue
        script_module = torch.jit.script(a_module)
        actual_model_version = get_output_model_version(script_module)
        current_operator_version = torch._C._get_max_operator_version()
        if actual_model_version >= current_operator_version + 1:
            logger.error(
                "Actual model version %s "
                "is equal to or larger than %s + 1. "
                "Please run this script before the commit that changes the operator.",
                actual_model_version,
                current_operator_version,
            )
            continue
        actual_operator_list = get_operator_list(script_module)
        if expect_operator not in actual_operator_list:
            logger.error(
                "The model includes the operators %s, "
                "but it doesn't cover the operator %s. "
                "Please ensure the output model includes the tested operator.",
                actual_operator_list,
                expect_operator,
            )
            continue
        export_model_path = str(model_directory_path / (str(model_name) + ".ptl"))
        script_module._save_for_lite_interpreter(export_model_path)
        logger.info(
            "Generated model %s and saved it to %s", model_name, export_model_path
        )
def main() -> None:
    model_directory_path = get_fixtures_path()
    generate_models(model_directory_path)


if __name__ == "__main__":
    main()