lmkit-go vs lmkit: pure Go on XLA against PyTorch, same 8GB card

Built lmkit-go to answer one question: can you train a real language model from scratch in pure Go, on XLA, and not give up much to PyTorch on the way? The honest answer needed a head-to-head, so I ran both trainers to a full 2B-token budget on the same card, same model, same data, same hyperparameters, same seed, and watched the numbers.

TLDR: on an 8GB RTX 3070 Ti, lmkit-go (Go on XLA via GoMLX) trains the 100M Llama at 85% of the PyTorch reference’s throughput, using roughly 1.7x the VRAM, and lands on the same loss curve. That throughput gap used to be 26x. Closing most of it was a side-quest into the Go ML stack itself, which is the more interesting story.

The setup

The point of the exercise is a matched run. Same architecture, same optimizer, same schedule, same tokenized shards, same seed. If anything drifts, the comparison is worthless, so the config is locked to lmkit’s lm-100m-en:

model      100M Llama: hidden 768, 12 layers, 12 heads, 4 KV heads (GQA 3:1), head_dim 64,
           SwiGLU ffn 2048, vocab 32k, seq 2048, RoPE base 10k, tied embeddings
optimizer  AdamW, lr 4e-4, betas 0.9/0.95, weight decay 0.1, grad clip 1.0
schedule   WSD, 1000-step warmup, constant trunk, no decay
precision  bf16 compute, fp32-internal norms
batch      micro 2 x grad-accum 32 x seq 2048 = 131,072 tokens/step
budget     Chinchilla, ~2B tokens = 15,406 steps
seed       1337, same shards for both arms

Both arms read the exact same tokenized .bin shards, train and val, so even the validation loss is comparable rather than just close.
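The batch and budget figures above follow from a few multiplications, and the schedule is simple enough to write down. Here is a minimal sketch of both. The `RunConfig` struct and its field names are illustrative rather than lmkit-go's actual config type, and the linear warmup shape is an assumption, since the config only says "1000-step warmup".

```go
package main

import "fmt"

// RunConfig mirrors the settings listed above. The struct and its field
// names are hypothetical, chosen for this example only.
type RunConfig struct {
	MicroBatch  int     // sequences per forward/backward pass
	GradAccum   int     // micro-batches accumulated per optimizer step
	SeqLen      int     // tokens per sequence
	Steps       int     // optimizer steps in the run
	PeakLR      float64 // learning rate after warmup
	WarmupSteps int     // steps spent ramping up to PeakLR
}

// TokensPerStep is the number of tokens consumed by one optimizer step.
func (c RunConfig) TokensPerStep() int {
	return c.MicroBatch * c.GradAccum * c.SeqLen
}

// TotalTokens is the token budget implied by the step count.
func (c RunConfig) TotalTokens() int {
	return c.Steps * c.TokensPerStep()
}

// LR returns the learning rate at a 0-indexed step: a linear ramp during
// warmup (assumed shape), then a constant trunk with no decay.
func (c RunConfig) LR(step int) float64 {
	if step < c.WarmupSteps {
		return c.PeakLR * float64(step+1) / float64(c.WarmupSteps)
	}
	return c.PeakLR
}

func main() {
	cfg := RunConfig{
		MicroBatch: 2, GradAccum: 32, SeqLen: 2048,
		Steps: 15406, PeakLR: 4e-4, WarmupSteps: 1000,
	}
	fmt.Println(cfg.TokensPerStep()) // 131072
	fmt.Println(cfg.TotalTokens())   // 2019295232
	fmt.Println(cfg.LR(0), cfg.LR(5000))
}
```

The total comes out at about 2.02B tokens, which is where the step count of 15,406 comes from.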

The numbers

Both arms ran the full 15,406 steps to ~2.02B tokens.

metric                 lmkit-go (Go / XLA)        lmkit (PyTorch)
throughput             28.8k tok/s                33.8k tok/s
VRAM @ batch 2         6.5 GB                     3.9 GB
final val loss (14k)   2.38                       2.33
run                    15,406 steps, 2.02B tok    15,406 steps, 2.02B tok

Go-on-XLA is at 85% of PyTorch’s throughput on identical hardware, and it uses more memory getting there: 6.5 GB against 3.9 GB at the same batch. Both arms fit the 8GB card at batch 2 and both OOM at batch 4, so the extra memory does not cost the Go arm a batch size on this card. PyTorch finished the run in 16.6 hours wall-clock; at 28.8k tok/s the Go run is about 19.5 hours of compute for the same budget (I bounced it on and off the card during the stack work, so its raw wall-clock isn’t representative).
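The ratios and the compute-hour figures can be checked from the throughput numbers and the token budget alone. A small sketch of that arithmetic follows; `hoursAt` is a helper made up for this example.

```go
package main

import "fmt"

// hoursAt returns the compute time in hours needed to process `tokens`
// tokens at a sustained throughput of `tokPerSec` tokens per second.
func hoursAt(tokens, tokPerSec float64) float64 {
	return tokens / tokPerSec / 3600
}

func main() {
	const tokens = 15406 * 131072 // steps x tokens/step, about 2.02B

	goTokS, ptTokS := 28.8e3, 33.8e3 // throughputs reported above

	fmt.Printf("throughput ratio: %.0f%%\n", 100*goTokS/ptTokS) // 85%
	fmt.Printf("VRAM ratio:       %.1fx\n", 6.5/3.9)            // 1.7x
	fmt.Printf("PyTorch:          %.1f h\n", hoursAt(tokens, ptTokS)) // 16.6 h
	fmt.Printf("lmkit-go:         %.1f h\n", hoursAt(tokens, goTokS)) // 19.5 h
}
```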

The number I actually wanted is the loss curve. Same seed, same tokens, same schedule, two completely different frameworks. Validation loss, every 2000 steps:

step     lmkit-go   lmkit (PyTorch)
2,000    3.00       2.95
4,000    2.68       2.77
6,000    2.50       2.63
8,000    2.50       2.59
10,000   2.43       2.39
12,000   2.33       2.39
14,000   2.38       2.33

They track within about 0.13 the whole way and trade the lead several times: PyTorch is ahead at 2k, the Go run is ahead from step 4k to 8k and again at 12k, and PyTorch is ahead at 10k and edges back in front by the end. That’s the result I cared about. The two stacks learn the same model off the same seed. The remaining differences are within the noise you’d get from re-seeding a single framework, and they come from the frameworks initializing weights differently under the same seed, not from anything diverging in the math.
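For anyone who wants to check how closely two validation curves track, the comparison is a few lines of Go. This is a sketch over the seven checkpoints in the table; `compareCurves` is a helper written for this example, not part of either trainer.

```go
package main

import (
	"fmt"
	"math"
)

// compareCurves returns the largest absolute gap between two loss curves
// and the number of times the lead changes hands between checkpoints.
// Exact ties are skipped when counting lead changes.
func compareCurves(a, b []float64) (maxGap float64, leadChanges int) {
	prevSign := 0
	for i := range a {
		d := a[i] - b[i]
		maxGap = math.Max(maxGap, math.Abs(d))

		sign := 0
		if d > 0 {
			sign = 1
		} else if d < 0 {
			sign = -1
		}
		if sign != 0 {
			if prevSign != 0 && sign != prevSign {
				leadChanges++
			}
			prevSign = sign
		}
	}
	return maxGap, leadChanges
}

func main() {
	// Validation loss at steps 2k, 4k, ..., 14k, copied from the table.
	goLoss := []float64{3.00, 2.68, 2.50, 2.50, 2.43, 2.33, 2.38}
	ptLoss := []float64{2.95, 2.77, 2.63, 2.59, 2.39, 2.39, 2.33}

	gap, changes := compareCurves(goLoss, ptLoss)
	fmt.Printf("max gap %.2f, lead changes %d\n", gap, changes)
}
```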

So the throughput gap, mostly, is closed, and it didn’t close by tuning lmkit-go. It closed by fixing GoMLX. The memory gap is the next thing to chase.

The side-quest: contributing to the Go ML stack

Writing a trainer on a young stack means you find its holes by falling into them. lmkit-go found two big ones, and both fixes landed upstream.

The 26x bug. Early on, lmkit-go was not 15% slower than PyTorch, it was 26x slower, and the profile made no sense: the matmuls were fine, but the gradient step was eating everything. The cause was in GoMLX’s autodiff. The backward pass for a matmul (DotGeneral) computes a couple of matmuls of its own for the weight gradients, and GoMLX was lowering those as a broadcasted multiply followed by a reduce, instead of as a real GEMM. On a GPU that means the weight gradients ran on the CUDA cores while the tensor cores sat idle. Filed it, the GoMLX maintainer turned a fix around fast, and it’s merged. The training step dropped from ~3.0s to ~0.16s, about 18x, and the 26x gap became the 15% gap above.
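To make the two lowerings concrete, here is the weight gradient dW = X^T dY written both ways in plain Go. This is only an illustration of the math on the CPU: it is not GoMLX's autodiff code, the function names are made up for this example, and it says nothing about tensor cores. What it does show is that the first form materializes a full [batch][in][out] intermediate before reducing, while the second is a single contraction over the batch axis, which is the shape a GEMM kernel wants.

```go
package main

import "fmt"

// zeros allocates a rows x cols matrix of zeros.
func zeros(rows, cols int) [][]float32 {
	m := make([][]float32, rows)
	for i := range m {
		m[i] = make([]float32, cols)
	}
	return m
}

// gradBroadcastReduce computes dW[i][j] = sum_b x[b][i]*dy[b][j] as a
// broadcasted multiply followed by a reduce over the batch axis.
func gradBroadcastReduce(x, dy [][]float32) [][]float32 {
	batch, in, out := len(x), len(x[0]), len(dy[0])
	// Step 1: broadcasted multiply into a [batch][in][out] intermediate.
	prod := make([][][]float32, batch)
	for b := range prod {
		prod[b] = zeros(in, out)
		for i := 0; i < in; i++ {
			for j := 0; j < out; j++ {
				prod[b][i][j] = x[b][i] * dy[b][j]
			}
		}
	}
	// Step 2: reduce-sum over the batch axis.
	dw := zeros(in, out)
	for b := 0; b < batch; b++ {
		for i := 0; i < in; i++ {
			for j := 0; j < out; j++ {
				dw[i][j] += prod[b][i][j]
			}
		}
	}
	return dw
}

// gradGEMM computes the same dW as one matrix product, contracting over
// the batch axis without any intermediate tensor.
func gradGEMM(x, dy [][]float32) [][]float32 {
	batch, in, out := len(x), len(x[0]), len(dy[0])
	dw := zeros(in, out)
	for i := 0; i < in; i++ {
		for j := 0; j < out; j++ {
			var acc float32
			for b := 0; b < batch; b++ {
				acc += x[b][i] * dy[b][j]
			}
			dw[i][j] = acc
		}
	}
	return dw
}

func main() {
	x := [][]float32{{1, 2}, {3, 4}, {5, 6}}           // [batch=3][in=2]
	dy := [][]float32{{1, 0, 2}, {0, 1, 1}, {1, 1, 0}} // [batch=3][out=3]
	fmt.Println(gradBroadcastReduce(x, dy)) // [[6 8 5] [8 10 8]]
	fmt.Println(gradGEMM(x, dy))            // [[6 8 5] [8 10 8]]
}
```

Both functions return the same numbers. The difference is entirely in what the hardware is asked to do.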

The lesson there is boring and correct: on real hardware, how a primitive lowers matters more than how many of them you emit. A whole training loop can be quietly running on the wrong half of the chip.

No flash attention. The other hole was memory. GoMLX had no fused, memory-efficient attention, so the 100M model materialized the full attention scores and OOM’d at seq 2048 on the 8GB card, the exact thing PyTorch avoids with F.scaled_dot_product_attention. The fix turned out not to need a kernel at all: emit the right StableHLO custom_call and XLA lowers it straight to cuDNN’s flash attention, forward and backward. That work is now being reshaped into a typed FusedScaledDotProductAttention op with a proper VJP, spread across the three repos in the stack (compute, go-xla and gomlx), per the maintainer’s design. The deep dive on that lives in its own post: flash attention in pure Go, by calling cuDNN through XLA.
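A rough sense of why the materialized scores hurt on an 8GB card comes from the shape alone. The sketch below sizes the [batch, heads, seq, seq] score tensor for this config. It assumes the scores are held in bf16 (2 bytes per element) and that one copy per layer is kept for the backward pass. Those are assumptions for the estimate, not measurements, and the real footprint also depends on the softmax output and on what XLA fuses or rematerializes.

```go
package main

import "fmt"

// scoreBytes returns the size in bytes of one layer's full attention
// score tensor, shape [batch, heads, seq, seq].
func scoreBytes(batch, heads, seq, bytesPerElem int64) int64 {
	return batch * heads * seq * seq * bytesPerElem
}

func main() {
	// Micro-batch 2, 12 heads, seq 2048, bf16 (assumed storage type).
	perLayer := scoreBytes(2, 12, 2048, 2)
	const layers = 12

	fmt.Printf("per layer: %d MiB\n", perLayer>>20)                         // 192 MiB
	fmt.Printf("12 layers: %.2f GiB\n", float64(layers*perLayer)/(1<<30)) // 2.25 GiB
}
```

A fused attention kernel avoids holding that seq x seq tensor at all, which is the saving the custom_call route buys.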

Neither fix was lmkit-go code. Both were the substrate. That’s the tax on building high on a stack that’s still filling in, and also the upside: the fixes help everyone downstream, not just one trainer.

Where this lands

Pure Go on XLA trains a 100M model from scratch to a Chinchilla budget at 85% of PyTorch’s throughput on the same 8GB card, on the same loss curve. The cost is memory, 1.7x, and that’s the next thing to go after.

What I didn’t expect was how little of the work was in lmkit-go. The trainer is a few thousand lines of ordinary Go: a loop, an optimizer, a data loader, checkpoints. The hard parts, the parts that decided whether this was 26x slower or 15% slower, all lived one layer down in GoMLX and XLA. Build high enough on a young stack and your performance work turns into other people’s bug reports. The flip side is that the fixes are upstream now, so the next person who tries this starts where I finished.