Where are they defined
- For fp32 GEMM (fp32 input, output, and weights):
  - /src/f32-gemm/ contains implementations for different architectures
  - Both assembly files (.S.in) and C templates (.c.in)
  - Examples: 1x8-aarch64-neonfma-ld64.S.in, avx-broadcast.c.in, neon-ld64.c.in
- For fp16 GEMM (fp16 input, output, and weights):
  - /src/f16-gemm/ directory
  - ARM implementations: 1x16-aarch64-neonfp16arith-ld32.S.in
  - x86 implementations: avx2-broadcast.c.in, avx512fp16-broadcast.c.in
- For fp32 with int4 packed weights:
  - /src/f32-qc4w-gemm/ - standard 4-bit quantized weights
  - /src/qp8-f32-qc4w-gemm/ - 4-bit weights with 8-bit activations
  - /src/qp8-f32-qb4w-gemm/ - block-wise 4-bit quantized weights
Function declarations are in src/xnnpack/gemm.h with all supported variants across architectures.
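The declarations in src/xnnpack/gemm.h all share a common parameter shape. As a rough illustration only, here is a hypothetical scalar reference kernel with that shape (names and the packed-weight layout are simplified assumptions, not actual XNNPACK code; in particular, real packed weights interleave biases and weights, while this sketch assumes plain column-major weights):

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical scalar reference kernel with the same parameter shape as
 * the xnn_f32_gemm_minmax_ukernel_* functions. kc, a_stride, cm_stride
 * and cn_stride are in BYTES, matching the XNNPACK convention.
 * w is assumed here to hold nc columns of kc/sizeof(float) weights each,
 * in column-major order -- a simplification of the real packed layout. */
static void ref_f32_gemm_minmax__scalar(
    size_t mr, size_t nc, size_t kc,
    const float* a, size_t a_stride,
    const float* w,
    float* c, size_t cm_stride, size_t cn_stride,
    float min, float max) {
  const size_t k_elems = kc / sizeof(float);
  for (size_t i = 0; i < mr; i++) {
    const float* a_row = (const float*) ((uintptr_t) a + i * a_stride);
    for (size_t j = 0; j < nc; j++) {
      float acc = 0.0f;
      for (size_t k = 0; k < k_elems; k++) {
        acc += a_row[k] * w[j * k_elems + k];
      }
      /* Apply the minmax (clamping) post-operation. */
      if (acc < min) acc = min;
      if (acc > max) acc = max;
      *(float*) ((uintptr_t) c + i * cm_stride + j * cn_stride) = acc;
    }
  }
}
```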
Choosing which GEMM
- This happens in the reshape_fully_connected call in src/operators/fully-connected-nc.c:
```c
// Compute the optimal tile size for this GEMM.
const size_t nc = xnn_gemm_best_tile_size(
    /*num_groups=*/1, /*m=*/batch_size, /*n=*/output_channels,
    /*m_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.a_stride,
    /*n_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.w_stride,
    /*cm_stride=*/
    dynamic_fully_connected_op->context.gemm.gemm.gemm.cm_stride,
    /*cn_stride=*/1 << log2_output_element_size, mr, nr,
    /*num_threads=*/pthreadpool_get_threads_count(threadpool));
```
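The intuition behind picking nc can be sketched with a deliberately simplified function (hypothetical, not the real xnn_gemm_best_tile_size heuristic, which also weighs cache footprint via the strides): shrink nc in multiples of nr until the resulting grid of mr-by-nc tiles gives every thread at least one tile.

```c
#include <stddef.h>

/* Hypothetical, simplified nc selection: keep halving the work along N
 * (in multiples of nr) until the m-tiles x n-tiles grid has at least one
 * tile per thread. Illustrative only. */
static size_t naive_best_nc(size_t m, size_t n, size_t mr, size_t nr,
                            size_t num_threads) {
  const size_t m_tiles = (m + mr - 1) / mr;  /* tile count along M */
  size_t nc = n;
  while (nc > nr && m_tiles * ((n + nc - 1) / nc) < num_threads) {
    nc = (nc - 1) / nr * nr;  /* next smaller multiple of nr */
  }
  return nc;
}
```

With a single thread the full N stays in one tile; with more threads nc shrinks so the tile grid can be distributed.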
Calling the micro-kernels
The GEMMBenchmark function in bench/f32-gemm.cc is the core function that directly executes a specific microkernel with given workload dimensions. Here is what it does:
- It takes a specific microkernel function pointer (gemm) as its first parameter
- It sets up the workload dimensions from state.range(0), state.range(1), and state.range(2) (which are M, N, and K)
- It prepares the input data (matrices A and B) with random values
- It packs the weights using xnn_pack_f32_gemm_goi_w for optimal memory layout
- It sets up the minmax parameters
- In the benchmark loop, it directly calls the microkernel function:
```cpp
gemm(
    mb, nc, kc * sizeof(float),
    a.data() + m * kc, kc * sizeof(float),
    w.data() + buffer_index * nc_stride * (kc_stride + 1),
    c.data() + (buffer_index * mc + m) * nc, nc * sizeof(float), nr * sizeof(float),
    &params);
```
This is the exact call that executes the computation for a specific microkernel with the given workload dimensions. The function processes the matrix in blocks of size mr × nc, where mr is the number of rows processed by the microkernel in one call.
The parameters to the microkernel are:
- mb: Number of rows to process (up to mr)
- nc: Number of columns to process
- kc * sizeof(float): Size of the K dimension in bytes
- a.data() + m * kc: Pointer to the current block of matrix A
- kc * sizeof(float): Row stride for matrix A, in bytes
- w.data() + buffer_index…: Pointer to the packed weights
- c.data() + …: Pointer to the output matrix C
- nc * sizeof(float): Row stride for matrix C
- nr * sizeof(float): Column stride for matrix C
- &params: Pointer to the minmax parameters
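Putting the parameters together, the outer loop over M can be sketched as follows. This is an illustrative driver only (hypothetical names, element strides instead of byte strides, and a trivial scalar kernel standing in for a real microkernel): it walks M in steps of mr and hands each block of mb = min(mr, remaining rows) rows to the kernel, which is exactly the blocking pattern described above.

```c
#include <stddef.h>

/* Stand-in for a real microkernel: computes an mb x nc output block.
 * Strides here are in ELEMENTS for brevity (the real call uses bytes). */
static void tiny_ukernel(size_t mb, size_t nc, size_t kc,
                         const float* a, size_t a_stride,
                         const float* w,  /* nc columns of kc weights */
                         float* c, size_t cm_stride) {
  for (size_t i = 0; i < mb; i++) {
    for (size_t j = 0; j < nc; j++) {
      float acc = 0.0f;
      for (size_t k = 0; k < kc; k++) {
        acc += a[i * a_stride + k] * w[j * kc + k];
      }
      c[i * cm_stride + j] = acc;
    }
  }
}

/* Illustrative driver: process the M dimension in blocks of up to mr rows,
 * mirroring the per-block dispatch in the benchmark loop. */
static void run_gemm(size_t mc, size_t nc, size_t kc, size_t mr,
                     const float* a, const float* w, float* c) {
  for (size_t m = 0; m < mc; m += mr) {
    const size_t mb = (mc - m < mr) ? mc - m : mr;  /* rows this block */
    tiny_ukernel(mb, nc, kc, a + m * kc, kc, w, c + m * nc, nc);
  }
}
```

Note how the last block may carry fewer than mr rows (mb < mr) when M is not a multiple of mr; real microkernels handle this remainder case internally.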