
HummingBird: A Fast IO Aware Attention Kernel For Small Sequences

12/30/24

tl;dr: I made an attention kernel optimized for a small number of tokens that runs almost three times as fast (wall-clock time) as flash attention on a 4090.

Results

First iterations

Before beginning my sophomore fall semester, I spent the majority of my August evenings thinking about how to reduce the time complexity of the attention mechanism. I started exploring methods of computing attention through cascading layers (basically a two-step sliding window attention).

Results looked promising. My models were training and the losses were going down. However, I was gluing my modifications on top of torch's scaled_dot_product_attention, so I wasn't getting the performance increase I was after. Therefore, I decided to take a quick (two-month) detour to learn CUDA.

I decided that my first kernel would implement variable sliding window attention. I spent a weekend on it, and once it was done, I benchmarked it and, to my surprise, it was faster than I expected! I made the sliding window as large as the number of tokens (equivalent to vanilla attention), and astonishingly it was still faster than flash attention.

Of course, this was an illusion. I was benchmarking the kernel launch time (CPU + CUDA time) instead of pure CUDA time. To the surprise of no one, my triple-for-loop matrix multiplication was actually 100x slower...
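
As an aside, one common way to get device-side timings without a full profiler is CUDA events, which are timestamped on the GPU timeline rather than the CPU clock. This is not how the numbers later in this post were produced (those come from torch.profiler); it is just a minimal sketch, with my_kernel standing in for whatever kernel is being measured.

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the kernel under test.
__global__ void my_kernel() {}

int main() {
    const int kIters = 1000;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    my_kernel<<<1, 32>>>();               // warm-up launch
    cudaEventRecord(start);
    for (int i = 0; i < kIters; ++i)      // average over many launches
        my_kernel<<<1, 32>>>();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);   // GPU-timeline time between the two events, in ms
    printf("avg device time per launch: %.3f us\n", ms * 1000.0f / kIters);
    return 0;
}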

But I persisted! I had already told my advisor that I was pivoting my research to writing CUDA kernels, so I had to deliver. Two hectic months of classes passed, and finally I had real results. Below is an explanation of how it works.

Technical Details

First of all, how do we make programs run faster? One way is by reducing the memory bottleneck. This is done in two ways: through kernel fusion and by keeping as many of the intermediate computations as possible in shared memory. Therefore, the most important consideration for my implementation was the architecture I was dealing with, as it determines the maximum size of the matrices I can keep in shared memory.
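
To make that concrete, below is a minimal sketch of kernel fusion (not the HummingBird kernel itself): a scaling step and a row-wise softmax run in a single launch, and the intermediate row lives in shared memory the whole time. Unfused, the first step would write its result to global memory and the second step would read it back.

// Fused scale + softmax, one row per block (assumes n <= blockDim.x).
// A real kernel would use parallel reductions instead of the serial loops.
__global__ void fused_scale_softmax(const float* in, float* out, int n) {
    extern __shared__ float row[];   // the intermediate never leaves shared memory
    int tid = threadIdx.x;

    // Step 1: scale the scores (e.g. by 1/sqrt(head_dim) with head_dim = 64).
    // Threads past the end of the row contribute a very negative score (-> 0 after exp).
    float x = (tid < n) ? in[blockIdx.x * n + tid] * rsqrtf(64.0f) : -1.0e30f;
    row[tid] = x;
    __syncthreads();

    // Step 2: numerically stable softmax over the row.
    float m = -1.0e30f;
    for (int i = 0; i < n; ++i) m = fmaxf(m, row[i]);
    float e = expf(x - m);
    row[tid] = e;
    __syncthreads();

    float s = 0.0f;
    for (int i = 0; i < n; ++i) s += row[i];
    if (tid < n) out[blockIdx.x * n + tid] = e / s;
}

// Launched with one block per row and shared memory sized to the block, e.g.:
// fused_scale_softmax<<<num_rows, 32, 32 * sizeof(float)>>>(d_in, d_out, 32);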

The 4090 has 128 KB of shared memory per streaming multiprocessor (SMs are sort of like CPU cores in the sense that they have hardware threads and there are many of them). Some of this memory is usually used as an L1 cache, but we can tell the GPU that we would rather use it as shared memory via cudaFuncAttributePreferredSharedMemoryCarveout. Now we have the full 128 KB per SM to ourselves; we run 4 blocks on each SM, leaving us with 32 KB per block.
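
For reference, setting that attribute looks roughly like this (a sketch, with attention_kernel as a hypothetical stand-in for the real kernel):

#include <cuda_runtime.h>

// Stand-in for the real attention kernel.
__global__ void attention_kernel() {}

int main() {
    // Ask the driver to carve the unified L1/shared-memory storage in favor of
    // shared memory for this kernel (the carveout is a preference, not a hard
    // guarantee).
    cudaFuncSetAttribute(attention_kernel,
                         cudaFuncAttributePreferredSharedMemoryCarveout,
                         cudaSharedmemCarveoutMaxShared);

    // Following the budget above: 4 blocks per SM at about 32 KB of shared
    // memory each, requested here as dynamic shared memory at launch time.
    // attention_kernel<<<grid, block, 32 * 1024>>>(...);
    return 0;
}

Since 32 KB is below the default 48 KB per-block limit, no extra opt-in via cudaFuncAttributeMaxDynamicSharedMemorySize is needed here.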

Below are the results for 32 tokens with a head dimension of 64 on a 4090, timed using torch.profiler's ProfilerActivity. The results for 48 tokens with a head dimension of 48 follow a similar pattern.

batch_size  num_heads  hummingbird  torch      flash_attn
1           1          2.720us      7.169us    6.624us
1           32         2.688us      7.072us    6.784us
1           128        3.168us      7.744us    7.232us
10          128        14.624us     43.040us   38.720us
100         128        130.271us    408.545us  377.185us
1000        128        1.266ms      4.039ms    4.138ms

So... Does it scale?

Acknowledgements

I would like to thank Professor Emily Wenger for advising me throughout this research project and for believing in me while I struggled to learn CUDA.