Writing CPU ML Kernels with XNNPACK

  • cpu
  • kernel

Even though the current state of ML computing looks like the image below, you should care about running your models on CPU. Because it’s damn convenient.

Nvidia winning meme


One H100 typically costs between $25’000 and $35’000. In contrast, there’s a CPU in just about every device. Obviously, CPUs are not well suited for serving a model over an API, and you’d be a fool to try to compete there. But if you ship the model directly onto your users’ hardware, there’s no round-trip latency, no bandwidth issues, and you get privacy!

Now, as it is not the focus of this article, I won’t delve into the details of which models you should run and which ones will be fast. As a rule of thumb, models such as Qwen3 or LFM2 (which I helped develop!) are usually a good default.

How to run your models on CPU

There are many options to run your models on CPU, such as TensorFlow Lite, TensorFlow.js, PyTorch, ONNX Runtime, ExecuTorch, MediaPipe, and llama.cpp.

They each have their specific pros and cons, but generally, there are three things that we care about as developers:

  1. portability (can we run it anywhere?)
  2. performance (is it fast enough to be usable?)
  3. flexibility (can we deviate from the text-only Llama architecture?)

Sadly, a good rule of thumb that holds for most of these frameworks is that the more performance and flexibility you get, the more painful the framework is to use.

For example, just running PyTorch models on CPU is portable and flexible (any PyTorch code is valid), but it’s sloooooow.

On the contrary, ExecuTorch (PyTorch’s native on-device inference framework) is portable and has good performance, but it is not very flexible. To be usable by ExecuTorch (and most other frameworks), your model must define a computational graph that

  • is static
    • This means that the graph traced by torch.export from example inputs must be reusable with different inputs.
    • This rules out any data-dependent computation, e.g.
      • token-dependent routing in Mixture of Experts;
      • the recurrent form of RNNs, i.e. for-loops over the sequence length[1].
  • uses simple computational primitives
    • The torch operations used must be part of a predefined operator set, the “Edge Operators”, defined in the Edge Dialect.

As far as I know, most CPU inference frameworks (e.g. ONNX Runtime or ggml) rely on computational graphs, which usually brings a similar set of constraints.

But what if you really want to run your very custom model (which you swear is better than a Llama transformer)?

This is what will interest us today.

The usual way to bypass the above constraints is to register your custom/complicated computational primitive as a kernel. Then, you can abstract away your dirty machinery as a node in the graph.

This is a fairly easy process in ExecuTorch, and we will assume that you have an easy way to register a C++ kernel in your framework of choice, as I can’t be bothered to write about this :)

Now, we will focus on the process of writing a custom kernel with XNNPACK.

Why Choose XNNPACK for CPU Kernels?

Wait, you spent 3 minutes explaining why inference frameworks are constraining, and now you’re introducing another one?

Indeed, the duality of the software engineer is fascinating, but bear with me.

Let me state an obvious fact: in 2025, CUDA still dominates the GPU landscape for ML workloads. While more competition would be welcome, this dominance has one advantage: only a few microarchitectures are relevant for CUDA kernels, namely Ampere (SM80), Hopper (SM90), and Blackwell (SM100). Writing a new kernel for every new generation is well worth the engineering time, given the performance benefits.

However, the CPU landscape has much more variety, and we get what I call the ISA zoo: SSE2, AVX, AVX2, and AVX-512 (x86-64); NEON, SVE, and SVE2 (ARM); RVV (RISC-V); and other older, more niche extensions.

Personally, I don’t want to update my kernels when AVX-$2^n$ comes out.

Thus, ideally, we would like a level of abstraction that lets us write our custom operator without worrying about the above. XNNPACK is one possible solution, which we will explore today.

As stated in its README, XNNPACK is a highly optimized solution for neural network inference on ARM, x86, WebAssembly, and RISC-V platforms.

Funnily enough, the README also says that XNNPACK is not intended for direct use by deep learning practitioners, and that it instead provides low-level performance primitives for accelerating high-level machine learning frameworks, such as the ones mentioned previously.

You should not let yourself be scared!

If we’re here, it means that we need to get our hands dirty, and that we already know we can’t rely entirely on the high-level frameworks.

So what does XNNPACK expose?

The public API exposes two layers: (1) standalone operators (fully connected/GEMM, depthwise convolution, etc.) that you create and run one by one, and (2) a small subgraph API that lets you stitch those operators into a static graph.

Before diving into how to use XNNPACK, let’s prime the reader on what it is good and bad at.

XNNPACK is good at expressing core ML operators such as GEMMs and convolutions.

Expressing static computation graphs, e.g. SwiGLU, is very easy with the subgraph API. Wait, wait: does that mean we’re dealing with the same issues as before (i.e. no data-dependent computation)? Indeed, we are!

But you should see XNNPACK as giving you building blocks (similar to CUTLASS) for your ML operators. You could build your whole model within XNNPACK alone, but you’ll likely end up using it in combination with C++ tensor libraries such as xtensor.

More importantly, it gives good building blocks for quantized operators (int4, int8), which is ultimately what we’ll care about, because we’re not going to run our models in fp32 ;) Its quantization support is extensive, e.g. dynamically quantized int8 activations combined with int4 or int8 weights, as well as weight-only quantization.

Good int8 support is very important for CPU performance for two reasons: (1) FLOPs are scarce, and (2) floating-point support on CPUs is not as good as on GPUs (e.g. fp16 often gets emulated[2]).
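To put rough numbers on point (1), here is some generic arithmetic for a 128-bit SIMD register (the NEON/SSE width) and for the weight storage of a model. This is not an XNNPACK benchmark: the "up to 4×" only materializes when the CPU has int8 dot-product instructions (e.g. Arm's SDOT), and real speedups depend on the microkernel.

```latex
\begin{aligned}
\text{fp32 lanes per 128-bit register} &= 128 / 32 = 4, \\
\text{int8 lanes per 128-bit register} &= 128 / 8 = 16 \quad (\text{up to } 4\times \text{ more multiply-accumulates per instruction}), \\
\text{weights of a } 10^9\text{-parameter model} &= 4\,\text{GB (fp32)},\ 1\,\text{GB (int8)},\ 0.5\,\text{GB (int4)}.
\end{aligned}
```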

As a final point, a good indicator that you might need something beyond XNNPACK is when your operator requires a level of detail not easily expressed at the graph level of abstraction. For example, writing a flash-attention-like kernel on CPU requires carefully tiling the matrix multiplications to avoid a memory footprint that is quadratic in the sequence length.
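To make the memory argument concrete, here is a minimal single-query sketch in plain C++ (no XNNPACK) of the online-softmax trick that flash-attention-style kernels rely on. Keys and values are consumed in tiles, and only one tile of scores, a running max, a running denominator, and the output accumulator are kept alive, instead of all n scores (or an n×n matrix once you have n queries). The function names, the `Matrix` alias, and the tile size are illustrative assumptions, not XNNPACK APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // one row per key (or value)

static float dot(const std::vector<float>& a, const std::vector<float>& b) {
  float s = 0.0f;
  for (size_t j = 0; j < a.size(); ++j) s += a[j] * b[j];
  return s;
}

// Naive softmax attention for one query: stores all n scores at once.
std::vector<float> attention_naive(const std::vector<float>& q, const Matrix& K, const Matrix& V) {
  assert(!K.empty() && K.size() == V.size());
  std::vector<float> scores(K.size());
  float max_s = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < K.size(); ++i) {
    scores[i] = dot(q, K[i]);
    max_s = std::max(max_s, scores[i]);
  }
  std::vector<float> out(V[0].size(), 0.0f);
  float denom = 0.0f;
  for (size_t i = 0; i < K.size(); ++i) {
    const float p = std::exp(scores[i] - max_s);
    denom += p;
    for (size_t j = 0; j < out.size(); ++j) out[j] += p * V[i][j];
  }
  for (float& o : out) o /= denom;
  return out;
}

// Tiled attention with an online softmax: only `block` scores are alive at a time,
// plus a running max, a running denominator, and the output accumulator.
std::vector<float> attention_online(const std::vector<float>& q, const Matrix& K,
                                    const Matrix& V, size_t block) {
  assert(!K.empty() && K.size() == V.size() && block > 0);
  float running_max = -std::numeric_limits<float>::infinity();
  float running_sum = 0.0f;
  std::vector<float> acc(V[0].size(), 0.0f);
  std::vector<float> tile(block);
  for (size_t start = 0; start < K.size(); start += block) {
    const size_t end = std::min(K.size(), start + block);
    float tile_max = running_max;
    for (size_t i = start; i < end; ++i) {
      tile[i - start] = dot(q, K[i]);
      tile_max = std::max(tile_max, tile[i - start]);
    }
    // Rescale everything accumulated so far to the new max (exp(-inf) == 0 on the first tile).
    const float correction = std::exp(running_max - tile_max);
    running_sum *= correction;
    for (float& a : acc) a *= correction;
    for (size_t i = start; i < end; ++i) {
      const float p = std::exp(tile[i - start] - tile_max);
      running_sum += p;
      for (size_t j = 0; j < acc.size(); ++j) acc[j] += p * V[i][j];
    }
    running_max = tile_max;
  }
  for (float& a : acc) a /= running_sum;
  return acc;
}
```

Both functions return the same result up to float rounding; the tiled one is the shape a real kernel would take, with the dot products replaced by blocked GEMMs.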


Writing a SwiGLU kernel

In this example, I will walk through how to write an XNNPACK kernel for SwiGLU (Swish-Gated Linear Unit)[3].

The computation can be defined as output = W2 @ (SiLU(W1 @ input) * (W3 @ input)), where @ denotes matrix multiplication and * denotes element-wise multiplication.

In ML jargon, W1 is referred to as the “gate projection”, W2 as the “down projection”, and W3 as the “up projection”. I may reuse this terminology.
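Before touching XNNPACK, it helps to have a plain C++ reference of this formula to check the kernel's output against. The sketch below is my own (the helpers `matvec`, `silu`, and `swiglu_reference` are hypothetical, not part of XNNPACK) and assumes row-major weights of shape [out_features, in_features], the same layout used for the XNNPACK weights later in this post.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major matrix-vector product: y[r] = sum_c W[r * cols + c] * x[c].
static std::vector<float> matvec(const std::vector<float>& W, size_t rows, size_t cols,
                                 const std::vector<float>& x) {
  assert(W.size() == rows * cols && x.size() == cols);
  std::vector<float> y(rows, 0.0f);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c) y[r] += W[r * cols + c] * x[c];
  return y;
}

// SiLU(z) = z * sigmoid(z) = z / (1 + exp(-z))
static float silu(float z) { return z / (1.0f + std::exp(-z)); }

// output = W2 @ (SiLU(W1 @ x) * (W3 @ x)) for a single input vector x.
std::vector<float> swiglu_reference(const std::vector<float>& w1, const std::vector<float>& w3,
                                    const std::vector<float>& w2, size_t input_dim,
                                    size_t inter_dim, size_t output_dim,
                                    const std::vector<float>& x) {
  std::vector<float> gated = matvec(w1, inter_dim, input_dim, x);      // W1 @ x (gate)
  const std::vector<float> up = matvec(w3, inter_dim, input_dim, x);   // W3 @ x (up)
  for (size_t i = 0; i < inter_dim; ++i) gated[i] = silu(gated[i]) * up[i];
  return matvec(w2, output_dim, inter_dim, gated);                     // W2 @ gated (down)
}
```

Feeding it the demo weights from step 2 (with w1_weight_data reused for W3) and the input {1, 2, 3} gives numbers you can compare with what the XNNPACK runtime prints at the end.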

The code for this example can be found on GitHub, with accompanying instructions on how to build XNNPACK.

1) Initialization

We start by including XNNPACK, setting a few compile-time constants, and initializing the library.

#include <stdio.h>
#include <math.h>
#include <xnnpack.h>
#include <vector>

// Use #define for compile-time constants to allow array initialization
#define INPUT_DIM  3
#define OUTPUT_DIM 2
#define INTER_DIM  4
#define BATCH_SIZE 1

int main(void) {
  // 1. Initialize XNNPACK
  if (xnn_initialize(NULL) != xnn_status_success) {
    fprintf(stderr, "Failed to initialize XNNPACK\n");
    return 1;
  }

  const size_t input_dim  = INPUT_DIM;
  const size_t output_dim = OUTPUT_DIM;
  const size_t inter_dim  = INTER_DIM;

Tip: for debugging, you can build with XNNPACK_LOG=Debug in your build env to get very verbose logs from the planner and runtime.

2) Defining the weights

For the demo, we just fill the weights with simple values. In a real kernel, these would be passed in (and likely packed/quantized).

  // Weights are row-major [out_features, in_features]. We reuse w1_weight_data for W3.
  float w1_weight_data[INTER_DIM * INPUT_DIM];
  for (size_t i = 0; i < INTER_DIM; ++i) {
    for (size_t j = 0; j < INPUT_DIM; ++j) {
      w1_weight_data[i * INPUT_DIM + j] =
        static_cast<float>(i * INPUT_DIM + j + 1) / (INTER_DIM * INPUT_DIM);
    }
  }

  float w2_weight_data[OUTPUT_DIM * INTER_DIM];
  for (size_t i = 0; i < OUTPUT_DIM; ++i) {
    for (size_t j = 0; j < INTER_DIM; ++j) {
      w2_weight_data[i * INTER_DIM + j] =
        static_cast<float>(i * INTER_DIM + j + 1) / (OUTPUT_DIM * INTER_DIM);
    }
  }

We’ll use w1_weight_data both for the W1 (gate) and W3 (up) projections to keep the example short.

3) Creating an XNN subgraph

All operators and tensors are defined within a subgraph.

  xnn_subgraph_t subgraph = NULL;
  enum xnn_status status = xnn_create_subgraph(
    /*external_value_ids=*/2,  // we have 2 external values: input and output
    /*flags=*/0,
    &subgraph);
  if (status != xnn_status_success) {
    fprintf(stderr, "xnn_create_subgraph failed: %d\n", status);
    return 1;
  }

Always check the returned xnn_status and bail early; it saves a lot of head-scratching later.

4) Defining tensors

For our computation, we will need the following tensors:

  • Input [BATCH_SIZE, INPUT_DIM]
  • Output [BATCH_SIZE, OUTPUT_DIM]
  • Internal intermediates for matmuls and elementwise ops
  • Static weight tensors for W1/W3 and W2

External tensors

We declare two external tensors (input and output). External tensors are how we feed data into and out of the subgraph.

  // Define input tensor
  uint32_t input_id;
  {
    std::vector<size_t> input_dims = {1, INPUT_DIM};
    status = xnn_define_tensor_value(
      subgraph,
      xnn_datatype_fp32,
      /*num_dims=*/input_dims.size(),
      /*dims=*/input_dims.data(),
      /*data=*/nullptr,                // provided at setup
      /*external_id=*/0,               // external slot 0
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT,
      &input_id);
    if (status != xnn_status_success) { /* ... */ }

  }


  // Define output tensor
  uint32_t output_id;
  {
    std::vector<size_t> output_dims = {1, OUTPUT_DIM};
    status = xnn_define_tensor_value(
      subgraph,
      xnn_datatype_fp32,
      output_dims.size(),
      output_dims.data(),
      /*data=*/nullptr,                // provided at setup
      /*external_id=*/1,               // external slot 1
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
      &output_id);
    if (status != xnn_status_success) { /* ... */ }

  }

Key points:

  • Set data=nullptr for external tensors (data is bound at runtime)
  • external_id must be unique and match what you’ll use later
  • The batch dimension can be reshaped later, but we need a placeholder value.
  • They are defined using the flag XNN_VALUE_FLAG_EXTERNAL_INPUT or XNN_VALUE_FLAG_EXTERNAL_OUTPUT.

Static weights (internal)

We now define the tensors holding the projection weights within the subgraph. Since we pass a data pointer, these are static (constant) tensors.

  // W1 (and W3) : [INTER_DIM, INPUT_DIM]
  uint32_t w1_weight_id;
  {
    std::vector<size_t> w1_weight_dims = {INTER_DIM, INPUT_DIM};
    status = xnn_define_tensor_value(
      subgraph,
      xnn_datatype_fp32,
      w1_weight_dims.size(),
      w1_weight_dims.data(),
      /*data=*/w1_weight_data,
      /*external_id=*/XNN_INVALID_VALUE_ID,
      /*flags=*/0,
      &w1_weight_id);
    if (status != xnn_status_success) { /* ... */ }
  }

  // W3 (up projection) : [INTER_DIM, INPUT_DIM]
  uint32_t w3_weight_id;
  {
    std::vector<size_t> w3_weight_dims = {INTER_DIM, INPUT_DIM};
    status = xnn_define_tensor_value(
      subgraph,
      xnn_datatype_fp32,
      w3_weight_dims.size(),
      w3_weight_dims.data(),
      /*data=*/w1_weight_data, // Reusing w1 weights for demonstration
      /*external_id=*/XNN_INVALID_VALUE_ID,
      /*flags=*/0,
      &w3_weight_id);
    if (status != xnn_status_success) { /* ... */ }
  }

  // W2 : [OUTPUT_DIM, INTER_DIM]
  uint32_t w2_weight_id;
  {
    std::vector<size_t> w2_weight_dims = {OUTPUT_DIM, INTER_DIM};
    status = xnn_define_tensor_value(
      subgraph,
      xnn_datatype_fp32,
      w2_weight_dims.size(),
      w2_weight_dims.data(),
      /*data=*/w2_weight_data,
      /*external_id=*/XNN_INVALID_VALUE_ID,
      /*flags=*/0,
      &w2_weight_id);
    if (status != xnn_status_success) { /* ... */ }
  }

Internal intermediates

Then, we need to define the tensors that will hold the intermediate results of the computation.

As a reminder, the computation is $W_2(\operatorname{SiLU}(W_1 x) * W_3 x)$.

For intermediate steps, we’ll need:

  • gate_output_id for $W_1 x$
  • sigmoid_output_id for $\sigma(W_1 x)$
  • silu_output_id for $\operatorname{SiLU}(W_1 x) = \sigma(W_1 x) * W_1 x$
  • up_output_id for $W_3 x$
  • gated_intermediate_output_id for $\operatorname{SiLU}(W_1 x) * W_3 x$

Each is a 2D [BATCH_SIZE, INTER_DIM] tensor:

  uint32_t gate_output_id, sigmoid_output_id, silu_output_id,
           up_output_id, gated_intermediate_output_id;

  {
    std::vector<size_t> inter_dims = {BATCH_SIZE, INTER_DIM};

    // gate_output_id
    status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32,
      inter_dims.size(), inter_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0,
      &gate_output_id);
    if (status != xnn_status_success) { /* ... */ }

    // sigmoid_output_id
    status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32,
      inter_dims.size(), inter_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0,
      &sigmoid_output_id);
    if (status != xnn_status_success) { /* ... */ }

    // silu_output_id
    status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32,
      inter_dims.size(), inter_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0,
      &silu_output_id);
    if (status != xnn_status_success) { /* ... */ }

    // up_output_id
    status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32,
      inter_dims.size(), inter_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0,
      &up_output_id);
    if (status != xnn_status_success) { /* ... */ }

    // gated_intermediate_output_id
    status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32,
      inter_dims.size(), inter_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0,
      &gated_intermediate_output_id);
    if (status != xnn_status_success) { /* ... */ }
  }

About fp16 → fp32 upcasting: if your inputs/weights are fp16, insert a convert/unary node to convert to fp32 for ops that only exist in fp32. In our example, we assume everything is in fp32.

5) Defining the operators

Now we stitch the nodes together to express the computation:

  // 1) Gate projection: W1 @ input -> gate_output_id
  status = xnn_define_fully_connected(
    subgraph,
    /*output_min=*/-INFINITY,
    /*output_max=*/INFINITY,
    /*input_id=*/input_id,
    /*filter_id=*/w1_weight_id,
    /*bias_id=*/XNN_INVALID_VALUE_ID,  // No bias
    /*output_id=*/gate_output_id,
    /*flags=*/0);
  if (status != xnn_status_success) { /* ... */ }

  // 2) Sigmoid(gate)
  status = xnn_define_unary(
    subgraph,
    xnn_unary_sigmoid,
    /*params=*/nullptr,
    /*input_id=*/gate_output_id,
    /*output_id=*/sigmoid_output_id,
    /*flags=*/0);
  if (status != xnn_status_success) { /* ... */ }

  // 3) SiLU = gate * sigmoid(gate)
  status = xnn_define_multiply2(
    subgraph,
    /*output_min=*/-INFINITY,
    /*output_max=*/INFINITY,
    /*input1_id=*/gate_output_id,
    /*input2_id=*/sigmoid_output_id,
    /*output_id=*/silu_output_id,
    /*flags=*/0);
  if (status != xnn_status_success) { /* ... */ }

  // 4) Up projection: W3 @ input
  status = xnn_define_fully_connected(
    subgraph,
    /*output_min=*/-INFINITY,
    /*output_max=*/INFINITY,
    /*input_id=*/input_id,
    /*filter_id=*/w3_weight_id,   
    /*bias_id=*/XNN_INVALID_VALUE_ID,
    /*output_id=*/up_output_id,
    /*flags=*/0);
  if (status != xnn_status_success) { /* ... */ }

  // 5) Gate multiply: silu * up -> gated_intermediate
  status = xnn_define_multiply2(
    subgraph,
    /*output_min=*/-INFINITY,
    /*output_max=*/INFINITY,
    /*input1_id=*/silu_output_id,
    /*input2_id=*/up_output_id,
    /*output_id=*/gated_intermediate_output_id,
    /*flags=*/0);
  if (status != xnn_status_success) { /* ... */ }

  // 6) Down projection: W2 @ gated_intermediate -> output
  status = xnn_define_fully_connected(
    subgraph,
    /*output_min=*/-INFINITY,
    /*output_max=*/INFINITY,
    /*input_id=*/gated_intermediate_output_id,
    /*filter_id=*/w2_weight_id,
    /*bias_id=*/XNN_INVALID_VALUE_ID,
    /*output_id=*/output_id,
    /*flags=*/0);
  if (status != xnn_status_success) { /* ... */ }

Fusion note: XNNPACK’s planner will fuse where possible (e.g., elementwise with adjacent ops) when creating the runtime. You don’t need to hand-fuse SiLU; defining sigmoid + multiply is fine.

6) Creating the runtime

We now create the runtime and its associated workspace (mempool). The workspace lets XNNPACK manage scratch and persistent buffers (i.e. activation memory) across runs. Note that we can also pass in a threadpool (a pthreadpool_t) to explicitly control the number of threads used by the runtime.

  // 6. Create a workspace and runtime for this SwiGLU
  xnn_workspace_t xnn_workspace;
  status = xnn_create_workspace(&xnn_workspace);
  if (status != xnn_status_success) { /* ... */ }


  xnn_runtime_t runtime;
  status = xnn_create_runtime_v4(
    subgraph,
    /*weights_cache=*/nullptr,
    /*workspace=*/xnn_workspace,
    /*threadpool=*/NULL,
    /*flags=*/0,
    &runtime);
  if (status != xnn_status_success) { /* ... */ }

What happens here?

  • Operator fusion: XNNPACK identifies fusible patterns (like sigmoid+multiply → SiLU) and merges them into single kernels
  • Microkernel selection: picks the best precompiled microkernel for your CPU’s instruction set (e.g., AVX-512 on x86, NEON on ARM)
  • Memory planning: Buffers are allocated out of the workspace; lifetimes are scheduled to reuse memory.

Managing threadpools and workspaces becomes more important as you write more complex kernels that may run multiple XNNPACK runtimes at the same time. If two runtimes run sequentially, they can share the same workspace to reduce peak memory usage, but this is not possible if they run in parallel (two concurrent runtimes must not read and write the same scratch memory). In that case, you need to explicitly allocate two workspaces. Similarly, you may not want two runtimes to share the same threadpool, for performance reasons.

The reason I’m discussing this is that, when writing your kernel, you can usually choose where your algorithm sits on the spectrum between fully sequential and fully parallel, and this choice tends to directly affect peak memory (more sequential usually means lower peak memory).

7) Invoking the kernel

We bind the external values defined previously in our graph to the actual input/output pointers at run time. Additionally, we inform the runtime of the expected batch size of our computation, an operation XNNPACK calls reshaping.

Then, finally, we can set up the runtime and invoke the kernel.

  // Example input/output buffers
  float input_data[BATCH_SIZE * INPUT_DIM]  = {1.0f, 2.0f, 3.0f};
  float output_data[BATCH_SIZE * OUTPUT_DIM];

  // Bind external values in the order of external IDs we declared (0=input, 1=output)
  std::vector<xnn_external_value> external_values(2);
  external_values[0].id   = 0; // input external id
  external_values[0].data = input_data;
  external_values[1].id   = 1; // output external id
  external_values[1].data = output_data;

  // Reshape (if dimensions changed since definition)
  {
    std::vector<size_t> input_dims  = {BATCH_SIZE, INPUT_DIM};
    std::vector<size_t> output_dims = {BATCH_SIZE, OUTPUT_DIM};

    status = xnn_reshape_external_value(runtime, 0, input_dims.size(),  input_dims.data());
    if (status != xnn_status_success) { /* ... */ }

    status = xnn_reshape_external_value(runtime, 1, output_dims.size(), output_dims.data());
    if (status != xnn_status_success) { /* ... */ }

    status = xnn_reshape_runtime(runtime);
    if (status != xnn_status_success) { /* ... */ }

  }

  // Setup + run
  status = xnn_setup_runtime_v2(runtime, external_values.size(), external_values.data());
  if (status != xnn_status_success) { /* ... */ }


  status = xnn_invoke_runtime(runtime);
  if (status != xnn_status_success) { /* ... */ }

  // Inspect result
  printf("Output: [%f, %f]\n", output_data[0], output_data[1]);

Reshaping caveat: reshaping can grow memory, but won’t necessarily shrink an already allocated workspace if your batch size decreases.

8) Cleanup

Finally, after running the kernel, we clean up resources.

  xnn_delete_runtime(runtime);
  xnn_release_workspace(xnn_workspace);
  xnn_delete_subgraph(subgraph);
  xnn_deinitialize();
  return 0;
}

Useful tips (from the trenches)

  • Check every status. After every xnn_define_* and runtime call, assert xnn_status_success.
  • Turn on logs: XNNPACK_LOG=Debug (or set XNN_LOG_LEVEL=4 depending on build) is invaluable to see chosen kernels, fusions, and memory plans.
  • Study prior art: complex subgraphs in XNNPACK’s source (e.g., quantized attention in attention.cc) are great templates for handling edge cases, quantization params, and dynamic shapes.

Quantized kernel version

We now know how to define a SwiGLU with fp32 weights and fp32 inputs in XNNPACK!

The logical next step is to write a quantized version supporting int4 weights and int8 activations. Surprisingly, this is not too hard.

Indeed, xnn_define_fully_connected automatically specializes to the correct implementation depending on the datatypes of the input and weight. You can see the supported combinations in fully-connected.c.

For the rest of the example, we will assume that we’re targeting xnn_create_fully_connected_nc_qd8_f32_qb4w, which is a GEMM with int4 blockwise-quantized weights, dynamically quantized int8 inputs, and fp32 outputs.
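To make this datatype combination concrete, here is a scalar sketch of what such a GEMM computes for one output element. This is my reading of the math, not XNNPACK's implementation: it assumes per-row asymmetric int8 activations (one scale and zero point per batch row) and symmetric int4 weights with one scale per (output channel, block of `block_size` input channels). Scales are plain fp32 here for simplicity, and the function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference for one output element of an int8-activation x int4-weight GEMM.
//   a_q:          quantized activations of one batch row (length K)
//   a_scale, a_zero_point: that row's params, x ~= a_scale * (a_q - a_zero_point)
//   w_q:          signed int4 weights of one output channel, held unpacked in int8, [-8, 7]
//   w_scales:     one scale per block of block_size input channels (length K / block_size)
float qd8_qb4w_dot(const std::vector<int8_t>& a_q, float a_scale, int32_t a_zero_point,
                   const std::vector<int8_t>& w_q, const std::vector<float>& w_scales,
                   size_t block_size) {
  const size_t K = a_q.size();
  assert(w_q.size() == K && block_size > 0 && K % block_size == 0);
  assert(w_scales.size() == K / block_size);
  float out = 0.0f;
  for (size_t b = 0; b < K / block_size; ++b) {
    int32_t acc = 0;  // exact integer accumulation inside one block
    for (size_t k = b * block_size; k < (b + 1) * block_size; ++k) {
      acc += (static_cast<int32_t>(a_q[k]) - a_zero_point) * static_cast<int32_t>(w_q[k]);
    }
    out += a_scale * w_scales[b] * static_cast<float>(acc);  // back to fp32 once per block
  }
  return out;
}
```

The design point is that the expensive inner loop stays in integer arithmetic, and floating point only appears once per block when the scales are applied.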

Generally, to write a quantized kernel, I would advise reading the function signatures in fully-connected.c carefully to understand the available quantization formats, e.g. whether you need to provide bf16 or fp16 scales, or whether the quantization is blockwise or channelwise. Not adhering to the expected format was one of the most common mistakes I encountered.
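As an example of producing such a format offline, here is a sketch of symmetric blockwise int4 quantization of one weight row, with scales stored as bf16 bit patterns in uint16_t (the scale pointer type passed to xnn_define_blockwise_quantized_tensor_value in this post). The choice scale = max|w| / 7 and the round-to-nearest-even bf16 conversion are my assumptions, not something mandated by XNNPACK; the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// fp32 -> bf16 with round-to-nearest-even (NaN handling omitted for brevity).
uint16_t fp32_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>(bits >> 16);
}

float bf16_to_fp32(uint16_t h) {
  const uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Symmetric blockwise int4 quantization of one output-channel row.
// Produces signed int4 values held in int8 (clamped to [-8, 7]) and one bf16 scale per block.
void quantize_row_int4_blockwise(const std::vector<float>& row, size_t block_size,
                                 std::vector<int8_t>& q, std::vector<uint16_t>& scales_bf16) {
  assert(block_size > 0 && row.size() % block_size == 0);
  q.assign(row.size(), 0);
  scales_bf16.clear();
  for (size_t start = 0; start < row.size(); start += block_size) {
    float amax = 0.0f;
    for (size_t k = start; k < start + block_size; ++k) amax = std::max(amax, std::fabs(row[k]));
    // Round the scale to bf16 first, so quantization uses exactly the scale that gets stored.
    const uint16_t s_bf16 = fp32_to_bf16(amax > 0.0f ? amax / 7.0f : 1.0f);
    const float s = bf16_to_fp32(s_bf16);
    scales_bf16.push_back(s_bf16);
    for (size_t k = start; k < start + block_size; ++k) {
      q[k] = static_cast<int8_t>(std::clamp(std::nearbyint(row[k] / s), -8.0f, 7.0f));
    }
  }
}
```

The resulting int8-held values still need to be packed two per byte before being handed to XNNPACK, which is what the packing step described next is about.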

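To extend our SwiGLU kernel, for each of our linear projections, we will need to:

  • define the weights as a blockwise-quantized int4 tensor, and
  • dynamically quantize the projection’s input to int8.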

Defining int4 weights

We define the weight matrices with xnn_define_blockwise_quantized_tensor_value.

    xnn_status status = xnn_define_blockwise_quantized_tensor_value(
        /*subgraph=*/subgraph,
        /*datatype=*/xnn_datatype_qbint4,
        /*zeropoint=*/ 8,
        /*scale=*/static_cast<const uint16_t*>(weight_scale_data), 
        /*num_dims=*/weight_dims.size(),
        /*channel_dim=*/ 0, 
        /*block_size=*/block_size,
        /*dims=*/weight_dims.data(),
        /*data=*/weight_data,
        /*external_id=*/XNN_INVALID_VALUE_ID,
        /*flags=*/ 0,
        /*id_out=*/&weight_id);

You will need to pay attention to the exact format expected by the signature. Here, we’ve defined a symmetrically quantized tensor. When using 4-bit weights, XNNPACK expects them packed two per byte in uint8 format, with each signed int4 value in [-8, 7] stored as an unsigned nibble in [0, 15]; this offset of 8 is why the zero point is 8.

To help you understand packing, here’s a function in PyTorch to illustrate how to pack unpacked int4 weights (stored in int8, because int4 isn’t supported in PyTorch) into uint8 packed weights.

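import torch

def pack_4_bit_weights_for_8da4w(weights: torch.Tensor) -> torch.Tensor:
    # weights: signed int4 values in [-8, 7] stored as int8, shape [out_features, in_features]
    assert weights.ndim == 2
    output_features, input_features = weights.shape
    assert input_features % 2 == 0, "two int4 values are packed per byte"

    # Shift to unsigned nibbles in [0, 15] (zero point 8), then flatten row-major.
    weights = (weights.to(torch.int16) + 8).to(torch.uint8).contiguous().view(-1)
    # Even elements go to the low nibble, odd elements to the high nibble.
    weights = (weights[1::2] << 4 | weights[::2]).view(output_features, input_features // 2)

    return weights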
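If you prefer to do the packing on the C++ side, here is a sketch that mirrors the nibble layout of the PyTorch function above: element 2i goes in the low nibble and element 2i+1 in the high nibble of byte i, both offset by 8. I have only derived this layout from that snippet, so treat it as an assumption to double-check against fully-connected.c; `pack_int4` and `unpack_int4` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack signed int4 values (held in int8, range [-8, 7]) two per byte, row-major
// [out_features, in_features]. Even element -> low nibble, odd element -> high nibble,
// both offset by the zero point 8 so that they become unsigned nibbles in [0, 15].
std::vector<uint8_t> pack_int4(const std::vector<int8_t>& w, size_t out_features,
                               size_t in_features) {
  assert(w.size() == out_features * in_features && in_features % 2 == 0);
  std::vector<uint8_t> packed(w.size() / 2);
  for (size_t i = 0; i < packed.size(); ++i) {
    const uint8_t lo = static_cast<uint8_t>(w[2 * i] + 8);
    const uint8_t hi = static_cast<uint8_t>(w[2 * i + 1] + 8);
    assert(lo <= 15 && hi <= 15);  // inputs must be valid int4 values
    packed[i] = static_cast<uint8_t>((hi << 4) | lo);
  }
  return packed;
}

// Inverse of pack_int4, handy for unit-testing the layout.
int8_t unpack_int4(const std::vector<uint8_t>& packed, size_t index) {
  const uint8_t byte = packed[index / 2];
  const uint8_t nibble = (index % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  return static_cast<int8_t>(static_cast<int>(nibble) - 8);
}
```

Since in_features is even, a byte never straddles two rows, so the packed buffer is simply [out_features, in_features / 2].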

Quantizing the input to int8

We need to define a conversion step that dynamically quantizes the fp32/bf16 input into int8. For example, to quantize input_id, we can write:

    uint32_t quantized_input_id;
    {
      std::vector<size_t> input_dims = {1, hidden_dim}; // Should match input_id dims
      status = xnn_define_dynamically_quantized_tensor_value(
          subgraph, xnn_datatype_qdint8, input_dims.size(),
          /*num_non_batch_dims=*/1, input_dims.data(), XNN_INVALID_VALUE_ID,
          /*flags=*/0, &quantized_input_id);
      if (status != xnn_status_success) { /* ... */ }

      // Define dynamic conversion from float/half to qdint8
      status = xnn_define_convert(
          /*subgraph=*/subgraph,
          /*input_id=*/input_id,
          /*output_id=*/quantized_input_id,
          /*flags=*/0);
      if (status != xnn_status_success) { /* ... */ }
    }
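Conceptually, this convert node computes a fresh scale and zero point for each batch row at run time, from that row's min and max. Here is a plain C++ sketch of per-row asymmetric int8 quantization to show the idea; XNNPACK's actual rounding and range handling may differ, and `RowQuantParams` and `quantize_row_int8` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-row quantization parameters: x ~= scale * (q - zero_point).
struct RowQuantParams {
  float scale;
  int32_t zero_point;
};

// Asymmetric int8 quantization of one activation row from its min/max.
// The range is widened to include 0 so that 0.0f is exactly representable.
RowQuantParams quantize_row_int8(const std::vector<float>& x, std::vector<int8_t>& q) {
  assert(!x.empty());
  const auto [mn_it, mx_it] = std::minmax_element(x.begin(), x.end());
  const float lo = std::min(*mn_it, 0.0f);
  const float hi = std::max(*mx_it, 0.0f);
  const float scale = (hi > lo) ? (hi - lo) / 255.0f : 1.0f;
  // Choose the zero point so that lo maps to -128.
  const int32_t zero_point = static_cast<int32_t>(std::lround(-128.0f - lo / scale));
  q.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const long v = std::lround(x[i] / scale) + zero_point;
    q[i] = static_cast<int8_t>(std::clamp<long>(v, -128, 127));
  }
  return {scale, zero_point};
}
```

Because the parameters are recomputed per row on every call, the activations use the full int8 range regardless of their magnitude, which is the point of "dynamic" quantization.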

Conclusion

Voilà, now you know how to write your own XNNPACK kernel!




Footnotes

  1. This is technically possible if higher-order ops (such as while loops) are accepted as traceable nodes.

  2. On most x86 CPUs without native FP16 support (pre-AVX512-FP16), FP16 operations are computed in FP32 with conversions, causing 2-4x slowdowns.

  3. SwiGLU: Introduced in “GLU Variants Improve Transformer” by Noam Shazeer, arXiv:2002.05202, 2020.

Reference

Cite this post using the BibTeX snippet below.

@misc{benoit2025xnnpackkernel,
  author       = {Benoit, Harold},
  title        = {Writing CPU ML Kernels with XNNPACK},
  year         = {2025},
  howpublished = {\url{https://haroldbenoit.com/blog/xnnpack_kernel/}},
  note         = {Blog post},
  url          = {https://haroldbenoit.com/blog/xnnpack_kernel/},
  urldate      = {2025-11-18}
}
