1
1
#!/usr/bin/env python3
2
2
import json
3
3
import os
4
- from typing import Dict
4
+ import argparse
5
+ import io
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Dict , Set , List , Iterable
9
+ from enum import Enum
10
+
5
11
BASE_DIR = os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))
6
12
13
+ class OperatingSystem (Enum ):
14
+ LINUX : str = "linux"
15
+ WINDOWS : str = "windows"
16
+ MACOS : str = "macos"
17
+
18
+ PRE_CXX11_ABI = "pre-cxx11"
19
+ CXX11_ABI = "cxx11-abi"
20
+ DEBUG = "debug"
21
+ RELEASE = "release"
22
+ DEFAULT = "default"
23
+ ENABLE = "enable"
24
+ DISABLE = "disable"
7
25
26
+ # Mapping json to release matrix is here for now
27
+ # TBD drive the mapping via:
28
+ # 1. Scanning release matrix and picking 2 latest cuda versions and 1 latest rocm
29
+ # 2. Possibility to override the scanning algorithm with arguments passed from workflow
30
+ acc_arch_map = {
31
+ "accnone" : ("cpu" , "" ),
32
+ "cuda.x" : ("cuda" , "11.6" ),
33
+ "cuda.y" : ("cuda" , "11.7" ),
34
+ "rocm5.x" : ("rocm" , "5.2" )
35
+ }
36
+
37
+ LIBTORCH_DWNL_INSTR = {
38
+ PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
39
+ CXX11_ABI : "Download here (cxx11 ABI):" ,
40
+ RELEASE : "Download here (Release version):" ,
41
+ DEBUG : "Download here (Debug version):" ,
42
+ }
8
43
9
44
def read_published_versions ():
10
45
with open (os .path .join (BASE_DIR , "published_versions.json" )) as fp :
11
46
return json .load (fp )
12
47
48
+ def write_published_versions (versions ):
49
+ with open (os .path .join (BASE_DIR , "published_versions.json" ), "w" ) as outfile :
50
+ json .dump (versions , outfile , indent = 2 )
51
+
52
+ def read_matrix_for_os (osys : OperatingSystem ):
53
+ try :
54
+ with open (os .path .join (BASE_DIR , f"{ osys .value } _matrix.json" )) as fp :
55
+ return json .load (fp )["include" ]
56
+ except FileNotFoundError as e :
57
+ raise ImportError (f"Release matrix not found for: { osys .value } error: { e .strerror } " ) from e
58
+
13
59
14
60
def read_quick_start_module_template ():
15
61
with open (os .path .join (BASE_DIR , "_includes" , "quick-start-module.js" )) as fp :
16
62
return fp .read ()
17
63
64
+ def update_versions (versions , release_matrix , version ):
65
+ version_map = {
66
+ "preview" : "preview" ,
67
+ }
68
+
69
+ # Generating for a specific version
70
+ if (version != "preview" ):
71
+ version_map = {
72
+ version : version ,
73
+ }
74
+ if version in versions ["versions" ]:
75
+ if version != versions ["latest_stable" ]:
76
+ raise RuntimeError (f"Can only update prview, latest stable: { versions ['latest_stable' ]} or new version" )
77
+ else :
78
+ import copy
79
+ new_version = copy .deepcopy (versions ["versions" ]["preview" ])
80
+ versions ["versions" ][version ] = new_version
81
+ versions ["latest_stable" ] = version
82
+
83
+ # Perform update of the json file from release matrix
84
+ for ver , ver_key in version_map .items ():
85
+ for os_key , os_vers in versions ["versions" ][ver_key ].items ():
86
+ for pkg_key , pkg_vers in os_vers .items ():
87
+ for acc_key , instr in pkg_vers .items ():
88
+
89
+ package_type = pkg_key
90
+ if pkg_key == 'pip' :
91
+ package_type = 'manywheel' if os_key == OperatingSystem .LINUX .value else 'wheel'
92
+
93
+ gpu_arch_type , gpu_arch_version = acc_arch_map [acc_key ]
94
+ if (DEFAULT in instr ):
95
+ gpu_arch_type , gpu_arch_version = acc_arch_map ["accnone" ]
96
+
97
+ pkg_arch_matrix = list (filter (
98
+ lambda x :
99
+ (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ]) ==
100
+ (package_type , gpu_arch_type , gpu_arch_version ),
101
+ release_matrix [os_key ]
102
+ ))
103
+
104
+ if pkg_arch_matrix :
105
+ if package_type != 'libtorch' :
106
+ instr ["command" ] = pkg_arch_matrix [0 ]["installation" ]
107
+ else :
108
+ if os_key == OperatingSystem .LINUX .value :
109
+ rel_entry_pre_cxx1 = next (filter (
110
+ lambda x :
111
+ x ["devtoolset" ] == PRE_CXX11_ABI ,
112
+ pkg_arch_matrix
113
+ ), None )
114
+ rel_entry_cxx1_abi = next (filter (
115
+ lambda x :
116
+ x ["devtoolset" ] == CXX11_ABI ,
117
+ pkg_arch_matrix
118
+ ), None )
119
+ if (instr ['versions' ] is not None ):
120
+ instr ['versions' ][LIBTORCH_DWNL_INSTR [PRE_CXX11_ABI ]] = rel_entry_pre_cxx1 ["installation" ]
121
+ instr ['versions' ][LIBTORCH_DWNL_INSTR [CXX11_ABI ]] = rel_entry_cxx1_abi ["installation" ]
122
+ elif os_key == OperatingSystem .WINDOWS .value :
123
+ rel_entry_release = next (filter (
124
+ lambda x :
125
+ x ["libtorch_config" ] == RELEASE ,
126
+ pkg_arch_matrix
127
+ ), None )
128
+ rel_entry_debug = next (filter (
129
+ lambda x :
130
+ x ["libtorch_config" ] == DEBUG ,
131
+ pkg_arch_matrix
132
+ ), None )
133
+ if (instr ['versions' ] is not None ):
134
+ instr ['versions' ][LIBTORCH_DWNL_INSTR [RELEASE ]] = rel_entry_release ["installation" ]
135
+ instr ['versions' ][LIBTORCH_DWNL_INSTR [DEBUG ]] = rel_entry_debug ["installation" ]
136
+
18
137
19
138
def gen_install_matrix (versions ) -> Dict [str , str ]:
20
139
rc = {}
@@ -41,8 +160,34 @@ def gen_install_matrix(versions) -> Dict[str, str]:
41
160
rc [key ] = "<br />" .join (lines )
42
161
return rc
43
162
163
+
44
164
def main ():
165
+ parser = argparse .ArgumentParser ()
166
+ parser .add_argument (
167
+ "--version" ,
168
+ help = "Version to generate the instructions for" ,
169
+ type = str ,
170
+ default = "preview" ,
171
+ )
172
+ parser .add_argument (
173
+ "--autogenerate" ,
174
+ help = "Is this call being initiated from workflow? update published_versions" ,
175
+ type = str ,
176
+ choices = [ENABLE , DISABLE ],
177
+ default = DISABLE ,
178
+ )
179
+
180
+ options = parser .parse_args ()
45
181
versions = read_published_versions ()
182
+
183
+ if options .autogenerate == ENABLE :
184
+ release_matrix = {}
185
+ for osys in OperatingSystem :
186
+ release_matrix [osys .value ] = read_matrix_for_os (osys )
187
+
188
+ update_versions (versions , release_matrix , options .version )
189
+ write_published_versions ()
190
+
46
191
template = read_quick_start_module_template ()
47
192
versions_str = json .dumps (gen_install_matrix (versions ))
48
193
print (template .replace ("{{ installMatrix }}" , versions_str ))
0 commit comments