LLM Inference - the state and outlook

This post is a draft for a lecture I'll be giving later this year at National Sun Yat-sen University in Kaohsiung, Taiwan. I'm a guest lecturer there, in a course by Prof. Chang, who was my advisor during my college years. Making a slide deck is hard without a script, hence this post, where I blabber on about the topic and collect my thoughts. I'll share the slides once I've given the talk.

At least parts of them. The field is moving so fast that I'd bet my legs something big will happen between the time I write this and the time I give the talk.

State of the union

Large language models have brought a great boom to the entire world. Hard problems we thought were impossible are now damn easy. Poof: natural-language understanding, translation and summarization can all be solved with a single LLM (though not in the most efficient way). However, we run into problems using these LLMs. The hardware needed to run them is expensive. As of writing, the most popular way to run LLMs is on GPUs. But even with an NVIDIA RTX 4090, we are looking at 100 tokens/s running a quantized LLaMA2-Chat 13B at 350W. That is good enough for a single user, but not for a large-scale application. We can't build large-scale services with this kind of power consumption and user-to-machine ratio. Models need to be more efficient, and we need new hardware to reduce power consumption. This is a matter of needs, not wants, lest we literally need a nuclear power plant in every data center. Not saying that wouldn't be cool.

On the model side, the bottleneck on current hardware is dot-product attention. It produces a [context-length, context-length] score matrix, which, after a softmax, is multiplied by the value matrix to produce the output. This operation scales with the square of the number of tokens, which leads to two issues: you need a lot of memory to store the matrix, and compute grows quadratically as the context grows. Fortunately, both short- and long-term solutions have emerged. Optimizations that work around the quadratic scaling, like flash-attention, have greatly improved the context window of LLMs. And new architectures that forgo dot-product attention entirely, like Mamba and RWKV, have shown that we can build language models comparable to Transformers without attention. In theory these models have infinite context windows (though not in practice). And since they don't generate a large matrix as an intermediate step, they can be batched effectively for efficient inference on current hardware.
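To make the quadratic part concrete, here is a minimal single-head sketch in plain Python (no causal mask, no batching; the function names are my own). The `scores` list is the [context-length, context-length] matrix; the softmaxed scores are then multiplied by the value vectors.

```python
import math


def softmax(row):
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_attention(q, k, v):
    """q, k: n vectors of dimension d; v: n vectors of dimension d_v.

    Returns (output, scores). `scores` is the full n x n matrix that
    grows quadratically with the number of tokens.
    """
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    # [n, d] x [d, n] -> [n, n]: one score for every (query, key) pair.
    scores = [[scale * sum(qi[t] * kj[t] for t in range(d)) for kj in k] for qi in q]
    weights = [softmax(row) for row in scores]
    # [n, n] x [n, d_v] -> [n, d_v]: weighted sum of the value vectors.
    d_v = len(v[0])
    out = [[sum(w[j] * v[j][c] for j in range(n)) for c in range(d_v)] for w in weights]
    return out, scores
```

For scale: at a 32k-token context, that matrix is 32768 x 32768 entries, or 2 GiB in FP16 per head if it is materialized. Avoiding that materialization, by processing the scores in tiles with an online softmax, is essentially what flash-attention does.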

In numbers: the latest RWKVv5 in FP16 on an RTX 3090 can do 8 tokens/s at a batch size of 160. Compare this to LLaMA on an H100 with 8-bit weights, getting 300 tokens/s at a batch size of 4.

I believe the trend toward compute-efficient LLMs will continue. We will see more models that can be batched, and services that take advantage of this. Latency is an issue, but jobs like responding to email, automated document summarization, etc. don't need to be realtime.

Source of the pain

The last AI boom, from 2012 to 2018, was driven by vision models. With a few exceptions, they were convolutional networks, which can easily be chunked into small pieces and run in parallel, with almost linear memory access if you optimize well. Heck, the reason CNNs were created in the first place is that fully connected layers are large and difficult to run. This is not the case with LLMs, which run almost exclusively on fully connected layers. Our current hardware, both on the edge and in data centers, is not designed for this. Worse, LLMs are so large that, to keep latency reasonable, no one runs them at large batch sizes, which is how accelerators are designed to be used. As a result, a lot of accelerators end up being decelerators.
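A rough way to see why batch size matters is arithmetic intensity: FLOPs performed per byte moved from DRAM. The sketch below uses a simple roofline model and assumes FP16 weights and activations; the 100 TFLOPS / 1 TB/s accelerator in the usage note is hypothetical.

```python
def fc_arithmetic_intensity(batch, d_in, d_out, bytes_per_weight=2.0, bytes_per_act=2.0):
    """FLOPs per DRAM byte for y = x @ W, with x: [batch, d_in] and W: [d_in, d_out].

    Assumes every weight is read from DRAM once per forward pass and each
    activation is read or written once (no other caching).
    """
    flops = 2.0 * batch * d_in * d_out  # one multiply + one add per MAC
    weight_bytes = d_in * d_out * bytes_per_weight
    act_bytes = batch * (d_in + d_out) * bytes_per_act
    return flops / (weight_bytes + act_bytes)


def attainable_tflops(intensity, peak_tflops, bandwidth_tb_s):
    """Roofline: performance is capped by either compute or memory bandwidth.

    intensity (FLOP/byte) * bandwidth (TB/s) gives TFLOP/s.
    """
    return min(peak_tflops, intensity * bandwidth_tb_s)
```

For a 4096x4096 layer, `fc_arithmetic_intensity(1, 4096, 4096)` is about 1 FLOP/byte, so the hypothetical accelerator reaches roughly 1 TFLOPS, about 1% of its peak. At batch 64 the intensity rises to about 62 FLOP/byte. A 3x3 convolution, by contrast, reuses each weight at every spatial position of the feature map, so its intensity is high even at batch 1.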

Current LLMs (LLaMA 1/2, RWKV, etc.) also don't work well with fixed-point math. Vision accelerators usually quantize the weights and activations to 8 or 4 bits to save bandwidth and squeeze more MACs into the same area. That doesn't work with LLMs. Naive quantization, even of just the activations, which we usually consider less critical, makes LLMs dumb. Frameworks like GGML do dynamic bit-depth quantization as a workaround. Even so, some weights are still preferably kept in FP16 to maintain quality, and anything below 4 bits starts to perform horribly.
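The sketch below illustrates why per-block scales help: a single outlier weight forces one shared scale so large that every other weight rounds to zero. It is a simplified symmetric 4-bit scheme of my own; GGML's Q4 formats also group weights into small blocks that share a scale, but their exact layout and rounding differ.

```python
def quantize_absmax_4bit(values):
    """Symmetric 4-bit quantization with one shared scale; codes lie in [-7, 7]."""
    amax = max(abs(x) for x in values)
    scale = amax / 7.0 if amax > 0 else 1.0
    codes = [max(-7, min(7, round(x / scale))) for x in values]
    return codes, scale


def dequantize(codes, scale):
    return [c * scale for c in codes]


def blockwise_roundtrip(values, block_size):
    """Quantize each block independently (its own scale), then dequantize.

    block_size == len(values) degenerates to naive per-tensor quantization.
    """
    out = []
    for start in range(0, len(values), block_size):
        codes, scale = quantize_absmax_4bit(values[start:start + block_size])
        out.extend(dequantize(codes, scale))
    return out


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
```

With 64 small weights (|w| <= 0.03) plus one outlier of 8.0, the per-tensor roundtrip zeroes out every small weight, while 32-weight blocks confine the damage to the block that contains the outlier.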

Every symptom points back to a fundamental problem with computers that we computer scientists have been sweeping under the rug for the past 30 years: memory bandwidth.

Memory speed grows much more slowly than the speed at which we can do math. We invented CPU caches, and later prefetching, to work around this. LLMs single-handedly brought this problem back to the forefront. There is no way to fit an entire LLM into SRAM (well, someone actually did it; more on that later). So we are left with the yucky problem of being bandwidth-bound while the demands of LLMs keep growing.
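A back-of-the-envelope bound shows how hard this limit is. If every weight has to be read from DRAM once per generated token, single-stream decode speed can never exceed bandwidth divided by model size. In the usage note, 4.5 bits/weight is my assumption for a typical 4-bit block format (4-bit codes plus per-block scales), and 1008 GB/s is the RTX 4090's spec-sheet memory bandwidth.

```python
def max_decode_tokens_per_s(n_params, bits_per_weight, bandwidth_gb_s):
    """Upper bound on single-stream decode speed when weights are streamed from DRAM.

    Assumes every weight is read exactly once per generated token and ignores
    the KV cache, activations and caches, so real throughput will be lower.
    """
    model_bytes = n_params * bits_per_weight / 8
    return bandwidth_gb_s * 1e9 / model_bytes
```

`max_decode_tokens_per_s(13e9, 4.5, 1008)` gives about 138 tokens/s as the ceiling for a 13B model, no matter how many FLOPS the GPU has. Batching helps because a single read of a weight can serve every sequence in the batch.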

Sustainability concerns

It should be obvious that AI uses, and will keep using, a lot of power. Continuing this trend during a climate crisis would be an interesting choice, and that's while we are already feeling its effects.

Fortunately, this need not be the case. It's just that this time, curbing the power needs will be a lot harder. Remember the last AI boom, when we were obsessed with YOLO and GANs? Later on, a lot of AI accelerator hardware showed up and drastically reduced the power needs of these models. It turned out YOLO wasn't that useful, and tech has improved so much that YOLO can even run on higher-end MCUs. The same thing can happen with LLMs. The solution has the same two parts: we need ways to run LLMs efficiently, and LLMs have to turn out to be less useful than people think.

Building hardware that runs LLMs fast is the relatively easy part, even though this post is all about how hard it is and how we can solve it. The problem is whether the efficiency paradox (also known as the Jevons paradox) shows up. That is, if new hardware brings LLM inference cost down 10x, would demand increase by more than 10x and cause more resource use? The first part can and will get solved. But the second part is a non-trivial social problem.

As for the second part, I do not know. But judging from my experiments with LLaMA 2, RWKVv4/5 and Gemma, I hardly think there are enough use cases for them. There always needs to be a human backup in case the models screw up big time. And the models are not that good at even understanding the context of the conversation. Future models might be better. But it does not stop there: larger models will become cheaper to run too. Stuff like Gemini Pro and GPT-4 will get cheaper (though they are subsidized by VC funding right now, so I don't know how much cheaper they can get). And those models have actual use cases even now. I've enjoyed AI-assisted search, and ChatGPT answering my questions instead of some a**hole on StackOverflow. Will usage grow so much that total energy use increases? I think it's 50/50. On one hand, people will start using ChatGPT to replace customer service and to automatically sort through their tasks today. On the other hand, these are tasks a smaller LLaMA could handle, if you try hard enough.

I might write a post about this later, after I think more about it.

Network level solutions

Researchers have been trying several directions to solve the issues above. I'll go through the ones I know of and feel have potential.

Extreme quantization of weights

Early this year, someone figured out that you can quantize the weights of an LLM to trits (3 values: -1, 0, 1) and still get a good model. This is potentially a huge deal. It means we can get rid of FP multiplication (multiplying by -1, 0 or 1 is just a subtraction, a skip or an addition) and work around memory bandwidth limitations. The paper estimates 72x lower compute energy, and I estimate 2.53x less memory than GGML's Q4 format. Further improvements can be achieved by trivially pruning the model. However, the paper is not well written and the model hasn't been publicly released, so it'll take time to verify the results. If they hold, this will become one of the important pillars of reducing the power needs of LLMs.
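Here is a sketch of both halves of that argument, under my own assumptions rather than the paper's format or kernels: ternary weights packed five per byte (3^5 = 243 values fit in a byte, so 1.6 bits/weight versus the theoretical log2(3) ≈ 1.58), and a dot product that uses only additions and subtractions.

```python
def pack_trits(trits):
    """Pack ternary weights (-1, 0, 1) five per byte, since 3**5 = 243 <= 256."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        group = trits[start:start + 5]
        value = 0
        for t in reversed(group):  # first trit ends up as the least significant digit
            value = value * 3 + (t + 1)  # map -1/0/1 to base-3 digits 0/1/2
        out.append(value)
    return bytes(out)


def unpack_trits(packed, count):
    """Inverse of pack_trits; `count` trims padding from the last byte."""
    trits = []
    for byte in packed:
        value = byte
        for _ in range(5):
            trits.append(value % 3 - 1)
            value //= 3
    return trits[:count]


def ternary_dot(weights, x):
    """Dot product with ternary weights: no multiplications at all."""
    acc = 0.0
    for w, xi in zip(weights, x):
        if w == 1:
            acc += xi
        elif w == -1:
            acc -= xi
        # w == 0 contributes nothing, so it is skipped entirely
    return acc
```

Real hardware would do the same selection with adders and multiplexers instead of branches; the Python version only shows that the arithmetic reduces to add, subtract or skip.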

Attention free language models

Models without the attention mechanism have started to appear, with Mamba and RWKV being the most popular in the category. Because they don't generate a large matrix as an intermediate step, they are efficient and easy to run inference on, and they avoid the KV cache commonly employed for Transformers.
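A simplified illustration of why no KV cache is needed: the decayed linear recurrence below (a generic linear-attention-style update, not RWKV's or Mamba's actual equations) carries a fixed-size state from token to token. `run_parallel` computes the same outputs the attention way, which needs every past key and value.

```python
def recurrent_step(state, q, k, v, decay):
    """One token of a decayed linear recurrence: S = decay * S + k v^T, y = q^T S.

    state is a d_k x d_v matrix (list of lists); its size never depends on context length.
    """
    d_k, d_v = len(k), len(v)
    new_state = [[decay * state[i][j] + k[i] * v[j] for j in range(d_v)] for i in range(d_k)]
    out = [sum(q[i] * new_state[i][j] for i in range(d_k)) for j in range(d_v)]
    return out, new_state


def run_recurrent(qs, ks, vs, decay):
    d_k, d_v = len(ks[0]), len(vs[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    outs = []
    for q, k, v in zip(qs, ks, vs):
        y, state = recurrent_step(state, q, k, v, decay)
        outs.append(y)
    return outs, state


def run_parallel(qs, ks, vs, decay):
    """Same outputs, attention style: token t looks back at all keys/values j <= t."""
    outs = []
    for t, q in enumerate(qs):
        y = [0.0] * len(vs[0])
        for j in range(t + 1):
            weight = (decay ** (t - j)) * sum(a * b for a, b in zip(q, ks[j]))
            y = [yc + weight * vc for yc, vc in zip(y, vs[j])]
        outs.append(y)
    return outs
```

Both functions produce the same outputs, but the recurrent form only keeps a d_k x d_v state however long the context gets. The rewrite works because there is no softmax; standard softmax attention cannot be folded into a fixed-size state like this, which is why Transformers keep a KV cache instead.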

RWKV has another major advantage: at batch size 1, its inference is 99% GEMV (matrix-vector multiplication). This is a big win for edge hardware designers. You only need to implement one operation well, and you don't have to care about batch size, which shrinks the design space drastically.

Model Pruning

Another direction that could work, but is underexplored, is pruning: removing weights that are not important. You look at the weights and activations and decide which weights can be removed. This complicates inference. First and foremost, decoding pruned models efficiently is hard. And the last thing you want during inference is to fight your branch predictor; that immediately kills performance.

I suspect the underexploration is caused by a lack of hardware that can efficiently use pruned models. CPUs don't pack enough FLOPS to run LLMs, and GPUs can't handle conditionals fast enough to make sparse models practical. Custom hardware has to be built to take advantage of pruned models, so we end up with a chicken-and-egg problem.
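A minimal sketch of the simplest variant, unstructured magnitude pruning (which looks only at the weights, not the activations), followed by a compressed sparse row (CSR) matrix-vector product. The function names are illustrative. The indirect load `x[col_idx[i]]` and the variable row lengths are the kind of irregular, data-dependent access that makes sparse kernels awkward on GPUs and SIMD hardware.

```python
def magnitude_prune(matrix, sparsity):
    """Zero out exactly int(size * sparsity) weights with the smallest magnitudes."""
    rows, cols = len(matrix), len(matrix[0])
    order = sorted(range(rows * cols), key=lambda idx: abs(matrix[idx // cols][idx % cols]))
    pruned = [row[:] for row in matrix]
    for idx in order[:int(rows * cols * sparsity)]:
        pruned[idx // cols][idx % cols] = 0.0
    return pruned


def to_csr(matrix):
    """Compressed sparse row: non-zero values, their column indices, and row offsets."""
    values, col_idx, row_ptr = [], [], [0]
    for row in matrix:
        for j, w in enumerate(row):
            if w != 0.0:
                values.append(w)
                col_idx.append(j)
        row_ptr.append(len(values))
    return values, col_idx, row_ptr


def csr_matvec(values, col_idx, row_ptr, x):
    y = []
    for r in range(len(row_ptr) - 1):
        acc = 0.0
        # Rows have different lengths and x is gathered through col_idx,
        # so the memory access pattern depends on the data.
        for i in range(row_ptr[r], row_ptr[r + 1]):
            acc += values[i] * x[col_idx[i]]
        y.append(acc)
    return y
```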

Novel hardware

On the hardware front, I think it is fair to say that GPUs are not the most efficient option, even though they are the most popular (and accessible) way to run LLMs. GPUs simply have too much silicon that doesn't help LLMs but still consumes power, like the L2 cache, the texture units, etc. And GPUs are too fast for their own good: GDDR6X is not even fast enough to feed the GPU.

SRAM! More SRAM!

Since DRAM bandwidth is THE issue, some companies had the bright idea of not having DRAM altogether. SRAM FTW.

For instance, Groq made a lot of noise this February. Their LPU has 230MB (bytes! not bits) of SRAM to hold the models. Usually this would be a really bad idea, as you can't fit an LLM into that amount of memory, and accessing the weights over the PCIe bus is not fast, not even if it goes through CXL. But Groq has a trick up their sleeve: simply have a pile of LPUs talk to each other. Now you can split the layers across different LPUs, and furthermore, you can pipeline the inference. They achieved more than 700 tokens/s throughput for a single user on Google's popular Gemma 7B model. Counting pipelining, we are looking at tens of thousands of aggregate tokens/s.

The downside is also obvious: SRAM, and thus the LPU, is expensive. And unlike a GPU, where the minimum unit is a single GPU, you need enough LPUs to hold the whole model. We are looking at a few dozen cards to run a single 7B model.
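The "few dozen cards" figure follows from simple division. This sketch counts only the SRAM needed to keep the weights resident, assuming a nominal 7e9-parameter model and the 230MB per LPU mentioned above; activations and any per-chip overhead are ignored, so real deployments need more.

```python
import math


def chips_to_hold_weights(n_params, bytes_per_weight, sram_bytes_per_chip):
    """Minimum number of chips needed just to keep all weights in on-chip SRAM."""
    return math.ceil(n_params * bytes_per_weight / sram_bytes_per_chip)
```

`chips_to_hold_weights(7e9, 2, 230e6)` gives 61 chips for FP16 weights, and 31 with 8-bit weights.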

Superscaling systolic arrays

Another approach to the DRAM bandwidth problem is similar to what Groq does, but without relying exclusively on expensive SRAM. It starts from a critical observation: even on current GPUs and accelerators, we try our best to fuse operations together to reduce the amount of data that has to be written to and then read back from DRAM. A systolic array lets us push this much further. Merging two convolutions into one kernel is almost impossible on a GPU, but almost trivial on a systolic array: make space for the second convolution and you are done.

It's not quite that easy, but you get the idea. The compute you give up costs less than the memory bandwidth you save. This, I think, is the idea behind Tenstorrent's architecture. Their processor is a 12x12 systolic array. Each "core" contains 5 RISC-V cores, a tensor engine and a very wide SIMD unit. The lack of a flat memory space makes the architecture hard to program for, but also freaking efficient. Their Grayskull e75 card can achieve 220 TFLOPS @ BFP2 at just 75W. That's approximately twice what the RTX 4090 can do, though at a fraction of the precision.

That's not where Tenstorrent's fun ends. Their next-generation card, Wormhole, has six 100GbE interfaces, enabling low-latency communication between cards and the same scaling and pipelining trick as Groq, but at a much lower cost. And by stringing multiple cards together, you can reach HBM-level speeds.
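The fusion argument above can be put in numbers with a toy traffic model that counts only activation bytes (weights are ignored). Unfused, every intermediate result makes a round trip through DRAM; fused, intermediates stay on-chip. The feature-map size in the usage note is an arbitrary example.

```python
def dram_traffic_bytes(activation_sizes, fused):
    """Activation bytes moved to and from DRAM for a chain of layers.

    activation_sizes[0] is the input size; activation_sizes[i] is layer i's output size.
    """
    if fused:
        # Only the first input is read and the last output is written.
        return activation_sizes[0] + activation_sizes[-1]
    # Every layer reads its input from DRAM and writes its output back.
    return sum(activation_sizes[i] + activation_sizes[i + 1]
               for i in range(len(activation_sizes) - 1))
```

For two 3x3 convolutions on a 224x224x64 FP16 feature map (6,422,528 bytes per activation), unfused execution moves four activations' worth of data and fused execution moves two. For a chain of N equal-sized layers, unfused traffic grows as 2N activations while fused traffic stays at 2.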

Fixed function accelerators

This is not really "future work", but a trend I'd like to see continue. Even with systolic arrays, you still spend quite a lot of power moving data around. If, and it's a big if, you can figure out what the common operations and their fusions are, you can avoid that movement, cutting down power and area, or spending them on more memory interfaces instead.

Rockchip has been building fixed-function accelerators for AI for a while now. Their RK3588 is their most powerful chip yet. It contains an NPU that can do 12 TOPS INT4, and it is extremely power efficient (when the NPU supports all the operations you need), as the control logic is mostly a state machine plus a few registers for things like data ordering, activation, etc.

I believe this is the kind of hardware that can run LLMs on the cheap, and the end-all be-all of LLM inference. In fact, I believe the end state will be models sharded across multiple fixed-function accelerators. Each accelerator is responsible for only part of a model, with high-bandwidth memory (not necessarily HBM; it could also be wide DDR) to stream in the weights. The actual area used for compute could be quite small, and most of the power would go to the memory interfaces.
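A sketch of the sharding step, assuming layers are assigned to accelerators as contiguous pipeline stages and that each stage is bound by streaming its own weights from local memory. The partitioning is the classic minimize-the-largest-shard split (binary search plus greedy fill); layer sizes and bandwidth are in arbitrary integer units, and the function names are hypothetical.

```python
def partition_layers(layer_bytes, n_devices):
    """Split layers into at most n_devices contiguous shards, minimizing the largest shard."""

    def shards_needed(cap):
        count, current = 1, 0
        for b in layer_bytes:
            if current + b > cap:
                count += 1
                current = 0
            current += b
        return count

    # Binary search for the smallest per-device capacity that still fits n_devices shards.
    lo, hi = max(layer_bytes), sum(layer_bytes)
    while lo < hi:
        mid = (lo + hi) // 2
        if shards_needed(mid) <= n_devices:
            hi = mid
        else:
            lo = mid + 1

    # Build the shards greedily using that capacity.
    shards, current, size = [], [], 0
    for i, b in enumerate(layer_bytes):
        if size + b > lo:
            shards.append(current)
            current, size = [], 0
        current.append(i)
        size += b
    shards.append(current)
    return shards


def pipeline_tokens_per_s(layer_bytes, shards, bandwidth_per_device):
    """Steady-state aggregate throughput once every stage works on a different sequence.

    Each stage reads its weights once per token, so the slowest shard sets the pace.
    """
    slowest = max(sum(layer_bytes[i] for i in shard) for shard in shards)
    return bandwidth_per_device / slowest
```

Eight equal layers on four devices give four two-layer shards. Once the pipeline is full, aggregate throughput depends on the slowest shard rather than on the size of the whole model.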

Lower and sub-byte precision math

It's public knowledge that both NVIDIA and AMD are using their own FP8 formats to further reduce the power needed for LLMs. FP8 is not a single format: there are too few bits for any one layout to keep LLMs happy, so there are two, with different numbers of exponent and mantissa bits. As the power of FP multiplication scales quadratically with the number of bits, this reduces the power needed a lot.

Image: The FP8 formats supported by NVIDIA's GPUs
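The two formats in the figure are commonly called E4M3 and E5M2. The decoder below is a sketch that follows the OCP FP8 conventions for special values as I understand them: E5M2 keeps IEEE 754 style infinities and NaNs, while E4M3 drops infinities and uses only S.1111.111 as NaN to extend its range. It shows the trade-off directly: E4M3's largest finite value is 448 and E5M2's is 57344, but E5M2 has one fewer mantissa bit.

```python
# (exponent bits, mantissa bits) for the two common FP8 layouts
FP8_FORMATS = {"e4m3": (4, 3), "e5m2": (5, 2)}


def decode_fp8(byte, fmt):
    """Decode one FP8 byte (1 sign bit + exponent + mantissa) into a Python float."""
    exp_bits, man_bits = FP8_FORMATS[fmt]
    man_mask = (1 << man_bits) - 1
    exp_all_ones = (1 << exp_bits) - 1
    bias = (1 << (exp_bits - 1)) - 1  # 7 for E4M3, 15 for E5M2

    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> man_bits) & exp_all_ones
    mantissa = byte & man_mask

    if fmt == "e5m2" and exponent == exp_all_ones:
        # IEEE-style specials: all-ones exponent is inf (mantissa 0) or NaN
        return sign * float("inf") if mantissa == 0 else float("nan")
    if fmt == "e4m3" and exponent == exp_all_ones and mantissa == man_mask:
        # E4M3 has no infinities; only S.1111.111 is NaN
        return float("nan")
    if exponent == 0:
        # Subnormal: no implicit leading 1
        return sign * (mantissa / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + mantissa / (1 << man_bits)) * 2.0 ** (exponent - bias)
```

For example, `decode_fp8(0x7E, "e4m3")` is 448.0 and `decode_fp8(0x7B, "e5m2")` is 57344.0.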

The question is, can we push this further? GGML has demonstrated we can go down to 2 bits and still keep the model sort of coherent. Instead of storing the activations as FP16, can we store them as FP8? Or even break the sub-byte barrier? Maybe even give up byte alignment and store the activations as 6 bits?
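Giving up byte alignment is mostly a packing problem. A minimal sketch, assuming a hypothetical layout in which four unsigned 6-bit codes share one little-endian 24-bit group (three bytes); how activations get mapped to those codes, and any scale that goes with them, is left out.

```python
def pack_6bit(values):
    """Pack unsigned 6-bit codes (0..63), four codes into three bytes."""
    if len(values) % 4:
        raise ValueError("length must be a multiple of 4 in this sketch")
    if any(not 0 <= v < 64 for v in values):
        raise ValueError("codes must fit in 6 bits")
    out = bytearray()
    for i in range(0, len(values), 4):
        a, b, c, d = values[i:i + 4]
        word = a | (b << 6) | (c << 12) | (d << 18)  # one 24-bit group
        out += word.to_bytes(3, "little")
    return bytes(out)


def unpack_6bit(packed):
    values = []
    for i in range(0, len(packed), 3):
        word = int.from_bytes(packed[i:i + 3], "little")
        values.extend((word >> shift) & 0x3F for shift in (0, 6, 12, 18))
    return values
```

Each group of four stays byte-aligned even though the individual values are not, so an unpacker only needs fixed shifts and masks.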

Author's profile. Photo taken in VRChat by my friend Tast+
Martin Chang
Systems software, HPC, GPGPU and AI. I mostly write stupid C++ code. Sometimes do AI research. Chronic VRChat addict

I run TLGS, a major search engine on Gemini. Used by Buran by default.

  • marty1885 \at protonmail.com
  • Matrix: @clehaxze:matrix.clehaxze.tw
  • Jami: a72b62ac04a958ca57739247aa1ed4fe0d11d2df