
Grokking of MLP on Modular Addition


Background

Grokking is a phenomenon in which a model quickly reaches near-perfect training accuracy (memorization) while validation accuracy stays near chance (or even below chance) for a long time, and then, after extended optimization, validation accuracy transitions sharply to strong generalization. I came across this concept while reading Provable Scaling Laws of Feature Emergence from Learning Dynamics of Grokking and found it interesting. The phenomenon has been known for several years, but I thought it would still be worthwhile to reproduce it myself.

The experiment was set up with the help of ChatGPT 5.2 and is also inspired by Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets.

Experiment Setup

Dataset

The full dataset (training and validation combined) is generated by the following procedure:

  • Choose a prime modulus p (here p = 97).

  • Generate the full set of input pairs (a, b) where a, b in {0, ..., p-1}.

  • Labels are modular addition: y = (a + b) mod p

  • Total dataset size is p^2 = 97^2 = 9409 examples.
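The generation procedure above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the author's code; the function name make_modular_addition_dataset is hypothetical.

```python
import torch

def make_modular_addition_dataset(p: int = 97):
    """Build all p*p input pairs (a, b) with labels y = (a + b) mod p.

    Hypothetical helper for illustration only.
    """
    a = torch.arange(p).repeat_interleave(p)  # 0,0,...,0,1,1,... (each value p times)
    b = torch.arange(p).repeat(p)             # 0,1,...,p-1,0,1,... (whole range p times)
    X = torch.stack([a, b], dim=1)            # shape [p*p, 2], dtype int64
    y = (a + b) % p                           # shape [p*p], class labels in 0..p-1
    return X, y

X, y = make_modular_addition_dataset(97)
print(X.shape, y.shape)  # torch.Size([9409, 2]) torch.Size([9409])
```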

The full dataset is then partitioned into training and validation sets:

  • A random permutation of the 9409 pairs is created.

  • The first train_frac * 9409 pairs are used for training; the remainder are validation.

  • Validation corresponds to unseen pairs (not unseen values of a or b).
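A sketch of the split, assuming the permutation is drawn with a seeded torch.Generator and that train_frac * 9409 is rounded down with int(); the post does not state either detail, and split_indices is a hypothetical name.

```python
import torch

def split_indices(n: int, train_frac: float, seed: int = 0):
    """Randomly partition indices 0..n-1 into train and validation sets.

    Assumptions: seeded generator, and int() rounding of train_frac * n.
    """
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=g)  # random permutation of all pairs
    n_train = int(train_frac * n)          # first train_frac * n indices -> training
    return perm[:n_train], perm[n_train:]  # the remainder -> validation

train_idx, val_idx = split_indices(97 * 97, train_frac=0.25)
print(len(train_idx), len(val_idx))  # 2352 7057
```

Indexing the full (X, y) tensors with these index tensors yields the training and validation sets. Because the split is over pairs, an individual value of a or b can appear in both sets.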

Model

This experiment uses an MLP (multi-layer perceptron) rather than the transformers used in the papers above.

The two integers (a, b) are embedded via a single learned embedding table shared by both inputs: Embedding(p, embed_dim) with embed_dim = 128. The embeddings of a and b are concatenated into a 256-dim vector, which is then passed through a depth-3 MLP with ReLU nonlinearities (two hidden layers of width hidden_dim, followed by a linear output layer producing p logits):

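import torch
import torch.nn as nn


class MLP(nn.Module):
    """Embed a and b with a shared table, concatenate, and predict (a + b) mod p."""

    def __init__(self, p: int, embed_dim: int, hidden_dim: int, depth: int) -> None:
        super().__init__()
        self.emb = nn.Embedding(p, embed_dim)  # one table shared by a and b
        layers = []
        in_dim = 2 * embed_dim  # size of the concatenated [emb(a), emb(b)]
        for _ in range(depth - 1):  # depth - 1 hidden Linear + ReLU blocks
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            in_dim = hidden_dim
        layers += [nn.Linear(in_dim, p)]  # output layer: logits over the p residues
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, 2] int64 -> [B, p] logits
        a = self.emb(x[:, 0])
        b = self.emb(x[:, 1])
        h = torch.cat([a, b], dim=-1)
        return self.net(h)

Hyperparameters:

  • p = 97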

  • embed_dim = 128

  • hidden_dim = 256

  • depth = 3
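As a sanity check on model size, the parameter count of the MLP class above can be worked out by hand. The helper below (hypothetical, plain Python) mirrors the layer structure; with a PyTorch model in hand, sum(q.numel() for q in model.parameters()) should give the same number.

```python
def count_mlp_params(p: int, embed_dim: int, hidden_dim: int, depth: int) -> int:
    """Count weights + biases of the MLP class above (hypothetical helper)."""
    total = p * embed_dim   # embedding table, shared by a and b
    in_dim = 2 * embed_dim  # concatenated [emb(a), emb(b)]
    for _ in range(depth - 1):
        total += in_dim * hidden_dim + hidden_dim  # hidden Linear: weight + bias
        in_dim = hidden_dim
    total += in_dim * p + p  # output Linear: weight + bias
    return total

print(count_mlp_params(p=97, embed_dim=128, hidden_dim=256, depth=3))  # 168929
```

That is 12,416 embedding parameters, 65,792 for each of the two hidden layers, and 24,929 for the output layer, so the model is small relative to the 9,409-example dataset only in the sense of being a few hundred thousand parameters.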

Training and Evaluation

The model is trained with the following settings:

  • Optimizer: AdamW

  • Learning rate: lr = 1e-3

  • Weight decay: weight_decay = 1e-2

  • Batch size: 256

  • Mini-batches are sampled uniformly with replacement from the training set.

The model is evaluated every eval_every steps, and the training and validation loss/accuracy are recorded at each evaluation.
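A minimal sketch of this training loop, under assumptions the post does not state: cross-entropy loss, full-batch evaluation on each split, and a fixed seed. The function names train and evaluate are hypothetical; AdamW, the learning rate, weight decay, batch size, and with-replacement sampling follow the settings listed above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def evaluate(model: nn.Module, X: torch.Tensor, y: torch.Tensor):
    """Full-batch cross-entropy loss and accuracy on (X, y)."""
    model.eval()
    logits = model(X)
    loss = F.cross_entropy(logits, y).item()
    acc = (logits.argmax(dim=-1) == y).float().mean().item()
    model.train()
    return loss, acc

def train(model, X_tr, y_tr, X_va, y_va, steps, eval_every,
          lr=1e-3, weight_decay=1e-2, batch_size=256, seed=0):
    """AdamW training on mini-batches sampled uniformly with replacement."""
    torch.manual_seed(seed)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    history = []  # rows of (step, train_loss, train_acc, val_loss, val_acc)
    model.train()
    for step in range(1, steps + 1):
        idx = torch.randint(0, len(X_tr), (batch_size,))  # with replacement
        loss = F.cross_entropy(model(X_tr[idx]), y_tr[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % eval_every == 0:
            history.append((step, *evaluate(model, X_tr, y_tr), *evaluate(model, X_va, y_va)))
    return history
```

To reproduce a run, one would pass an instance of the MLP above together with the training and validation tensors; steps and eval_every are left as parameters because the post does not fix eval_every.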

Results

Across runs, the training dynamics consistently show:

  1. Memorization phase: training accuracy rises to ~100% in the first ~500-800 steps.

  2. Delayed generalization phase (grokking):

    • validation accuracy stays near 0% for a long time

    • then transitions sharply to near-100% after extended training

For example, with train_frac = 0.25:

  • At step 500 (hard to see on the graph), the model already has training accuracy close to 1, but validation accuracy is only ~0.00028 (0.028%, below the 1/97 ≈ 1% chance level) and validation loss is very high.

  • After extensive training (about 0.9 million steps), the validation loss starts to decrease and the validation accuracy starts to rise.

  • Validation accuracy snaps to ~1.0 and validation loss collapses to near zero around step 1.2 million.

This is characteristic grokking behavior: the model first fits a brittle solution that fails on unseen pairs, then later finds a structured rule that generalizes.

The following table summarizes the metrics for different values of train_frac:

| train_frac | total steps | best val acc | step (val_acc >= 0.99) | Notes |
|---|---|---|---|---|
| 0.25 | 1,500,000 | 1.000 | 1,126,000 | Clear grokking |
| 0.30 | 800,000 | 1.000 | 637,000 | Extension reaches grokking |
| 0.35 | 500,000 | 1.000 | 447,500 | Clear grokking |
| 0.40 | 400,000 | 1.000 | 340,800 | Clear grokking |
| 0.45 | 300,000 | 1.000 | 267,400 | Clear grokking |
| 0.50 | 200,000 | 0.902 | n/a | A bit undertrained; val acc should increase with more training |

The following figure plots training fraction against grokking step, where the grokking step is defined as the first step at which validation accuracy reaches 0.9. The larger the training fraction, the fewer steps are needed for grokking, which is consistent with the papers.
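Extracting the grokking step from the logged evaluations amounts to finding the first evaluation whose validation accuracy meets the threshold. A sketch with a hypothetical helper name; the example log is made up for illustration and is not data from the runs above.

```python
def grokking_step(steps, val_accs, threshold=0.9):
    """Return the first logged step with val accuracy >= threshold, else None."""
    for step, acc in zip(steps, val_accs):
        if acc >= threshold:
            return step
    return None  # threshold never reached, e.g. an undertrained run

# Made-up log for illustration only:
log_steps = [1000, 2000, 3000, 4000]
log_accs = [0.01, 0.35, 0.93, 0.995]
print(grokking_step(log_steps, log_accs))        # 3000
print(grokking_step(log_steps, log_accs, 0.99))  # 4000
```

Since evaluation only happens every eval_every steps, the reported grokking step is resolved to that granularity. The same helper with threshold=0.99 gives the step column in the table above.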

Next Steps

This experiment is limited: it only shows that grokking occurs with an MLP, and it only tests the relationship between training fraction and grokking.

A natural next step is to run the same experiment with a transformer.

The papers also report that weight decay is an important factor in grokking, so varying it would be another interesting next step.