- https://pytorch.org/tutorials/advanced/cpp_extension.html
- How to use CUDA and C++ files to write customized kernel ops:
How to (high level)
- To create a custom op:
- Write CUDA kernels in .cu files
- Create C++ wrapper functions in .cpp files
- Define the Python interface using pybind11 or the PyTorch C++ extension system
- Use setuptools to compile and link the code
- Creating a PyTorch autograd Function class
Python interface (.py)
- can use
from torch.utils.cpp_extension import load_inline
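A minimal JIT sketch of `load_inline` (assumes torch and a working C++ toolchain are installed; the op name `scale_add` and its source are made up for illustration, not from the tutorial):

```python
# Sketch: JIT-compile an inline C++ op with load_inline.
# Assumes torch + a C++ compiler; the op and module names are hypothetical.
from torch.utils.cpp_extension import load_inline

cpp_source = """
torch::Tensor scale_add(torch::Tensor a, torch::Tensor b, double alpha) {
    return a + alpha * b;
}
"""

ext = load_inline(
    name="scale_add_ext",
    cpp_sources=cpp_source,
    functions=["scale_add"],  # load_inline generates the pybind11 glue for these
)
# ext.scale_add(a, b, 2.0) is then callable on torch tensors
```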
CUDA files (.cu):
- Contains CUDA kernel implementations
- Written in CUDA C++
- Compiled with NVIDIA's CUDA compiler (nvcc)
C++ files (.cpp):
C++ interface
- Contains CPU implementations and CUDA kernel launchers
- e.g.
Python binding
- Can use pybind11 within C++
Header files (.h):
- Contain function declarations and interface definition
How to (low level)
- C++ extensions come in two flavors: they can be built "ahead of time" with setuptools, or "just in time" via torch.utils.cpp_extension.load()
- Running example = the lltm operation
- We want to be able to write import lltm_cpp in our code
Folder setup and building
- The folder structure is
pytorch/
lltm-extension/
lltm.cpp
setup.py  # if we're using setuptools
- With setuptools, this is setup.py:
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='lltm_cpp',
      ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})
- and you run
python setup.py install
- With JIT compiling
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"])
Writing the C++ op
- Let's say we need the derivative of the sigmoid for the backward pass
#include <torch/extension.h>
#include <iostream>
torch::Tensor d_sigmoid(torch::Tensor z) {
auto s = torch::sigmoid(z);
return (1 - s) * s;
}
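The `d_sigmoid` above relies on the identity d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)). A quick pure-Python sanity check of that identity (standalone sketch, no torch required, not part of the tutorial):

```python
# Verify d/dz sigmoid(z) == (1 - s) * s against a finite difference.
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def d_sigmoid(z):
    # Same formula as the C++ op above
    s = sigmoid(z)
    return (1 - s) * s

for z in (-2.0, 0.0, 1.5):
    h = 1e-6
    # Central finite difference approximation of the derivative
    numeric = (sigmoid(z + h) - sigmoid(z - h)) / (2 * h)
    assert abs(d_sigmoid(z) - numeric) < 1e-6
```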
- <torch/extension.h> is the one-stop header to include all the necessary PyTorch bits to write C++ extensions. It includes:
- The ATen library, which is our primary API for tensor computation,
- pybind11, which is how we create Python bindings for our C++ code,
- Headers that manage the details of interaction between ATen and pybind11.
Exposing the functions
- Once you have your operation written in C++ and ATen, you can use pybind11 to bind your C++ functions or classes into Python in the C++ files
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &lltm_forward, "LLTM forward");
m.def("backward", &lltm_backward, "LLTM backward");
}
Mixed C++/CUDA
- https://pytorch.org/tutorials/advanced/cpp_extension.html#writing-a-mixed-c-cuda-extension
- The general strategy for writing a CUDA extension is to
  - first write a C++ file which defines the functions that will be called from Python, and binds those functions to Python with pybind11
  - this file will also declare functions that are defined in CUDA (.cu) files
  - the C++ functions will then do some checks and ultimately forward their calls to the CUDA functions
- In the CUDA files, we write our actual CUDA kernels and the interfaces that do the kernel launches
- The cpp_extension package will then take care of compiling the C++ sources with a C++ compiler like gcc and the CUDA sources with NVIDIA's nvcc compiler. This ensures that each compiler takes care of the files it knows best to compile.
- Defining the CUDA file
  - Use the .cu extension
  - NVCC can reasonably compile C++11, thus we still have ATen and the C++ standard library available to us (but not torch.h)
- Note that setuptools cannot handle files with the same name but different extensions, so if you use the setup.py method instead of the JIT method, you must give your CUDA file a different name than your C++ file
Integrating into PyTorch
- https://pytorch.org/tutorials/advanced/cpp_extension.html#integrating-a-c-cuda-operation-with-pytorch
- Just like previously, can use setuptools or JIT compiling; the args are slightly different
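For reference, a sketch of both mixed C++/CUDA build variants (file names follow the tutorial's lltm_cuda naming; this is build configuration assuming torch and the CUDA toolkit are installed, not tested here):

```python
# setup.py variant: note the CUDA file gets a different base name
# (lltm_cuda_kernel.cu) than the C++ file (lltm_cuda.cpp).
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='lltm_cuda',
    ext_modules=[
        CUDAExtension('lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu']),
    ],
    cmdclass={'build_ext': BuildExtension})

# JIT variant:
# from torch.utils.cpp_extension import load
# lltm_cuda = load(name='lltm_cuda',
#                  sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
```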