Deploying scikit-learn MLP on Tenstorrent hardware

My llama.cpp fork for Tenstorrent hardware is going fairly well. But this weekend I want to rest a bit and do something else. I really want some time to go and play WatchDogs2. But then I got a new idea. I have already programmed Rockchip NPUs to run MLPs from scikit-learn. It should be just as easy to get them working on Tenstorrent via TTNN.

Oh yeah, it is. No more ONNX hacking. No messing around with the pile of c**p that's the Rockchip compiler. Just invoke the TTNN operator library. The code is so much shorter and cleaner now. It's almost trivial. Unfortunately, TTNN can't directly communicate with NumPy yet, so the data has to go through a very heavy conversion via torch first (I heard we are working on that) before it is loaded onto the device.

The model conversion process is literally just the following:

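import torch
import ttnn
from sklearn.neural_network import MLPClassifier

# Wraps a trained scikit-learn MLPClassifier and runs its layers through TTNN
class TTMLPClassifier:
    def __init__(self, device, scikit_model: MLPClassifier):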
        self.weights = []
        self.biases = []
        self.activations = []
        self.device = device
        activation = scikit_model.activation

        for i, (weight, bias) in enumerate(zip(scikit_model.coefs_, scikit_model.intercepts_)):
            w = ttnn.from_torch(torch.tensor(weight), device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)
            b = ttnn.from_torch(torch.tensor(bias), device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)

            self.weights.append(w)
            self.biases.append(b)

            # scikit-learn does not apply the hidden activation on the final layer
            self.activations.append(activation if i < len(scikit_model.coefs_) - 1 else 'identity')

    # Returns a ttnn tensor
    def forward(self, x):
        x = ttnn.from_torch(torch.tensor(x), device=self.device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)
        for i, (w, b) in enumerate(zip(self.weights, self.biases)):
            x = ttnn.matmul(x, w) + b

            activation = self.activations[i]
            if activation == 'identity':
                pass
            elif activation == 'logistic':
                x = ttnn.sigmoid(x)
            elif activation == 'tanh':
                x = ttnn.tanh(x)
            elif activation == 'relu':
                x = ttnn.relu(x)
            else:
                raise ValueError(f"Unsupported activation function: {activation}")
        return x


    def predict(self, x):
        x = self.forward(x)
        return ttnn.typecast(x, ttnn.float32).cpu().to_numpy()
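
For sanity-checking the device output, it can help to have the same forward pass on the host. The following is a minimal NumPy sketch of that idea, not part of the library: the names `reference_forward` and `ACTIVATIONS` are made up for illustration. Like the forward method above, it returns the raw output-layer scores with no softmax applied.

```python
import numpy as np

# Hypothetical host-side reference; mirrors the activation names scikit-learn uses.
ACTIVATIONS = {
    "identity": lambda x: x,
    "logistic": lambda x: 1.0 / (1.0 + np.exp(-x)),
    "tanh": np.tanh,
    "relu": lambda x: np.maximum(x, 0.0),
}

def reference_forward(x, coefs, intercepts, activation="relu"):
    """Return raw output-layer scores for input x, one row per sample."""
    x = np.asarray(x, dtype=np.float32)
    last = len(coefs) - 1
    for i, (w, b) in enumerate(zip(coefs, intercepts)):
        x = x @ w + b
        # Hidden layers use the model's activation; the last layer stays linear.
        name = activation if i < last else "identity"
        x = ACTIVATIONS[name](x)
    return x
```

With a trained model this would be called as `reference_forward(X_test, clf.coefs_, clf.intercepts_, clf.activation)`. Since the device side runs in bfloat16, expect the two results to agree only approximately, so compare them with a loose tolerance.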

To use it, create a model, train it, and pass it to the library:

device = ttnn.open_device(device_id=0)

# Train a scikit model
# This model size is absurd for MNIST, but serves to illustrate performance differences
clf = MLPClassifier(hidden_layer_sizes=(2048, 2048), max_iter=500, alpha=0.0001, solver='adam')
clf.fit(X_train, y_train)

# Convert to TTNN
model = scittnn.TTMLPClassifier(device, clf)

# Now use it
pred_tt = model.predict(X_test)
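
Note that predict as written above hands back the raw output-layer scores, not class labels. A small sketch of how those scores could be turned into labels for an accuracy check follows. The helper name `scores_to_labels` is hypothetical, and it assumes the array has shape (n_samples, n_outputs).

```python
import numpy as np

def scores_to_labels(scores, classes):
    """Map raw output-layer scores to class labels (hypothetical helper)."""
    scores = np.asarray(scores)
    classes = np.asarray(classes)
    if scores.ndim == 1 or scores.shape[1] == 1:
        # Binary case: scikit-learn uses a single logistic output unit,
        # so a raw score > 0 corresponds to a probability > 0.5.
        return classes[(scores.reshape(-1) > 0).astype(int)]
    # Multiclass: softmax is monotonic, so argmax over raw scores is enough.
    return classes[np.argmax(scores, axis=1)]
```

Something like `np.mean(scores_to_labels(pred_tt, clf.classes_) == y_test)` would then give the accuracy.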

Some quick benchmarking. Even for small models (32 hidden neurons), TTNN is at least as fast as the CPU. And for large models like the (2048, 2048) one above, TTNN is about 24x faster! That includes the time TTNN needs to tilize and untilize the input and output. Woohoo!

scikit-learn accuracy: 0.9805555555555555
scikit-learn prediction time: 0.022970 seconds
scittnn accuracy: 0.9805555555555555
scittnn prediction time: 0.000940 seconds
Speedup: 24.44x
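
For anyone who wants to reproduce this kind of measurement, here is a rough timing sketch. It is not the script that produced the numbers above; `best_time` and `speedup` are hypothetical helpers. Taking the best of a few runs is one simple way to keep one-time warm-up cost out of the figure.

```python
import time

def best_time(fn, *args, repeats=5):
    """Return the shortest wall-clock time over `repeats` calls to fn(*args)."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

def speedup(baseline_seconds, candidate_seconds):
    """How many times faster the candidate is than the baseline."""
    return baseline_seconds / candidate_seconds

# Plugging in the two prediction times reported above:
print(f"Speedup: {speedup(0.022970, 0.000940):.2f}x")
```

In practice the two timings would come from calls like `best_time(clf.predict, X_test)` and `best_time(model.predict, X_test)`.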

I made this in an hour. It's not much work. Hope you learned something. I guess the use case for this thing is if you are a company that runs a lot of small models and thinks enterprise GPUs are way too expensive to be worth it, but finds CPUs too slow. TT hardware is much cheaper, without all the licensing madness. Let me know, so I can help my sales friends close some deals :)

Code available at the following link:

Author's profile. Made by my friend.
Martin Chang
Systems software, HPC, GPGPU and AI. I mostly write stupid C++ code. Sometimes I do AI research. Chronic VRChat addict.

I run TLGS, a major search engine on Gemini. Used by Buran by default.


  • martin \at clehaxze.tw
  • Matrix: @clehaxze:matrix.clehaxze.tw
  • Jami: a72b62ac04a958ca57739247aa1ed4fe0d11d2df