Where the GEMM micro-kernels are defined

  1. For fp32 GEMM (input, output, weights):
     • /src/f32-gemm/ contains implementations for different architectures
     • Both assembly templates (.S.in) and C templates (.c.in)
     • Examples: 1x8-aarch64-neonfma-ld64.S.in, avx-broadcast.c.in, neon-ld64.c.in
  2. For fp16 GEMM (input, output, weights):
     • /src/f16-gemm/ directory
     • ARM implementations: 1x16-aarch64-neonfp16arith-ld32.S.in
     • x86 implementations: avx2-broadcast.c.in, avx512fp16-broadcast.c.in
  3. For fp32 with int4 packed weights:
     • /src/f32-qc4w-gemm/ - standard 4-bit quantized weights
     • /src/qp8-f32-qc4w-gemm/ - 4-bit weights with 8-bit activations
     • /src/qp8-f32-qb4w-gemm/ - block-wise 4-bit quantized weights

Function declarations are in src/xnnpack/gemm.h, which lists all supported variants across architectures.
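For orientation, every fp32 GEMM microkernel declared there shares the same ten-parameter shape. The prototype below is a rough sketch of one such declaration; the kernel name is only an illustrative variant, and the exact attributes in the real header differ slightly:

    #include <stddef.h>

    // Representative shape of an fp32 minmax GEMM microkernel declaration
    // (illustrative; gemm.h declares many such variants per architecture).
    void xnn_f32_gemm_minmax_ukernel_4x8__neonfma_lane_ld64(
        size_t mr,         // rows of A/C handled in this call (<= the kernel's MR)
        size_t nc,         // output columns to produce
        size_t kc,         // size of the K dimension, in bytes
        const float* a,    // input activations
        size_t a_stride,   // byte stride between rows of A
        const float* w,    // packed weights (bias and weights in the kernel's packed layout)
        float* c,          // output
        size_t cm_stride,  // byte stride between rows of C
        size_t cn_stride,  // byte stride between nr-wide column tiles of C
        const union xnn_f32_minmax_params* params);  // min/max output clamping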

Choosing which GEMM

  • This happens at the reshape_fully_connected call in src/operators/fully-connected-nc.c:
    // Compute the optimal tile size for this GEMM.
    const size_t nc = xnn_gemm_best_tile_size(
        /*num_groups=*/1, /*m=*/batch_size, /*n=*/output_channels,
        /*m_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.a_stride,
        /*n_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.w_stride,
        /*cm_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.cm_stride,
        /*cn_stride=*/1 << log2_output_element_size, mr, nr,
        /*num_threads=*/pthreadpool_get_threads_count(threadpool));
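Conceptually, a tile-size chooser has to balance per-tile efficiency against having enough tiles to keep every thread busy: nc should be a multiple of nr, and ceil(m/mr) × ceil(n/nc) should comfortably exceed the thread count. The sketch below is a hypothetical stand-in for that idea, not XNNPACK's actual heuristic; the function name and the 4-tiles-per-thread threshold are made up for illustration.

    #include <stddef.h>

    // Hypothetical sketch only: pick an N tile (a multiple of nr) so that the
    // ceil(m/mr) x ceil(n/nc) tile grid keeps every thread busy. This is NOT
    // xnn_gemm_best_tile_size.
    static size_t sketch_best_tile_size(size_t m, size_t n, size_t mr, size_t nr,
                                        size_t num_threads) {
      const size_t row_tiles = (m + mr - 1) / mr;  // tile count along M is fixed by mr
      size_t nc = n;                               // start with one tile covering all of N
      while (nc > nr) {
        const size_t col_tiles = (n + nc - 1) / nc;
        if (row_tiles * col_tiles >= 4 * num_threads) break;  // enough parallelism
        nc = (nc / 2 + nr - 1) / nr * nr;          // halve, rounded up to a multiple of nr
      }
      return nc;
    }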

Calling the micro-kernels

The GEMMBenchmark function in bench/f32-gemm.cc is the core routine that directly executes a specific microkernel with the given workload dimensions. Breaking down what it does:

  • It takes a specific microkernel function pointer (gemm) as its first parameter

  • It sets up the workload dimensions from state.range(0), state.range(1), and state.range(2) (which are M, N, and K)

  • It prepares the input data (matrices A and B) with random values

  • It packs the weights using xnn_pack_f32_gemm_goi_w into the blocked layout the microkernel expects

  • It sets up the minmax parameters

  • In the benchmark loop, it directly calls the microkernel function:

    gemm(
      mb, nc, kc * sizeof(float),
      a.data() + m * kc, kc * sizeof(float),
      w.data() + buffer_index * nc_stride * (kc_stride + 1),
      c.data() + (buffer_index * mc + m) * nc, nc * sizeof(float), nr * sizeof(float),
      &params);

This is the exact call that executes the computation for a specific microkernel with the given workload dimensions. The benchmark walks the M dimension in steps of mr, so each call processes a block of at most mr × nc outputs, where mr is the number of rows the microkernel handles in one call; a simplified sketch of that outer loop follows.
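To make the block structure concrete, here is a self-contained C sketch of that outer loop. It is a paraphrase of the benchmark, not its literal code: std::vector accesses are replaced by raw pointers, the multi-buffer indexing (buffer_index, nc_stride, kc_stride) is dropped, and the function-pointer type is declared locally rather than taken from XNNPACK's headers.

    #include <stddef.h>

    // Local function-pointer type mirroring the microkernel prototype shown
    // earlier (XNNPACK has its own typedef; params is left as void* here to
    // keep the sketch self-contained).
    typedef void (*gemm_minmax_fn)(
        size_t mr, size_t nc, size_t kc,
        const float* a, size_t a_stride,
        const float* w,
        float* c, size_t cm_stride, size_t cn_stride,
        const void* params);

    // Walk the M dimension in steps of mr; each call hands the microkernel at
    // most mr rows and the full nc columns (the kernel loops over N internally
    // in steps of nr).
    static void run_gemm_blocks(gemm_minmax_fn gemm,
                                size_t mc, size_t nc, size_t kc,
                                size_t mr, size_t nr,
                                const float* a, const float* packed_w, float* c,
                                const void* params) {
      for (size_t m = 0; m < mc; m += mr) {
        const size_t mb = (mc - m) < mr ? (mc - m) : mr;  // rows in this block, <= mr
        gemm(mb, nc, kc * sizeof(float),
             a + m * kc, kc * sizeof(float),  // A rows [m, m + mb), row stride = kc floats
             packed_w,                        // packed weights, reused for every row block
             c + m * nc, nc * sizeof(float),  // C rows [m, m + mb), row stride = nc floats
             nr * sizeof(float),              // cn_stride between nr-wide column tiles
             params);
      }
    }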

The parameters to the microkernel are:

  • mb: Number of rows to process (up to mr)

  • nc: Number of columns to process

  • kc * sizeof(float): Size of K dimension in bytes

  • a.data() + m * kc: Pointer to the current block of matrix A

  • kc * sizeof(float): Row stride of matrix A, in bytes

  • w.data() + buffer_index…: Pointer to the packed weights

  • c.data() + …: Pointer to the output matrix C

  • nc * sizeof(float): Row stride (cm_stride) for matrix C, in bytes

  • nr * sizeof(float): Column stride (cn_stride) for matrix C, in bytes

  • &params: Pointer to the minmax parameters
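Putting those parameters together, each call conceptually computes the following for its mb × nc block. This scalar reference is only meant to show the math (GEMM plus bias, clamped by the minmax params); it uses plain unpacked weights and counts kc in floats, whereas the real microkernel reads the packed-weight buffer and takes kc in bytes.

    #include <stddef.h>

    // Scalar reference for one mb x nc block:
    //   C[i][j] = clamp(bias[j] + sum_k A[i][k] * W[j][k], min, max)
    // Plain row-major A (mb x kc) and W (nc x kc) stand in for the packed-weight
    // buffer; kc is a count of floats here, not bytes.
    static void gemm_block_reference(size_t mb, size_t nc, size_t kc,
                                     const float* a,     // mb x kc, row-major
                                     const float* w,     // nc x kc, one row per output channel
                                     const float* bias,  // nc entries
                                     float* c,           // mb x nc, row-major
                                     float min, float max) {
      for (size_t i = 0; i < mb; i++) {
        for (size_t j = 0; j < nc; j++) {
          float acc = bias[j];
          for (size_t k = 0; k < kc; k++) {
            acc += a[i * kc + k] * w[j * kc + k];
          }
          // The "minmax" in the kernel names refers to this output clamp.
          acc = acc < min ? min : acc;
          acc = acc > max ? max : acc;
          c[i * nc + j] = acc;
        }
      }
    }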