Rockchip NPUs and deploying scikit-learn models on them

My first experience with my RK3588 board was mildly infuriating. I bought my Orange Pi 5 Plus for its quite capable NPU. However, the low-level matrix multiplication API segfaults every single time. After a long period of headbanging I decided to drop that approach for now and use the barely working rknn-toolkit2 high-level interface. Even that has its own set of ridiculous problems. I thought converting scikit-learn models, from the most basic and widely used ML library, would be a breeze. I was wrong. I ended up writing my own converter. With this experience I'll be able to tackle larger and more useful models in the future.

TL;DR (if you just want to use it)

The converter I wrote, scirknn, is hosted on GitHub. The core of the project is two Python files: sklearn2rknn.py and scirknn.py. The former converts scikit-learn models to rknn-toolkit2's format. The latter is a wrapper around rknn-toolkit2 so it behaves like scikit-learn's MLPClassifier/MLPRegressor.

Let's say you have a scikit-learn MLPClassifier. You can convert it to rknn-toolkit2's format in two ways: by calling sklearn2rknn.convert, or by invoking sklearn2rknn as a script.

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

clf = MLPClassifier(random_state=1, max_iter=300)
clf.fit(X_train, y_train)

To convert it via sklearn2rknn.convert:

import sklearn2rknn
target_platform = 'rk3588'
sklearn2rknn.convert(clf, 'iris.rknn', target_platform)

Alternatively, save the model into iris.pkl then invoke sklearn2rknn as a module:

python3 -m sklearn2rknn iris.pkl iris.rknn --target_platform rk3588
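For the pickle step, a plain pickle dump of the fitted classifier is enough. A minimal sketch, assuming the script accepts a standard pickle file (which is what the .pkl extension suggests):

import pickle

# Serialize the fitted MLPClassifier so the command-line converter can load it
with open('iris.pkl', 'wb') as f:
    pickle.dump(clf, f)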

Either way, you should end up with iris.rknn and iris.rknn.json. These are the actual model and the model's metadata; the metadata is generated by sklearn2rknn for the wrapper. To use the model, you can use scirknn as a drop-in replacement for scikit-learn's MLPClassifier. But first, copy iris.rknn and iris.rknn.json to your dev board.

import scirknn
clf = scirknn.MLPClassifier('iris.rknn')
clf.predict(X_test)

Quantization

The NPU on the RK3588 itself supports operating on many data types: regular 32-bit floating point, int16, int8, float16 and even int4. However, rknn-toolkit2 only supports int8 and float16. By default, without quantizing, the model is converted to float16. In this mode the NPU has a peak performance of 1.5 TOPS (with many asterisks; most operators besides convolution and data movement do not support multi-core cooperation). Currently there's no way to fully utilize the 6 TOPS of compute, as int4 quantization is not supported by the conversion process. The best rknn-toolkit2 can do is 3 TOPS with int8 (at the cost of some accuracy; again, asterisks apply). To do so, call sklearn2rknn.convert with the quantization argument and provide an example dataset. RKNN uses the provided dataset to calibrate the quantization.

sklearn2rknn.convert(clf, 'iris_quantized.rknn', target_platform, quantization=True, example_input=X_train[:100])

Quantization does not change how inference is done.

import scirknn
clf = scirknn.MLPClassifier('iris_quantized.rknn')
clf.predict(X_test)
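Since int8 trades accuracy for throughput, it's worth measuring what you actually lose. A quick sketch comparing the float16 and int8 models on the held-out set, run on the board; it assumes you've copied X_test and y_test over as well, and that predict returns class labels the same way scikit-learn does:

from sklearn.metrics import accuracy_score
import scirknn

fp16 = scirknn.MLPClassifier('iris.rknn')            # float16 (non-quantized) model
int8 = scirknn.MLPClassifier('iris_quantized.rknn')  # int8 quantized model

print('fp16 accuracy:', accuracy_score(y_test, fp16.predict(X_test)))
print('int8 accuracy:', accuracy_score(y_test, int8.predict(X_test)))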

The RK3588 NPU and rknn-toolkit2

The NPU documentation (some of which is in Chinese; good thing I speak it) heavily implies that it is designed as a vision model processor. What I'm doing is really abusing its capabilities. Not that I care, that's the fun, right? However, this also means that models must adhere to some strict requirements, lest stuff starts running on the CPU.

NPU architecture

The best way to describe the NPU on the RK3588 is as a fixed-pipeline dataflow processor, with 3 of them on each RK3588 chip. It looks almost like the processor I worked on in college, but with a much simpler control scheme and more flexible dataflow. That's not to say the control interface is simple, but it's not executing a program. Instead, there's a huge set of registers that controls how data goes in and out of the NPU. Once execution starts, the dataflow is fixed; it does not change until the next execution. If you are familiar with Texas Instruments' C7x DSP, it's similar to the matrix unit plus streaming engine, but directly exposed to the main CPU instead of being a coprocessor of the DSP.

Image: Block diagram of a RK3588 NPU core

The NPU core runs at 1 GHz and can perform 2048 int4 operations per cycle, 1024 int8 operations per cycle or 512 fp16 operations per cycle. The NPU is also multi-core: each RK3588 SoC comes with 3 NPU cores, which adds up to 6 TOPS of compute. However, since rknn-toolkit2 does not support quantization to int4, the best we can do is 3 TOPS with int8. Even that comes with asterisks. As of RKNN 1.5.0, only convolution and some data movement operators support multi-core; all other operations run on a single core. That includes matrix multiplication, LSTM, GRU, etc. Thus, if we were to attempt to run language models on the NPU, we'd be limited to at best 1 TOPS @ int8 or 500 GFLOPS @ fp16.
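The headline numbers fall straight out of clock rate × operations per cycle × core count. A quick back-of-the-envelope check:

clock_hz = 1e9   # 1 GHz NPU clock
cores = 3        # 3 NPU cores per RK3588

ops_per_cycle = {'int4': 2048, 'int8': 1024, 'fp16': 512}
for dtype, ops in ops_per_cycle.items():
    per_core = clock_hz * ops / 1e12  # TOPS per core
    print(f"{dtype}: {per_core:.1f} TOPS/core, {per_core * cores:.1f} TOPS total")
# int4: 2.0 TOPS/core, 6.0 TOPS total
# int8: 1.0 TOPS/core, 3.0 TOPS total
# fp16: 0.5 TOPS/core, 1.5 TOPS total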

CPU fallbacks

Because the NPU is very static, it's not possible to just run arbitrary layers on it. Strict alignment and size constraints apply. When a layer does not match these strict requirements, RKNN runs the layer on the CPU instead. Of course, this is not ideal, so we ought to avoid it as much as possible.

The detailed list of constraints and requirements can be found in the compiler operation manual, but it's in Chinese.

Static Shapes

rknn-toolkit2 really, really does not support dynamic shape inference, unlike almost every other inference engine out there, which will run a model with any input shape as long as it's within the model's input shape range (e.g. an arbitrary batch size). RKNN needs to know the exact shape of the input when compiling the model. This usually means fixing the batch size to 1 and making the model's input shape a constant.

Passing None as the batch size results in an error when compiling the model.

W __init__: rknn-toolkit2 version: 1.5.0+1fa95b5c
W load_onnx: Onnx opset14 is not fully supported, it may cause convert fail, it is recommended to use opset12!
E load_onnx: The shape ['', 2] of 'output0' is not support! Please use the 'inputs' / 'input_size_list' of load_onnx to set the correct shape!
W load_onnx: ===================== WARN(2) =====================
E rknn-toolkit2 version: 1.5.0+1fa95b5c
E load_onnx: Catch exception when loading onnx model: /tmp/C24UVSTUWR8M0FUS.onnx!
E load_onnx: Traceback (most recent call last):
E load_onnx:   File "rknn/api/rknn_base.py", line 1382, in rknn.api.rknn_base.RKNNBase.load_onnx
E load_onnx:   File "rknn/api/rknn_base.py", line 658, in rknn.api.rknn_base.RKNNBase._create_ir_and_inputs_meta
E load_onnx:   File "rknn/api/ir_graph.py", line 58, in rknn.api.ir_graph.IRGraph.__init__
E load_onnx:   File "rknn/api/ir_graph.py", line 453, in rknn.api.ir_graph.IRGraph.rebuild
E load_onnx:   File "rknn/api/rknn_log.py", line 112, in rknn.api.rknn_log.RKNNLog.e
E load_onnx: ValueError: The shape ['', 2] of 'output0' is not support! Please use the 'inputs' / 'input_size_list' of load_onnx to set the correct shape!
W If you can't handle this error, please try updating to the latest version of the toolkit2 and runtime from:
  https://eyun.baidu.com/s/3eTDMk6Y (Pwd: rknn)  Path: RK_NPU_SDK / RK_NPU_SDK_1.X.0 / develop /
  If the error still exists in the latest version, please collect the corresponding error logs and the model,
  convert script, and input data that can reproduce the problem, and then submit an issue on:
  https://redmine.rock-chips.com (Please consult our sales or FAE for the redmine account)

This works for vision models, as the input shape is always the same, but it's a pain for audio and language models, where the input length is usually however many tokens/phonemes/frames there are, which is not known until runtime. Remember the pain of training recurrent models on old TensorFlow? It's back. Take YOLOv5 for example: it takes an image of [1x3x640x640] as input and outputs 3 [1x255x80x80] tensors. But VITS (speech synthesis) takes input of [batch_size, phonemes, all_known_phonemes], of which phonemes is not known until runtime, while also being impractical to pre-define as it either limits the max length or wastes compute and memory.

That's not to say this can't be worked around. In theory, it's possible to decompose a model into multiple RKNN models, each with a fixed input shape, then run them in sequence. But that's a pain in the butt and has performance implications. Rockchip, please fix this.
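In the meantime, if your exported graph has a symbolic batch dimension, the error message above already points at the escape hatch: pin the shape at load time with load_onnx's inputs / input_size_list arguments. A sketch; model.onnx is a hypothetical path, and 'output0' / [1, 2] match the input name and shape from the log above:

from rknn.api import RKNN

rknn = RKNN()
rknn.config(target_platform='rk3588')
# Pin the symbolic ['', 2] shape of the graph's input tensor to a concrete [1, 2]
rknn.load_onnx(model='model.onnx',
               inputs=['output0'],
               input_size_list=[[1, 2]])
rknn.build(do_quantization=False)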

Limited ONNX operator support

rknn-toolkit2 walks the ONNX graph and converts each operator to an RKNN operator. However, it does not support all ONNX operators. For instance, ZipMap is not supported. The list of supported operators can be found on GitHub. IMPORTANT: support simply means RKNN can execute the operator; it does not mean it can run on the NPU. For instance, Gemm is supported, but it can only run on the CPU.
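A quick way to see whether your graph uses anything off the supported list is to enumerate the operator types in the ONNX model before handing it to RKNN. A small helper sketch (not part of sklearn2rknn; iris.onnx is a hypothetical path to the exported graph):

import onnx
from collections import Counter

m = onnx.load('iris.onnx')
print(Counter(node.op_type for node in m.graph.node))
# e.g. Counter({'Constant': 4, 'MatMul': 2, 'Add': 2, 'Relu': 1, 'Softmax': 1})
# Check each op_type against the supported-operator list, and remember that
# "supported" can still mean "runs on the CPU" (e.g. Gemm).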

Converting scikit-learn models to rknn-toolkit2

With the restrictions I need to work with out of the way, I can start getting into how I converted scikit-learn models to rknn-toolkit2. The structure of the sklearn model is simple. Any MLP model has a few attributes we care about: coefs_ and intercepts_ are the weights and biases, activation is the hidden activation function and out_activation_ is the output activation function. The rest are just hyperparameters and stats.
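Concretely, for the iris classifier trained above (default hidden_layer_sizes), those attributes look like this; the sketch just prints what the converter reads:

print([w.shape for w in clf.coefs_])       # per-layer weight matrices, e.g. [(4, 100), (100, 3)]
print([b.shape for b in clf.intercepts_])  # per-layer bias vectors, e.g. [(100,), (3,)]
print(clf.activation)                      # hidden activation, 'relu' by default
print(clf.out_activation_)                 # output activation, 'softmax' for this classifier
print(clf.n_layers_)                       # 3 here -- sklearn counts the input layer too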

Yet, it's easier said than done. The RKNN compiler is full of BS and you only find out when it fails out of nowhere. So much so that I ended up building ONNX graphs by hand and changing my code to fit the compiler. The standard sklearn-onnx converter does not work because it uses unsupported operators. Also, ONNX versioning and version conversion are messed up: rknn-toolkit2 warns if you try to load anything newer than ONNX opset 12, yet Add and various activation functions all fail to convert down to opset 12 or lower. Luckily, rknn-toolkit2 will accept opset 13 and a few versions above.

Building the ONNX graph

This is straightforward: iterate over the sklearn model and add MatMul and Add nodes, then add activation functions. The only tricky part is naming the intermediate tensors. I ended up being lazy and just used the index of the layer, with _tmp_mul for the output of the MatMul node and _tmp for the output of the Add node (the input to the activation function).

import onnx.helper as helper
import onnx

# Names used below, derived from the sklearn model (the mapping here is my assumption
# of how sklearn activation names translate to ONNX operator names)
activation_functions_map = {"identity": "Identity", "logistic": "Sigmoid",
                            "tanh": "Tanh", "relu": "Relu", "softmax": "Softmax"}
activation_func_name = activation_functions_map[model.activation]
n_layers = model.n_layers_
batch_size = 1   # RKNN needs a static shape, so fix the batch size to 1
opset_ver = 13   # rknn-toolkit2 accepts opset 13

nodes = [] # This holds all the nodes in the ONNX graph
for i, (weight, bias) in enumerate(zip(model.coefs_, model.intercepts_)):
    # Name the weight and bias tensors to be used in the ONNX graph
    weight_name = f"weight{i}"
    bias_name = f"bias{i}"
    nodes += [helper.make_node("Constant", inputs=[], outputs=[weight_name]
        , value=helper.make_tensor(name=weight_name, data_type=onnx.TensorProto.FLOAT, dims=weight.shape, vals=weight.flatten().tolist()))]
    nodes += [helper.make_node("Constant", inputs=[], outputs=[bias_name]
        , value=helper.make_tensor(name=bias_name, data_type=onnx.TensorProto.FLOAT, dims=bias.shape, vals=bias.flatten().tolist()))]

    # If this is the last layer, use the output activation function. Otherwise use the hidden activation function
    # model.n_layers_ - 2 because sklearn considers the input layer as a layer
    act_func_name = activation_func_name if i < model.n_layers_ - 2 else activation_functions_map[model.out_activation_]
    # Can't use an Identity node as the version converter will complain. Also we can't use Gemm as RKNN implements it on the CPU.
    # However, MatMul + Add does run on the NPU
    if act_func_name == "Identity":
        nodes += [helper.make_node("MatMul", inputs=[f"output{i}", weight_name], outputs=[f"output_tmp_mul{i+1}"])]
        nodes += [helper.make_node("Add", inputs=[f"output_tmp_mul{i+1}", bias_name], outputs=[f"output{i+1}"])]
    else:
        nodes += [helper.make_node("MatMul", inputs=[f"output{i}", weight_name], outputs=[f"output_tmp_mul{i+1}"])]
        nodes += [helper.make_node("Add", inputs=[f"output_tmp_mul{i+1}", bias_name], outputs=[f"output_tmp{i+1}"])]
        nodes += [helper.make_node(act_func_name, inputs=[f"output_tmp{i+1}"], outputs=[f"output{i+1}"])]

# Now we have the ONNX graph. Let's build the model
graph = helper.make_graph(nodes, "scikit2rknn"
    , [helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, [batch_size, model.n_features_in_])]
    , [helper.make_tensor_value_info(f"output{n_layers-1}", onnx.TensorProto.FLOAT, [batch_size, model.n_outputs_])]
)
model = helper.make_model(graph)
# Call the checker to make sure the graph is valid
onnx.checker.check_model(model)
# Convert the model to opset_ver (13), as rknn-toolkit2 does not like opset 18 (the default on my system)
model = onnx.version_converter.convert_version(model, opset_ver)

Image: ONNX graph of a scikit-learn MLP converted by sklearn2rknn

I recommend the following article on towardsdatascience if you want to try building ONNX graphs by hand. It's actually fun to see how neural processors understand models under the hood.

From ONNX to RKNN

With the ONNX model ready, we can call rknn-toolkit2 to convert it to RKNN. This is straightforward: just call rknn.load_onnx and rknn.build. However, there are a few quirks. Note that the build process is platform specific, so you need to specify the target platform. I'm using rk3588, but it works the same on other platforms.

import onnx
from rknn.api import RKNN

# load_onnx expects a path on disk, so save the ONNX model first
onnx.save(model, "model.onnx")

rknn = RKNN()
rknn.config(target_platform='rk3588')
rknn.load_onnx(model="model.onnx")
rknn.build(do_quantization=False)
rknn.export_rknn("model.rknn")

In order for RKNN to quantize the model, you need to provide an example input to build(). Then save the model as usual.

rknn.build(do_quantization=True, dataset=[some_data])

Now the model can be loaded on your dev board with rknn-toolkit-lite.

from rknnlite.api import RKNNLite

rknn = RKNNLite()
rknn.load_rknn("model.rknn")
rknn.init_runtime()
out = rknn.inference(inputs=[some_data])
print(out) # yay! it works!!

Inference quirks

I've handled these edge cases in scirknn.py, but they're noteworthy enough to mention. RKNN inference is... weird. For some reason, it does not care about the input shape; it just does what it can with what it has. It's also very sensitive to the data type. My assumption is that this is related to how it works internally.

import numpy as np

rknn = RKNNLite()
.... # Load a model with 2 inputs

x = np.array([[1, 2, 3, 4]], dtype=np.float32)
rknn.inference(inputs=[x]) # Somehow this works?????
# The return value is as if I had only provided [1, 2] as input

But passing integers makes the SDK spew out errors.

rknn.inference(inputs=[[1, 4]]) # This blows up!

I RKNN: [10:10:55.646] RKNN Runtime Information: librknnrt version: 1.5.0 (e6fe0c678@2023-05-25T08:09:20)
I RKNN: [10:10:55.646] RKNN Driver Information: version: 0.8.5
I RKNN: [10:10:55.646] RKNN Model Information: version: 4, toolkit version: 1.5.0+1fa95b5c(compiler version: 1.5.0 (e6fe0c678@2023-05-25T08:18:57)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
E RKNN: [10:10:55.649] Normalize does not support for this data type. src type(6), dst type fbs::TensorType_FLOAT16
E RKNN: [10:10:55.649] rknn_inputs_set, normalize error(-1) index=0

Forcing it to use float32 works.

rknn.inference(inputs=[[1, 4], [1, 4]], data_type='float32') # This works!
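scirknn.py essentially just normalizes whatever it's given before calling inference. Something along these lines (a simplified sketch, not the actual wrapper code):

import numpy as np

def _prepare_input(x):
    # Cast lists / ints / float64 arrays into a contiguous float32 array so the
    # runtime's Normalize stage doesn't reject the data type
    x = np.ascontiguousarray(np.asarray(x, dtype=np.float32))
    if x.ndim == 1:
        x = x[np.newaxis, :]  # promote a single sample to a batch of one
    return x

out = rknn.inference(inputs=[_prepare_input([1, 4])])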

Future work

From the logs, I suspect Rockchip will support dynamic shape inference in the future. But for now, I'll have to work with what I have. I'm planning to decompose the RWKV language model into pieces and see if I can get it to run on the NPU. But that's a post for sometime in the future.
