Thoughts and logs after messing with Tenstorrent Grayskull

I got my Tenstorrent card last week or so, set it up and gave it a test drive. My end goal is to develop its software stack and applications so that it can be used as a cheaper, lower-power replacement for Nvidia GPUs. But for now, it's time to get my hands dirty and see what it can do as it is.

You can get a glimpse of what their hardware is like in their programming guide.

Setting up

I'm an Arch user and I really want to use it for my Tenstorrent development rig, mainly because Arch is a rolling distro, so I won't need to worry about upgrading the system and breaking everything all at once. But, of course, Arch is not officially supported by Tenstorrent, so I had to patch their code to make it work and write PKGBUILD packages for it. Stuff will be gradually upstreamed as I get more familiar with the codebase and write higher-quality fixes.

tt-metalium

Metalium is their equivalent of OpenCL and CUDA.

Metalium is really in its early days. When I first started, I spent more time debugging processor hangs and rebooting the entire machine than actually writing code. Like in the early days of AMD APP, if you accidentally write a kernel that never terminates, you are out of luck. Gotta REISUB or blindly type on the terminal to reboot the machine. Luckily, Tenstorrent cards don't drive the video output and can be reset using their tt-smi tool.

One reason I'm interested in Metalium is that it gives you complete access to the hardware, which also means I can write my own code and push the hardware to see what it can do.

As of now I've upstreamed a well-annotated example, which I hope helps others get started with Metalium.

Hanging in Metalium is a lot easier than in OpenCL. In OpenCL, the only way to hang is an infinite loop. However, due to Tenstorrent's architecture, I believe both code and data are stored in the same SRAM without any hardware protection. So if you write a kernel that writes to an invalid memory location, there is a non-zero chance that you end up executing bad instructions and break the entire control flow. Furthermore, because intra-Tensix communication is done through circular buffers, it is really easy to misconfigure a circular buffer and end up in a deadlock. Both situations require a board reset.
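To make the circular-buffer deadlock concrete, here is a toy Python model of a producer and a consumer sharing one buffer through the four calls the kernels in this post use (cb_reserve_back, cb_push_back, cb_wait_front, cb_pop_front). It is my own simplification, not how Metalium or the hardware implements it: capacity and step sizes are counted in tiles, and instead of spinning forever a blocking call just reports that it would block, so the simulation can give up when neither side can make progress.

```python
class CircularBuffer:
    """Toy model of a tile circular buffer (capacity counted in tiles)."""

    def __init__(self, capacity_tiles: int) -> None:
        self.capacity = capacity_tiles
        self.reserved = 0  # slots reserved by the producer but not pushed yet
        self.filled = 0    # tiles pushed and visible to the consumer

    def reserve_back(self, n: int) -> bool:
        # Producer needs n free slots; returns False where the kernel would spin.
        if self.filled + self.reserved + n > self.capacity:
            return False
        self.reserved += n
        return True

    def push_back(self, n: int) -> None:
        assert n <= self.reserved, "pushing tiles that were never reserved"
        self.reserved -= n
        self.filled += n

    def wait_front(self, n: int) -> bool:
        # Consumer needs n tiles; returns False where the kernel would spin.
        return self.filled >= n

    def pop_front(self, n: int) -> None:
        assert n <= self.filled, "popping tiles that are not there"
        self.filled -= n


def runs_to_completion(capacity: int, produce_n: int, consume_n: int, total_tiles: int) -> bool:
    """Alternate producer and consumer steps; False means both sides are stuck."""
    cb = CircularBuffer(capacity)
    produced = consumed = 0
    while consumed < total_tiles:
        progress = False
        if produced < total_tiles and cb.reserve_back(produce_n):
            cb.push_back(produce_n)
            produced += produce_n
            progress = True
        if cb.wait_front(consume_n):
            cb.pop_front(consume_n)
            consumed += consume_n
            progress = True
        if not progress:
            return False  # on the real card: a hung Tensix and a board reset
    return True


print(runs_to_completion(capacity=2, produce_n=1, consume_n=1, total_tiles=8))  # True
print(runs_to_completion(capacity=1, produce_n=1, consume_n=2, total_tiles=8))  # False
print(runs_to_completion(capacity=2, produce_n=1, consume_n=2, total_tiles=8))  # True
```

The second case is the classic mistake: the consumer waits for 2 tiles, but the buffer can only ever hold 1, so both sides wait on each other forever.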

Programming Metalium, a gist

Programming Tenstorrent hardware in Metalium is a unique experience. On Tenstorrent, the minimal unit of computation is a 32x32 matrix, called a tile. The Grayskull is made of a grid of Tensix cores, and each Tensix contains 5 (I assume) single-issue RISC-V cores. 2 of them are used to copy data in and out of the Tensix, and 3 of them are used to control the SIMD engine and the tensor engine. Both engines operate on a tile-by-tile basis.
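Since everything is tile-granular, it helps to do the tile arithmetic up front. A back-of-the-envelope sketch, assuming a plain 32x32 tile with no per-tile header and that partial tiles at the edges are padded; with Float16_b (bfloat16, 2 bytes per element) one tile comes out to 2 KiB. The helper names are my own, not Metalium APIs.

```python
import math

TILE_DIM = 32  # a tile is a 32x32 matrix


def tiles_for_matrix(rows: int, cols: int) -> int:
    """Number of tiles needed to cover a rows x cols matrix (edges padded)."""
    return math.ceil(rows / TILE_DIM) * math.ceil(cols / TILE_DIM)


def tile_size_bytes(bytes_per_element: int) -> int:
    """Raw payload of one tile, assuming no per-tile metadata."""
    return TILE_DIM * TILE_DIM * bytes_per_element


print(tile_size_bytes(2))         # 2048 for a bfloat16 tile
print(tiles_for_matrix(64, 128))  # 8
print(tiles_for_matrix(1, 4096))  # 128, even though only 1 row holds data
```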

Note the wording here: "control" the SIMD and tensor engine, not "utilize". Forget the traditional processor model, where each CPU core has its own SIMD unit. On Tenstorrent, the 3 RISC-V cores copy data from SRAM to the SIMD engine, issue instructions to the SIMD engine and then copy the results back to SRAM. The SIMD/tensor engine is not part of the RISC-V cores, but a separate unit.

Because the SIMD/tensor engine is not part of the RISC-V cores, communication with it is explicit. There are 16 tile registers on each Tensix core. You must explicitly load data from SRAM into the tile registers and copy it back out. And before using them, you must acquire the tile registers so other cores won't be using them at the same time. You'll see what I mean.

The following Metalium program reads two sets of 8 tiles from DRAM and adds them together. The result is written back to DRAM. This actually runs on all 5 cores: the 2 data movement cores run read_data_in.cpp and write.cpp, and the 3 compute cores run compute_add.cpp (through some C++ macro magic).

// read_data_in.cpp (data movement core 0)
#include <cstdint>
#include "dataflow_api.h"
void kernel_main()
{
    uint32_t a_addr = get_arg_val<uint32_t>(0);
    uint32_t b_addr = get_arg_val<uint32_t>(1);
    uint32_t n_tiles = 8; // hardcoded here; the other two kernels get it as a runtime arg

    constexpr uint32_t cb_in0 = tt::CB::c_in0;
    constexpr uint32_t cb_in1 = tt::CB::c_in1;

    const uint32_t tile_size_bytes = get_tile_size(cb_in0);

    const InterleavedAddrGenFast<true> a = {
        .bank_base_address = a_addr,          // The base address of the buffer
        .page_size = tile_size_bytes,         // The size of a buffer page
        .data_format = DataFormat::Float16_b, // The data format of the buffer
    };
    const InterleavedAddrGenFast<true> b = {
        .bank_base_address = b_addr,
        .page_size = tile_size_bytes,
        .data_format = DataFormat::Float16_b,
    };

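    for(uint32_t i = 0; i < n_tiles; i++) {
        // Wait until both input circular buffers have a free slot
        cb_reserve_back(cb_in0, 1);
        cb_reserve_back(cb_in1, 1);
        uint32_t cb_in0_addr = get_write_ptr(cb_in0);
        uint32_t cb_in1_addr = get_write_ptr(cb_in1);
        // Issue asynchronous NoC reads of tile i from the DRAM buffers
        noc_async_read_tile(i, a, cb_in0_addr);
        noc_async_read_tile(i, b, cb_in1_addr);
        noc_async_read_barrier(); // block until both reads have landed
        // Hand the filled slots over to the compute kernel
        cb_push_back(cb_in0, 1);
        cb_push_back(cb_in1, 1);
    }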
}

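// compute_add.cpp (the 3 compute cores)
#include <cstdint>
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/tile_move_copy.h"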
namespace NAMESPACE {
void MAIN {
    uint32_t n_tiles = get_arg_val<uint32_t>(0);

    constexpr auto cb_in0 = tt::CB::c_in0;
    constexpr auto cb_in1 = tt::CB::c_in1;
    constexpr auto cb_out0 = tt::CB::c_out0;
    constexpr uint32_t dst_reg = 0;

    binary_op_init_common(cb_in0, cb_in1, cb_out0);
    add_tiles_init();

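    for(uint32_t i = 0; i < n_tiles; i++) {
        acquire_dst(tt::DstMode::Half);           // lock the tile (destination) registers
        cb_wait_front(cb_in0, 1);                 // wait for the reader to push a tile into each input
        cb_wait_front(cb_in1, 1);
        add_tiles(cb_in0, cb_in1, 0, 0, dst_reg); // dst_reg = in0[0] + in1[0]
        cb_reserve_back(cb_out0, 1);              // make room in the output buffer
        pack_tile(dst_reg, cb_out0);              // copy the result from the tile register to SRAM
        cb_push_back(cb_out0, 1);                 // hand it to the writer
        cb_pop_front(cb_in0, 1);                  // free the input slots for the reader
        cb_pop_front(cb_in1, 1);
        release_dst(tt::DstMode::Half);
    }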
}
}

// write.cpp (data movement core 1)
#include <cstdint>
#include "dataflow_api.h"
void kernel_main()
{
    uint32_t c_addr = get_arg_val<uint32_t>(0);
    uint32_t n_tiles = get_arg_val<uint32_t>(1);

    constexpr uint32_t cb_out0 = tt::CB::c_out0;
    const uint32_t tile_size_bytes = get_tile_size(cb_out0);

    const InterleavedAddrGenFast<true> c = {
        .bank_base_address = c_addr,
        .page_size = tile_size_bytes,
        .data_format = DataFormat::Float16_b,
    };

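    for(uint32_t i = 0; i < n_tiles; i++)
    {
        cb_wait_front(cb_out0, 1); // wait for the compute kernel to produce a tile
        uint32_t cb_out0_addr = get_read_ptr(cb_out0);
        noc_async_write_tile(i, c, cb_out0_addr);
        noc_async_write_barrier(); // block until the write has completed
        cb_pop_front(cb_out0, 1);  // free the slot for the compute kernel
    }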
}

The downside of this approach is that compute and data movement are forced into separate, detached kernels, which makes the code harder to write. However, it enables software pipelining at the hardware level. The following is a rough idea of how the code above runs on the hardware.

|     Data Movement 0   |      Compute 0      |     Compute 1      |      Compute 2      |   Data Movement 1   |
|      DRAM -> SRAM     |                     |                    |                     |                     |
|      DRAM -> SRAM     |  SRAM -> tile reg   |                    |                     |                     |
|      DRAM -> SRAM     |  SRAM -> tile reg   |   issue tile add   |                     |                     |
|      DRAM -> SRAM     |  SRAM -> tile reg   |   issue tile add   |  tile reg -> SRAM   |                     |
|      DRAM -> SRAM     |  SRAM -> tile reg   |   issue tile add   |  tile reg -> SRAM   |     SRAM -> DRAM    |
|      DRAM -> SRAM     |  SRAM -> tile reg   |   issue tile add   |  tile reg -> SRAM   |     SRAM -> DRAM    |
|      DRAM -> SRAM     |  SRAM -> tile reg   |   issue tile add   |  tile reg -> SRAM   |     SRAM -> DRAM    |
|      DRAM -> SRAM     |  SRAM -> tile reg   |   issue tile add   |  tile reg -> SRAM   |     SRAM -> DRAM    |
|                       |  SRAM -> tile reg   |   issue tile add   |  tile reg -> SRAM   |     SRAM -> DRAM    |
|                       |                     |   issue tile add   |  tile reg -> SRAM   |     SRAM -> DRAM    |
|                       |                     |                    |  tile reg -> SRAM   |     SRAM -> DRAM    |
|                       |                     |                    |                     |     SRAM -> DRAM    |
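The table is an idealized linear pipeline: stage s works on tile t at step s + t, so 8 tiles through 5 stages take 8 + 5 - 1 = 12 steps. A small Python sketch that regenerates the schedule under that assumption (every stage takes exactly one step per tile and never waits on a full buffer, which real kernels will not always manage):

```python
from typing import List, Optional

STAGE_NAMES = ["Data Movement 0", "Compute 0", "Compute 1", "Compute 2", "Data Movement 1"]


def pipeline_schedule(n_tiles: int, n_stages: int) -> List[List[Optional[int]]]:
    """schedule[step][stage] is the tile that stage works on at that step,
    or None if the stage is idle. Stage s handles tile t at step s + t."""
    n_steps = n_tiles + n_stages - 1
    schedule = []
    for step in range(n_steps):
        row = []
        for stage in range(n_stages):
            tile = step - stage
            row.append(tile if 0 <= tile < n_tiles else None)
        schedule.append(row)
    return schedule


schedule = pipeline_schedule(n_tiles=8, n_stages=len(STAGE_NAMES))
print(" | ".join(f"{name:>15}" for name in STAGE_NAMES))
for row in schedule:
    cells = ["tile %d" % t if t is not None else "" for t in row]
    print(" | ".join(f"{cell:>15}" for cell in cells))
print(len(schedule))  # 12
```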

I wonder if there is a way to turn it into a more traditional programming model, either with more macro magic or an intermediate language that compiles down to Metalium.

Comments on the programming model

I have mixed feelings about the programming model. On one hand, it is really clever: they achieved software pipelining with minimal hardware. No crazy LOOP instructions like DSPs, nor GPU-like dynamic warps. On the other hand, it is also really difficult to reason about. They did provide synchronization primitives to make things sane-ish, but even with years of parallel and GPU programming experience, I still find it difficult to program the hardware effectively and extract the most parallelism out of it. Usually either the slow RISC-V cores become the bottleneck in tight loops, or the tile computation takes too long and the data movement cores sit underutilized. Either way, performance devolves back to what the GPU model provides.
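The bottleneck argument can be made a bit more quantitative with a textbook pipeline recurrence. This is a sketch, not a model of the real hardware: it assumes each stage takes a fixed, made-up time per tile and that the buffers between stages never fill up. Under those assumptions the total time works out to sum(stage times) + (n_tiles - 1) * max(stage times), so the slowest stage dictates throughput.

```python
from typing import Sequence


def pipelined_time(stage_times: Sequence[float], n_tiles: int) -> float:
    """Time at which the last tile leaves a linear pipeline.

    Stage s can start tile t once it has finished tile t - 1 and stage s - 1
    has finished tile t. Assumes unbounded buffers between stages."""
    finish = [0.0] * len(stage_times)  # finish[s]: when stage s finished its latest tile
    for _ in range(n_tiles):
        ready = 0.0  # when the current tile is available to the current stage
        for s, t in enumerate(stage_times):
            start = max(ready, finish[s])
            finish[s] = start + t
            ready = finish[s]
    return finish[-1]


balanced = [1, 1, 1, 1, 1]      # the idealized table above
slow_compute = [1, 1, 4, 1, 1]  # made-up: the "issue tile add" stage is 4x slower
print(pipelined_time(balanced, 8))      # 12.0
print(pipelined_time(slow_compute, 8))  # 36.0 = 8 + 7 * 4
print(sum(slow_compute) * 8)            # 64, if nothing overlapped at all
```

In the second case the whole pipeline runs at the rate of the slow compute stage, and the data movement cores mostly sit idle waiting for it.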

Hardware limitations

The design of operating on 32x32 matrices is an interesting one. I believe this is to make convolutions easier. However, it also basically forces language models to run at batch=32 all the time. (You can run at batch=1, but then hardware utilization drops to 1/32.) This is a bit of a bummer for me. Being able to fully utilize the matrix multiplication engine at batch=1 would be a killer feature for me. But I guess I'll have to wait for the next generation of Tenstorrent hardware.
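A rough way to see the batch=1 problem: if the batch dimension maps onto the 32 rows of a tile, any batch size is effectively padded up to a multiple of 32, and the fraction of useful rows is batch / padded_batch. This is a simplified model of my own that ignores every other overhead:

```python
import math

TILE_ROWS = 32  # a tile is 32x32


def row_utilization(batch: int) -> float:
    """Fraction of tile rows doing useful work when the batch is padded
    up to a multiple of 32 rows (simplified model, ignores all other costs)."""
    padded = math.ceil(batch / TILE_ROWS) * TILE_ROWS
    return batch / padded


for batch in (1, 8, 32, 33, 64):
    print(batch, row_utilization(batch))
# 1 0.03125, 8 0.25, 32 1.0, 33 0.515625, 64 1.0
```

Note the cliff just past 32: batch=33 needs a second row of tiles that is almost entirely padding.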

tt-BUDA

BUDA is the high-level stack that Tenstorrent provides. It loads models from PyTorch, TensorFlow and ONNX and runs them on the card. It's a bit like Nvidia's TensorRT or ARM's ARM-NN.

Getting the prebuilt BUDA to work on Arch is kind of a hassle, mainly because Tenstorrent only officially supports Ubuntu 20.04 while I use Arch. When I first ran it, it complained about not being able to find specific versions of Boost and yaml-cpp. I hacked around it by using ldd to find out exactly what was missing, compiling the correct versions of Boost and yaml-cpp myself, and dumping the shared libraries into the virtual environment. It's a bit of a hack, but it works.

(buda) ➜  ~ python
Python 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:38) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pybuda
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/marty/micromamba/envs/buda/lib/python3.8/site-packages/pybuda/__init__.py", line 42, in <module>
    from .module import Module, PyTorchModule, PyBudaModule, TFModule, TFGraphDefModule, OnnxModule, MXNetModule, JaxModule, TFLiteModule
  File "/home/marty/micromamba/envs/buda/lib/python3.8/site-packages/pybuda/module.py", line 14, in <module>
    from .pybudaglobal import register_module, lazy_trace_data
  File "/home/marty/micromamba/envs/buda/lib/python3.8/site-packages/pybuda/pybudaglobal.py", line 16, in <module>
    from pybuda._C.backend_api import BackendType
ImportError: libboost_serialization.so.1.71.0: cannot open shared object file: No such file or directory
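The ldd step can be scripted. Below is a small helper that runs ldd on a shared object and lists every dependency reported as "not found". The path to pybuda's compiled extension is left as a command-line argument; the exact file under site-packages/pybuda/ is a placeholder here, not something shown above. It sticks to typing.List so it also runs on the Python 3.8 seen in the traceback.

```python
import subprocess
import sys
from typing import List


def missing_libraries(ldd_output: str) -> List[str]:
    """Return the sonames that ldd reports as 'not found'."""
    missing = []
    for line in ldd_output.splitlines():
        # Unresolved dependencies look like: "libfoo.so.1 => not found"
        if "=> not found" in line:
            missing.append(line.split("=>")[0].strip())
    return missing


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python find_missing.py <path to a compiled extension .so>
    result = subprocess.run(["ldd", sys.argv[1]], capture_output=True, text=True)
    for soname in missing_libraries(result.stdout):
        print(soname)
```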

Loading ONNX models into BUDA

It took me a bit of time to figure out how to use BUDA. The documentation does not contain examples of how to get ONNX working with it, and the API is confusing, to say the least.

import pybuda
import onnx
import torch

# Yeah, you need to load the ONNX model and then pass it to BUDA, along with the path again
model_path = "/path/to/your/model.onnx"
onnx_model = onnx.load(model_path)
buda_model = pybuda.OnnxModule("module_name", onnx_model, model_path)

# Now call pybuda.run_inference() to run the model.
# Note that inputs is a list of arrays. Each array is a batch of inputs.
# And for some reason each one must be a torch tensor.
inputs = [[1, 2, 3, 4], [5, 6, 7, 8]]
inputs = [torch.tensor(x) for x in inputs]
out = pybuda.run_inference(buda_model, inputs)

The ugly part of BUDA

Unfortunately, just like Rockchip's RKNN compiler, almost everything that is not in the official demos doesn't work. Their employees do say that they are working on enabling more features, and their engineering team has delivered a lot of features in a short time. So I'm hopeful that, given some time, BUDA will be good enough to run most models.

Martin Chang
Systems software, HPC, GPGPU and AI. I mostly write stupid C++ code. Sometimes does AI research. Chronic VRChat addict

I run TLGS, a major search engine on Gemini. Used by Buran by default.


  • marty1885 \at protonmail.com
  • Matrix: @clehaxze:matrix.clehaxze.tw
  • Jami: a72b62ac04a958ca57739247aa1ed4fe0d11d2df