Building new GGML backends for novel accelerators: how, challenges and opportunities (FOSDEM 2025 draft)
This is a draft for a talk that I'll be giving at FOSDEM 2025 under the same title. As usual, I like to write out the full possible extent of the talk so I have a good idea of what's on the table. Since I'm already writing, why not share it? I wish I could write and spell everything out. But that would just be the source code itself.
This post, and thus the talk, assumes readers are reasonably familiar with both computer architecture and deep learning. I will try to keep everything factual. But some details really get in the way and will therefore be skipped.
Bear with me.
Disclosure: Neither the following post nor my GGML backend is work paid for by Tenstorrent. I make them out of my own interest. However, I do have a Wormhole N300 and a QuietBox that Tenstorrent sent me free of charge, no strings attached. Personally I count them as tools to accelerate efforts I was already making, which they think are worth their monetary cost, rather than as any form of payment.
GGML has emerged as one of the leading frameworks for efficient inference, particularly for large language models. GGML is engineered with performance in mind and with great quantization support. More importantly, its flexibility and growing community support make it an ideal platform for exploring new hardware acceleration options. Some rare hardware (mostly Chinese alternatives to Western chips) has already made it into GGML.
My journey with GGML started back at the end of 2022, when LLMs were new, hacking GPT-3 was damn easy and the RK3588 was new and exciting. Rockchip had finally gotten around to making their matrix multiplication API work. And I was starting to get fairly eco-conscious. I was thinking: LLMs are power hungry running on GPUs, and I want a local AI assistant without a literal space heater in my room. Which in Taiwan can be 32°C indoors during summer! So I started looking into utilizing the RK3588 NPU in llama.cpp. It sorta works. But the NPU's design is really not suitable for LLMs and has miserable performance for this use case (worse than the CPU!).
Discussions about llama.cpp + RK3588 NPU can be found at the following link. I'm happy my work sparked some interest in the NPU and efforts to reverse engineer the thing.
I decided to pivot. There weren't many options at the time. Non-traditional (non-CPU/GPU) options included Intel's NPU (which already has OpenVINO support), AMD XDNA (the driver was an entire kernel fork), SiFive's X280 (not commercially available until at least the end of 2024), Vivante's VIP9000 (needs an NDA, and is basically a GPU with a convolution engine bolted on) and Tenstorrent's cards. It just happened that Tenstorrent was selling their cards (vetting required at the time, though lifted shortly after) and I got my hands on a Grayskull e75 devkit.
Tenstorrent hardware and SDK
Tenstorrent doesn't make GPUs. Their processor is a manycore architecture where each core, called a Tensix core, contains RISC-V CPUs with tensor engines attached. A Tensix core contains 5 very small RISC-V CPUs referred to as baby cores - your textbook standard single-issue, 5-stage pipelined CPU (they take up very little chip area!). These baby cores are really slow and thus don't perform the bulk of the AI computation. Instead, their purpose is to issue commands to the NoC and the matrix/vector engines so those can carry out the real work.
The same kind of baby RISC-V cores are also connected to the Ethernet and DRAM controllers for data prefetch and external communication. These different cores are arranged into a grid and connected together via a NoC. The plan is to exploit this physical layout to reduce power and improve performance. Unlike GPUs, which are presented as a uniform pile of cores, Tenstorrent does not hide the physical location of a core from the program. Developers can then make data flow only between nearby cores or in a straight line to maximize aggregate data throughput (not interfering with traffic from other cores). Also unlike GPUs, there's no global data cache nor a unified address space, which would get congested with frequent access. Instead, cores issue DMA requests to move data between their local SRAM and DRAM or another core's SRAM (SRAM is referred to as L1 in the documentation).
This grid structure also scales naturally. Connecting multiple chips simply gives you a larger grid to play with.
At the time, Tenstorrent was advertising 2 SDKs: tt-BUDA and tt-Metalium. BUDA is high level: their compiler looks at a model (Torch, ONNX, etc.), does some magic and produces binaries that run inference on the card. Metalium is sort of like OpenCL; the API looks alike and it gives you control of the hardware. The programming model is however quite different, so much so that no OpenCL code can be ported to Metalium without a complete rewrite. This is where TTNN comes in. TTNN is both the operator library and the tensor library built on top of Metalium, providing a NumPy/PyTorch-like API. Rebuilding a tensor library is not duplicated work nor a bad engineering decision. Tenstorrent processors use their own tiled layout for tensors, which hardware designers prefer since it reduces power and improves efficiency. However, it is not compatible with the row-major format that every CPU and GPU tensor library uses. (Tiled layouts are used by many ASICs, even Rockchip and NVDLA. It's not Tenstorrent being unique or Jim Keller being bad at designing.)
I won't talk about BUDA since it has no place in GGML and it's being superseded by the new MLIR approach. TTNN and Metalium it is.
The following is an example TTNN program and should be self explanatory. It creates a bfloat16 tensor then multiplies it by 2. Basically PyTorch or NumPy with a slightly different syntax. The curious function is tilize_with_zero_padding, which converts from the standard row-major format to Tenstorrent's tiled layout. Only after converting to tiles can multiplications be performed.
int main()
{
auto device = &ttnn::open_device(0);
auto a = ttnn::arange(0, 100, 1, tt::tt_metal::DataType::BFLOAT16);
a = ttnn::tilize_with_zero_padding(a.to(device));
auto res = ttnn::multiply(a, 2.f);
std::cout << res.write_to_string() << std::endl;
}
bfloat16 is the most commonly used floating point format on Tenstorrent processors. Beyond bfloat16, the chips also support custom block floating point formats (I'll just call them Tenstorrent quantized types from this point on, as that's what they really are) called BFLOAT8_B and BFLOAT4_B. As the names suggest, these formats share exponents across a block while each mantissa takes up 8 or 4 bits. These types are natively supported by the hardware and thus incur no cost to dequantize, as opposed to GGML needing software to read quantized data, which costs power. Otherwise they work on the same principle. Their 2nd-generation (Wormhole) or later processors also support standard IEEE 754 32-bit floating point, as well as integer types like int32 and int16.
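To give a feel for how shared-exponent formats work, here is a minimal, generic block floating point encoder. It is purely illustrative: the block size, bit packing and rounding do NOT match Tenstorrent's actual BFLOAT8_B/BFLOAT4_B layout.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// One block: a single shared exponent plus a small signed mantissa per element.
struct bfp_block {
    int8_t shared_exp;
    std::vector<int8_t> mantissa;
};

// Encode a block of floats with mant_bits (<= 8) bits of mantissa per element.
bfp_block encode_block(const std::vector<float>& x, int mant_bits) {
    // The shared exponent comes from the largest magnitude in the block
    float max_abs = 0.f;
    for (float v : x) max_abs = std::max(max_abs, std::fabs(v));
    int exp = 0;
    std::frexp(max_abs, &exp); // max_abs == m * 2^exp, with m in [0.5, 1)

    const float max_mant = float((1 << (mant_bits - 1)) - 1); // e.g. 127 for 8 bits
    bfp_block out{static_cast<int8_t>(exp), {}};
    for (float v : x) {
        // Scale by the shared exponent and round onto the mantissa grid.
        // Decoding is simply mantissa / max_mant * 2^shared_exp.
        float scaled = std::ldexp(v, -exp) * max_mant;
        out.mantissa.push_back(static_cast<int8_t>(
            std::lround(std::clamp(scaled, -max_mant, max_mant))));
    }
    return out;
}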
GGML backend
Now the juicy part: bridging the gaping hole between GGML and TTNN. Damn, they expect things to be so different. How do we get from the following to instructing TTNN to perform tilization, carry out the computation, then untilize back into what GGML can use? Furthermore, what kinds of hacks do we need to bridge the differences in level of abstraction?
ggml_tensor* a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2048, 64);
ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2048);
return ggml_add(ctx, a, b);
Backends in GGML
We gotta register the new backend so GGML knows it exists. The backend architecture has evolved a lot between when I started and now, and is bound to evolve further. As of now, GGML has both the notion of a backend (software) and a backend device (hardware). On startup, in ggml/src/ggml-backend-reg.cpp, GGML constructs a ggml_backend_registry, to whose constructor backend developers add a call to register their backend. It's a pile of ifdefs and could be done with static initialization, but that would not guarantee ordering. What works works.
struct ggml_backend_registry {
std::vector<ggml_backend_reg_entry> backends;
std::vector<ggml_backend_dev_t> devices;
ggml_backend_registry() {
#ifdef GGML_USE_CUDA
register_backend(ggml_backend_cuda_reg());
#endif
#ifdef GGML_USE_METAL
register_backend(ggml_backend_metal_reg());
#endif
#ifdef GGML_USE_SYCL
register_backend(ggml_backend_sycl_reg());
#endif
#ifdef GGML_USE_VULKAN
register_backend(ggml_backend_vk_reg());
#endif
#ifdef GGML_USE_OPENCL
register_backend(ggml_backend_opencl_reg());
#endif
#ifdef GGML_USE_CANN
register_backend(ggml_backend_cann_reg());
#endif
#ifdef GGML_USE_BLAS
register_backend(ggml_backend_blas_reg());
#endif
#ifdef GGML_USE_RPC
register_backend(ggml_backend_rpc_reg());
#endif
#ifdef GGML_USE_KOMPUTE
register_backend(ggml_backend_kompute_reg());
#endif
#ifdef GGML_USE_METALIUM
register_backend(ggml_backend_metalium_reg());
#endif
#ifdef GGML_USE_CPU
register_backend(ggml_backend_cpu_reg());
#endif
}
...
};
Each ggml_backend_xxx_reg() function returns a pointer to a backend descriptor. This is a common pattern in GGML. A descriptor contains at least 2 items: 1. a data pointer that the backend developer is responsible for filling and that can point to an arbitrary struct, and 2. what is essentially a vtable pointing to different methods. Some methods are optional and can be NULL (read the GGML header to find out), while some are not and will crash if you set them to NULL. Take what's in the Metalium backend for example:
static const ggml_backend_reg_i ggml_backend_metalium_reg_interface = {
/* .get_name = */ ggml_backend_metalium_reg_get_name,
/* .get_device_count = */ ggml_backend_metalium_reg_get_device_count,
/* .get_device = */ ggml_backend_metalium_reg_get_device,
/* .get_proc_address = */ NULL,
};
// Corresponding structure in ggml-backend-impl.h
// Edited to make it more suitable for a blog post.
struct ggml_backend_reg_i {
const char * (*get_name)(ggml_backend_reg_t reg);
size_t (*get_device_count)(ggml_backend_reg_t reg); // enumerate available devices
ggml_backend_dev_t (*get_device)(ggml_backend_reg_t reg, size_t index);
// (optional) get a pointer to a function in the backend backends can add custom
// functions that are not part of the standard ggml-backend interface
void * (*get_proc_address)(ggml_backend_reg_t reg, const char * name);
};
GGML follows what OpenCL and CUDA do. Each backend reports the number of devices available, the framework picks and chooses which to use based on the information reported about each device, then initializes them separately. This however does not work with TTNN, as it needs to initialize devices before information can be gathered, as opposed to, say, OpenCL, where device information can be queried by calling clGetDeviceInfo() without creating a context.
(Also, TTNN needs special code to support multiple devices being initialized at the same time, so the backend is limited to the 1st device for now.)
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metalium_reg()
{
static ggml_backend_reg reg;
static std::once_flag once;
std::call_once(once, [&]() {
tt::tt_metal::detail::EnablePersistentKernelCache();
static std::unique_ptr<ggml_backend_metalium_reg_context> ctx =
std::make_unique<ggml_backend_metalium_reg_context>();
// initialize and collect device info here
...
reg = ggml_backend_reg {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .interface = */ ggml_backend_metalium_reg_interface,
/* .context = */ ctx.get()
};
});
return &reg;
}
Tensor creation and management
GGML computes the required buffer sizes from the compute graph then attempts to create memory pools on the device. Backends are expected to reserve a chunk of memory and let GGML write data into it. GGML only makes two pools on startup and they are reused: one for weights and one for intermediate buffers. There is no API for GGML to allocate temporary tensors; all allocations are static.
static ggml_backend_buffer_t
ggml_backend_metalium_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size);
This presents several challenges that need to be overcome:
- TTNN doesn't use row-major tensors. GGML can't directly write into a pointer
- Metalium doesn't even support mapping device memory into the host. Nor is it technically possible at times (devices connected via Ethernet)
- Allocation in both Metalium and TTNN is typed, and that information is not available through the standard API
- TTNN has its own quantization types, which are not compatible with GGML's
"We are so fucked" would be a reasonable reaction. Every bullet point above is a violation of the fundamental programming contract. I felt the same for a few days. Turns out luckily one of the backends already faced nearly the exact same issue and had a solution. The SYCL backend under horizontal multi device mode also can't return a valid address since tensors are chopped onto different devices. Let's see what they did..
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
// the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
return (void *)0x1000;
GGML_UNUSED(buffer);
}
static void
ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor) try {
...
// ggml_tensor_extra_gpu holds metadata to track buffers without mapping
// to host address space
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
}
A dummy address is returned and never used. Instead, a pointer to custom data is stored in tensor->extra. The famous quote from Butler Lampson cannot be more correct: "All problems in computer science can be solved by another level of indirection". Some extra code was added in the Metalium backend to aid debugging in the early phase. Instead of returning a fixed dummy address, an offset is added to make each address unique, making it possible to reverse search which buffer is causing trouble in a crash.
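A minimal sketch of that trick. The fake_base field and the counter are illustrative assumptions, not the backend's literal code:
#include <atomic>
#include <cstdint>

// The returned pointer is never dereferenced; the real data lives in the TTNN
// tensors hanging off tensor->extra. Each buffer still gets a unique address
// range so a crashing address can be traced back to its buffer.
static std::atomic<uintptr_t> g_next_fake_base{0x1000};

static void * ggml_backend_metalium_buffer_get_base(ggml_backend_buffer_t buffer) {
    auto * ctx = (ggml_backend_metalium_buffer_context *)buffer->context;
    if (ctx->fake_base == 0) {
        // hypothetical field: remember the fake base the first time it is asked for
        ctx->fake_base = g_next_fake_base.fetch_add(buffer->size);
    }
    return (void *)ctx->fake_base;
}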
Isolating the actual allocation from GGML itself solves the programming model mismatch. However, deallocation now also has to be handled explicitly, or else we get the unacceptable problem of leaking device-side memory. As said above, GGML initializes 2 memory pools for inference: a large pool to hold model weights and a smaller one to hold intermediate state. The second pool will have its reset function in the management interface called after each round of inference, giving the backend a chance to clean up anything allocated in the meantime (so do store the new-ed pointers somewhere to delete them later).
static struct ggml_backend_buffer_i ggml_backend_metalium_buffer_interface = {
...
/* .reset = */ ggml_backend_metalium_buffer_reset, // <--- called when GGML wants to clear the pool
};
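A sketch of what that cleanup can look like, assuming the buffer context keeps a list of the metadata objects it new-ed for intermediate tensors (the live_metadata field is a hypothetical name):
static void ggml_backend_metalium_buffer_reset(ggml_backend_buffer_t buffer) {
    auto * ctx = (ggml_backend_metalium_buffer_context *)buffer->context;
    for (TensorWithMetadata * meta : ctx->live_metadata) {
        // Destroying the metadata releases its TTNN tensor, which in turn
        // frees the underlying device memory
        delete meta;
    }
    ctx->live_metadata.clear();
}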
Getting data in and out
With creation working, the next order of business is to get data into and out of TTNN. Reads and writes to tensors are mainly controlled by 3 functions: init_tensor, set_tensor and get_tensor. Due to the deferred, typed allocation, the Metalium backend allocates in the set_tensor function, where the data type is finally known. Intermediate buffers are allocated as needed in each operator's implementation, completely reversing GGML's memory pooling.
It should be noted that it is absolutely doable to restore the pooling behavior. Metalium needs type information during allocation; that information is used to calculate alignment and other parameters, for which the backend could apply a safe value directly and manage the address space itself. However, the API to support such a scheme does not exist in Metalium and I'm not about to open that can of worms. Maybe some time after I get the entire thing production ready.
Either way, with device memory at hand, data from GGML has to be uploaded. But wait. Tenstorrent devices use their own quantization types, different from what GGML uses. The solution is obvious: convert the GGML quantization to standard FP32, convert the floating point array to a TTNN tensor, upload to device, tilize it (remember, row-major is not efficient for the hardware) and cast to the desired final type. The code is nasty, so pseudo code will have to do:
static void ggml_backend_xxx_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                               ggml_tensor *tensor,
                                               const void *data, size_t offset,
                                               size_t size)
{
    auto * bufctx = (ggml_backend_metalium_buffer_context *)buffer->context;
    // 1. Dequantize (or convert) the incoming GGML data into plain FP32
    std::vector<float> vec(ggml_nelements(tensor));
    const ggml_type_traits* trait = ggml_get_type_traits(tensor->type);
    if(trait->is_quantized) {
        trait->to_float(data, vec.data(), ggml_nelements(tensor));
    }
    else {
        // deal with different floating point types here
    }
    // 2. Wrap the FP32 array into a host-side, row-major TTNN tensor
    //    (shape is built from tensor->ne, in reverse order)
    auto host_buffer = tt::tt_metal::owned_buffer::Buffer<float>(std::move(vec));
    auto storage = tt::tt_metal::OwnedStorage(std::move(host_buffer));
    tt::tt_metal::Tensor t(std::move(storage), ttnn::Shape(shape)
        , tt::tt_metal::DataType::FLOAT32, tt::tt_metal::Layout::ROW_MAJOR);
    // 3. Upload to device, tilize and cast to the final Tenstorrent type
    tt::tt_metal::DataType final_type = convert_to_final_tt_type(tensor->type);
    t = ttnn::tilize_with_zero_padding(t.to(bufctx->device), std::nullopt, final_type);
    ...
    // Keep on attaching the resulting TTNN tensor to the GGML tensor
}
The reverse process happens to get data from TTNN back into GGML: whatever is stored in TTNN has to be converted into floating point and untilized, then GGML converts it into the desired type. Though not used during inference, GGML will read back quantized data as a part of its unit tests, to ensure weights were not corrupted by out-of-bounds writes in operations.
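A rough sketch of that read-back path. The helper is illustrative and the TTNN calls are simplified; the real get_tensor also has to deal with views and metadata:
#include <vector>

static void copy_back_as_ggml_type(const tt::tt_metal::Tensor & dev_tensor,
                                   const ggml_tensor * tensor, void * dst) {
    // 1. Cast to FP32, untilize back to row-major and pull the data to the host
    auto host = ttnn::untilize(ttnn::typecast(dev_tensor, tt::tt_metal::DataType::FLOAT32)).cpu();
    std::vector<float> fp32(ggml_nelements(tensor));
    // ... copy the host tensor's data into fp32 ...

    // 2. Let GGML re-quantize into the type the caller expects
    if (ggml_is_quantized(tensor->type)) {
        ggml_quantize_chunk(tensor->type, fp32.data(), dst, 0,
                            ggml_nrows(tensor), tensor->ne[0], nullptr);
    } else {
        // ... plain conversion for F32/F16/BF16 ...
    }
}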
Running operations
This is what backends are for: to make running neural network operations fast. Backends need not support every operation possible. GGML will check with each backend and schedule operators accordingly. Whether an operator is supported by a backend is reported through supports_op. The function simply returns false to signal GGML that the backend does not support an operation. A special operator, GGML_OP_NONE, denotes an actual tensor rather than a computation; it is easy to miss and that causes hours of head scratching over cryptic error messages.
For example, if a backend wish to support only adding 1D tensors together:
static bool ggml_backend_xxx_device_supports_op
(ggml_backend_dev_t device, const struct ggml_tensor * op) {
switch(op->op) {
// You must accept NONE (the actual tensor data). Rejecting them will
// lead to falling back to CPU. Which may be desired if the backend
// does not support certain data types
case GGML_OP_NONE:
return true;
case GGML_OP_ADD:
return ggml_n_dims(op->src[0]) == 1 && ggml_n_dims(op->src[1]) == 1;
default:
return false;
}
}
After graph allocation, the graph_compute interface is called with an array of nodes to run in order. The backend is responsible for dispatching operations to its implementations.
static enum ggml_status ggml_backend_metalium_graph_compute
(ggml_backend_t backend, struct ggml_cgraph * cgraph)
{
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
switch (node->op) {
case GGML_OP_NONE:
break;
case GGML_OP_ADD:
// Call the actual implementation of the operator
ggml_backend_metalium_add(node);
break;
default:
// Should never be executed since we rejected
// all other ops.
GGML_UNREACHABLE();
}
}
return GGML_STATUS_SUCCESS;
}
An operator implementation in the Metalium backend could look like the following (with annotations):
static void ggml_backend_metalium_add(struct ggml_tensor * dst) {
// Add has two operands, stored in src[0] and src[1]
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
// Grab the `extra` created during allocation, this has metadata we need later
TensorWithMetadata* meta0 = (TensorWithMetadata*)src0->extra;
TensorWithMetadata* dst_meta = (TensorWithMetadata*)dst->extra;
// realize_ggml_view() deals with view handling (explained in a later section)
// this function returns a raw TTNN tensor if the source tensor is not a view
auto src_tensor0 = realize_ggml_view(src0);
auto src_tensor1 = realize_ggml_view(src1);
// assign to destination tensor
*dst_meta = {
.tensor = ttnn::add(*src_tensor0, *src_tensor1), // add
.ggtype = dst->type,
.bufctx = meta0->bufctx
};
}
Dealing with views
Most accelerators, including the chips made by Tenstorrent, use some sort of tiled layout to improve hardware efficiency. This leads to yet another programming model mismatch. Most deep learning frameworks use strides to represent views and permutations of a tensor. Strides are very general, powerful and really cheap to manipulate (changing metadata vs. moving the entire tensor). However, strides only work when the underlying data is laid out in row-major order. Tensors in TTNN are in tiled format, so plain strides are a no-go. The following link is a good read if you are not already familiar with strides.
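For a quick intuition of why strides are so cheap on row-major data, here is a tiny 2D example (names are illustrative; it mirrors GGML's ne/nb convention):
#include <cstdint>
#include <utility>

// A 2D view over row-major data: ne = shape, nb = strides in bytes.
struct view2d {
    int64_t ne[2];
    int64_t nb[2];
    const float * data;
};

// Transposing only swaps metadata; the underlying data is untouched.
view2d transpose(view2d v) {
    std::swap(v.ne[0], v.ne[1]);
    std::swap(v.nb[0], v.nb[1]);
    return v;
}

// Element access walks the strides, so the same code works for any such view.
// This is exactly what breaks once the data is stored in tiles instead of rows.
float at(const view2d & v, int64_t i0, int64_t i1) {
    return *(const float *)((const char *)v.data + i0 * v.nb[0] + i1 * v.nb[1]);
}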
There are 2 routes to solve the situation.
- Eagerly evaluate all view and view-like operations into new tensors
- Evaluate views only when a downstream operator needs them (lazy evaluation)
The Metalium backend initially implemented eager evaluation as it was easier to write. Later on, the simpler design proved to be more of a problem than a solution and the backend switched to lazy evaluation. From time to time, GGML will attempt to write data through views, which is natural when the data is behind a pointer (ex: writing to part of the kv-cache), but is impossible under eager evaluation of views. There are also some edge cases around copying viewed tensors from the Metalium backend to the CPU. Since the CPU expects the underlying data to not change, just the metadata, the original, non-viewed tensor must be copied instead of the viewed one. These kinds of bugs are subtle and a pain in the bu** to figure out. The reversed darn problem exists for setting tensors. Deal with it.
static void ggml_backend_metalium_buffer_get_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor *tensor,
void *data, size_t offset,
size_t size)
{
std::shared_ptr<tt::tt_metal::Tensor> t; // The TTNN tensor that will be copied to GGML
// Handle view ops where GGML expects the underlying data to not change
if(tensor->op == GGML_OP_TRANSPOSE) {
t = realize_ggml_view(tensor->src[0]);
}
else if (tensor->op == GGML_OP_PERMUTE) {
// DITTO above.
...
}
else if (tensor->op == GGML_OP_RESHAPE) {
// ......
}
else {
t = realize_ggml_view(tensor);
}
// actually performing the copy
...
}
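Since realize_ggml_view() shows up in the snippets above, here is a heavily simplified sketch of the lazy approach. The exact TTNN calls and the caching the real function performs are omitted:
// Views are not materialized when GGML creates them. Instead, whenever an
// operator (or a copy to the CPU) needs the data, the parent tensor is
// realized first and the view is applied on the fly.
static std::shared_ptr<tt::tt_metal::Tensor> realize_ggml_view(const ggml_tensor * t)
{
    auto * meta = (TensorWithMetadata *)t->extra;
    if (t->op == GGML_OP_TRANSPOSE) {
        // Realize the parent, then swap the last two dimensions in TTNN
        auto parent = realize_ggml_view(t->src[0]);
        auto transposed = ttnn::transpose(*parent, -2, -1);
        return std::make_shared<tt::tt_metal::Tensor>(std::move(transposed));
    }
    if (t->op == GGML_OP_PERMUTE || t->op == GGML_OP_VIEW || t->op == GGML_OP_RESHAPE) {
        // ... same idea: realize the parent, then permute/slice/reshape it ...
    }
    // Not a view: hand back the TTNN tensor attached during allocation or compute.
    // TTNN tensors share their underlying storage, so this copy is cheap.
    return std::make_shared<tt::tt_metal::Tensor>(meta->tensor);
}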
Missing GGML features
The above is the general flow of writing a backend. During the development of the Metalium backend, I found some optimizations to be impossible in GGML, and I really wish someone could implement them. If someone reading this post wants to give it a shot, please do. I'll be happy to help in any way I can.
Compute graph rewrite
GGML completely lacks a graph rewrite stage. The compute graph is created, scheduled to backends and run as-is. This immediately presents missed optimization opportunities, like fusing activations with fully connected layers to remove unnecessary kernel launches and reduce latency - a trivial optimization that even Caffe supports.
TTNN has support for storing tensors in the on-chip SRAM. However, we can't simply dump every intermediate tensor in SRAM due to its small size (1MB per core!). Without graph rewriting, and thus no guarantee over operator execution order, only intermediate states within each operator can safely be stored on chip, and such use cases are few and far between. Likewise, there is work to prefetch the weights of the next operation into SRAM while the current one is still running - pushing the roofline by mixing memory-bound and compute-bound tasks - which would also be impossible to enable in GGML due to the lack of rewriting.
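To make the idea concrete, a hypothetical fusion pass could look like the sketch below. GGML offers no hook to run such a pass today; this only illustrates the kind of rewrite I'd like to be able to do:
#include <vector>

// Find MUL_MAT -> GELU pairs that could be launched as a single fused kernel.
// Simplified on purpose: only adjacent nodes, no multi-consumer checks.
struct fused_pair { int matmul_idx; int act_idx; };

static std::vector<fused_pair> find_matmul_gelu(const struct ggml_cgraph * g) {
    std::vector<fused_pair> out;
    for (int i = 0; i + 1 < g->n_nodes; i++) {
        struct ggml_tensor * a = g->nodes[i];
        struct ggml_tensor * b = g->nodes[i + 1];
        if (a->op == GGML_OP_MUL_MAT &&
            b->op == GGML_OP_UNARY &&
            ggml_get_unary_op(b) == GGML_UNARY_OP_GELU &&
            b->src[0] == a) {
            out.push_back({i, i + 1});
        }
    }
    return out;
}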
Parallel weight upload
The Metalium backend is unique in needing this, so I don't blame GGML for lacking the capability. Model weights are usually quantized, and because TTNN demands the use of its native quantized types, the backend is forced to dequantize to floating point, upload to the device, then quantize into Tenstorrent's native types. Though not horribly slow, this is still far from saturating the PCIe/Ethernet bandwidth and leads to unwanted waiting time when loading larger models. Currently GGML uses a single thread to upload weights. Although it is possible to parallelize the process - for larger weight tensors, simply get multiple threads to dequantize and upload different parts of the tensor - the overhead is too high for smaller tensors and thus smaller models. And the accumulated overhead of distributing the work to multiple threads over and over again is not ideal. The parallelization should be done at the model level, enabling much better utilization of CPU time and bandwidth.
Model load times on a QuietBox to a single N300 card are as follows (AMD EPYC 8124P 2.2GHz connected to the N300 via PCIe Gen4 x16. Note: only parts of the model are loaded onto the N300 due to missing operator support, thus some operators fall back to the CPU). Ideally LLaMA 3.2 8B should load in less than 1s (copying BFP16 tensors at PCIe bandwidth speeds).
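For the per-tensor variant, a sketch could look like the following. The upload step is a hypothetical helper, and as argued above, the real win would come from parallelizing at the model level instead:
#include <thread>
#include <vector>

// Dequantize and upload one large, contiguous weight tensor with several
// threads, each handling a contiguous range of rows.
static void parallel_dequant_upload(const ggml_tensor * t, const void * data, int n_threads) {
    const ggml_type_traits * traits = ggml_get_type_traits(t->type);
    const int64_t n_rows = ggml_nrows(t);

    std::vector<std::thread> workers;
    for (int i = 0; i < n_threads; i++) {
        workers.emplace_back([=]() {
            const int64_t begin = n_rows * i / n_threads;
            const int64_t end   = n_rows * (i + 1) / n_threads;
            std::vector<float> row(t->ne[0]);
            for (int64_t r = begin; r < end; r++) {
                const char * src = (const char *)data + r * t->nb[1];
                traits->to_float(src, row.data(), t->ne[0]);
                // upload_row_to_device(t, r, row); // hypothetical upload step
            }
        });
    }
    for (auto & w : workers) w.join();
}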
Model | Quantization | time (seconds) |
--- | --- | --- |
LLaMA 3.2 8B | Q4_K_M | 8.542 |
Tiny LLaMA 1.1B | Q4_0 | 2.078 |
Gemma 2B | Q4_K_M | 4.740 |
Debugging tricks
Some tricks I found really helpful while building the backend.
Dimension order, fixed dimension arrays and naming
GGML stores dimensions and strides in REVERSE order! Keep this fact in mind. A tensor with shape [2, 3, 4] will result in a tensor->ne of [4, 3, 2, 1]. This is by design, to make coding easier as no reverse loop is needed. Likewise, GGML defines the GGML_MAX_DIMS macro to be 4. A tensor always has 4 dimensions, even if it's a 1D tensor; the higher dimensions are just 1. This again is to make coding easier. TTNN and libraries like ArrayFire also use fixed-dimension arrays. Keep these in mind when developing.
GGML's naming is concise but takes getting used to. tensor->ne stands for "number of elements"; this is the shape of the tensor. tensor->nb stands for "number of bytes", the strides of the tensor. However, strides do not work alone. You need to query ggml_get_type_traits(tensor->type) to get the size of the data type and how many elements are in a block. For example, a block of type Q4_0 has a size of 18 bytes and contains 32 elements. At minimum, you need to process an entire block of 32 elements at a time.
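Putting ne, nb and the type traits together, a small helper that walks a tensor row by row might look like this (simplified to the first two dimensions; the helper name is illustrative):
#include <vector>

static void dump_rows_as_float(const struct ggml_tensor * t) {
    const ggml_type_traits * traits = ggml_get_type_traits(t->type);
    std::vector<float> row(t->ne[0]); // ne[0] is the innermost (fastest-moving) dimension

    for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
        // nb[1] is the byte stride between rows; for Q4_0 a row is ne[0]/32 blocks of 18 bytes
        const char * src = (const char *)t->data + i1 * t->nb[1];
        if (traits->is_quantized) {
            traits->to_float(src, row.data(), t->ne[0]); // whole blocks at a time
        } else {
            // ... handle F32/F16 here ...
        }
        // ... print or compare the row here ...
    }
}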
Spam asserts
Both TTNN and GGML have their own tensor class/struct, which SHOULD carry the same information. But duplicated data is always a source of desync and pain. Spam asserts everywhere to check that GGML agrees with TTNN's opinion of the tensor - shape and everything. Especially because TTNN is still a young project and things change every day, making an early crash even more important to maintain the developer's sanity.
For example, checks like these are all over the backend, making sure TTNN behaves the same way as on the day I wrote the corresponding code:
GGML_ASSERT(row_major_tensor.storage_type() == StorageType::OWNED or row_major_tensor.storage_type() == StorageType::BORROWED);
GGML_ASSERT(std::holds_alternative<OwnedStorage>(row_major_tensor.storage()) || std::holds_alternative<BorrowedStorage>(row_major_tensor.storage()));
llama-eval-callback
llama-eval-callback is a program built alongside the main llama-cli that people use for actual inference. llama-eval-callback takes a prompt, loads the model, then runs the inference flow for a single token, printing all tensors and some metadata to the console along the way. To debug your backend, run it once with CPU only and save the output to a text file, run it again with your backend (by setting -ngl to a non-zero number), then diff the outputs. Literally diffing tensors to figure out where your backend starts diverging significantly from the CPU's results.
llama-eval-callback is made such that no randomness seeps into the results at the framework level, so there is no need to worry about seeds and other sources of randomness.
Results
As of writing this post, the Metalium backend can run smaller LLMs on both Grayskull and Wormhole. However, missing or limited operators in TTNN cause a lot of falling back to the CPU. Beyond the cost of data transfer, TTNN also needs to tilize/untilize each time this happens, leading to further performance degradation. I am in contact with Tenstorrent engineers and they are very supportive of the GGML effort. I am hopeful that the missing operators can be fixed in the coming months.
Anyway, here's a screenshot of LLaMA 3.2 1B running on a Wormhole N300 card (again, performance is not good until TTNN adds more op support).
And with all the informational output from llama.cpp suppressed:
Current performance on a QuietBox using one N300 card is as follows (yet again, the point is to get things working; speed will be the next objective). Note that Tenstorrent hardware natively supports batch=32, so the potential throughput is up to 32 times higher than in single-batch mode.
Model | Quantization | TTNN quant type | tok/s |
--- | --- | --- | --- |
LLaMA 3.2 8B | Q4_K_M | BFLOAT8_B | 3.56 |
LLaMA 3.1 1B | Q4_K_M | BFLOAT8_B | 22.18 |
TinyLLaMA 1.1B | Q4_0 | BFLOAT8_B | 21.85 |
Gemma2 2B | Q4_K_M | BFLOAT8_B | 11.14 |
Gemma2 2B | Q4_K_M | BFLOAT4_B | 11.40 |
I would say that making the integration is less than half of the real task. The majority of the time goes into isolating problems, making damn sure they are TTNN bugs, reporting them, and communicating with TTNN developers about needs and issues. I've been spending more time lately contributing and fixing lower-level problems in TTNN that directly impact GGML.
The end goal is to eventually upstream all my changes. But that is still some ways out.
Conclusion
It is not overly complicated to get new hardware and SDKs working in GGML - not saying it's easy. But I hope I have outlined the general direction, so people won't have to rediscover what is written in this post. There is also missing functionality that would empower GGML to better utilize novel hardware, and I hope people will take on the challenge (before I get to the point where I gotta do it myself, hehe..). It's also exciting how easy scaling should be on Tenstorrent hardware. The Metalium backend should be able to support multi-device with ease once a few blockers are removed.
Anyway, feel free to check out my code. And if you are coming to FOSDEM, please definitely be at my talk under the same title, on the 2nd day at 3PM. I will be thrilled to see everyone.
Now I gotta finish my slides. Someone please give me an entire pallet worth of energy drinks. I'm dying.