1.6 blog post #428

Merged: 24 commits, Jul 28, 2020
Commits
b371d26
Create 1.6 blog post
brucejlin1 Jul 20, 2020
893b22a
Update 2020-7-20-pytorch-1.6-released.md
andresruizfacebook Jul 22, 2020
1da8ac1
Merge pull request #426 from andresruizfacebook/patch-2
andresruizfacebook Jul 22, 2020
f184249
Update 2020-7-20-pytorch-1.6-released.md
andresruizfacebook Jul 24, 2020
da7a6cd
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 24, 2020
c6cded4
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 24, 2020
d9ca67f
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
0d3ddda
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
49b7db2
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
fb98ec1
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
a03b8c3
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
410e925
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
d2a3609
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
a35e1c0
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
8d200cc
Update 2020-7-20-pytorch-1.6-released.md
brucejlin1 Jul 27, 2020
ed7c1db
Create news-item-4.md
andresruizfacebook Jul 27, 2020
a80ad45
Update news-item-3.md
andresruizfacebook Jul 27, 2020
dd1e7c4
Update news-item-2.md
andresruizfacebook Jul 27, 2020
8b51d5c
Update news-item-1.md
andresruizfacebook Jul 27, 2020
b095a4a
Update news-item-3.md
andresruizfacebook Jul 27, 2020
8f0556f
Update news-item-4.md
brucejlin1 Jul 28, 2020
647f422
Updating Mem Profiler content
brucejlin1 Jul 28, 2020
2fe1157
Rename 2020-7-20-pytorch-1.6-released.md to 2020-7-28-pytorch-1.6-rel…
brucejlin1 Jul 28, 2020
58f8db1
Update 2020-7-28-pytorch-1.6-released.md
brucejlin1 Jul 28, 2020
4 changes: 2 additions & 2 deletions _news/news-item-1.md
@@ -1,7 +1,7 @@
---
order: 1
link: https://pytorch.org/blog/updates-improvements-to-pytorch-tutorials/
summary: Click Here to Read About Latest Updates and Improvements to PyTorch Tutorials
link: https://pytorch.org/blog/pytorch-1.6-released/
summary: PyTorch 1.6 released w/ Native AMP Support, Microsoft joins as maintainers for Windows.
---


4 changes: 2 additions & 2 deletions _news/news-item-2.md
@@ -1,6 +1,6 @@
---
order: 2
link: https://pytorch.org/blog/pytorch-1-dot-5-released-with-new-and-updated-apis
summary: PyTorch 1.5 released, new and updated APIs including C++ frontend API parity with Python.
link: https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
summary: Accelerating Training on NVIDIA GPUs with PyTorch Automatic Mixed Precision.
---

4 changes: 2 additions & 2 deletions _news/news-item-3.md
@@ -1,6 +1,6 @@
---
order: 3
link: https://pytorch.org/blog/pytorch-library-updates-new-model-serving-library
summary: PyTorch library updates including new model serving library
link: https://pytorch.org/blog/microsoft-becomes-maintainer-of-the-windows-version-of-pytorch/
summary: Microsoft becomes maintainer of the Windows version of PyTorch.
---

5 changes: 5 additions & 0 deletions _news/news-item-4.md
@@ -0,0 +1,5 @@
---
order: 4
link: https://pytorch.org/blog/pytorch-feature-classification-changes/
summary: See the new PyTorch feature classification changes
---
213 changes: 213 additions & 0 deletions _posts/2020-7-28-pytorch-1.6-released.md
@@ -0,0 +1,213 @@
---
layout: blog_detail
title: 'PyTorch 1.6 released w/ Native AMP Support, Microsoft joins as maintainers for Windows'
author: Team PyTorch
---

Today, we’re announcing the availability of PyTorch 1.6, along with updated domain libraries. We are also excited to announce the team at [Microsoft is now maintaining Windows builds and binaries](https://pytorch.org/blog/microsoft-becomes-maintainer-of-the-windows-version-of-pytorch) and will also be supporting the community on GitHub as well as the PyTorch Windows discussion forums.

The PyTorch 1.6 release includes a number of new APIs, tools for performance improvement and profiling, as well as major updates to both distributed data parallel (DDP) and remote procedure call (RPC) based distributed training.
A few of the highlights include:

1. Automatic mixed precision (AMP) training is now natively supported and a stable feature (see [here](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) for more details), thanks to NVIDIA’s contributions;
2. Native support for TensorPipe, a tensor-aware, point-to-point communication primitive built specifically for machine learning;
3. Added support for complex tensors to the frontend API surface;
4. New profiling tools providing tensor-level memory consumption information;
5. Numerous improvements and new features for both the distributed data parallel (DDP) training and remote procedure call (RPC) packages.

Additionally, from this release onward, features will be classified as Stable, Beta and Prototype. Prototype features are not included as part of the binary distribution and are instead available through building from source, using nightlies or via a compiler flag. You can learn more about what this change means in the post [here](https://pytorch.org/blog/pytorch-feature-classification-changes/). You can also find the full release notes [here](https://github.com/pytorch/pytorch/releases).

# Performance & Profiling

## [Stable] Automatic Mixed Precision (AMP) Training

AMP allows users to easily enable mixed precision training, delivering higher performance and memory savings of up to 50% on Tensor Core GPUs. Using the natively supported `torch.cuda.amp` API, AMP provides convenience methods for mixed precision, where some operations use the `torch.float32` (`float`) datatype and other operations use `torch.float16` (`half`). Some ops, like linear layers and convolutions, are much faster in `float16`. Other ops, like reductions, often require the dynamic range of `float32`. Mixed precision tries to match each op to its appropriate datatype.
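
As a quick illustration, here is a minimal training-loop sketch using the `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` APIs; the toy model and random data below are placeholders for a real training setup.

```python
import torch

# Toy model and data standing in for a real training setup (illustration only)
model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
data = [(torch.randn(32, 128).cuda(), torch.randint(0, 10, (32,)).cuda()) for _ in range(4)]

# GradScaler guards against gradient underflow when parts of the backward run in float16
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in data:
    optimizer.zero_grad()
    # autocast chooses float16 or float32 per op during the forward pass
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()   # scale the loss to avoid underflow
    scaler.step(optimizer)          # unscale gradients and update weights
    scaler.update()                 # adjust the scale for the next iteration
```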

* Design doc ([Link](https://github.com/pytorch/pytorch/issues/25081))
* Documentation ([Link](https://pytorch.org/docs/stable/amp.html))
* Usage examples ([Link](https://pytorch.org/docs/stable/notes/amp_examples.html))

## [Beta] Fork/Join Parallelism

This release adds both a language-level construct and runtime support for coarse-grained parallelism in TorchScript code. This support is useful for situations such as running models in an ensemble in parallel or running bidirectional components of recurrent nets in parallel, and it unlocks the computational power of parallel architectures (e.g., many-core CPUs) for task-level parallelism.

Parallel execution of TorchScript programs is enabled through two primitives: `torch.jit.fork` and `torch.jit.wait`. In the below example, we parallelize execution of `foo`:

```python
import torch
from typing import List

def foo(x):
    return torch.neg(x)

@torch.jit.script
def example(x):
    futures = [torch.jit.fork(foo, x) for _ in range(100)]
    results = [torch.jit.wait(future) for future in futures]
    return torch.sum(torch.stack(results))

print(example(torch.ones([])))
```

* Documentation ([Link](https://pytorch.org/docs/stable/jit.html))

## [Beta] Memory Profiler

The `torch.autograd.profiler` API now includes a memory profiler that lets you inspect the tensor memory cost of different operators inside your CPU and GPU models.

Here is an example usage of the API:

```python
import torch
import torchvision.models as models
import torch.autograd.profiler as profiler

model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)
with profiler.profile(profile_memory=True, record_shapes=True) as prof:
    model(inputs)

# NOTE: some columns were removed for brevity
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
# ---------------------------  ------------  --------------  ---------------
# Name                         CPU Mem       Self CPU Mem    Number of Calls
# ---------------------------  ------------  --------------  ---------------
# empty                        94.79 Mb      94.79 Mb        123
# resize_                      11.48 Mb      11.48 Mb        2
# addmm                        19.53 Kb      19.53 Kb        1
# empty_strided                4 b           4 b             1
# conv2d                       47.37 Mb      0 b             20
# ---------------------------  ------------  --------------  ---------------
```

* PR ([Link](https://github.com/pytorch/pytorch/pull/37775))
* Documentation ([Link](https://pytorch.org/docs/stable/autograd.html#profiler))

# Distributed Training & RPC

## [Beta] TensorPipe backend for RPC

PyTorch 1.6 introduces a new backend for the RPC module which leverages the TensorPipe library, a tensor-aware point-to-point communication primitive targeted at machine learning. It is intended to complement the current primitives for distributed training in PyTorch (Gloo, MPI, ...), which are collective and blocking. The pairwise and asynchronous nature of TensorPipe lends itself to new networking paradigms that go beyond data parallel: client-server approaches (e.g., a parameter server for embeddings, or actor-learner separation in Impala-style RL), model and pipeline parallel training (think GPipe), gossip SGD, and so on.

```python
# One-line change needed to opt in
torch.distributed.rpc.init_rpc(
    ...
    backend=torch.distributed.rpc.BackendType.TENSORPIPE,
)

# No changes to the rest of the RPC API
torch.distributed.rpc.rpc_sync(...)
```

* Design doc ([Link](https://github.com/pytorch/pytorch/issues/35251))
* Documentation ([Link](https://pytorch.org/docs/stable/rpc/index.html))

## [Beta] DDP+RPC

PyTorch Distributed supports two powerful paradigms: DDP for full sync data parallel training of models and the RPC framework which allows for distributed model parallelism. Previously, these two features worked independently and users couldn’t mix and match these to try out hybrid parallelism paradigms.

Starting in PyTorch 1.6, we’ve enabled DDP and RPC to work together seamlessly so that users can combine these two techniques to achieve both data parallelism and model parallelism. An example is where users would like to place large embedding tables on parameter servers and use the RPC framework for embedding lookups, but store smaller dense parameters on trainers and use DDP to synchronize the dense parameters. Below is a simple code snippet.

```python
# On each trainer

remote_emb = create_emb(on="ps", ...)
ddp_model = DDP(dense_model)

for data in batch:
    with torch.distributed.autograd.context():
        res = remote_emb(data)
        loss = ddp_model(res)
        torch.distributed.autograd.backward([loss])
```

* DDP+RPC Tutorial ([Link](https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html))
* Documentation ([Link](https://pytorch.org/docs/stable/rpc/index.html))
* Usage Examples ([Link](https://github.com/pytorch/examples/pull/800))

## [Beta] RPC - Asynchronous User Functions

RPC asynchronous user functions add the ability to yield and resume on the server side when executing a user-defined function. Prior to this feature, when a callee processed a request, one RPC thread waited until the user function returned. If the user function contained I/O (e.g., a nested RPC) or signaling (e.g., waiting for another request to unblock), the corresponding RPC thread would sit idle waiting for these events. As a result, some applications had to use a very large number of threads and send additional RPC requests, which can potentially lead to performance degradation. To make a user function yield on such events, applications need to: 1) decorate the function with the `@rpc.functions.async_execution` decorator; and 2) let the function return a `torch.futures.Future` and install the resume logic as callbacks on the `Future` object. See below for an example:


```python
@rpc.functions.async_execution
def async_add_chained(to, x, y, z):
    return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        lambda fut: fut.wait() + z
    )

ret = rpc.rpc_sync(
    "worker1",
    async_add_chained,
    args=("worker2", torch.ones(2), 1, 1)
)

print(ret)  # prints tensor([3., 3.])
```

* Tutorial for performant batch RPC using Asynchronous User Functions ([Link](https://github.com/pytorch/tutorials/blob/release/1.6/intermediate_source/rpc_async_execution.rst))
* Documentation ([Link](https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.functions.async_execution))
* Usage examples ([Link](https://github.com/pytorch/examples/tree/master/distributed/rpc/batch))

# Frontend API Updates

## [Beta] Complex Numbers

The PyTorch 1.6 release brings beta-level support for complex tensors, including the `torch.complex64` and `torch.complex128` dtypes. A complex number is a number that can be expressed in the form a + bj, where a and b are real numbers and j is a solution of the equation x^2 = −1. Complex numbers frequently occur in mathematics and engineering, especially in signal processing, and complex neural networks are an active area of research. The beta release of complex tensors will support common PyTorch and complex tensor functionality, plus functions needed by Torchaudio, ESPnet and others. While this is an early version of this feature and we expect it to improve over time, the overall goal is to provide a NumPy-compatible user experience that leverages PyTorch’s ability to run on accelerators and work with autograd to better support the scientific community.
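
As a small, hedged illustration of the beta API (assuming basic construction and the real/imag accessors described in the docs; per-op coverage is still evolving):

```python
import torch

# Construct a complex tensor directly from Python complex literals
x = torch.tensor([1 + 2j, 3 - 1j], dtype=torch.complex64)

print(x.dtype)       # torch.complex64
print(x.real)        # tensor([1., 3.])
print(x.imag)        # tensor([ 2., -1.])
print(torch.abs(x))  # element-wise magnitudes
```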

# Updated Domain Libraries

## torchvision 0.7

torchvision 0.7 introduces two new pretrained semantic segmentation models, [FCN ResNet50](https://arxiv.org/abs/1411.4038) and [DeepLabV3 ResNet50](https://arxiv.org/abs/1706.05587), both trained on COCO and using smaller memory footprints than the ResNet101 backbone. We also introduced support for AMP (Automatic Mixed Precision) autocasting for torchvision models and operators, which automatically selects the floating point precision for different GPU operations to improve performance while maintaining accuracy.
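For reference, loading one of the new pretrained segmentation models is a one-liner; the sketch below assumes torchvision 0.7 is installed and downloads the pretrained COCO weights.

```python
import torch
import torchvision

# Load the new FCN ResNet50 segmentation model with pretrained COCO weights
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True).eval()

# Dummy forward pass; the returned dict holds per-pixel class scores under "out"
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))["out"]
print(out.shape)  # torch.Size([1, 21, 224, 224])
```
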

* Release notes ([Link](https://github.com/pytorch/vision/releases))

## torchaudio 0.6

torchaudio now officially supports Windows. This release also introduces a new model module (with wav2letter included), new functionals (contrast, cvm, dcshift, overdrive, vad, phaser, flanger, biquad), datasets (GTZAN, CMU), and a new optional sox backend with support for TorchScript.
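As one example, the sketch below loads the new GTZAN dataset; it assumes the dataset entry point returns a `(waveform, sample_rate, genre)` tuple and that downloading is permitted in your environment.

```python
import torchaudio

# Download the GTZAN genre dataset (new in torchaudio 0.6) and read the first clip
dataset = torchaudio.datasets.GTZAN(root="./data", download=True)
waveform, sample_rate, genre = dataset[0]
print(waveform.shape, sample_rate, genre)
```
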

* Release notes ([Link](https://github.com/pytorch/audio/releases))

# Additional updates

## HACKATHON

The Global PyTorch Summer Hackathon is back! This year, teams can compete virtually in three categories:

1. **PyTorch Developer Tools:** Tools or libraries designed to improve productivity and efficiency of PyTorch for researchers and developers
2. **Web/Mobile Applications powered by PyTorch:** Applications with web/mobile interfaces and/or embedded devices powered by PyTorch
3. **PyTorch Responsible AI Development Tools:** Tools, libraries, or web/mobile apps for responsible AI development

This is a great opportunity to connect with the community and practice your machine learning skills.

* [Join the hackathon](http://pytorch2020.devpost.com/)
* [Watch educational videos](https://www.youtube.com/pytorch)


## LPCV Challenge

The [2020 CVPR Low-Power Vision Challenge (LPCV) - Online Track for UAV video](https://lpcv.ai/2020CVPR/video-track) submission deadline is coming up shortly. You have until July 31, 2020 to build a system that can accurately discover and recognize characters in video captured by an unmanned aerial vehicle (UAV), using PyTorch and a Raspberry Pi 3B+.

## Prototype Features

To reiterate, Prototype features in PyTorch are early features that we are looking to gather feedback on, gauge the usefulness of, and improve ahead of graduating them to Beta or Stable. The following features are not part of the PyTorch 1.6 release and are instead available in nightlies, with separate docs/tutorials to help facilitate early usage and feedback.

#### Distributed RPC/Profiler
Allow users to profile training jobs that use `torch.distributed.rpc` using the autograd profiler, and remotely invoke the profiler in order to collect profiling information across different nodes. The RFC can be found [here](https://github.com/pytorch/pytorch/issues/39675) and a short recipe on how to use this feature can be found [here](https://github.com/pytorch/tutorials/tree/master/prototype_source).
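
A rough sketch of what this enables, assuming an RPC group has already been initialized via `init_rpc` on workers named `"worker0"` and `"worker1"`; per the linked RFC, wrapping an RPC call in the autograd profiler should then capture the remote events as well.

```python
import torch
import torch.distributed.rpc as rpc
from torch.autograd import profiler

# Assumes rpc.init_rpc("worker0", rank=0, world_size=2) has already run on this process
with profiler.profile() as prof:
    fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))
    fut.wait()

# Per the RFC, events from the RPC call (including remote ones) should appear here
print(prof.key_averages().table(sort_by="cpu_time_total"))
```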

#### TorchScript Module Freezing
Module freezing is the process of inlining module parameter and attribute values into the TorchScript internal representation. Parameter and attribute values are treated as final and cannot be modified in the frozen module. The PR for this feature can be found [here](https://github.com/pytorch/pytorch/pull/32178) and a short tutorial on how to use this feature can be found [here](https://github.com/pytorch/tutorials/tree/master/prototype_source).
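
A minimal sketch, assuming the prototype exposes `torch.jit.freeze` on a scripted module in eval mode (see the linked PR and tutorial for the exact API):

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return x @ self.weight

# Script the module, put it in eval mode, then freeze it so that
# parameters and attributes are inlined as constants in the graph
scripted = torch.jit.script(MyModule()).eval()
frozen = torch.jit.freeze(scripted)
print(frozen.graph)  # the weight now appears as a constant in the IR
```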

#### Graph Mode Quantization
Eager mode quantization requires users to make changes to their model: activations must be quantized explicitly, modules must be fused by hand, uses of torch ops must be rewritten with Functional Modules, and quantization of functionals is not supported. If the model can be traced or scripted, quantization can instead be done automatically with graph mode quantization, without any of the complexities of eager mode, and it is configurable through a `qconfig_dict`. A tutorial on how to use this feature can be found [here](https://github.com/pytorch/tutorials/tree/master/prototype_source).
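
A rough sketch, under the assumption that the prototype follows the linked tutorial and exposes `torch.quantization.quantize_jit` taking a scripted model, a `qconfig_dict` and a calibration function; `MyFloatModel` and `calibration_loader` are placeholders, and the exact entry points may differ, so treat this as illustrative only.

```python
import torch
from torch.quantization import get_default_qconfig, quantize_jit

# MyFloatModel and calibration_loader are hypothetical placeholders
float_model = torch.jit.script(MyFloatModel().eval())

# An empty key applies the qconfig to the whole model
qconfig_dict = {"": get_default_qconfig("fbgemm")}

def calibrate(model, data_loader):
    # Run representative data through the model to collect activation statistics
    with torch.no_grad():
        for inputs, _ in data_loader:
            model(inputs)

quantized_model = quantize_jit(float_model, qconfig_dict, calibrate, [calibration_loader])
```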

#### Quantization Numerical Suite
Quantization is good when it works, but it’s difficult to know what's wrong when it doesn't deliver the expected accuracy. A prototype of a Numerical Suite is now available that measures comparison statistics between quantized modules and float modules. It can be tested in eager mode and on CPU only, with more support coming. A tutorial on how to use this feature can be found [here](https://github.com/pytorch/tutorials/tree/master/prototype_source).


Cheers!

Team PyTorch