Commit 2151259

Merge branch 'main' into main

2 parents 366fcbb + 37462ab

File tree: 9 files changed, +242 -10 lines

_static/img/distributed/fsdp_sharding.png (binary image, 91 KB)
advanced_source/neural_style_tutorial.py (+2 -2)

@@ -56,7 +56,7 @@
 import matplotlib.pyplot as plt

 import torchvision.transforms as transforms
-import torchvision.models as models
+from torchvision.models import vgg19, VGG19_Weights

 import copy

@@ -262,7 +262,7 @@ def forward(self, input):
 # network to evaluation mode using ``.eval()``.
 #

-cnn = models.vgg19(pretrained=True).features.eval()
+cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()

docathon-leaderboard.md (+43)

@@ -1,3 +1,46 @@
+# 🎉 Docathon H2 2023 Leaderboard 🎉
+
+This is the list of docathon contributors who participated in and contributed to the H2 2023 PyTorch docathon.
+A big shout out to everyone who participated! We awarded points for each merged PR:
+2 points for the **easy** label, 5 points for the **medium** label, and 10 points for the
+**advanced** label. In some cases, we awarded half credit for PRs that were not merged or
+for issues that were closed without a merged PR. Thank you all for your awesome contributions! 🎉
+
+| Author | Points | PR |
+|--- | --- | ---|
+| ahoblitz | 25 | https://github.com/pytorch/pytorch/pull/112992, https://github.com/pytorch/tutorials/pull/2662, https://github.com/pytorch/tutorials/pull/2647, https://github.com/pytorch/tutorials/pull/2642, https://github.com/pytorch/tutorials/pull/2640, https://github.com/pytorch/pytorch/pull/113092, https://github.com/pytorch/pytorch/pull/113348 |
+| ChanBong | 22 | https://github.com/pytorch/pytorch/pull/113337, https://github.com/pytorch/pytorch/pull/113336, https://github.com/pytorch/pytorch/pull/113335, https://github.com/pytorch/tutorials/pull/2644, https://github.com/pytorch/tutorials/pull/2639 |
+| alperenunlu | 22 | https://github.com/pytorch/pytorch/pull/113260, https://github.com/pytorch/tutorials/pull/2673, https://github.com/pytorch/tutorials/pull/2660, https://github.com/pytorch/tutorials/pull/2656, https://github.com/pytorch/tutorials/pull/2649, https://github.com/pytorch/pytorch/pull/113505, https://github.com/pytorch/pytorch/pull/113218, https://github.com/pytorch/pytorch/pull/113505 |
+| spzala | 22 | https://github.com/pytorch/pytorch/pull/113200, https://github.com/pytorch/pytorch/pull/112693, https://github.com/pytorch/tutorials/pull/2667, https://github.com/pytorch/tutorials/pull/2635 |
+| bjhargrave | 21 | https://github.com/pytorch/pytorch/pull/113358, https://github.com/pytorch/pytorch/pull/113206, https://github.com/pytorch/pytorch/pull/112786, https://github.com/pytorch/tutorials/pull/2661, https://github.com/pytorch/tutorials/pull/1272 |
+| zabboud | 21 | https://github.com/pytorch/pytorch/pull/113233, https://github.com/pytorch/pytorch/pull/113227, https://github.com/pytorch/pytorch/pull/113177, https://github.com/pytorch/pytorch/pull/113219, https://github.com/pytorch/pytorch/pull/113311 |
+| nvs-abhilash | 20 | https://github.com/pytorch/pytorch/pull/113241, https://github.com/pytorch/pytorch/pull/112765, https://github.com/pytorch/pytorch/pull/112695, https://github.com/pytorch/pytorch/pull/112657 |
+| guptaaryan16 | 19 | https://github.com/pytorch/pytorch/pull/112817, https://github.com/pytorch/pytorch/pull/112735, https://github.com/pytorch/tutorials/pull/2674, https://github.com/pytorch/pytorch/pull/113196, https://github.com/pytorch/pytorch/pull/113532 |
+| min-jean-cho | 17 | https://github.com/pytorch/pytorch/pull/113195, https://github.com/pytorch/pytorch/pull/113183, https://github.com/pytorch/pytorch/pull/113178, https://github.com/pytorch/pytorch/pull/113109, https://github.com/pytorch/pytorch/pull/112892 |
+| markstur | 14 | https://github.com/pytorch/pytorch/pull/113250, https://github.com/pytorch/tutorials/pull/2643, https://github.com/pytorch/tutorials/pull/2638, https://github.com/pytorch/tutorials/pull/2636 |
+| RustyGrackle | 13 | https://github.com/pytorch/pytorch/pull/113371, https://github.com/pytorch/pytorch/pull/113266, https://github.com/pytorch/pytorch/pull/113435 |
+| Viditagarwal7479 | 12 | https://github.com/pytorch/pytorch/pull/112860, https://github.com/pytorch/tutorials/pull/2659, https://github.com/pytorch/tutorials/pull/2671 |
+| kiszk | 10 | https://github.com/pytorch/pytorch/pull/113523, https://github.com/pytorch/pytorch/pull/112751 |
+| awaelchli | 10 | https://github.com/pytorch/pytorch/pull/113216, https://github.com/pytorch/pytorch/pull/112674 |
+| pilot-j | 10 | https://github.com/pytorch/pytorch/pull/112964, https://github.com/pytorch/pytorch/pull/112856 |
+| krishnakalyan3 | 7 | https://github.com/pytorch/tutorials/pull/2653, https://github.com/pytorch/tutorials/pull/1235, https://github.com/pytorch/tutorials/pull/1705 |
+| ash-01xor | 5 | https://github.com/pytorch/pytorch/pull/113511 |
+| IvanLauLinTiong | 5 | https://github.com/pytorch/pytorch/pull/113052 |
+| Senthi1Kumar | 5 | https://github.com/pytorch/pytorch/pull/113021 |
+| ooooo-create | 5 | https://github.com/pytorch/pytorch/pull/112953 |
+| stanleyedward | 5 | https://github.com/pytorch/pytorch/pull/112864, https://github.com/pytorch/pytorch/pull/112617 |
+| leslie-fang-intel | 5 | https://github.com/pytorch/tutorials/pull/2668 |
+| measty | 5 | https://github.com/pytorch/tutorials/pull/2675 |
+| Hhhhhhao | 5 | https://github.com/pytorch/tutorials/pull/2676 |
+| andrewashere | 3 | https://github.com/pytorch/pytorch/pull/112721 |
+| aalhendi | 3 | https://github.com/pytorch/pytorch/pull/112947 |
+| sitamgithub-MSIT | 3 | https://github.com/pytorch/pytorch/pull/113264 |
+| Jarlaze | 3 | https://github.com/pytorch/pytorch/pull/113531 |
+| jingxu10 | 2 | https://github.com/pytorch/tutorials/pull/2657 |
+| cirquit | 2 | https://github.com/pytorch/tutorials/pull/2529 |
+| prithviraj-maurya | 1 | https://github.com/pytorch/tutorials/pull/2652 |
+| MirMustafaAli | 1 | https://github.com/pytorch/tutorials/pull/2645 |
+
 # 🎉 Docathon H1 2023 Leaderboard 🎉
 This is the list of the docathon contributors that have participated and contributed to the PyTorch docathon.
 A big shout out to everyone who have participated! We have awarded points for each merged PR.

index.rst (+2 -1)

@@ -293,7 +293,7 @@ What's new in PyTorch tutorials?
    :header: Introduction to ONNX Registry
    :card_description: Demonstrate end-to-end how to address unsupported operators by using ONNX Registry.
    :image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
-   :link: advanced/onnx_registry_tutorial.html
+   :link: advanced/onnx_registry_tutorial.html
    :tags: Production,ONNX,Backends

 .. Reinforcement Learning

@@ -1050,6 +1050,7 @@ Additional Resources
    intermediate/scaled_dot_product_attention_tutorial
    beginner/knowledge_distillation_tutorial

+
 .. toctree::
    :maxdepth: 2
    :includehidden:

intermediate_source/FSDP_tutorial.rst (+9)

@@ -46,6 +46,15 @@ At a high level FSDP works as follow:
 * Run reduce_scatter to sync gradients
 * Discard parameters.

+One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into a reduce-scatter and an all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank holds a shard of the gradients. It then updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.
+
+.. figure:: /_static/img/distributed/fsdp_sharding.png
+   :width: 100%
+   :align: center
+   :alt: FSDP allreduce
+
+   FSDP Allreduce
+
 How to use FSDP
 ---------------

 Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
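
To make the decomposition concrete, here is a minimal sketch (not part of the diff) of the collective identity the added paragraph describes: an all-reduce equals a reduce-scatter followed by an all-gather. It assumes an initialized ``torch.distributed`` process group and a tensor whose element count is divisible by the world size; note that FSDP itself all-gathers parameters, not gradients, in the forward pass.

    import torch
    import torch.distributed as dist

    def allreduce_via_reduce_scatter(grad: torch.Tensor) -> torch.Tensor:
        """Same result as dist.all_reduce(grad): reduce-scatter, then all-gather."""
        world_size = dist.get_world_size()
        flat = grad.reshape(-1)
        shard = torch.empty(flat.numel() // world_size, dtype=flat.dtype, device=flat.device)
        dist.reduce_scatter_tensor(shard, flat)   # each rank now holds the summed values of its shard
        out = torch.empty_like(flat)
        dist.all_gather_into_tensor(out, shard)   # every rank reassembles the full reduced tensor
        return out.view_as(grad)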

intermediate_source/reinforcement_q_learning.py (+4 -4)

@@ -7,7 +7,7 @@

 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v1 task from `Gymnasium <https://www.gymnasium.farama.org>`__.
+on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.

 **Task**

@@ -283,7 +283,7 @@ def select_action(state):
             # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
             # found, so we pick action with the larger expected reward.
-            return policy_net(state).max(1)[1].view(1, 1)
+            return policy_net(state).max(1).indices.view(1, 1)
     else:
         return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

@@ -360,12 +360,12 @@ def optimize_model():

     # Compute V(s_{t+1}) for all next states.
     # Expected values of actions for non_final_next_states are computed based
-    # on the "older" target_net; selecting their best reward with max(1)[0].
+    # on the "older" target_net; selecting their best reward with max(1).values
     # This is merged based on the mask, such that we'll have either the expected
     # state value or 0 in case the state was final.
     next_state_values = torch.zeros(BATCH_SIZE, device=device)
     with torch.no_grad():
-        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
+        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
     # Compute the expected Q values
     expected_state_action_values = (next_state_values * GAMMA) + reward_batch
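
The ``[0]``/``[1]`` to ``.values``/``.indices`` changes above are purely cosmetic: ``Tensor.max(dim)`` returns a named tuple, so both spellings are equivalent, as this small check shows:

    import torch

    q = torch.tensor([[0.1, 0.9], [0.7, 0.3]])
    result = q.max(1)  # torch.return_types.max(values=..., indices=...)
    assert torch.equal(result.values, q.max(1)[0])
    assert torch.equal(result.indices, q.max(1)[1])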

prototype_source/pt2e_quant_ptq_x86_inductor.rst (+29 -3)

@@ -165,11 +165,37 @@ After we get the quantized model, we will further lower it to the inductor backend

 ::

-    optimized_model = torch.compile(converted_model)
+    with torch.no_grad():
+        optimized_model = torch.compile(converted_model)

-    # Running some benchmark
-    optimized_model(*example_inputs)
+        # Running some benchmark
+        optimized_model(*example_inputs)
+
+In a more advanced scenario, int8-mixed-bf16 quantization comes into play: a Convolution or GEMM
+operator produces BFloat16 output instead of Float32 when it has no subsequent quantization node.
+The BFloat16 tensor then propagates through the following pointwise operators, reducing memory
+usage and potentially improving performance. Using this feature is as simple as using regular
+BFloat16 Autocast: wrap the script within the BFloat16 Autocast context.
+
+::
+
+    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
+        # Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into the Inductor CPP backend,
+        # for operators such as QConvolution and QLinear:
+        # * The input data type is consistently defined as int8, due to the pair of
+        #   quantization and dequantization nodes inserted at the input.
+        # * The computation precision remains int8.
+        # * The output data type may be either int8 or BFloat16, depending on the presence
+        #   of a pair of quantization and dequantization nodes at the output.
+        # For non-quantizable pointwise operators, the data type is inherited from the previous node,
+        # potentially resulting in BFloat16 in this scenario.
+        # For quantizable pointwise operators such as QMaxpool2D, input and output stay int8.
+        optimized_model = torch.compile(converted_model)
+
+        # Running some benchmark
+        optimized_model(*example_inputs)

 Putting all this code together, we have the toy example code.
 Please note that since the Inductor ``freeze`` feature is not turned on by default yet, run your example code with ``TORCHINDUCTOR_FREEZING=1``.

recipes_source/recipes_index.rst (+9)

@@ -324,6 +324,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
    :link: ../recipes/DCP_tutorial.html
    :tags: Distributed-Training

+.. TorchServe
+
+.. customcarditem::
+   :header: Deploying a PyTorch Stable Diffusion model as a Vertex AI Endpoint
+   :card_description: Learn how to deploy a model to Vertex AI with TorchServe
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torchserve_vertexai_tutorial.html
+   :tags: Production
+
 .. End of tutorial card section

 .. raw:: html
recipes_source/torchserve_vertexai_tutorial.rst (new file, +144)

@@ -0,0 +1,144 @@
+Deploying a PyTorch Stable Diffusion model as a Vertex AI Endpoint
+==================================================================
+
+Deploying large models, like Stable Diffusion, can be challenging and time-consuming.
+In this recipe, we will show how you can streamline the deployment of a PyTorch Stable Diffusion
+model by leveraging Vertex AI.
+
+PyTorch is the framework used by Stability AI for Stable Diffusion v1.5. Vertex AI is a
+fully-managed machine learning platform with tools and infrastructure designed to help ML
+practitioners accelerate and scale ML in production with the benefit of open-source
+frameworks like PyTorch.
+
+Deploying a PyTorch Stable Diffusion model (v1.5) on a Vertex AI Endpoint can be done in four steps:
+
+* Create a custom TorchServe handler.
+
+* Upload model artifacts to Google Cloud Storage (GCS).
+
+* Create a Vertex AI model with the model artifacts and a prebuilt PyTorch container image.
+
+* Deploy the Vertex AI model onto an endpoint.
+
+Let's have a look at each step in more detail. You can follow and implement the steps using the
+`Notebook example <https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/vertex_endpoints/torchserve/dreambooth_stablediffusion.ipynb>`__.
+
+NOTE: Please keep in mind that this recipe requires billable Vertex AI resources, as explained in more detail in the notebook example.
+
+Create a custom TorchServe handler
+----------------------------------
+
+TorchServe is an easy and flexible tool for serving PyTorch models. The model deployed to Vertex AI
+uses TorchServe to handle requests and return responses from the model.
+You must create a custom TorchServe handler to include in the model artifacts uploaded to Vertex AI. Include the handler file in the
+directory with the other model artifacts, like this: `model_artifacts/handler.py`.
+
+After creating the handler file, you must package the handler as a model archiver (MAR) file.
+The output file must be named `model.mar`.
+
+.. code:: shell
+
+    !torch-model-archiver \
+    -f \
+    --model-name <your_model_name> \
+    --version 1.0 \
+    --handler model_artifacts/handler.py \
+    --export-path model_artifacts
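
The recipe does not reproduce the handler itself. As a rough, hypothetical sketch of what ``model_artifacts/handler.py`` could look like (the ``diffusers`` pipeline, the request parsing, and the base64 JPEG response format are all assumptions here; the notebook's actual handler may differ):

    import base64
    from io import BytesIO

    import torch
    from diffusers import StableDiffusionPipeline
    from ts.torch_handler.base_handler import BaseHandler


    class DiffusionHandler(BaseHandler):
        def initialize(self, context):
            # Model files ship inside the MAR; load the pipeline once per worker.
            model_dir = context.system_properties.get("model_dir")
            self.pipe = StableDiffusionPipeline.from_pretrained(model_dir)
            self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
            self.initialized = True

        def preprocess(self, data):
            # TorchServe passes a list of request rows; with Vertex AI each row
            # typically carries the JSON instance, e.g. {"prompt": "..."}.
            prompts = []
            for row in data:
                payload = row.get("data") or row.get("body") or row
                prompts.append(payload["prompt"])
            return prompts

        def inference(self, prompts):
            return self.pipe(prompts).images

        def postprocess(self, images):
            # Return one base64-encoded JPEG per request, matching the
            # base64.b64decode(...) call in the prediction snippet below.
            results = []
            for img in images:
                buf = BytesIO()
                img.save(buf, format="JPEG")
                results.append(base64.b64encode(buf.getvalue()).decode("utf-8"))
            return results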
+
+Upload model artifacts to Google Cloud Storage (GCS)
+----------------------------------------------------
+
+In this step we upload the
+`model artifacts <https://github.com/pytorch/serve/tree/master/model-archiver#artifact-details>`__,
+such as the model file and the handler, to GCS. The advantage of storing your artifacts on GCS is that you can
+track the artifacts in a central bucket.
+
+.. code:: python
+
+    BUCKET_NAME = "your-bucket-name-unique"  # @param {type:"string"}
+    BUCKET_URI = f"gs://{BUCKET_NAME}/"
+
+    # Will copy the artifacts into the bucket
+    !gsutil cp -r model_artifacts $BUCKET_URI
+
+Create a Vertex AI model with the model artifacts and a prebuilt PyTorch container image
+----------------------------------------------------------------------------------------
+
+Once you've uploaded the model artifacts into a GCS bucket, you can upload your PyTorch model to
+`Vertex AI Model Registry <https://cloud.google.com/vertex-ai/docs/model-registry/introduction>`__.
+From the Vertex AI Model Registry, you have an overview of your models
+so you can better organize, track, and train new versions. For this you can use the
+`Vertex AI SDK <https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk>`__
+and this
+`pre-built PyTorch container <https://cloud.google.com/blog/products/ai-machine-learning/prebuilt-containers-with-pytorch-and-vertex-ai>`__.
+
+.. code:: python
+
+    from google.cloud import aiplatform as vertexai
+
+    PYTORCH_PREDICTION_IMAGE_URI = (
+        "us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest"
+    )
+    MODEL_DISPLAY_NAME = "stable_diffusion_1_5-unique"
+    MODEL_DESCRIPTION = "stable_diffusion_1_5 container"
+
+    vertexai.init(project='your_project', location='us-central1', staging_bucket=BUCKET_NAME)
+
+    model = vertexai.Model.upload(
+        display_name=MODEL_DISPLAY_NAME,
+        description=MODEL_DESCRIPTION,
+        serving_container_image_uri=PYTORCH_PREDICTION_IMAGE_URI,
+        artifact_uri=BUCKET_URI,
+    )
+
+Deploy the Vertex AI model onto an endpoint
+-------------------------------------------
+
+Once the model has been uploaded to the Vertex AI Model Registry, you can deploy
+it to a Vertex AI Endpoint. For this you can use the Console or the Vertex AI SDK. In this
+example, you will deploy the model on an NVIDIA Tesla P100 GPU and an n1-standard-8 machine. You can
+specify your own machine type.
+
+.. code:: python
+
+    ENDPOINT_DISPLAY_NAME = "stable_diffusion_1_5-endpoint"  # any display name works
+    endpoint = vertexai.Endpoint.create(display_name=ENDPOINT_DISPLAY_NAME)
+
+    model.deploy(
+        endpoint=endpoint,
+        deployed_model_display_name=MODEL_DISPLAY_NAME,
+        machine_type="n1-standard-8",
+        accelerator_type="NVIDIA_TESLA_P100",
+        accelerator_count=1,
+        traffic_percentage=100,
+        deploy_request_timeout=1200,
+        sync=True,
+    )
+
+If you follow this
+`notebook <https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/vertex_endpoints/torchserve/dreambooth_stablediffusion.ipynb>`__
+you can also get online predictions using the Vertex AI SDK, as shown in the following snippet.
+
+.. code:: python
+
+    import base64
+    from IPython import display
+
+    instances = [{"prompt": "An examplePup dog with a baseball jersey."}]
+    response = endpoint.predict(instances=instances)
+
+    with open("img.jpg", "wb") as g:
+        g.write(base64.b64decode(response.predictions[0]))
+
+    display.Image("img.jpg")
+
+More resources
+--------------
+
+This tutorial was created using the vendor documentation. To refer to the original documentation on the vendor site, please see the
+`torchserve example <https://cloud.google.com/blog/products/ai-machine-learning/get-your-genai-model-going-in-four-easy-steps>`__.
