
Commit a66464b

A tutorial on pin_memory and non_blocking usage (#2983)

1 parent c3882db

11 files changed: +770 −22 lines

.ci/docker/requirements.txt (+3, −2)

@@ -28,8 +28,9 @@ tensorboard
 jinja2==3.1.3
 pytorch-lightning
 torchx
-torchrl==0.3.0
-tensordict==0.3.0
+# TODO: use stable 0.5 when released
+-e git+https://github.com/pytorch/rl.git#egg=torchrl
+-e git+https://github.com/pytorch/tensordict.git#egg=tensordict
 ax-platform
 nbformat>==5.9.2
 datasets

_static/img/pinmem/pinmem.png

[Binary image assets added (72 KB, 81.2 KB, 81.4 KB, 85.4 KB, 90.6 KB); previews not included.]

advanced_source/coding_ddpg.py (+18, −16)

@@ -182,7 +182,7 @@
 # Later, we will see how the target parameters should be updated in TorchRL.
 #
 
-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential
 
 
 def _init(
@@ -290,12 +290,11 @@ def _loss_actor(
 ) -> torch.Tensor:
     td_copy = tensordict.select(*self.actor_in_keys)
     # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
     # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
     return -td_copy.get("state_action_value")
 
 
@@ -317,7 +316,8 @@ def _loss_value(
     td_copy = tensordict.clone()
 
     # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
     pred_val = td_copy.get("state_action_value").squeeze(-1)
 
     # we manually reconstruct the parameters of the actor-critic, where the first
@@ -332,9 +332,8 @@ def _loss_value(
         batch_size=self.target_actor_network_params.batch_size,
         device=self.target_actor_network_params.device,
     )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.actor_critic):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
 
     # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
     loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
@@ -717,7 +716,7 @@ def get_env_stats():
     ActorCriticWrapper,
     DdpgMlpActor,
     DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,
@@ -776,15 +775,18 @@ def make_ddpg_actor(
 # Exploration
 # ~~~~~~~~~~~
 #
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
 # exploration module, as suggested in the original paper.
 # Let's define the number of frames before OU noise reaches its minimum value
 annealing_frames = 1_000_000
 
-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
     actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
 if device == torch.device("cpu"):
     actor_model_explore.share_memory()
 
@@ -1168,7 +1170,7 @@ def ceil_div(x, y):
     )
 
     # update the exploration strategy
-    actor_model_explore.step(current_frames)
+    actor_model_explore[1].step(current_frames)
 
 collector.shutdown()
 del collector
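
Note on the recurring change in this file: the functional ``params=...`` call convention is replaced by tensordict's ``to_module`` context manager, which temporarily loads a parameter TensorDict into a module and restores the original parameters on exit. A minimal sketch of that pattern, under the assumption that only torch and tensordict are installed (the ``net`` module and tensor shapes below are illustrative, not taken from the tutorial):

import torch
from torch import nn
from tensordict import TensorDict

# Illustrative module; not part of the tutorial.
net = nn.Linear(3, 4)

# Capture the module's parameters as a TensorDict, the way torchrl loss modules store them.
params = TensorDict.from_module(net)

# Temporarily swap the module's parameters for (a detached copy of) ``params``
# for the duration of the call, instead of passing ``params=...`` to the call itself.
with params.detach().to_module(net):
    out = net(torch.randn(2, 3))

Because the original parameters are restored when the ``with`` block exits, the same module instance can be reused with live, detached, or target parameters without keeping extra copies around, which is what the hunks above rely on.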

en-wordlist.txt (+11, −3)

@@ -1,3 +1,4 @@
+
 ACL
 ADI
 AOT
@@ -50,6 +51,7 @@ DDP
 DDPG
 DDQN
 DLRM
+DMA
 DNN
 DQN
 DataLoaders
@@ -68,6 +70,8 @@ Ecker
 ExportDB
 FC
 FGSM
+tensordict
+DataLoader's
 FLAVA
 FSDP
 FX
@@ -139,6 +143,7 @@ MKLDNN
 MLP
 MLPs
 MNIST
+MPS
 MUC
 MacBook
 MacOS
@@ -219,6 +224,7 @@ STR
 SVE
 SciPy
 Sequentials
+Sharding
 Sigmoid
 SoTA
 Sohn
@@ -254,6 +260,7 @@ VLDB
 VQA
 VS Code
 ViT
+Volterra
 WMT
 WSI
 WSIs
@@ -336,11 +343,11 @@ dataset’s
 deallocation
 decompositions
 decorrelated
-devicemesh
 deserialize
 deserialized
 desynchronization
 deterministically
+devicemesh
 dimensionality
 dir
 discontiguous
@@ -384,6 +391,7 @@ hessian
 hessians
 histoencoder
 histologically
+homonymous
 hotspot
 hvp
 hyperparameter
@@ -459,6 +467,7 @@ optimizer's
 optimizers
 otsu
 overfitting
+pageable
 parallelizable
 parallelization
 parametrization
@@ -522,7 +531,6 @@ runtime
 runtimes
 scalable
 sharded
-Sharding
 softmax
 sparsified
 sparsifier
@@ -609,4 +617,4 @@ warmstarting
 warmup
 webp
 wsi
-wsis
+wsis

index.rst (+9)

@@ -3,6 +3,7 @@ Welcome to PyTorch Tutorials
 
 **What's new in PyTorch tutorials?**
 
+* `A guide on good usage of non_blocking and pin_memory() in PyTorch <https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html>`__
 * `Introduction to Distributed Pipeline Parallelism <https://pytorch.org/tutorials/intermediate/pipelining_tutorial.html>`__
 * `Introduction to Libuv TCPStore Backend <https://pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html>`__
 * `Asynchronous Saving with Distributed Checkpoint (DCP) <https://pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html>`__
@@ -93,6 +94,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/tensorboard_tutorial.html
    :tags: Interpretability,Getting-Started,TensorBoard
 
+.. customcarditem::
+   :header: Good usage of `non_blocking` and `pin_memory()` in PyTorch
+   :card_description: A guide on best practices to copy data from CPU to GPU.
+   :image: _static/img/pinmem.png
+   :link: intermediate/pinmem_nonblock.html
+   :tags: Getting-Started
+
 .. Image/Video
 
 .. customcarditem::
@@ -969,6 +977,7 @@ Additional Resources
    beginner/pytorch_with_examples
    beginner/nn_tutorial
    intermediate/tensorboard_tutorial
+   intermediate/pinmem_nonblock
 
 .. toctree::
    :maxdepth: 2
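
The new card describes the tutorial as a guide to copying data from CPU to GPU. As a rough sketch of the two features it names, pinned host memory and non-blocking copies; the shapes, dataset, and device check below are placeholders rather than excerpts from the tutorial itself:

import torch
from torch.utils.data import DataLoader, TensorDataset

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Pinned (page-locked) host memory lets the copy engine DMA the data, so the
# subsequent non_blocking copy can overlap with computation on the GPU.
x = torch.randn(1024, 1024)
if use_cuda:
    x = x.pin_memory()
x_gpu = x.to(device, non_blocking=True)  # asynchronous when the source is pinned

# DataLoader can pin each batch for you; the training loop then issues
# non-blocking copies to overlap transfer and compute.
loader = DataLoader(
    TensorDataset(torch.randn(256, 16)),
    batch_size=32,
    pin_memory=use_cuda,
)
for (batch,) in loader:
    batch = batch.to(device, non_blocking=True)
    # ... forward / backward pass would go here

As a rule of thumb, ``non_blocking=True`` only overlaps the host-to-device copy with computation when the source tensor lives in pinned memory; otherwise the copy behaves like an ordinary synchronous transfer.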

intermediate_source/dqn_with_rnn_tutorial.py (+1, −1)

@@ -298,7 +298,7 @@
 # either by passing a string or an action-spec. This allows us to use
 # Categorical (sometimes called "sparse") encoding or the one-hot version of it.
 #
-qval = QValueModule(action_space=env.action_spec)
+qval = QValueModule(spec=env.action_spec)
 
 ######################################################################
 # .. note::
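
For context, the one-line change means ``QValueModule`` now receives the environment's action spec through the ``spec`` keyword rather than ``action_space``. A hedged sketch of how such a call is typically set up; the CartPole environment is illustrative and not the environment used in this tutorial:

from torchrl.envs import GymEnv
from torchrl.modules import QValueModule

env = GymEnv("CartPole-v1")  # illustrative environment, requires gymnasium
qval = QValueModule(spec=env.action_spec)  # the spec carries the action encoding (one-hot or categorical)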
