HummingBird: A Fast IO Aware Attention Kernel For Small Sequences
12/30/24
tl;dr I made an attention kernel optimized for a small number of tokens that runs almost thrice as fast (wall clock time) as flash attention on a 4090
Results
- I made two kernels: one for sequences of <= 32 tokens with a head dimension of 64, and one for sequences of <= 48 tokens with a head dimension of 32. Both kernels are built for float16 precision.
- The 32 token kernel runs in 790μs for a batch size of 1000 and a model dimension of 5120 = 64 * 80 (head_dim * num_heads). Flash attention takes 2.26ms. This is almost a 3x speed-up on wall clock (CUDA) time.
- The 48 token kernel runs in 710μs for the same batch size and model dimension. Flash attention takes 1.44ms. This is slightly over a 2x speed-up.
- These speed-ups hold as batch size and model dimension scale up to any size (HBM permitting), but only for those token counts and head dimensions, NOT for other shapes.
First iterations
Before beginning sophomore fall semester, I had spent the majority of August evenings thinking about how to reduce the time complexity of the attention mechanism. I had started to think about methods of computing attention through cascading layers (basically a two-step sliding window attention).
Results looked promising. My models were training and the losses were going down. However, I was gluing my modifications on top of torch's scaled_dot_product_attention so I wasn't getting the desired performance increase. Therefore, I decided to take a quick (2 month) detour to learn CUDA.
I decided that my first kernel would implement variable sliding window attention. I spent a weekend on that, and once it was done, I benchmarked it and, to my surprise, it was faster than I expected! I made the sliding window as big as the number of tokens (equivalent to vanilla attention), and astonishingly it was also faster than flash attention.
Of course, this was a fallacy. I was benchmarking the kernel launch time (CPU + CUDA time) instead of pure CUDA time. To the surprise of no one, my triple-for-loop matrix multiplication was actually 100x slower...
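For anyone about to make the same mistake: below is a minimal, hypothetical sketch of timing only GPU execution with CUDA events, so CPU-side launch overhead isn't counted. `dummy_kernel` is just a placeholder; my actual numbers later in this post come from torch.profiler instead.

```cuda
// Hypothetical example: time pure GPU execution with CUDA events so that
// CPU-side launch overhead is excluded. dummy_kernel stands in for the
// kernel under test.
#include <cuda_runtime.h>
#include <stdio.h>

__global__ void dummy_kernel() {}

int main() {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    dummy_kernel<<<1, 32>>>();          // warm-up launch

    cudaEventRecord(start);
    dummy_kernel<<<1, 32>>>();          // the launch we actually time
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);   // GPU time only, in milliseconds
    printf("kernel time: %.3f ms\n", ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return 0;
}
```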
But I persisted! I had already told my advisor that I was pivoting my research to making CUDA kernels so I had to deliver. Two hectic months of classes passed and finally I had real results. Below is an explanation of how it works.
Technical Details
First of all, how do we make programs run faster? One way is by reducing the memory bottleneck. This is done in two ways: through kernel fusion and by keeping as many of the intermediate computations as possible in shared memory. Therefore, the most important consideration of my implementation was the architecture I was dealing with, as this would determine the maximum size of the matrices that I could keep in shared memory.
4090s have 128 KB of shared memory per streaming multiprocessor (SMs are sort of like CPU cores in the sense that they have hardware threads and there are many of them). Some of this memory is usually used as an L1 cache, but we can tell the GPU that we want all of it as shared memory using cudaFuncAttributePreferredSharedMemoryCarveout. Now we have the full 128 KB per SM to ourselves, on which we run 4 blocks, leaving us with 32 KB per block.
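To make the idea concrete, here is a heavily simplified sketch of a fused small-sequence attention kernel in this spirit. It is not my actual kernel: it uses fp32 instead of fp16, one thread per query row, and no tensor cores. But it shows the key points: Q, K, V and the score matrix live entirely in shared memory (about 28 KB, inside the 32 KB per-block budget), and the host requests the full shared-memory carveout. The kernel name and the [batch * heads, N, D] layout are my own assumptions.

```cuda
// Simplified sketch of a fused attention kernel for short sequences.
// Assumptions (mine, not the real kernel's): fp32 instead of fp16,
// one thread block per (batch, head) pair, one thread per query row,
// and Q/K/V/O stored contiguously as [batch * heads, N, D].
#include <cuda_runtime.h>
#include <math.h>

#define N 32   // number of tokens
#define D 64   // head dimension

__global__ void fused_attention_small(const float* Q, const float* K,
                                      const float* V, float* O) {
    // Everything the block needs stays in shared memory:
    // 3 * 32 * 64 * 4 B (Q, K, V) + 32 * 32 * 4 B (scores) = 28 KB.
    __shared__ float q[N][D], k[N][D], v[N][D], s[N][N];

    const size_t base = (size_t)blockIdx.x * N * D;

    // Cooperatively load Q, K, V for this (batch, head) into shared memory.
    for (int idx = threadIdx.x; idx < N * D; idx += blockDim.x) {
        q[idx / D][idx % D] = Q[base + idx];
        k[idx / D][idx % D] = K[base + idx];
        v[idx / D][idx % D] = V[base + idx];
    }
    __syncthreads();

    const int i = threadIdx.x;  // query row owned by this thread
    if (i >= N) return;

    // Scores s[i][j] = (q_i . k_j) / sqrt(D), tracking the row max for softmax.
    float row_max = -INFINITY;
    for (int j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (int d = 0; d < D; ++d) acc += q[i][d] * k[j][d];
        s[i][j] = acc * rsqrtf((float)D);
        row_max = fmaxf(row_max, s[i][j]);
    }

    // Numerically stable softmax over the row.
    float denom = 0.0f;
    for (int j = 0; j < N; ++j) {
        s[i][j] = expf(s[i][j] - row_max);
        denom += s[i][j];
    }

    // O_i = sum_j softmax(s)_{ij} * v_j, written straight back to HBM.
    for (int d = 0; d < D; ++d) {
        float acc = 0.0f;
        for (int j = 0; j < N; ++j) acc += s[i][j] * v[j][d];
        O[base + (size_t)i * D + d] = acc / denom;
    }
}

int main() {
    const int batch = 1000, heads = 80;  // example sizes
    const size_t bytes = (size_t)batch * heads * N * D * sizeof(float);

    float *Q, *K, *V, *O;
    cudaMalloc(&Q, bytes);
    cudaMalloc(&K, bytes);
    cudaMalloc(&V, bytes);
    cudaMalloc(&O, bytes);
    cudaMemset(Q, 0, bytes);  // dummy inputs, just so the sketch runs cleanly
    cudaMemset(K, 0, bytes);
    cudaMemset(V, 0, bytes);

    // Ask for the full shared-memory carveout (no L1) for this kernel.
    cudaFuncSetAttribute(fused_attention_small,
                         cudaFuncAttributePreferredSharedMemoryCarveout,
                         cudaSharedmemCarveoutMaxShared);

    // One block per (batch, head), N threads per block.
    fused_attention_small<<<batch * heads, N>>>(Q, K, V, O);
    cudaDeviceSynchronize();

    cudaFree(Q); cudaFree(K); cudaFree(V); cudaFree(O);
    return 0;
}
```

The real kernel uses fp16 (halving the shared-memory footprint and enabling tensor cores) and far more parallelism per block, but the fusion-plus-shared-memory structure is the same: one kernel launch, no intermediate round trips to HBM.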
Below are the timings for 32 tokens with a head dimension of 64 on a 4090, measured with torch.profiler.ProfilerActivity. The results for 48 tokens with a head dimension of 32 follow a similar pattern.
| batch_size | num_heads | hummingbird | torch     | flash_attn |
|-----------:|----------:|------------:|----------:|-----------:|
| 1          | 1         | 2.720us     | 7.169us   | 6.624us    |
| 1          | 32        | 2.688us     | 7.072us   | 6.784us    |
| 1          | 128       | 3.168us     | 7.744us   | 7.232us    |
| 10         | 128       | 14.624us    | 43.040us  | 38.720us   |
| 100        | 128       | 130.271us   | 408.545us | 377.185us  |
| 1000       | 128       | 1.266ms     | 4.039ms   | 4.138ms    |
So... Does it scale?
Acknowledgements
I would like to sincerely thank Professor Emily Wenger for advising me through this research project and for believing in me while I struggled to learn CUDA.