Skip to content

Commit bb5a838

Browse files
committed
Testing update quick start moduke
1 parent 877de8e commit bb5a838

File tree

2 files changed

+103
-121
lines changed

2 files changed

+103
-121
lines changed

.github/workflows/update-quick-start-module.yml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@ on:
77
paths:
88
- .github/workflows/update-quick-start-module.yml
99
- /scripts/gen_quick_start_module.py
10+
- _includes/quick-start-module.js
11+
- _includes/quick_start_local.html
1012
push:
1113
branches:
1214
site
1315
paths:
1416
- .github/workflows/update-quick-start-module.yml
1517
- /scripts/gen_quick_start_module.py
18+
- _includes/quick-start-module.js
19+
- _includes/quick_start_local.html
1620
workflow_dispatch:
1721

1822
jobs:
@@ -56,7 +60,7 @@ jobs:
5660
os: macos
5761
channel: "release"
5862

59-
generate-json-file:
63+
update-quick-start:
6064
needs: [linux-nightly-matrix, windows-nightly-matrix, macos-nightly-matrix,
6165
linux-release-matrix, windows-release-matrix, macos-release-matrix]
6266
runs-on: "ubuntu-18.04"
@@ -79,13 +83,13 @@ jobs:
7983
MACOS_RELEASE_MATRIX: ${{ needs.macos-release-matrix.outputs.matrix }}
8084
run: |
8185
set -ex
82-
# printf '%s\n' "$LINUX_NIGHTLY_MATRIX" > linux_nightly_matrix.json
83-
# printf '%s\n' "$WINDOWS_NIGHTLY_MATRIX" > windows_nightly_matrix.json
86+
printf '%s\n' "$LINUX_NIGHTLY_MATRIX" > linux_nightly_matrix.json
87+
printf '%s\n' "$WINDOWS_NIGHTLY_MATRIX" > windows_nightly_matrix.json
8488
printf '%s\n' "$MACOS_NIGHTLY_MATRIX" > macos_nightly_matrix.json
8589
printf '%s\n' "$LINUX_RELEASE_MATRIX" > linux_release_matrix.json
8690
printf '%s\n' "$WINDOWS_RELEASE_MATRIX" > windows_release_matrix.json
8791
printf '%s\n' "$MACOS_RELEASE_MATRIX" > macos_release_matrix.json
88-
python3 ./scripts/gen_quick_start_module.py --autogenerate enable > assets/quick-start-module.js
92+
python3 ./scripts/gen_quick_start_module.py --autogenerate > assets/quick-start-module.js
8993
rm *_matrix.json
9094
- name: Create Issue if failed
9195
uses: dacbd/create-issue-action@main

scripts/gen_quick_start_module.py

Lines changed: 95 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
#!/usr/bin/env python3
2+
"""
3+
Generates quick start module for https://pytorch.org/get-started/locally/ page
4+
If called from update-quick-start-module.yml workflow (--autogenerate parameter set)
5+
Will output new quick-start-module.js, and new published_version.json file
6+
based on the current release matrix.
7+
If called standalone will generate quick-start-module.js from existing
8+
published_version.json file
9+
"""
10+
211
import json
3-
import os
12+
import copy
413
import argparse
5-
import io
6-
import sys
714
from pathlib import Path
8-
from typing import Dict, Set, List, Iterable
15+
from typing import Dict
916
from enum import Enum
1017

11-
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18+
BASE_DIR = Path(__file__).parent.parent
1219

1320
class OperatingSystem(Enum):
1421
LINUX: str = "linux"
@@ -27,7 +34,6 @@ class OperatingSystem(Enum):
2734
# TBD drive the mapping via:
2835
# 1. Scanning release matrix and picking 2 latest cuda versions and 1 latest rocm
2936
# 2. Possibility to override the scanning algorithm with arguments passed from workflow
30-
3137
acc_arch_ver_map = {
3238
"nightly": {
3339
"accnone": ("cpu", ""),
@@ -50,106 +56,87 @@ class OperatingSystem(Enum):
5056
DEBUG: "Download here (Debug version):",
5157
}
5258

59+
def load_json_from_basedir(filename: str):
60+
try:
61+
with open(BASE_DIR / filename) as fptr:
62+
return json.load(fptr)
63+
except FileNotFoundError as exc:
64+
raise ImportError(f"File {filename} not found error: {exc.strerror}") from exc
65+
except json.JSONDecodeError as exc:
66+
raise ImportError(f"Invalid JSON {filename}") from exc
67+
5368
def read_published_versions():
54-
with open(os.path.join(BASE_DIR, "published_versions.json")) as fp:
55-
return json.load(fp)
69+
return load_json_from_basedir("published_versions.json")
5670

5771
def write_published_versions(versions):
58-
with open(os.path.join(BASE_DIR, "published_versions.json"), "w") as outfile:
59-
json.dump(versions, outfile, indent=2)
60-
61-
def read_matrix_for_os(osys: OperatingSystem, value: str):
62-
try:
63-
with open(os.path.join(BASE_DIR, f"{osys.value}_{value}_matrix.json")) as fp:
64-
return json.load(fp)["include"]
65-
except FileNotFoundError as e:
66-
raise ImportError(f"Release matrix not found for: {osys.value} error: {e.strerror}") from e
72+
with open(BASE_DIR / "published_versions.json", "w") as outfile:
73+
json.dump(versions, outfile, indent=2)
6774

75+
def read_matrix_for_os(osys: OperatingSystem, channel: str):
76+
jsonfile = load_json_from_basedir(f"{osys.value}_{channel}_matrix.json")
77+
return jsonfile["include"]
6878

6979
def read_quick_start_module_template():
70-
with open(os.path.join(BASE_DIR, "_includes", "quick-start-module.js")) as fp:
71-
return fp.read()
72-
80+
with open(BASE_DIR / "_includes" / "quick-start-module.js") as fptr:
81+
return fptr.read()
82+
83+
def get_package_type(pkg_key: str, os_key: OperatingSystem) -> str:
84+
if pkg_key != "pip":
85+
return pkg_key
86+
return "manywheel" if os_key == OperatingSystem.LINUX.value else "wheel"
87+
88+
def get_gpu_info(acc_key, instr, acc_arch_map):
89+
gpu_arch_type, gpu_arch_version = acc_arch_map[acc_key]
90+
if DEFAULT in instr:
91+
gpu_arch_type, gpu_arch_version = acc_arch_map["accnone"]
92+
return (gpu_arch_type, gpu_arch_version)
93+
94+
# This method is used for generating new published_versions.json file
95+
# It will modify versions json object with installation instructions
96+
# Provided by generate install matrix Github Workflow, stored in release_matrix
97+
# json object.
7398
def update_versions(versions, release_matrix, release_version):
74-
version_map = {
75-
"preview": "preview",
76-
}
77-
version = ""
99+
version = "preview"
78100
acc_arch_map = acc_arch_ver_map[release_version]
79101

80-
if(release_version == "nightly"):
81-
version = "preview"
82-
else:
102+
if release_version != "nightly":
83103
version = release_matrix[OperatingSystem.LINUX.value][0]["stable_version"]
84-
85-
# Generating for a specific version
86-
if(version != "preview"):
87-
version_map = {
88-
version: version,
89-
}
90104
if version not in versions["versions"]:
91-
import copy
92-
new_version = copy.deepcopy(versions["versions"]["preview"])
93-
versions["versions"][version] = new_version
105+
versions["versions"][version] = copy.deepcopy(versions["versions"]["preview"])
94106
versions["latest_stable"] = version
95107

96108
# Perform update of the json file from release matrix
97-
for ver, ver_key in version_map.items():
98-
for os_key, os_vers in versions["versions"][ver_key].items():
99-
for pkg_key, pkg_vers in os_vers.items():
100-
for acc_key, instr in pkg_vers.items():
101-
102-
package_type = pkg_key
103-
if pkg_key == 'pip':
104-
package_type = 'manywheel' if os_key == OperatingSystem.LINUX.value else 'wheel'
105-
106-
gpu_arch_type, gpu_arch_version = acc_arch_map[acc_key]
107-
if(DEFAULT in instr):
108-
gpu_arch_type, gpu_arch_version = acc_arch_map["accnone"]
109-
110-
pkg_arch_matrix = list(filter(
111-
lambda x:
112-
(x["package_type"], x["gpu_arch_type"], x["gpu_arch_version"]) ==
113-
(package_type, gpu_arch_type, gpu_arch_version),
114-
release_matrix[os_key]
115-
))
116-
117-
if pkg_arch_matrix:
118-
if package_type != 'libtorch':
119-
instr["command"] = pkg_arch_matrix[0]["installation"]
120-
else:
121-
if os_key == OperatingSystem.LINUX.value:
122-
rel_entry_pre_cxx1 = next(filter(
123-
lambda x:
124-
x["devtoolset"] == PRE_CXX11_ABI,
125-
pkg_arch_matrix
126-
), None)
127-
rel_entry_cxx1_abi = next(filter(
128-
lambda x:
129-
x["devtoolset"] == CXX11_ABI,
130-
pkg_arch_matrix
131-
), None)
132-
if(instr['versions'] is not None):
133-
instr['versions'][LIBTORCH_DWNL_INSTR[PRE_CXX11_ABI]] = rel_entry_pre_cxx1["installation"]
134-
instr['versions'][LIBTORCH_DWNL_INSTR[CXX11_ABI]] = rel_entry_cxx1_abi["installation"]
135-
elif os_key == OperatingSystem.WINDOWS.value:
136-
rel_entry_release = next(filter(
137-
lambda x:
138-
x["libtorch_config"] == RELEASE,
139-
pkg_arch_matrix
140-
), None)
141-
rel_entry_debug = next(filter(
142-
lambda x:
143-
x["libtorch_config"] == DEBUG,
144-
pkg_arch_matrix
145-
), None)
146-
if(instr['versions'] is not None):
147-
instr['versions'][LIBTORCH_DWNL_INSTR[RELEASE]] = rel_entry_release["installation"]
148-
instr['versions'][LIBTORCH_DWNL_INSTR[DEBUG]] = rel_entry_debug["installation"]
149-
150-
109+
for os_key, os_vers in versions["versions"][version].items():
110+
for pkg_key, pkg_vers in os_vers.items():
111+
for acc_key, instr in pkg_vers.items():
112+
package_type = get_package_type(pkg_key, os_key)
113+
gpu_arch_type, gpu_arch_version = get_gpu_info(acc_key, instr, acc_arch_map)
114+
115+
pkg_arch_matrix = [
116+
x for x in release_matrix[os_key]
117+
if (x["package_type"], x["gpu_arch_type"], x["gpu_arch_version"]) ==
118+
(package_type, gpu_arch_type, gpu_arch_version)
119+
]
120+
121+
if pkg_arch_matrix:
122+
if package_type != "libtorch":
123+
instr["command"] = pkg_arch_matrix[0]["installation"]
124+
else:
125+
if os_key == OperatingSystem.LINUX.value:
126+
rel_entry_dict = {x["devtoolset"]: x["installation"] for x in pkg_arch_matrix}
127+
if instr["versions"] is not None:
128+
for ver in [PRE_CXX11_ABI, CXX11_ABI]:
129+
instr["versions"][LIBTORCH_DWNL_INSTR[ver]] = rel_entry_dict[ver]
130+
elif os_key == OperatingSystem.WINDOWS.value:
131+
rel_entry_dict = {x["libtorch_config"]: x["installation"] for x in pkg_arch_matrix}
132+
if instr["versions"] is not None:
133+
for ver in [RELEASE, DEBUG]:
134+
instr["versions"][LIBTORCH_DWNL_INSTR[ver]] = rel_entry_dict[ver]
135+
136+
# This method is used for generating new quick-start-module.js
137+
# from the versions json object
151138
def gen_install_matrix(versions) -> Dict[str, str]:
152-
rc = {}
139+
result = {}
153140
version_map = {
154141
"preview": "preview",
155142
"stable": versions["latest_stable"],
@@ -158,36 +145,30 @@ def gen_install_matrix(versions) -> Dict[str, str]:
158145
for os_key, os_vers in versions["versions"][ver_key].items():
159146
for pkg_key, pkg_vers in os_vers.items():
160147
for acc_key, instr in pkg_vers.items():
161-
extra_key = 'python' if pkg_key != 'libtorch' else 'cplusplus'
162-
key = f"{ver},{pkg_key},{os_key},{acc_key},{extra_key}"
163-
note = instr["note"]
164-
lines = [note] if note is not None else []
165-
if pkg_key == "libtorch":
166-
ivers = instr["versions"]
167-
if ivers is not None:
168-
lines += [f"{lab}<br /><a href='{val}'>{val}</a>" for (lab, val) in ivers.items()]
169-
else:
170-
command = instr["command"]
171-
if command is not None:
172-
lines.append(command)
173-
rc[key] = "<br />".join(lines)
174-
return rc
148+
extra_key = 'python' if pkg_key != 'libtorch' else 'cplusplus'
149+
key = f"{ver},{pkg_key},{os_key},{acc_key},{extra_key}"
150+
note = instr["note"]
151+
lines = [note] if note is not None else []
152+
if pkg_key == "libtorch":
153+
ivers = instr["versions"]
154+
if ivers is not None:
155+
lines += [f"{lab}<br /><a href='{val}'>{val}</a>" for (lab, val) in ivers.items()]
156+
else:
157+
command = instr["command"]
158+
if command is not None:
159+
lines.append(command)
160+
result[key] = "<br />".join(lines)
161+
return result
175162

176163
def main():
177164
parser = argparse.ArgumentParser()
178-
parser.add_argument(
179-
"--autogenerate",
180-
help="Is this call being initiated from workflow? update published_versions",
181-
type=str,
182-
choices=[ENABLE, DISABLE],
183-
default=ENABLE,
184-
)
165+
parser.add_argument('--autogenerate', dest='autogenerate', action='store_true')
166+
parser.set_defaults(autogenerate=False)
185167

186168
options = parser.parse_args()
187169
versions = read_published_versions()
188170

189-
190-
if options.autogenerate == ENABLE:
171+
if options.autogenerate:
191172
release_matrix = {}
192173
for val in ("nightly", "release"):
193174
release_matrix[val] = {}
@@ -199,14 +180,11 @@ def main():
199180

200181
write_published_versions(versions)
201182

202-
203183
template = read_quick_start_module_template()
204184
versions_str = json.dumps(gen_install_matrix(versions))
205185
template = template.replace("{{ installMatrix }}", versions_str)
206186
template = template.replace("{{ VERSION }}", f"\"Stable ({versions['latest_stable']})\"")
207187
print(template.replace("{{ ACC ARCH MAP }}", json.dumps(acc_arch_ver_map)))
208188

209-
210-
211189
if __name__ == "__main__":
212190
main()

0 commit comments

Comments
 (0)