0x511 Computing
This note is about kernel implementation on a single device.
For heterogeneous computing (mainly GPU), check this lecture series.
1. CUDA Programming
This section follows this book (Kirk and Wen-Mei, 2016)1
CUDA programming is an instance of SPMD (single program, multiple data). Note that SPMD is a higher-level concept than SIMD and is not the same thing: in an SPMD system, processing units do not need to execute the same instruction at the same time, whereas processing units in SIMD execute the same instruction at all times in strict lockstep.
CUDA's flavor of SPMD is called SIMT (single instruction, multiple threads).
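As a concrete illustration, here is a minimal vector-add kernel; it is the kind of kernel the driver-API example below loads as VecAdd. Every thread runs the same program and picks its own element from its block and thread coordinates.
// Each thread processes one element: same program, different data.
// Within a warp, threads execute in SIMT fashion on the hardware.
extern "C" __global__ void VecAdd(const float* A, const float* B,
                                  float* C, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N)  // guard: the last block may be partially filled
        C[i] = A[i] + B[i];
}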
1.1. Driver API
The driver API is a handle-based, imperative API implemented in libcuda.so. It can be used, for example, to load an existing PTX kernel from a file:
// Initialize the driver API
cuInit(0);
// Get handle for device 0
CUdevice cuDevice;
cuDeviceGet(&cuDevice, 0);
// Create context
CUcontext cuContext;
cuCtxCreate(&cuContext, 0, cuDevice);
// Create module from binary file
CUmodule cuModule;
cuModuleLoad(&cuModule, "VecAdd.ptx");
// Allocate vectors in device memory
CUdeviceptr d_A;
cuMemAlloc(&d_A, size);
...
// Copy vectors from host memory to device memory
cuMemcpyHtoD(d_A, h_A, size);
// Get function handle from module
CUfunction vecAdd;
cuModuleGetFunction(&vecAdd, cuModule, "VecAdd");
// Invoke kernel
int threadsPerBlock = 256;
int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
void* args[] = { &d_A, &d_B, &d_C, &N };
cuLaunchKernel(vecAdd,
               blocksPerGrid, 1, 1,     // grid dimensions
               threadsPerBlock, 1, 1,   // block dimensions
               0, 0, args, 0);          // shared mem, stream, args, extra
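After the launch, the result is copied back and the handles are released. A sketch of the remaining steps, assuming the same handles as above (the PTX file itself can be produced with nvcc --ptx from a kernel like the one in the SIMT example):
// Copy result from device memory back to host memory
cuMemcpyDtoH(h_C, d_C, size);
// Free device memory, unload the module, destroy the context
cuMemFree(d_A);
...
cuModuleUnload(cuModule);
cuCtxDestroy(cuContext);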
2. GEMM
Good to check the CUTLASS code. See this post.
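As a baseline before reading CUTLASS, the sketch below is a classic shared-memory-tiled SGEMM. It assumes row-major matrices with M, N, K divisible by TILE; the kernel name and layout are illustrative, not CUTLASS API.
#define TILE 16

// C = A * B for row-major float matrices; M, N, K multiples of TILE.
extern "C" __global__ void sgemm_tiled(const float* A, const float* B,
                                       float* C, int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;

    for (int t = 0; t < K / TILE; ++t) {
        // Stage one TILE x TILE tile of A and of B into shared memory.
        As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(t * TILE + threadIdx.y) * N + col];
        __syncthreads();
        // Each thread accumulates one element of C from the staged tiles.
        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();
    }
    C[row * N + col] = acc;
}
CUTLASS layers more machinery on top of this same idea: multi-level tiling (threadblock, warp, tensor-core fragment), software pipelining of the global-to-shared loads, and epilogue fusion.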
3. Convolution
3.1. Winograd
Fast Algorithms for Convolutional Neural Networks (Lavin and Gray, 2016) introduces Winograd convolution for CNNs: it trades multiplications for additions by transforming input tiles and filters.
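The paper's F(2,3) case computes two outputs of a 3-tap 1-D filter with 4 multiplications instead of the naive 6; the 2-D F(2x2, 3x3) case is built by nesting this. A scalar sketch of F(2,3) (illustrative helper, not the paper's code):
// Winograd F(2,3): two outputs of a 3-tap filter g over four inputs d,
// using 4 multiplications (m1..m4) instead of 6.
void winograd_f23(const float d[4], const float g[3], float y[2]) {
    float m1 = (d[0] - d[2]) * g[0];
    float m2 = (d[1] + d[2]) * 0.5f * (g[0] + g[1] + g[2]);
    float m3 = (d[2] - d[1]) * 0.5f * (g[0] - g[1] + g[2]);
    float m4 = (d[1] - d[3]) * g[2];
    y[0] = m1 + m2 + m3;  // equals d0*g0 + d1*g1 + d2*g2
    y[1] = m2 - m3 - m4;  // equals d1*g0 + d2*g1 + d3*g2
}
The filter-side factors (g0 + g1 + g2) / 2 and (g0 - g1 + g2) / 2 depend only on the weights, so they can be precomputed once per filter.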
4. Attention
4.1. FlashAttention
FlashAttention reduces communication between SRAM and HBM by tiling plus rematerialization: attention is computed block by block in on-chip SRAM, and the backward pass recomputes intermediates instead of storing the full attention matrix in HBM.
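The core trick is an online softmax: stream K/V in blocks, keep a running row max and denominator, and rescale the partial output whenever the max grows, so the full attention matrix is never materialized. A scalar single-query sketch under those assumptions (the 1/sqrt(d) score scaling is omitted for brevity):
#include <algorithm>
#include <cmath>
#include <vector>

// One query row q against seq_len keys/values of head dim d,
// streamed in blocks of Bc columns (the block that would live in SRAM).
void flash_attention_row(const float* q, const float* K, const float* V,
                         float* out, int seq_len, int d, int Bc) {
    float m = -INFINITY;              // running max of scores
    float l = 0.0f;                   // running softmax denominator
    std::vector<float> acc(d, 0.0f);  // unnormalized output
    for (int start = 0; start < seq_len; start += Bc) {
        int end = std::min(start + Bc, seq_len);
        for (int j = start; j < end; ++j) {
            float s = 0.0f;  // score = q . K[j]
            for (int k = 0; k < d; ++k) s += q[k] * K[j * d + k];
            float m_new = std::max(m, s);
            float scale = std::exp(m - m_new);  // rescale old statistics
            float p = std::exp(s - m_new);
            l = l * scale + p;
            for (int k = 0; k < d; ++k)
                acc[k] = acc[k] * scale + p * V[j * d + k];
            m = m_new;
        }
    }
    for (int k = 0; k < d; ++k) out[k] = acc[k] / l;  // normalize once
}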
4.2. Flash Decoding
Flash-Decoding splits the keys/values over the sequence dimension and applies FlashAttention at two levels: each split is processed in parallel with FlashAttention, then a second reduction step merges the partial results.
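The merge is exact because each split returns its running max m, denominator l, and unnormalized accumulator, and these compose associatively. A sketch of combining two splits (illustrative names, matching the FlashAttention sketch above):
#include <algorithm>
#include <cmath>

// Merge two partial attention results computed over disjoint KV splits.
void merge_splits(float m1, float l1, const float* acc1,
                  float m2, float l2, const float* acc2,
                  float* m_out, float* l_out, float* acc_out, int d) {
    float m = std::max(m1, m2);
    float s1 = std::exp(m1 - m);  // rescale split 1 to the common max
    float s2 = std::exp(m2 - m);  // rescale split 2 to the common max
    *m_out = m;
    *l_out = l1 * s1 + l2 * s2;
    for (int k = 0; k < d; ++k)
        acc_out[k] = acc1[k] * s1 + acc2[k] * s2;
    // After all splits are merged, the final output is acc / l.
}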
5. Communications
This section is about communication primitives and their implementations.
5.1. ReduceScatter
5.2. AllReduce
There are a few methods to implement allreduce. For example, it can be implemented as ReduceScatter + AllGather.
Ring-based algorithms are implemented in Horovod and Baidu Allreduce. See Baidu's simple allreduce implementation using MPI_Irecv and MPI_Send; a sketch of this pattern follows below. An advanced ring-based approach is the 2D ring algorithm.
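A minimal sketch of the ring algorithm, in the spirit of Baidu's code: nranks - 1 reduce-scatter steps followed by nranks - 1 allgather steps, each posting MPI_Irecv from the left neighbor before MPI_Send to the right (assumes count divides evenly by the number of ranks):
#include <mpi.h>
#include <stdlib.h>

// In-place ring allreduce (sum) over floats; nranks must divide count.
void ring_allreduce(float* data, int count, MPI_Comm comm) {
    int rank, nranks;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &nranks);
    int chunk = count / nranks;
    int right = (rank + 1) % nranks;
    int left  = (rank - 1 + nranks) % nranks;
    float* recv_buf = (float*)malloc(chunk * sizeof(float));

    // Phase 1: reduce-scatter. After nranks - 1 steps, each rank holds
    // one chunk (index (rank + 1) % nranks) reduced across all ranks.
    for (int s = 0; s < nranks - 1; ++s) {
        int send_c = (rank - s + nranks) % nranks;
        int recv_c = (rank - s - 1 + nranks) % nranks;
        MPI_Request req;
        MPI_Irecv(recv_buf, chunk, MPI_FLOAT, left, 0, comm, &req);
        MPI_Send(data + send_c * chunk, chunk, MPI_FLOAT, right, 0, comm);
        MPI_Wait(&req, MPI_STATUS_IGNORE);
        for (int i = 0; i < chunk; ++i)
            data[recv_c * chunk + i] += recv_buf[i];
    }

    // Phase 2: allgather. Circulate the reduced chunks around the ring.
    for (int s = 0; s < nranks - 1; ++s) {
        int send_c = (rank - s + 1 + nranks) % nranks;
        int recv_c = (rank - s + nranks) % nranks;
        MPI_Request req;
        MPI_Irecv(data + recv_c * chunk, chunk, MPI_FLOAT, left, 0, comm, &req);
        MPI_Send(data + send_c * chunk, chunk, MPI_FLOAT, right, 0, comm);
        MPI_Wait(&req, MPI_STATUS_IGNORE);
    }
    free(recv_buf);
}
Each rank sends and receives 2 * (nranks - 1) / nranks of the data in total, which is why ring allreduce is bandwidth-optimal independent of the number of ranks.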
Double binary trees are used in NCCL's implementation. See this blog.
6. Libs
6.1. cuDNN
There are two APIs right now:
- imperative Legacy API
- declarative Graph API
6.1.1. Graph API
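The Graph API (also called the backend API) is declarative: a computation is built by creating opaque descriptors, setting their attributes, and finalizing them, then handing the finalized operation graph to an execution engine. A minimal sketch of building one tensor descriptor, with attribute names from the cuDNN backend docs (error checking omitted):
#include <cudnn.h>

// Describe a packed NCHW float tensor of shape (8, 64, 56, 56).
cudnnBackendDescriptor_t xDesc;
cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &xDesc);

cudnnDataType_t dtype = CUDNN_DATA_FLOAT;
int64_t dims[]    = {8, 64, 56, 56};
int64_t strides[] = {64 * 56 * 56, 56 * 56, 56, 1};
int64_t uid       = 1;  // id used later to bind the actual data pointer
int64_t alignment = 4;

cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DATA_TYPE,
                         CUDNN_TYPE_DATA_TYPE, 1, &dtype);
cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DIM,
                         CUDNN_TYPE_INT64, 4, dims);
cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_STRIDES,
                         CUDNN_TYPE_INT64, 4, strides);
cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID,
                         CUDNN_TYPE_INT64, 1, &uid);
cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT,
                         CUDNN_TYPE_INT64, 1, &alignment);
cudnnBackendFinalize(xDesc);
Operation descriptors (e.g. a convolution forward node) are built the same way and assembled into an operation-graph descriptor, which is then finalized and executed through an engine and execution plan.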
7. References
- David B. Kirk and W. Hwu Wen-Mei. 2016. Programming Massively Parallel Processors: A Hands-on Approach. Morgan Kaufmann. ↩