
Commit 75a581c

rgommers authored and soumith committed
Fix some minor issues in Custom C++ and CUDA Extensions (#580)
1 parent 9341570 commit 75a581c

File tree

1 file changed: +9 −8 lines

advanced_source/cpp_extension.rst

+9-8
@@ -147,23 +147,22 @@ For the "ahead of time" flavor, we build our C++ extension by writing a
 ``setup.py`` script that uses setuptools to compile our C++ code. For the LLTM, it
 looks as simple as this::
 
-    from setuptools import setup
-    from torch.utils.cpp_extension import CppExtension, BuildExtension
+    from setuptools import setup, Extension
+    from torch.utils import cpp_extension
 
     setup(name='lltm_cpp',
-          ext_modules=[CppExtension('lltm', ['lltm.cpp'])],
-          cmdclass={'build_ext': BuildExtension})
-
+          ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
+          cmdclass={'build_ext': cpp_extension.BuildExtension})
 
 In this code, :class:`CppExtension` is a convenience wrapper around
 :class:`setuptools.Extension` that passes the correct include paths and sets
 the language of the extension to C++. The equivalent vanilla :mod:`setuptools`
 code would simply be::
 
-    setuptools.Extension(
+    Extension(
         name='lltm_cpp',
         sources=['lltm.cpp'],
-        include_dirs=torch.utils.cpp_extension.include_paths(),
+        include_dirs=cpp_extension.include_paths(),
         language='c++')
 
 :class:`BuildExtension` performs a number of required configuration steps and
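The changed lines above boil down to the fact that :class:`CppExtension` is thin sugar over :class:`setuptools.Extension`. A minimal sketch of the fully spelled-out vanilla equivalent is below; the literal include path is a placeholder assumption, since the real value comes from ``cpp_extension.include_paths()``, which requires PyTorch to be installed:

```python
from setuptools import Extension

# Placeholder for cpp_extension.include_paths(); the real call returns
# the PyTorch header directories needed to compile the extension.
include_dirs = ['/path/to/torch/include']  # assumption, not a real path

# Equivalent of CppExtension('lltm_cpp', ['lltm.cpp']) spelled out by hand.
ext = Extension(
    name='lltm_cpp',
    sources=['lltm.cpp'],
    include_dirs=include_dirs,
    language='c++')
```

In a real ``setup.py`` this ``ext`` would be passed as ``ext_modules=[ext]`` to ``setup()``, with ``cpp_extension.BuildExtension`` as the ``build_ext`` command class.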
@@ -413,7 +412,7 @@ see::
 If we call ``help()`` on the function or module, we can see that its signature
 matches our C++ code::
 
-    In[4] help(lltm.forward)
+    In[4] help(lltm_cpp.forward)
     forward(...) method of builtins.PyCapsule instance
     forward(arg0: torch::Tensor, arg1: torch::Tensor, arg2: torch::Tensor, arg3: torch::Tensor, arg4: torch::Tensor) -> List[torch::Tensor]
 
@@ -473,6 +472,8 @@ small benchmark to see how much performance we gained from rewriting our op in
 C++. We'll run the LLTM forwards and backwards a few times and measure the
 duration::
 
+    import time
+
     import torch
 
     batch_size = 16
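The ``import time`` added in this hunk supports the benchmarking loop that follows in the tutorial. A minimal sketch of such a timing loop is below; it times a stand-in pure-Python function rather than the compiled LLTM, whose module is not assumed to be built here:

```python
import time

# Stand-in for the extension's forward pass; the tutorial's benchmark
# calls the LLTM forward and backward on real tensors instead.
def forward(batch):
    return [v * v for v in batch]

batch = list(range(16))
runs = 1000

forward_time = 0.0
for _ in range(runs):
    start = time.time()
    forward(batch)
    forward_time += time.time() - start

print('Forward: {:.3f} us per run'.format(forward_time * 1e6 / runs))
```

The accumulate-per-iteration pattern mirrors the tutorial's benchmark, which reports the total forward and backward durations separately.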

0 commit comments
