1
1
#!/usr/bin/env python3
2
+ """
3
+ Generates quick start module for https://pytorch.org/get-started/locally/ page
4
+ If called from update-quick-start-module.yml workflow (--autogenerate parameter set)
5
+ Will output new quick-start-module.js, and new published_version.json file
6
+ based on the current release matrix.
7
+ If called standalone will generate quick-start-module.js from existing
8
+ published_version.json file
9
+ """
10
+
2
11
import json
3
- import os
12
+ import copy
4
13
import argparse
5
- import io
6
- import sys
7
14
from pathlib import Path
8
- from typing import Dict , Set , List , Iterable
15
+ from typing import Dict
9
16
from enum import Enum
10
17
11
- BASE_DIR = os . path . dirname ( os . path . dirname ( os . path . abspath ( __file__ )))
18
+ BASE_DIR = Path ( __file__ ). parent . parent
12
19
13
20
class OperatingSystem (Enum ):
14
21
LINUX : str = "linux"
@@ -27,7 +34,6 @@ class OperatingSystem(Enum):
27
34
# TBD drive the mapping via:
28
35
# 1. Scanning release matrix and picking 2 latest cuda versions and 1 latest rocm
29
36
# 2. Possibility to override the scanning algorithm with arguments passed from workflow
30
-
31
37
acc_arch_ver_map = {
32
38
"nightly" : {
33
39
"accnone" : ("cpu" , "" ),
@@ -50,106 +56,87 @@ class OperatingSystem(Enum):
50
56
DEBUG : "Download here (Debug version):" ,
51
57
}
52
58
59
+ def load_json_from_basedir (filename : str ):
60
+ try :
61
+ with open (BASE_DIR / filename ) as fptr :
62
+ return json .load (fptr )
63
+ except FileNotFoundError as exc :
64
+ raise ImportError (f"File { filename } not found error: { exc .strerror } " ) from exc
65
+ except json .JSONDecodeError as exc :
66
+ raise ImportError (f"Invalid JSON { filename } " ) from exc
67
+
53
68
def read_published_versions ():
54
- with open (os .path .join (BASE_DIR , "published_versions.json" )) as fp :
55
- return json .load (fp )
69
+ return load_json_from_basedir ("published_versions.json" )
56
70
57
71
def write_published_versions (versions ):
58
- with open (os .path .join (BASE_DIR , "published_versions.json" ), "w" ) as outfile :
59
- json .dump (versions , outfile , indent = 2 )
60
-
61
- def read_matrix_for_os (osys : OperatingSystem , value : str ):
62
- try :
63
- with open (os .path .join (BASE_DIR , f"{ osys .value } _{ value } _matrix.json" )) as fp :
64
- return json .load (fp )["include" ]
65
- except FileNotFoundError as e :
66
- raise ImportError (f"Release matrix not found for: { osys .value } error: { e .strerror } " ) from e
72
+ with open (BASE_DIR / "published_versions.json" , "w" ) as outfile :
73
+ json .dump (versions , outfile , indent = 2 )
67
74
75
+ def read_matrix_for_os (osys : OperatingSystem , channel : str ):
76
+ jsonfile = load_json_from_basedir (f"{ osys .value } _{ channel } _matrix.json" )
77
+ return jsonfile ["include" ]
68
78
69
79
def read_quick_start_module_template ():
70
- with open (os .path .join (BASE_DIR , "_includes" , "quick-start-module.js" )) as fp :
71
- return fp .read ()
72
-
80
+ with open (BASE_DIR / "_includes" / "quick-start-module.js" ) as fptr :
81
+ return fptr .read ()
82
+
83
+ def get_package_type (pkg_key : str , os_key : OperatingSystem ) -> str :
84
+ if pkg_key != "pip" :
85
+ return pkg_key
86
+ return "manywheel" if os_key == OperatingSystem .LINUX .value else "wheel"
87
+
88
+ def get_gpu_info (acc_key , instr , acc_arch_map ):
89
+ gpu_arch_type , gpu_arch_version = acc_arch_map [acc_key ]
90
+ if DEFAULT in instr :
91
+ gpu_arch_type , gpu_arch_version = acc_arch_map ["accnone" ]
92
+ return (gpu_arch_type , gpu_arch_version )
93
+
94
+ # This method is used for generating new published_versions.json file
95
+ # It will modify versions json object with installation instructions
96
+ # Provided by generate install matrix Github Workflow, stored in release_matrix
97
+ # json object.
73
98
def update_versions (versions , release_matrix , release_version ):
74
- version_map = {
75
- "preview" : "preview" ,
76
- }
77
- version = ""
99
+ version = "preview"
78
100
acc_arch_map = acc_arch_ver_map [release_version ]
79
101
80
- if (release_version == "nightly" ):
81
- version = "preview"
82
- else :
102
+ if release_version != "nightly" :
83
103
version = release_matrix [OperatingSystem .LINUX .value ][0 ]["stable_version" ]
84
-
85
- # Generating for a specific version
86
- if (version != "preview" ):
87
- version_map = {
88
- version : version ,
89
- }
90
104
if version not in versions ["versions" ]:
91
- import copy
92
- new_version = copy .deepcopy (versions ["versions" ]["preview" ])
93
- versions ["versions" ][version ] = new_version
105
+ versions ["versions" ][version ] = copy .deepcopy (versions ["versions" ]["preview" ])
94
106
versions ["latest_stable" ] = version
95
107
96
108
# Perform update of the json file from release matrix
97
- for ver , ver_key in version_map .items ():
98
- for os_key , os_vers in versions ["versions" ][ver_key ].items ():
99
- for pkg_key , pkg_vers in os_vers .items ():
100
- for acc_key , instr in pkg_vers .items ():
101
-
102
- package_type = pkg_key
103
- if pkg_key == 'pip' :
104
- package_type = 'manywheel' if os_key == OperatingSystem .LINUX .value else 'wheel'
105
-
106
- gpu_arch_type , gpu_arch_version = acc_arch_map [acc_key ]
107
- if (DEFAULT in instr ):
108
- gpu_arch_type , gpu_arch_version = acc_arch_map ["accnone" ]
109
-
110
- pkg_arch_matrix = list (filter (
111
- lambda x :
112
- (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ]) ==
113
- (package_type , gpu_arch_type , gpu_arch_version ),
114
- release_matrix [os_key ]
115
- ))
116
-
117
- if pkg_arch_matrix :
118
- if package_type != 'libtorch' :
119
- instr ["command" ] = pkg_arch_matrix [0 ]["installation" ]
120
- else :
121
- if os_key == OperatingSystem .LINUX .value :
122
- rel_entry_pre_cxx1 = next (filter (
123
- lambda x :
124
- x ["devtoolset" ] == PRE_CXX11_ABI ,
125
- pkg_arch_matrix
126
- ), None )
127
- rel_entry_cxx1_abi = next (filter (
128
- lambda x :
129
- x ["devtoolset" ] == CXX11_ABI ,
130
- pkg_arch_matrix
131
- ), None )
132
- if (instr ['versions' ] is not None ):
133
- instr ['versions' ][LIBTORCH_DWNL_INSTR [PRE_CXX11_ABI ]] = rel_entry_pre_cxx1 ["installation" ]
134
- instr ['versions' ][LIBTORCH_DWNL_INSTR [CXX11_ABI ]] = rel_entry_cxx1_abi ["installation" ]
135
- elif os_key == OperatingSystem .WINDOWS .value :
136
- rel_entry_release = next (filter (
137
- lambda x :
138
- x ["libtorch_config" ] == RELEASE ,
139
- pkg_arch_matrix
140
- ), None )
141
- rel_entry_debug = next (filter (
142
- lambda x :
143
- x ["libtorch_config" ] == DEBUG ,
144
- pkg_arch_matrix
145
- ), None )
146
- if (instr ['versions' ] is not None ):
147
- instr ['versions' ][LIBTORCH_DWNL_INSTR [RELEASE ]] = rel_entry_release ["installation" ]
148
- instr ['versions' ][LIBTORCH_DWNL_INSTR [DEBUG ]] = rel_entry_debug ["installation" ]
149
-
150
-
109
+ for os_key , os_vers in versions ["versions" ][version ].items ():
110
+ for pkg_key , pkg_vers in os_vers .items ():
111
+ for acc_key , instr in pkg_vers .items ():
112
+ package_type = get_package_type (pkg_key , os_key )
113
+ gpu_arch_type , gpu_arch_version = get_gpu_info (acc_key , instr , acc_arch_map )
114
+
115
+ pkg_arch_matrix = [
116
+ x for x in release_matrix [os_key ]
117
+ if (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ]) ==
118
+ (package_type , gpu_arch_type , gpu_arch_version )
119
+ ]
120
+
121
+ if pkg_arch_matrix :
122
+ if package_type != "libtorch" :
123
+ instr ["command" ] = pkg_arch_matrix [0 ]["installation" ]
124
+ else :
125
+ if os_key == OperatingSystem .LINUX .value :
126
+ rel_entry_dict = {x ["devtoolset" ]: x ["installation" ] for x in pkg_arch_matrix }
127
+ if instr ["versions" ] is not None :
128
+ for ver in [PRE_CXX11_ABI , CXX11_ABI ]:
129
+ instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = rel_entry_dict [ver ]
130
+ elif os_key == OperatingSystem .WINDOWS .value :
131
+ rel_entry_dict = {x ["libtorch_config" ]: x ["installation" ] for x in pkg_arch_matrix }
132
+ if instr ["versions" ] is not None :
133
+ for ver in [RELEASE , DEBUG ]:
134
+ instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = rel_entry_dict [ver ]
135
+
136
+ # This method is used for generating new quick-start-module.js
137
+ # from the versions json object
151
138
def gen_install_matrix (versions ) -> Dict [str , str ]:
152
- rc = {}
139
+ result = {}
153
140
version_map = {
154
141
"preview" : "preview" ,
155
142
"stable" : versions ["latest_stable" ],
@@ -158,36 +145,30 @@ def gen_install_matrix(versions) -> Dict[str, str]:
158
145
for os_key , os_vers in versions ["versions" ][ver_key ].items ():
159
146
for pkg_key , pkg_vers in os_vers .items ():
160
147
for acc_key , instr in pkg_vers .items ():
161
- extra_key = 'python' if pkg_key != 'libtorch' else 'cplusplus'
162
- key = f"{ ver } ,{ pkg_key } ,{ os_key } ,{ acc_key } ,{ extra_key } "
163
- note = instr ["note" ]
164
- lines = [note ] if note is not None else []
165
- if pkg_key == "libtorch" :
166
- ivers = instr ["versions" ]
167
- if ivers is not None :
168
- lines += [f"{ lab } <br /><a href='{ val } '>{ val } </a>" for (lab , val ) in ivers .items ()]
169
- else :
170
- command = instr ["command" ]
171
- if command is not None :
172
- lines .append (command )
173
- rc [key ] = "<br />" .join (lines )
174
- return rc
148
+ extra_key = 'python' if pkg_key != 'libtorch' else 'cplusplus'
149
+ key = f"{ ver } ,{ pkg_key } ,{ os_key } ,{ acc_key } ,{ extra_key } "
150
+ note = instr ["note" ]
151
+ lines = [note ] if note is not None else []
152
+ if pkg_key == "libtorch" :
153
+ ivers = instr ["versions" ]
154
+ if ivers is not None :
155
+ lines += [f"{ lab } <br /><a href='{ val } '>{ val } </a>" for (lab , val ) in ivers .items ()]
156
+ else :
157
+ command = instr ["command" ]
158
+ if command is not None :
159
+ lines .append (command )
160
+ result [key ] = "<br />" .join (lines )
161
+ return result
175
162
176
163
def main ():
177
164
parser = argparse .ArgumentParser ()
178
- parser .add_argument (
179
- "--autogenerate" ,
180
- help = "Is this call being initiated from workflow? update published_versions" ,
181
- type = str ,
182
- choices = [ENABLE , DISABLE ],
183
- default = ENABLE ,
184
- )
165
+ parser .add_argument ('--autogenerate' , dest = 'autogenerate' , action = 'store_true' )
166
+ parser .set_defaults (autogenerate = False )
185
167
186
168
options = parser .parse_args ()
187
169
versions = read_published_versions ()
188
170
189
-
190
- if options .autogenerate == ENABLE :
171
+ if options .autogenerate :
191
172
release_matrix = {}
192
173
for val in ("nightly" , "release" ):
193
174
release_matrix [val ] = {}
@@ -199,14 +180,11 @@ def main():
199
180
200
181
write_published_versions (versions )
201
182
202
-
203
183
template = read_quick_start_module_template ()
204
184
versions_str = json .dumps (gen_install_matrix (versions ))
205
185
template = template .replace ("{{ installMatrix }}" , versions_str )
206
186
template = template .replace ("{{ VERSION }}" , f"\" Stable ({ versions ['latest_stable' ]} )\" " )
207
187
print (template .replace ("{{ ACC ARCH MAP }}" , json .dumps (acc_arch_ver_map )))
208
188
209
-
210
-
211
189
if __name__ == "__main__" :
212
190
main ()
0 commit comments