Grokking of MLP on Modular Addition
Background
Grokking is a phenomenon where a model quickly achieves near-perfect training accuracy (memorization) while validation accuracy remains near chance (or even worse than chance) for a long time, and then transitions sharply to strong generalization after extended optimization. I came across this concept while reading Provable Scaling Laws of Feature Emergence from Learning Dynamics of Grokking and found it interesting. The phenomenon has been known for several years, but I still thought it would be worthwhile to reproduce it myself.
The experiment was set up with the help of ChatGPT 5.2 and is also inspired by Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets.
Experiment Setup
Dataset
The full dataset (training and validation combined) is generated by the following procedure:
Choose a prime modulus p (here p = 97).
Generate the full set of input pairs (a, b) where a, b in {0, ..., p-1}.
Labels are modular addition: y = (a + b) mod p
Total dataset size is p^2 = 97^2 = 9409 examples.
The full dataset is then partitioned into training and validation sets:
A random permutation of the 9409 pairs is created.
The first train_frac * 9409 pairs are used for training; the remainder are used for validation.
Validation corresponds to unseen pairs (not unseen values of a or b).
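The procedure above can be sketched in a few lines of plain Python. This is a minimal illustration rather than the script used for the experiment: the function name make_split, the seed argument, and the choice to round the training size down with int() are assumptions.

```python
import random


def make_split(p=97, train_frac=0.25, seed=0):
    """Build all p*p modular-addition examples and split them at random.

    Returns (train, val), each a list of (a, b, y) with y = (a + b) mod p.
    """
    # Every ordered pair (a, b) with a, b in {0, ..., p-1}.
    pairs = [(a, b) for a in range(p) for b in range(p)]

    # Random permutation of the pairs (seeded for reproducibility).
    rng = random.Random(seed)
    rng.shuffle(pairs)

    # Assumption: the training size is rounded down.
    n_train = int(train_frac * len(pairs))

    train = [(a, b, (a + b) % p) for a, b in pairs[:n_train]]
    val = [(a, b, (a + b) % p) for a, b in pairs[n_train:]]
    return train, val


train, val = make_split(p=97, train_frac=0.25)
print(len(train), len(val))
```

Because the split is over pairs, a validation example such as (3, 5) can be unseen even though both 3 and 5 appear in many training pairs.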
Model
This experiment uses an MLP (multi-layer perceptron) rather than the transformers used in the papers above.
Two integers (a, b) are embedded via a learned embedding table: Embedding(p, embed_dim) with embed_dim = 128. The embeddings for a and b are concatenated into a 256-dim vector. This vector is then passed through a depth-3 MLP with ReLU nonlinearity:
import torch
from torch import nn


class MLP(nn.Module):
    def __init__(self, p: int, embed_dim: int, hidden_dim: int, depth: int) -> None:
        super().__init__()
        # One shared embedding table for both operands.
        self.emb = nn.Embedding(p, embed_dim)
        layers = []
        in_dim = 2 * embed_dim
        # depth - 1 hidden layers, each followed by ReLU.
        for _ in range(depth - 1):
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            in_dim = hidden_dim
        # Final linear layer produces one logit per residue class.
        layers += [nn.Linear(in_dim, p)]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, 2]
        a = self.emb(x[:, 0])
        b = self.emb(x[:, 1])
        h = torch.cat([a, b], dim=-1)  # [B, 2 * embed_dim]
        return self.net(h)  # [B, p] logits


p = 97
embed_dim = 128
hidden_dim = 256
depth = 3
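To get a feel for the size of this model, the parameter count can be worked out by hand from the layer shapes. The helper below (count_params is a name introduced here for illustration) mirrors the constructor above in plain Python, assuming the default nn.Linear configuration with a bias term.

```python
def count_params(p, embed_dim, hidden_dim, depth):
    """Count trainable parameters of the MLP described above."""
    total = p * embed_dim  # embedding table: p rows of size embed_dim
    in_dim = 2 * embed_dim  # concatenated embeddings of a and b
    for _ in range(depth - 1):
        total += in_dim * hidden_dim + hidden_dim  # weights + biases
        in_dim = hidden_dim
    total += in_dim * p + p  # output layer: weights + biases
    return total


print(count_params(p=97, embed_dim=128, hidden_dim=256, depth=3))  # 168929
```

With the settings used here the model has about 169k parameters, compared with 9409 examples in the full dataset, so it has ample capacity to memorize the training set.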
Training and Evaluation
The model is trained with the following parameters:
Optimizer: AdamW
Learning rate: lr = 1e-3
Weight decay: weight_decay = 1e-2
Batch size: 256
Mini-batches are sampled uniformly with replacement from the training set.
The model is evaluated every eval_every steps, and the training and validation loss/accuracy are recorded.
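The control flow of this training procedure can be sketched independently of the model. The skeleton below is an illustration, not the original script: run_training, train_step, and evaluate are hypothetical names, and evaluating at exact multiples of eval_every (and not at step 0) is an assumption. In the real experiment, train_step would perform one AdamW update on the batch and evaluate would compute loss and accuracy on both splits.

```python
import random


def run_training(train_set, total_steps, eval_every, train_step, evaluate,
                 batch_size=256, seed=0):
    """Generic loop: sample a batch, update, and evaluate periodically.

    train_step(batch) and evaluate() are caller-supplied callbacks.
    Returns a list of (step, metrics) tuples.
    """
    rng = random.Random(seed)
    history = []
    for step in range(1, total_steps + 1):
        # random.choices samples uniformly WITH replacement.
        batch = rng.choices(train_set, k=batch_size)
        train_step(batch)
        if step % eval_every == 0:
            history.append((step, evaluate()))
    return history


# Toy usage with stand-in callbacks (no real model involved).
toy_train = [(a, b, (a + b) % 97) for a in range(97) for b in range(3)]
log = run_training(
    toy_train,
    total_steps=1000,
    eval_every=250,
    train_step=lambda batch: None,
    evaluate=lambda: {"val_acc": 0.0},
)
print([step for step, _ in log])  # [250, 500, 750, 1000]
```

Sampling with replacement means a "step" is not tied to an epoch boundary; a batch may contain the same training pair more than once.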
Results
Across runs, the training dynamics consistently show:
Memorization phase: training accuracy rises to ~100% in the first ~500-800 steps.
Delayed generalization phase (grokking):
validation accuracy stays near 0% for a long time
then transitions sharply to near-100% after extended training
An example is train_frac = 0.25:

At step 500 (hard to see on the graph), the model already has training accuracy close to 1 but a validation accuracy of only ~0.00028 (0.028%), with very high validation loss.
The validation loss starts to decrease and the validation accuracy starts to increase only after extensive training (about 0.9 million steps).
Validation accuracy snaps to ~1.0 and validation loss collapses to near zero around step 1.2 million.
This is characteristic grokking behavior: the model first fits a brittle solution that fails on unseen pairs, then later finds a structured rule that generalizes.
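To put the early validation accuracy of ~0.00028 in context, it helps to compare it with the chance level for a 97-way classification task. The arithmetic below is a rough sketch; it assumes the training set size is int(train_frac * 9409), which is not stated explicitly above.

```python
p = 97
total = p * p                 # 9409 examples in the full dataset
train_frac = 0.25

# Assumption: the training size is rounded down.
n_train = int(train_frac * total)
n_val = total - n_train

chance_acc = 1 / p            # uniform guessing over p classes
reported_val_acc = 0.00028    # value quoted for step 500

# Number of correct validation predictions implied by that accuracy.
implied_correct = reported_val_acc * n_val

print(n_train, n_val)
print(f"chance accuracy: {chance_acc:.4f}")
print(f"implied correct predictions: {implied_correct:.2f}")
```

Under this assumption, the reported accuracy corresponds to roughly 2 correct predictions out of 7057 validation pairs, far below the ~1% expected from uniform guessing. This matches the "worse than chance" behavior mentioned in the background.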
The following table shows the metrics for different values of train_frac:
| train_frac | total steps | best val acc | step(val_acc>=0.99) | Notes |
| --- | --- | --- | --- | --- |
| 0.25 | 1,500,000 | 1.000 | 1,126,000 | Clear grokking |
| 0.30 | 800,000 | 1.000 | 637,000 | Extension reaches grokking |
| 0.35 | 500,000 | 1.000 | 447,500 | Clear grokking |
| 0.40 | 400,000 | 1.000 | 340,800 | Clear grokking |
| 0.45 | 300,000 | 1.000 | 267,400 | Clear grokking |
| 0.50 | 200,000 | 0.902 | n/a | A bit undertrained. Val acc should increase with more training |
The following figure shows the number of steps to grokking as a function of training fraction. The grokking step is defined as the step at which validation accuracy reaches 0.9. The larger the training fraction, the fewer steps are needed for grokking, which is consistent with the papers.
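Given a log of evaluation results, the grokking step under this definition can be extracted with a short helper. The sketch below is illustrative: grokking_step is a hypothetical name, the log format (a list of (step, val_acc) tuples sorted by step) is an assumption, and the numbers in the example log are made up rather than taken from the experiment.

```python
def grokking_step(history, threshold=0.9):
    """Return the first logged step whose val_acc >= threshold, else None.

    history: iterable of (step, val_acc) tuples in increasing step order.
    """
    for step, val_acc in history:
        if val_acc >= threshold:
            return step
    return None


# Made-up log for illustration only (not real measurements).
example_log = [
    (1000, 0.001),
    (2000, 0.004),
    (3000, 0.350),
    (4000, 0.930),
    (5000, 0.999),
]

print(grokking_step(example_log, threshold=0.9))   # 4000
print(grokking_step(example_log, threshold=0.99))  # 5000
print(grokking_step(example_log[:3]))              # None
```

The resolution of this measurement is limited by eval_every, and the resulting step depends on the threshold chosen, as the two thresholds in the example show.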

Next Steps
This experiment is very limited: it only shows that grokking happens with an MLP, and it only tests the relationship between training fraction and the number of steps to grokking.
It would be good to run the experiment with transformers.
The papers also claim that weight decay is another important factor affecting grokking, which would be an interesting next step to investigate.