Commit 68a1f59

goldsborough authored and soumith committed
Some corrections to the C++ extensions tutorial (#230)
1 parent 47f3385 commit 68a1f59

1 file changed: +8 -7 lines changed
advanced_source/cpp_extension.rst

@@ -61,8 +61,8 @@ look something like this::
           # 3 * state_size for input gate, output gate and candidate cell gate.
           # input_features + state_size because we will multiply with [input, h].
           self.weights = torch.nn.Parameter(
-              torch.Tensor(3 * state_size, input_features + state_size))
-          self.bias = torch.nn.Parameter(torch.Tensor(3 * state_size))
+              torch.empty(3 * state_size, input_features + state_size))
+          self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
           self.reset_parameters()
 
       def reset_parameters(self):
@@ -389,10 +389,9 @@ should look something like this::
 
 A small note on compilers: Due to ABI versioning issues, the compiler you use to
 build your C++ extension must be *ABI-compatible* with the compiler PyTorch was
-built with. In practice, this means that you must use GCC version 4.9 and above.
+built with. In practice, this means that you must use GCC version 4.9 and above on Linux.
 For Ubuntu 16.04 and other more-recent Linux distributions, this should be the
-default compiler already. On MacOS, you will have to download GCC (e.g. `brew
-install gcc` will give you GCC 7 at the time of this writing). In the worst
+default compiler already. On MacOS, you must use clang (which does not have any ABI versioning issues). In the worst
 case, you can build PyTorch from source with your compiler and then build the
 extension with that same compiler.
 
@@ -449,8 +448,8 @@ class citizens of PyTorch::
           self.input_features = input_features
           self.state_size = state_size
           self.weights = torch.nn.Parameter(
-              torch.Tensor(3 * state_size, input_features + state_size))
-          self.bias = torch.nn.Parameter(torch.Tensor(3 * state_size))
+              torch.empty(3 * state_size, input_features + state_size))
+          self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
           self.reset_parameters()
 
       def reset_parameters(self):
@@ -543,10 +542,12 @@ memory with ``.cuda()`` from Python::
   for _ in range(100000):
       start = time.time()
       new_h, new_C = rnn(X, (h, C))
+      torch.cuda.synchronize()
       forward += time.time() - start
 
       start = time.time()
       (new_h.sum() + new_C.sum()).backward()
+      torch.cuda.synchronize()
       backward += time.time() - start
 
   print('Forward: {:.3f} us | Backward {:.3f} us'.format(forward * 1e6/1e5, backward * 1e6/1e5))
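
Note on the last hunk: CUDA work is generally launched asynchronously, so a timer stopped before `torch.cuda.synchronize()` may capture only the launch cost rather than the work itself. The sketch below is an analogy only, not PyTorch code: it assumes a standard-library thread pool as a stand-in for the GPU, and `slow_kernel` and `timed_launch` are hypothetical names introduced purely for illustration.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_kernel(duration):
    """Hypothetical stand-in for GPU work: sleeps for `duration` seconds."""
    time.sleep(duration)


def timed_launch(executor, duration, synchronize):
    """Time one asynchronous launch, optionally waiting for it to finish."""
    start = time.perf_counter()
    future = executor.submit(slow_kernel, duration)  # returns immediately
    if synchronize:
        future.result()  # plays the role of torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    future.result()  # always drain the work so runs do not overlap
    return elapsed


with ThreadPoolExecutor(max_workers=1) as executor:
    unsynced = timed_launch(executor, 0.05, synchronize=False)
    synced = timed_launch(executor, 0.05, synchronize=True)

print('without sync: {:.1f} ms | with sync: {:.1f} ms'.format(
    unsynced * 1e3, synced * 1e3))
```

Without the wait, the measured interval covers little more than the submission; with it, the interval includes the work itself, which is the effect the two added `torch.cuda.synchronize()` lines aim for in the benchmark loop.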