Custom CUDA Quantization

A hand-written, register-tiled INT8 quantization engine that outperformed bitsandbytes on a consumer GPU (GTX 1660 Ti). Built from scratch in C++, CUDA, and Python.


The Memory Wall

Why LLMs starve on consumer GPUs

Large Language Models are massive. A standard GPT-2 Large takes nearly 3 GB of VRAM. On a consumer card like my GTX 1660 Ti, memory bandwidth, not compute, is the bottleneck: the GPU spends more time waiting for weights to arrive than doing arithmetic on them. Industry libraries like bitsandbytes are great generalists, but they carry overhead for safety and compatibility. I wanted to see whether a specialist engine, hand-tuned for this specific hardware, could outperform them.

The Dark Arts of Optimization

Or: How I learned to stop worrying and love pointers

Act I: The Tiled Kernel

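Naive matrix multiplication is slow because every multiply-add fetches its operands from Global Memory (VRAM). My kernel uses two levels of tiling instead. Threads in a block cooperatively load tiles of W and X into Shared Memory, a fast on-chip scratchpad that shares hardware with the L1 cache but is managed explicitly by the program. Each thread then computes a 4x4 sub-matrix of the output and keeps its partial sums as 32-bit integers in registers (Register Tiling). Every value fetched from VRAM is reused many times, which drastically increases arithmetic intensity.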

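// From qgemm_tiled.cu
__global__ void qgemm_kernel(...) {
    // Static shared arrays for Block Tiling
    __shared__ int8_t  Wq_tile[TILE_SIZE][TILE_SIZE];
    __shared__ int32_t Xq_tile[TILE_SIZE][TILE_SIZE];

    // Per-thread register tile of int32 accumulators.
    // REG_TILE_N and thread_row_in_tile are assumed names that mirror
    // REG_TILE_M and thread_col_in_tile.
    int32_t acc[REG_TILE_M][REG_TILE_N] = {};
    int8_t  w_reg[REG_TILE_M];

    // Cooperative load: threads load chunks of W and X
    for (int t = 0; t < numTiles; ++t) {
        // ... loading logic ...
        Wq_tile[tile_r][tile_c] = Wq[Wq_row * K + Wq_col];

        // Zero-point subtraction happens ON LOAD
        int32_t xval = static_cast<int32_t>(Xq[Xq_row * N + Xq_col]);
        Xq_tile[tile_r][tile_c] = xval - Zx[Xq_col];

        __syncthreads(); // Wait until the whole tile is in shared memory

        // Compute dot products using registers (Reg Tiling)
        #pragma unroll
        for (int k = 0; k < TILE_SIZE; ++k) {
            // Cache this thread's slice of W for step k in registers
            #pragma unroll
            for (int i = 0; i < REG_TILE_M; ++i)
                w_reg[i] = Wq_tile[thread_row_in_tile + i][k];

            #pragma unroll
            for (int i = 0; i < REG_TILE_M; ++i) {
                #pragma unroll
                for (int j = 0; j < REG_TILE_N; ++j) {
                    // Integer accumulation!
                    acc[i][j] += static_cast<int32_t>(w_reg[i]) * Xq_tile[k][thread_col_in_tile + j];
                }
            }
        }

        __syncthreads(); // Don't overwrite the tiles while other threads still read them
    }
    // ... Output scaling fused at the end ...
}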
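When checking a kernel like this, a slow reference that performs the same arithmetic is useful. The sketch below is a pure-Python stand-in, not the project's code. It assumes the layout implied by the kernel's indexing: Wq is M×K, Xq is K×N, and Zx holds one zero-point per column of X. The per-row weight scales `w_scale` and per-column activation scales `x_scale` are hypothetical names for the scaling step that the kernel elides. The function walks K in tiles, gives each simulated thread a 4x4 block of integer accumulators, and applies the scaling once at the end.

```python
from typing import List

IntMatrix = List[List[int]]


def qgemm_reference(Wq: IntMatrix, Xq: IntMatrix, Zx: List[int],
                    w_scale: List[float], x_scale: List[float],
                    tile: int = 4, reg_m: int = 4, reg_n: int = 4) -> List[List[float]]:
    """Y[m][n] = w_scale[m] * x_scale[n] * sum_k Wq[m][k] * (Xq[k][n] - Zx[n])."""
    M, K, N = len(Wq), len(Xq), len(Xq[0])
    Y = [[0.0] * N for _ in range(M)]
    # Each (m0, n0) block stands in for one thread's reg_m x reg_n register tile
    for m0 in range(0, M, reg_m):
        for n0 in range(0, N, reg_n):
            acc = [[0] * reg_n for _ in range(reg_m)]  # int32 accumulators on the GPU
            # Walk K one shared-memory tile at a time (staging does not change the math)
            for k0 in range(0, K, tile):
                for k in range(k0, min(k0 + tile, K)):
                    # Zero-point subtracted "on load"; out-of-range lanes padded with 0
                    x_reg = [Xq[k][n] - Zx[n] if n < N else 0 for n in range(n0, n0 + reg_n)]
                    w_reg = [Wq[m][k] if m < M else 0 for m in range(m0, m0 + reg_m)]
                    # Outer product: reg_m * reg_n multiply-adds from reg_m + reg_n loads
                    for i in range(reg_m):
                        for j in range(reg_n):
                            acc[i][j] += w_reg[i] * x_reg[j]
            # Fused epilogue: scale each integer sum once, at the very end
            for i in range(reg_m):
                for j in range(reg_n):
                    m, n = m0 + i, n0 + j
                    if m < M and n < N:
                        Y[m][n] = w_scale[m] * x_scale[n] * acc[i][j]
    return Y
```

The default 4x4 register tile shows the payoff. At each k step, a thread reads 4 W values and 4 X values and performs 16 multiply-adds, which is two multiply-adds per load. A thread that owned a single output would get only one multiply-add per two loads.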
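
Act II: Static Calibration

Standard libraries compute activation ranges at runtime (dynamic quantization), which adds an extra pass over every input before the matmul can start. I took a "cheat code" approach instead: Static Calibration. I run sample data through the model offline, record the activations, and pre-calculate clipping ranges from quantiles rather than the raw min/max, so a handful of extreme outliers cannot stretch the range. My kernel doesn't search for limits at inference time; it already knows them.

# From calibration.py
# The "Cheat Code": pre-calculating clipping ranges offline to avoid runtime overhead
import torch
from tqdm import tqdm

calibration_ranges = {}

# Symmetric tail cut: a QUANTILE of 0.999 keeps the central 99.9% of values
min_quantile = (1.0 - QUANTILE) / 2.0
max_quantile = 1.0 - min_quantile

for name, tensor_list in tqdm(captured_activations.items()):
    # Flatten each captured batch to (tokens, features) and stack them.
    # torch.quantile needs a float input and rejects very large tensors.
    flattened_tensors = [t.reshape(-1, t.shape[-1]).float() for t in tensor_list]
    all_activations = torch.cat(flattened_tensors, dim=0)

    # Per-feature clipping range, found before inference even starts
    q_min = torch.quantile(all_activations, min_quantile, dim=0)
    q_max = torch.quantile(all_activations, max_quantile, dim=0)

    calibration_ranges[name] = {'min': q_min.cpu(), 'max': q_max.cpu()}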
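A calibrated [min, max] range only becomes useful once it has been turned into quantization parameters. Asymmetric (affine) quantization is a common way to do that. The sketch below is a generic illustration, not the project's code. It assumes a signed 8-bit target range of [-128, 127] and works on plain Python floats rather than per-feature tensors.

```python
from typing import List, Tuple


def affine_params(x_min: float, x_max: float,
                  qmin: int = -128, qmax: int = 127) -> Tuple[float, int]:
    """Turn a calibrated [x_min, x_max] range into (scale, zero_point)."""
    # Widen the range to include 0.0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # all-zero range: any positive scale works
    zero_point = round(qmin - x_min / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize(values: List[float], scale: float, zero_point: int,
             qmin: int = -128, qmax: int = 127) -> List[int]:
    # Anything outside the calibrated range saturates at qmin / qmax
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values: List[int], scale: float, zero_point: int) -> List[float]:
    return [scale * (q - zero_point) for q in q_values]


scale, zp = affine_params(-1.0, 3.0)
print(zp, quantize([-1.0, 0.0, 3.0, 10.0], scale, zp))  # prints: -64 [-128, -64, 127, 127]
```

In this example, the outlier 10.0 simply saturates at 127. Quantile clipping relies on exactly this behavior: a few extreme values are sacrificed so that the scale stays fine-grained for everything else.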

Act III: The Hybrid Strategy

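Pure INT8 quantization can destroy accuracy because rounding and clipping errors compound from layer to layer. I implemented a "Moneyball" strategy: robust layers run on my custom INT8 kernel for speed, while sensitive layers use QuantLinearDequant. That module still stores weights in INT8 (saving memory), but it dequantizes them and computes in FP16 with unquantized activations. No activation quantization error is added in those layers, so error accumulation is reset there.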

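# From qlinear.py
import torch
import torch.nn as nn

import qgemm_cuda  # the project's compiled CUDA extension


class QuantLinearDequant(nn.Module):
    """
    Memory-efficient safety net.
    Stores weights in INT8 but dequantizes on-the-fly to FP16.
    Acts as a 'circuit breaker' for error propagation.
    (Wq, row_scales and bias_param are set up in __init__, omitted here.)
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on-the-fly
        W_dequant = qgemm_cuda.dequantize_rowwise(self.Wq, self.row_scales)

        # Standard high-precision matmul. The layout is inferred from shapes:
        # (in, out) weights are used directly, (out, in) weights are transposed.
        # Caveat: a square weight always takes the first branch, so an explicit
        # layout flag is safer if both layouts can occur.
        if W_dequant.shape[0] == x.shape[-1]:
            Y = torch.matmul(x, W_dequant.to(x.dtype))
        else:
            Y = torch.matmul(x, W_dequant.t().to(x.dtype))

        if self.bias_param is not None:
            Y = Y + self.bias_param.to(x.dtype)
        return Y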
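`qgemm_cuda.dequantize_rowwise` is the project's own extension, so its exact behavior is not shown here. The sketch below assumes that `row_scales` holds one symmetric scale per row, since no zero-point is passed. Under that assumption, a pure-Python reference for the INT8 round trip looks like this. `quantize_rowwise` and `dequantize_rowwise_ref` are illustrative names, not part of the project.

```python
from typing import List, Tuple


def quantize_rowwise(W: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row INT8: row r uses scale_r = max|W[r]| / 127, no zero-point."""
    Wq, scales = [], []
    for row in W:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        Wq.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return Wq, scales


def dequantize_rowwise_ref(Wq: List[List[int]], scales: List[float]) -> List[List[float]]:
    """W[r][c] is approximated by Wq[r][c] * scales[r]."""
    return [[q * s for q in row] for row, s in zip(Wq, scales)]


W = [[0.30, -1.20, 0.05],
     [40.0, -3.0, 0.5]]
Wq, scales = quantize_rowwise(W)
W_hat = dequantize_rowwise_ref(Wq, scales)
```

Each weight costs one byte instead of FP16's two, plus one scale per row. Because the scales are per row, a single large weight (like the 40.0 above) only coarsens its own row and leaves the others untouched.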

The Receipts

David vs. Goliath (and Goliath lost)

I benchmarked my custom engine against the baseline FP16 model and bitsandbytes (BnB). The results show that, on this specific hardware, a specialized, hand-tuned engine can outperform a generalist library.

Throughput

2401 tok/s

Ahead of bitsandbytes (2270 tok/s) and the FP16 baseline (2082 tok/s).

VRAM Usage

1.00 GB

Slashed from the 2.99 GB baseline. Fits comfortably on older GPUs.

Perplexity

17.76

Within 0.02 of the baseline (17.74): no meaningful loss in model quality.
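Relative numbers make the comparison easier to read. This small script only re-derives them from the three results above. It measures nothing new, and the dictionary labels are my own.

```python
# Numbers copied from the results above
tok_per_s = {"custom": 2401, "bitsandbytes": 2270, "fp16_baseline": 2082}
vram_gb = {"custom": 1.00, "fp16_baseline": 2.99}
perplexity = {"custom": 17.76, "fp16_baseline": 17.74}


def pct_change(new: float, old: float) -> float:
    """Relative change of new versus old, in percent."""
    return (new / old - 1.0) * 100.0


gain_vs_bnb = pct_change(tok_per_s["custom"], tok_per_s["bitsandbytes"])
gain_vs_fp16 = pct_change(tok_per_s["custom"], tok_per_s["fp16_baseline"])
vram_reduction = vram_gb["fp16_baseline"] / vram_gb["custom"]
ppl_change = pct_change(perplexity["custom"], perplexity["fp16_baseline"])

print(f"throughput: +{gain_vs_bnb:.1f}% vs bitsandbytes, +{gain_vs_fp16:.1f}% vs FP16")
print(f"VRAM: {vram_reduction:.2f}x smaller; perplexity: +{ppl_change:.2f}%")
```

The results work out to roughly 5.8% more throughput than bitsandbytes and 15.3% more than the FP16 baseline. VRAM use is 2.99x lower, and perplexity rises by 0.11%.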

Why It Matters

It's not just about the speed

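  • Manual Memory Management: Staging data in __shared__ memory and synchronizing with barriers by hand is terrifying but powerful; one missing __syncthreads() is a silent race condition.
  • Quantization is Architecture: It's not just compressing numbers; it's architecting a pipeline that knows when to compress and when to preserve precision.
  • Fused Operations: Fusing dequantization (zero-point subtraction on load, output scaling at the end) directly into the GEMM kernel saves memory bandwidth, because no intermediate tensor has to be written to VRAM and read back.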
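A back-of-the-envelope byte count shows why fusion matters. This is a rough model, not a measurement. It counts only weight bytes and ignores scales, activations and caches. It also assumes that an unfused path materializes an FP16 copy of the weights in VRAM, as a separate dequantize step followed by a matmul does. The layer size comes from GPT-2 Large, whose hidden size is 1280 and whose MLP width is 5120.

```python
def weight_traffic_bytes(rows: int, cols: int, fused: bool) -> int:
    """Rough global-memory traffic for one INT8 weight matrix (scales ignored)."""
    n = rows * cols
    int8_read = n * 1            # every path reads the INT8 weights once
    if fused:
        return int8_read         # fused GEMM consumes them straight from VRAM
    fp16_write = n * 2           # separate dequantize pass writes an FP16 copy...
    fp16_read = n * 2            # ...which the FP16 matmul then reads back
    return int8_read + fp16_write + fp16_read


# One GPT-2 Large MLP projection: hidden size 1280, MLP width 5120
fused = weight_traffic_bytes(1280, 5120, fused=True)
unfused = weight_traffic_bytes(1280, 5120, fused=False)
print(f"fused {fused / 1e6:.2f} MB, unfused {unfused / 1e6:.2f} MB, {unfused / fused:.0f}x")
```

Under these assumptions, the unfused path moves about 32.8 MB against 6.6 MB for the fused one, a 5x difference per layer. By the same rough model, this is also the bandwidth price the QuantLinearDequant safety net pays for its precision.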

Next Steps

Where do we go from here?

This project showed that custom kernels can squeeze extra performance out of consumer hardware. The next step is to explore Double Buffering: loading the next tile into a second shared-memory buffer while the current one is being computed on, to hide global-memory latency. After that, I may implement per-channel quantization for even higher accuracy on more complex models.