Bitter Lesson for faster kernels?
Before you read this, you may think something along the lines of:
"What in God's holy name are you blathering about?"
or maybe you feel like you've got no frame of reference here.
or that you're out of your element.
Maybe you don't know the preferred nomenclature.
I'll leave some links to some really cool stuff that you may find useful:
- GPU Glossary by Modal
- OG matmul optimization blog
- Another really nice matmul optimization blog
- PTX instruction set docs
The universal approximation theorems state something along the lines of: any continuous function between two real vector spaces (restricted to a compact domain) can be approximated arbitrarily well by a plain feedforward neural network with enough hidden units.
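Roughly, one of the standard single-hidden-layer versions of the statement (there are many variants, and I'm glossing over the exact conditions on the activation) reads:

$$
\forall\, f \in C(K, \mathbb{R}^{m}),\ \forall\, \varepsilon > 0,\ \exists\, N,\ W_1 \in \mathbb{R}^{N \times n},\ b \in \mathbb{R}^{N},\ W_2 \in \mathbb{R}^{m \times N} \ \text{such that} \ \sup_{x \in K} \big\lVert f(x) - W_2\,\sigma(W_1 x + b) \big\rVert < \varepsilon,
$$

where $K \subset \mathbb{R}^{n}$ is compact and $\sigma$ is a fixed non-polynomial activation applied element-wise.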
It’s time to replace PyTorch with one GIGANTIC neural net and start training your LLMs that way! At this point, you realize I’m being goofy. The truth of the situation is obvious: that gigantic neural net, which could essentially learn the same function as an LLM, is probably EXTREMELY GIGANTIC and probably requires an EXTREMELY LARGE AMOUNT OF DATA. So, generality is lovely, but to even expose a tractable search in the weight space to learn anything meaningful from text data, the specific architecture of the transformer significantly mogs (in the parlance of our times) a feedforward neural net.
Next, a CPU and GPU are both Turing complete. Time to run all our deep learning on CPUs! Well, no—the specific architecture of the GPU yet again outperforms the CPU for deep learning applications.
The Evolution of GPUs and Their Instruction Set
Let’s just look at NVIDIA and their Tensor Cores. Initially, you had thread-level `load_matrix`, `mma` (matrix multiply and accumulate), and `store_matrix`, where you had to orchestrate the loads/stores and compute at the thread level (the load was often from shared memory to registers), and these instructions imposed certain rules and restrictions on access patterns and so on.
Then, you had warp-level and warp-group-level analogs of these instructions. Now, with the 5th-generation Tensor Cores, you have instructions that operate at the Cooperative Thread Array level (block, if you prefer, but some caveats there), and they have a wild amount of flexibility, but each instruction also comes with a wild amount of specification on legal ways of using it. Even the somewhat dear axiom of “memory is just 1D addresses at the end of the day” does not hold with the new Tensor Memory, which is organized as a 2D matrix.
So, a general model of the whole GPU and its full instruction set obviously can't both stand the test of time and expose a tractable search. A lot of deep learning ops are here to stay (matmul, haha), and SIMT is really good for them. But we can't make general models of the whole GPU and instruction set without sacrificing either robustness across hardware changes or search tractability.
Tiny SIMT Mathematical RISC Machines
I mean, zooming out a lot and waving hands, a GPU kernel is a series of: take your “compute object” (input and output tensors, and the thing you want to do with them), take your “work object” (the nested hierarchy of CTA cluster, CTA, warp, thread, or block, warp, thread if you prefer that naming), and assign partitions of “work objects” to partitions of “compute objects” cleverly.
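To make the hand-waving slightly more concrete, here is a tiny, hypothetical Python sketch of that view for a matmul: partition the “compute object” (the output matrix) into tiles and assign them to partitions of the “work object” (CTAs). All of the names and numbers are made up for illustration:

```python
# Hypothetical sketch: partition a matmul's output (the "compute object")
# into tiles and assign each tile to a CTA (the "work object").
from dataclasses import dataclass
from itertools import product

@dataclass(frozen=True)
class Tile:
    row: int  # tile row index in the output matrix
    col: int  # tile column index in the output matrix

def partition_output(M: int, N: int, tile_m: int, tile_n: int) -> list:
    """Split an M x N output into (tile_m x tile_n) tiles."""
    return [Tile(r, c) for r, c in product(range(M // tile_m), range(N // tile_n))]

def assign(tiles: list, num_ctas: int) -> dict:
    """One (of many possible) mappings: round-robin tiles over CTAs."""
    mapping = {cta: [] for cta in range(num_ctas)}
    for i, tile in enumerate(tiles):
        mapping[i % num_ctas].append(tile)
    return mapping

# 4096 x 4096 output, 128 x 128 tiles, 132 CTAs (roughly one per SM on some parts)
tiles = partition_output(4096, 4096, 128, 128)
schedule = assign(tiles, num_ctas=132)
print(len(tiles), "tiles;", len(schedule[0]), "tiles for CTA 0")
```

The round-robin mapping is the dumbest possible choice; picking a better mapping (and better tile shapes) is exactly the kind of thing the search below is supposed to do.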
Suppose there is a universal mathematical language (say, category theory + additional structure), and suppose you want to do a matmul. What you need is a tiny SIMT matmul machine model with:
- A model of the few (≤ O(10)) instructions you will use, along with their constraints and so on. As a (hypothetical) example: `ld_gmem_to_smem`, `ld_smem_to_tmem`, `barriers` (things that can control synchronization and scheduling and stuff), `mma_on_tensor_core`, `st_tmem_to_smem`, `st_smem_to_gmem`.
- A necessary and sufficient model of the GPU for this kernel. As an example model: GMEM, L2 cache, SM, SMEM, L1 cache, TMEM (tensor memory), Tensor Core.
- Models for guiding search toward speed: GMEM coalescing, warp divergence, SMEM bank conflicts, cache locality, occupancy at various stages, and so on.
- Now, you can define search spaces over access patterns for each instruction, memory layouts, sizes, and so on, and define constraints to throw out ridiculous stuff and prune before the actual search. (A rough sketch of what such a machine model might look like in code follows this list.)
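Here is a minimal sketch, in plain Python, of what describing such a tinyGPU could look like. All of the names, capacities, and constraints below are hypothetical and chosen only for illustration; this is nowhere near a faithful model of real hardware:

```python
# Hypothetical tinyGPU for a matmul kernel: a handful of instructions,
# a handful of memory spaces, and a constraint to prune nonsense configurations.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class MemSpace:
    name: str
    capacity_bytes: int   # made-up capacities, just to have something to prune against
    is_2d: bool = False   # e.g. TMEM is organized as a 2D matrix, not a flat address range

GMEM = MemSpace("gmem", capacity_bytes=80 * 2**30)
SMEM = MemSpace("smem", capacity_bytes=228 * 2**10)
TMEM = MemSpace("tmem", capacity_bytes=256 * 2**10, is_2d=True)

@dataclass(frozen=True)
class Instr:
    name: str
    src: Optional[MemSpace]
    dst: Optional[MemSpace]

# The handful of (hypothetical) instructions our tiny matmul machine knows about.
ISA = [
    Instr("ld_gmem_to_smem", GMEM, SMEM),
    Instr("ld_smem_to_tmem", SMEM, TMEM),
    Instr("barrier", None, None),
    Instr("mma_on_tensor_core", TMEM, TMEM),
    Instr("st_tmem_to_smem", TMEM, SMEM),
    Instr("st_smem_to_gmem", SMEM, GMEM),
]

def tiles_fit_in_smem(tile_m: int, tile_n: int, tile_k: int, bytes_per_elem: int = 2) -> bool:
    """Constraint example: the A, B, and C tiles of one pipeline stage must fit in SMEM."""
    a = tile_m * tile_k * bytes_per_elem
    b = tile_k * tile_n * bytes_per_elem
    c = tile_m * tile_n * bytes_per_elem
    return a + b + c <= SMEM.capacity_bytes

print(tiles_fit_in_smem(128, 128, 64))    # True: a candidate worth searching over
print(tiles_fit_in_smem(512, 512, 256))   # False: pruned before any real search
```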
First, we need some math that is general enough to model these tiny SIMT RISC machines while still exposing a tractable search. Next, build a programming language off this math model. Like how one builds a Python class or a PyTorch model, one should be able to build operation-specific tiny SIMT RISC machines, and then for the search, use maybe some combination of classical methods + RL and so on.
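To make "classical search" concrete in the crudest possible way, here is a toy enumerate-prune-rank loop over tile sizes. The memory budget and the cost model are made up for illustration; a real cost model (or an RL-guided search) would replace `toy_cost`:

```python
# Naive search sketch: enumerate tilings, prune illegal ones, rank by a toy cost model.
from itertools import product

SMEM_BYTES = 228 * 2**10   # illustrative per-SM shared memory budget
BYTES_PER_ELEM = 2         # fp16/bf16 operands

def legal(tm: int, tn: int, tk: int) -> bool:
    """Prune: the A, B, and C tiles of one stage must fit in shared memory."""
    return (tm * tk + tk * tn + tm * tn) * BYTES_PER_ELEM <= SMEM_BYTES

def toy_cost(tm: int, tn: int, tk: int, M: int = 4096, N: int = 4096, K: int = 4096) -> float:
    """Made-up cost: total bytes moved from GMEM across all output tiles (lower is better)."""
    num_tiles = (M // tm) * (N // tn)
    k_steps = K // tk
    bytes_per_tile = k_steps * (tm * tk + tk * tn) * BYTES_PER_ELEM
    return num_tiles * bytes_per_tile

candidates = [
    (tm, tn, tk)
    for tm, tn, tk in product([64, 128, 256], [64, 128, 256], [32, 64, 128])
    if legal(tm, tn, tk)
]
best = sorted(candidates, key=lambda c: toy_cost(*c))[:3]
print("best (by toy cost):", best)
```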
I did try modeling a really simple element-wise kernel machine for a talk I gave, and it seems reasonable in the sense that it isn't crazy hard to come up with a model (albeit probably a terrible one for our purposes) for something like this with category theory. I’ll put the gory details some other time; I’m feeling lazy.
If AlphaTensor can use RL to search for better matmul algorithms, maybe RL can find something interesting when searching for better orchestrations of the same matmul algorithm on some tiny SIMT RISC machine model (I'll just call these tinyGPUs from now on).
To build mathematical machinery that can house any tinyGPU, one has to try to build machinery that can house tinyGPUs for current kernels on current hardware (maybe even previous generations), and to do that, one has to build lots of fast(er) kernels by hand and understand them.
I've written a decent number of kernels, but haven't really touched crazy peak performance yet.
That's where I'll start. I'll make a folder here, and probably (eventually) a Git repo going through each PTX Tensor Core Gen 5 instruction, breaking it down, and understanding it.