Metal Compute Shader Patterns for Sparse ML Training
There is essentially zero academic literature on writing ML training kernels for Apple Metal. Everything is CUDA. Here's what we learned building a sparse training engine from scratch on Apple Silicon.
Overview
Spark is a neural network training engine written from scratch in C++ with Metal compute shaders for GPU acceleration on Apple Silicon. The network has a sparse topology — not all neurons connect to all others — which means standard dense matrix multiplication libraries (BLAS, MPS, cuBLAS) are not applicable. Every kernel had to be written by hand.
This report covers the practical patterns that emerged for sparse ML training on Metal.
1. CSR-Based Sparse Forward and Backward
The network's connectivity is stored in Compressed Sparse Row (CSR) format: for each target neuron, a contiguous array of source neuron indices and corresponding weights. This gives O(1) lookup for "which neurons feed into neuron X?" — exactly the access pattern needed for the forward pass.
Forward pass: Each thread processes one (target neuron, batch element) pair. It iterates over the CSR row for that target, reading source activations and accumulating the weighted sum. The access pattern is sequential over edges and strided over the batch dimension for source activations.
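The per-thread logic can be sketched as a CPU reference in C++. Names (`row_offsets`, `col_indices`, `src_act`) and the batch-strided activation layout are illustrative assumptions, not the engine's actual identifiers; each GPU thread would handle one `(target, b)` pair, which here is emulated with the outer loops.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// CPU reference for the sparse CSR forward pass. On the GPU, one thread
// handles one (target, b) pair; here the two outer loops stand in for the grid.
std::vector<float> csr_forward(const std::vector<uint32_t>& row_offsets, // size n_targets + 1
                               const std::vector<uint32_t>& col_indices, // source neuron per edge
                               const std::vector<float>& weights,        // weight per edge
                               const std::vector<float>& src_act,        // [source * batch + b]
                               uint32_t n_targets, uint32_t batch) {
    std::vector<float> out(n_targets * batch, 0.0f);
    for (uint32_t t = 0; t < n_targets; ++t) {
        for (uint32_t b = 0; b < batch; ++b) {
            float sum = 0.0f;
            // Sequential walk over the CSR row for target t; the source
            // activation reads are strided by the batch dimension.
            for (uint32_t e = row_offsets[t]; e < row_offsets[t + 1]; ++e)
                sum += weights[e] * src_act[col_indices[e] * batch + b];
            out[t * batch + b] = sum;
        }
    }
    return out;
}
```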
Backward pass: Weight gradients require iterating the same CSR structure but accumulating into gradient buffers. Since multiple batch elements contribute to the same weight gradient, we use atomic_fetch_add_explicit with memory_order_relaxed. Input gradients require a transpose traversal — we maintain a separate CSR indexed by source neuron for this purpose.
Key insight: Maintaining two CSR representations (by-target and by-source) doubles memory for index arrays but avoids the need for a sparse transpose operation at runtime.
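The by-source CSR can be derived from the by-target CSR with a one-time counting pass on the CPU. This is a sketch under assumed names; storing each edge's position in the by-target arrays (`edge_ids`) is one way to let the transpose traversal find the matching weight without duplicating the weight buffer.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// By-source (transposed) CSR: for each source neuron, which targets it feeds
// and where each edge lives in the by-target weight array.
struct SourceCSR {
    std::vector<uint32_t> offsets;  // size n_sources + 1
    std::vector<uint32_t> targets;  // target neuron per edge
    std::vector<uint32_t> edge_ids; // index into the by-target weight array
};

SourceCSR transpose_csr(const std::vector<uint32_t>& row_offsets,
                        const std::vector<uint32_t>& col_indices,
                        uint32_t n_sources) {
    uint32_t n_targets = static_cast<uint32_t>(row_offsets.size()) - 1;
    SourceCSR out;
    out.offsets.assign(n_sources + 1, 0);
    // Pass 1: count edges per source neuron.
    for (uint32_t s : col_indices) ++out.offsets[s + 1];
    // Prefix sum turns counts into row starts.
    for (uint32_t s = 0; s < n_sources; ++s) out.offsets[s + 1] += out.offsets[s];
    // Pass 2: scatter each edge into its source's row.
    out.targets.resize(col_indices.size());
    out.edge_ids.resize(col_indices.size());
    std::vector<uint32_t> cursor(out.offsets.begin(), out.offsets.end() - 1);
    for (uint32_t t = 0; t < n_targets; ++t)
        for (uint32_t e = row_offsets[t]; e < row_offsets[t + 1]; ++e) {
            uint32_t pos = cursor[col_indices[e]]++;
            out.targets[pos]  = t;
            out.edge_ids[pos] = e;
        }
    return out;
}
```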
2. Memory Barriers for Layer-Sequential Dispatch
In a layered network, layer L+1 depends on the outputs of layer L. On CUDA, this is typically handled with cudaDeviceSynchronize() or stream dependencies between kernel launches. Metal has a different model.
Metal compute commands within a single command buffer execute in order, but the GPU may overlap or reorder memory operations. The correct pattern for layer-sequential dispatch is:
- One dispatch per layer within the same command encoder
- `memoryBarrier(scope: .device)` between dispatches to ensure layer L's writes are visible to layer L+1's reads
- A single command buffer commit at the end of the forward pass, not per layer
This pattern gives the Metal driver maximum flexibility to schedule threadgroups while preserving data dependencies. The alternative — separate command buffers per layer — adds substantial overhead from repeated buffer commits and GPU scheduling round-trips.
Pitfall: Using threadgroup scope instead of device scope for the memory barrier silently produces incorrect results. Layer L+1 will read stale values from device memory. This is not caught by Metal validation layers.
3. Unified Memory: Zero-Copy Gradient Readback
Apple Silicon's unified memory architecture means CPU and GPU share the same physical DRAM. Metal buffers created with MTLResourceStorageModeShared are accessible from both sides without explicit copy commands.
For ML training, this has a concrete advantage: gradient readback is free. After the backward pass computes gradients on GPU, the CPU can read them directly from the same buffer — no cudaMemcpy equivalent needed. The only requirement is a waitUntilCompleted() to ensure the GPU has finished writing.
We use this for:
- Gradient statistics on CPU — reading per-neuron gradient magnitudes directly from GPU buffers without a copy
- Validation loss readback — the cross-entropy loss computed on GPU is read directly by the CPU training loop
- Weight updates — the optimizer runs on GPU but the CPU can inspect updated weights immediately
Compared to CUDA: On discrete GPUs, these operations each require explicit host-device transfers (or pinned memory with async streams). On Apple Silicon, they are pointer dereferences. For a training loop that reads gradient statistics every step, this eliminates a meaningful source of latency.
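The gradient-statistics case above reduces to plain pointer arithmetic on the CPU side. A minimal sketch, assuming the gradient buffer's contents are visible as a `float*` (as they would be for a shared-storage buffer after the GPU finishes) and a `[neuron * batch + b]` layout; the function name is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Per-neuron gradient L2 norm, read directly from a shared-storage buffer.
// On Apple Silicon the pointer would come straight from the MTLBuffer's
// contents after waitUntilCompleted(); no copy is performed.
float grad_norm_for_neuron(const float* grads, size_t neuron, size_t batch) {
    float sum_sq = 0.0f;
    for (size_t b = 0; b < batch; ++b) {
        float g = grads[neuron * batch + b];
        sum_sq += g * g;
    }
    return std::sqrt(sum_sq);
}
```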
4. Fuse Everything You Can
Metal dispatch overhead is non-trivial for small kernels. Each dispatchThreadgroups call carries fixed cost from command encoding and GPU scheduling. For sparse networks with modest layer sizes, this overhead can dominate compute time.
Our approach: fuse aggressively. The backward pass computes weight gradients, bias gradients, and input gradients in a single dispatch per layer rather than three separate dispatches. The activation function and its derivative are inlined into the forward and backward kernels respectively.
We empirically validated this with an A/B experiment: splitting the backward pass into two separate dispatches (one for weight gradients, one for input gradients) resulted in a 2.8x slower backward pass — even though the split version eliminated all atomic operations. The dispatch overhead and cache thrashing from the restructured memory access pattern more than offset the atomic savings. (See: Why Atomic-Free Sparse Backward Passes Are Slower on Metal)
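The shape of the fused backward pass can be sketched as a single-threaded CPU reference. Buffer names and layouts are illustrative; on the GPU, the `+=` accumulations into shared slots become `atomic_fetch_add_explicit` with relaxed ordering, since many `(target, b)` threads touch the same weight and source-gradient entries.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// CPU reference for a fused backward step: one walk over the CSR edges
// accumulates weight grads, bias grads, and input grads together, rather
// than three separate dispatches.
void fused_backward(const std::vector<uint32_t>& row_offsets,
                    const std::vector<uint32_t>& col_indices,
                    const std::vector<float>& weights,
                    const std::vector<float>& src_act,   // [source * batch + b]
                    const std::vector<float>& out_grad,  // [target * batch + b]
                    uint32_t batch,
                    std::vector<float>& w_grad,          // per edge
                    std::vector<float>& bias_grad,       // per target
                    std::vector<float>& in_grad) {       // [source * batch + b]
    uint32_t n_targets = static_cast<uint32_t>(row_offsets.size()) - 1;
    for (uint32_t t = 0; t < n_targets; ++t)
        for (uint32_t b = 0; b < batch; ++b) {
            float g = out_grad[t * batch + b];
            bias_grad[t] += g; // atomic on GPU
            for (uint32_t e = row_offsets[t]; e < row_offsets[t + 1]; ++e) {
                uint32_t s = col_indices[e];
                w_grad[e]              += g * src_act[s * batch + b]; // atomic on GPU
                in_grad[s * batch + b] += g * weights[e];            // atomic on GPU
            }
        }
}
```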
5. Thread Organization for Sparse Workloads
Dense ML kernels map neatly onto 2D grids (neurons x batch). Sparse kernels don't — each neuron has a different number of incoming edges, so the work per thread varies wildly.
Our thread mapping:
- 1D dispatch over (neuron * batch_size) for the forward pass. Each thread processes one neuron for one batch element, iterating over that neuron's CSR row.
- Threadgroup size of 256 for most kernels. Apple Silicon's SIMD width is 32; 256 gives 8 SIMD groups per threadgroup, which provides good occupancy without excessive register pressure.
- No threadgroup memory (shared memory in CUDA terms) for the core forward/backward kernels. The sparse access patterns don't benefit from tiling, and the unified memory architecture means device memory access is already fast.
Load imbalance: Neurons with many incoming edges take longer than neurons with few. We don't address this with work-stealing or dynamic parallelism — Metal doesn't support dynamic dispatch from within a kernel. Instead, the layer-sequential approach naturally groups neurons of similar depth, and within a layer the variation is manageable.
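The 1D dispatch reduces to decomposing the flat thread index into a (neuron, batch element) pair. One plausible decomposition, shown as a sketch (the actual ordering in the engine may differ):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Decompose a flat 1D thread index into (neuron, batch element), mirroring
// how a kernel would interpret thread_position_in_grid. The grid covers
// n_neurons * batch_size threads; indices past that (from rounding the grid
// up to the threadgroup size) are discarded by the kernel's bounds check.
std::pair<uint32_t, uint32_t> decode_tid(uint32_t tid, uint32_t batch_size) {
    return { tid / batch_size,   // neuron index
             tid % batch_size }; // batch element: adjacent threads share a neuron
}
```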
Performance Profile
Step time breakdown at ~60K edges (the approximate topology size at the end of a typical 10-minute training run):
| Phase | Time (ms) | % of Step |
|---|---|---|
| Forward | 13.7 | 37% |
| Backward | 17.3 | 47% |
| Cross-entropy | 2.6 | 7% |
| Optimizer (AdamW) | 0.5 | 1% |
| Gradient readback | 0.1 | <1% |
| Validation | 1.6 | 4% |
| Total | ~37 | 100% |
Data from a 10-minute training run on M1 Pro. Gradient readback is effectively free due to unified memory.
The backward pass dominates at 47% of step time, which is typical: it computes weight gradients, bias gradients, and input gradients in one fused dispatch. The forward pass at 37% is roughly proportional to its lighter per-edge work. Notably, gradient readback is under 1%; on a discrete GPU with PCIe transfers, it would be significantly higher.
Lessons for Metal ML Developers
- Atomics are cheap. On unified memory, `atomic_fetch_add` with relaxed ordering is fast. Don't restructure your kernels to avoid them unless you have profiling data showing contention.
- Fuse dispatches aggressively. Metal dispatch overhead matters. Two small dispatches are worse than one larger dispatch, even if the larger one has some atomic contention.
- Use device-scope memory barriers. Between dependent dispatches in the same encoder, `memoryBarrier(scope: .device)` is the correct synchronization primitive. Threadgroup scope is wrong and fails silently.
- Exploit unified memory. Read gradients, loss values, and statistics from the CPU without explicit transfers. Design your data layout so the CPU can read what it needs directly from GPU buffers.
- Skip threadgroup memory for sparse kernels. The tiling strategies that help dense GEMM kernels don't apply when each thread traverses a different-length edge list. Device memory with good coalescing is sufficient.