
DINO: Self-Distillation


If contrastive learning says:

"pull positives together, push negatives apart,"

DINO-style learning says:

"make a student network match a teacher network across different views."

DINOv2 is one of the strongest examples of this idea at scale. It trains a powerful visual encoder without manual labels and produces features that transfer well to many tasks:

  • classification,
  • retrieval,
  • segmentation,
  • depth estimation,
  • dense patch matching.


The Teacher-Student Setup

The key actors are:

  • a student network,
  • a teacher network,
  • multiple crops or views of the same image.

Both networks process different augmentations of the same image. The student is trained to match the teacher outputs.

The teacher is not trained by backpropagation in the usual way. Instead, it is updated as an exponential moving average (EMA) of the student:

\[ \theta_{\text{teacher}} \leftarrow \lambda \theta_{\text{teacher}} + (1-\lambda)\theta_{\text{student}} \]

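Here \(\lambda\) is a momentum coefficient close to 1 (0.996 in the code later in this section), so the teacher changes slowly and gives the student a stable target. This stabilizes training.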
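To get a feel for how slowly the teacher moves, the update can be unrolled on a single scalar weight. This is only a toy sketch: the helper name `ema_step`, the starting values, and the assumption that the student stays fixed are all illustrative, not part of DINO itself.

```python
def ema_step(teacher_w: float, student_w: float, lam: float) -> float:
    """One EMA update: teacher <- lam * teacher + (1 - lam) * student."""
    return lam * teacher_w + (1.0 - lam) * student_w


lam = 0.996      # assumed momentum, close to 1
teacher_w = 0.0  # toy teacher weight
student_w = 1.0  # toy student weight, held fixed for illustration

for step in range(1, 1001):
    teacher_w = ema_step(teacher_w, student_w, lam)
    if step in (1, 100, 1000):
        print(f"step {step:4d}: teacher = {teacher_w:.4f}")

# With a fixed student at 1 and a teacher starting at 0,
# the teacher value after n steps equals 1 - lam**n.
```

The closer \(\lambda\) is to 1, the more steps the teacher needs to catch up with a change in the student.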

Multi-Crop Intuition

One image can generate:

  • a couple of large global crops,
  • several small local crops.

The model learns that these all describe the same underlying scene or object. This encourages invariance across scale, viewpoint, and partial observation.
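As a rough illustration of multi-crop, the sketch below cuts random square crops from one image tensor and resizes them to two fixed resolutions. The helper names, the crop-area ranges, the output sizes (224 and 96), and the crop counts are assumptions chosen for illustration. Real pipelines also apply color jitter, flips, and blur, which are omitted here.

```python
import torch
import torch.nn.functional as F


def random_crop_resize(img, scale_range, out_size):
    """Crop a random square covering a fraction of the image area, then resize it."""
    _, h, w = img.shape
    frac = torch.empty(1).uniform_(*scale_range).item()
    side = max(1, int(round((frac * h * w) ** 0.5)))
    side = min(side, h, w)
    top = torch.randint(0, h - side + 1, (1,)).item()
    left = torch.randint(0, w - side + 1, (1,)).item()
    crop = img[:, top:top + side, left:left + side]
    # F.interpolate expects a batch dimension: (1, C, side, side) -> (1, C, out, out)
    resized = F.interpolate(crop.unsqueeze(0), size=(out_size, out_size),
                            mode="bilinear", align_corners=False)
    return resized.squeeze(0)


def multi_crop(img, n_global=2, n_local=6):
    """Return a few large (global) views and several small (local) views."""
    global_views = [random_crop_resize(img, (0.4, 1.0), 224) for _ in range(n_global)]
    local_views = [random_crop_resize(img, (0.05, 0.4), 96) for _ in range(n_local)]
    return global_views, local_views


img = torch.rand(3, 256, 256)  # stand-in for a real image tensor (C, H, W)
global_views, local_views = multi_crop(img)
print(len(global_views), global_views[0].shape)
print(len(local_views), local_views[0].shape)
```

Local views are smaller both in the area they cover and in resolution, which keeps the extra student forward passes cheap.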

DINO Loss

Let:

  • \(p_t(x)\) be the teacher output distribution,
  • \(p_s(x)\) be the student output distribution.

DINO uses a cross-entropy-like loss:

\[ \mathcal{L}_{\text{DINO}} = - \sum_k p_t^{(k)} \log p_s^{(k)} \]

where the teacher target is treated as fixed for that step.
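Because the teacher target is held fixed, it can help to split the cross-entropy into two standard pieces. The decomposition below is a general identity for any two distributions, not something specific to DINO:

\[ -\sum_k p_t^{(k)} \log p_s^{(k)} = \underbrace{-\sum_k p_t^{(k)} \log p_t^{(k)}}_{H(p_t)} + \underbrace{\sum_k p_t^{(k)} \log \frac{p_t^{(k)}}{p_s^{(k)}}}_{\mathrm{KL}(p_t \,\|\, p_s)} \]

The entropy term \(H(p_t)\) does not depend on the student, so the student's gradient comes entirely from the KL term, which is zero exactly when \(p_s = p_t\).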

Avoid Collapse

A beginner might worry:

If the student just copies the teacher, couldn’t both output the same boring constant vector forever?

Yes, that is exactly the collapse problem. As the figure below illustrates, the networks can reach a trivially low loss by always predicting the same single-peaked distribution, regardless of the input.

[Figure: dino]
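A tiny numeric check makes the problem concrete. The distributions below are made-up toy values and `cross_entropy` is a hypothetical helper. The point is only that a collapsed pair scores better on the loss than an informative pair.

```python
import math


def cross_entropy(p_t, p_s, eps=1e-12):
    """Cross-entropy -sum_k p_t[k] * log p_s[k], with clamping to avoid log(0)."""
    return -sum(t * math.log(max(s, eps)) for t, s in zip(p_t, p_s))


# Collapsed: both networks ignore the input and always emit the same one-hot output.
collapsed = [1.0, 0.0, 0.0, 0.0]
loss_collapsed = cross_entropy(collapsed, collapsed)

# Non-collapsed: soft, input-dependent distributions that roughly agree.
teacher = [0.7, 0.1, 0.1, 0.1]
student = [0.6, 0.2, 0.1, 0.1]
loss_informative = cross_entropy(teacher, student)

print(f"collapsed loss:   {loss_collapsed:.4f}")
print(f"informative loss: {loss_informative:.4f}")
```

The loss alone cannot tell a useful representation from a degenerate one, so something outside the loss has to rule out the degenerate solution.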

DINO-style methods avoid it using a combination of:

  • multi-crop views,
  • teacher EMA updates,
  • centering,
  • sharpening.

These ingredients make the teacher signal informative enough to organize the representation space instead of flattening it.

We have already introduced the first two. The latter two are the stabilizers that act directly on the teacher output:

1. Teacher Centering

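The teacher logits are centered before the softmax so that no single output dimension can dominate:

\[ \tilde{q}_t = q_t - c \]

where \(c\) is a running center, typically maintained as a moving average of the teacher logits over batches. On its own, centering pushes the teacher toward a uniform output, which is why it is paired with sharpening.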
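One common way to maintain such a center is an EMA of the batch mean of the teacher logits. The sketch below assumes that scheme. The helper name `update_center`, the momentum value 0.9, and the tensor shapes are illustrative assumptions.

```python
import torch


@torch.no_grad()
def update_center(center, teacher_logits, center_momentum=0.9):
    """EMA of the per-dimension batch mean of the teacher logits."""
    batch_mean = teacher_logits.mean(dim=0, keepdim=True)  # shape (1, out_dim)
    return center * center_momentum + batch_mean * (1.0 - center_momentum)


out_dim = 8
center = torch.zeros(1, out_dim)                 # assumed initial center
teacher_logits = torch.randn(16, out_dim) + 5.0  # toy batch with a large shared offset

center = update_center(center, teacher_logits)
centered_logits = teacher_logits - center        # what goes into the teacher softmax
print(center.shape, centered_logits.shape)
```

Over many batches the center tracks the average teacher logit per dimension, so subtracting it removes any offset shared across inputs.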

2. Sharpening

The teacher uses a lower temperature than the student (\(T_t < T_s\)), which makes its distribution sharper (more peaked) and counteracts the pull toward a uniform output:

\[ p_t = \mathrm{softmax}\left(\frac{\tilde{q}_t}{T_t}\right) \]

The student uses its own temperature:

\[ p_s = \mathrm{softmax}\left(\frac{q_s}{T_s}\right) \]

Then the student is trained to match the teacher.
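The effect of temperature is easy to see on a toy logit vector. In this sketch the logits are made up, the temperatures 0.1 and 0.04 match the defaults used in the loss code in this section, and `entropy` is a hypothetical helper for measuring how peaked a distribution is.

```python
import torch
import torch.nn.functional as F


def entropy(p):
    """Shannon entropy of a probability vector (lower means sharper)."""
    return -(p * torch.log(p.clamp_min(1e-12))).sum()


logits = torch.tensor([2.0, 1.0, 0.5, 0.0])  # toy logits

p_plain = F.softmax(logits, dim=-1)           # temperature 1
p_student = F.softmax(logits / 0.1, dim=-1)   # student temperature
p_teacher = F.softmax(logits / 0.04, dim=-1)  # lower teacher temperature

for name, p in [("plain", p_plain), ("student", p_student), ("teacher", p_teacher)]:
    print(f"{name:8s} max prob = {p.max().item():.6f}, entropy = {entropy(p).item():.6f}")
```

Dividing by a smaller temperature stretches the gaps between logits, so the softmax puts more mass on the largest one.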

DINO Code

Here is sample code that defines the DINO loss.

import torch
import torch.nn.functional as F


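def dino_loss(student_logits, teacher_logits, center, student_temp=0.1, teacher_temp=0.04):
    # Student: log-probabilities at the (higher) student temperature.
    student_log_probs = F.log_softmax(student_logits / student_temp, dim=-1)
    # Teacher: center, sharpen with the lower temperature, and block gradients.
    teacher_probs = F.softmax((teacher_logits - center) / teacher_temp, dim=-1).detach()
    # Cross-entropy between teacher and student, averaged over the batch.
    loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()
    return loss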

This is simplified, but it captures the central pattern:

  • teacher produces a probability target,
  • student matches it with cross-entropy,
  • center and temperatures help stabilize learning.

EMA Teacher Update

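@torch.no_grad()
def update_teacher(student, teacher, momentum=0.996):
    # teacher <- momentum * teacher + (1 - momentum) * student, parameter by parameter.
    # Assumes both networks have identical architectures and parameter order.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s.detach(), alpha=1 - momentum)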

This is one of the most elegant ideas in self-supervised learning: the teacher improves by being a moving average of the student rather than a separately trained network.
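The update presupposes that a teacher with the same architecture already exists. A common way to set this up is to copy the student and freeze the copy. The sketch below assumes that approach and uses a tiny `nn.Linear` as a stand-in for the real backbone and head. The fake "optimizer step" only exists to make the student and teacher differ.

```python
import copy

import torch
import torch.nn as nn

student = nn.Linear(4, 3)         # toy stand-in for the real student network
teacher = copy.deepcopy(student)  # same architecture, same initial weights
for p in teacher.parameters():
    p.requires_grad_(False)       # the optimizer never touches the teacher

momentum = 0.996
with torch.no_grad():
    # Pretend the optimizer moved every student parameter by +1.
    for p in student.parameters():
        p.add_(1.0)
    # One EMA step: the teacher moves a fraction (1 - momentum) toward the student.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1 - momentum)

gap = (student.weight - teacher.weight).abs().max().item()
print(f"largest remaining student-teacher gap: {gap:.4f}")
```

Only the student's parameters should be handed to the optimizer. The teacher changes solely through the EMA step.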

Simplified Training Skeleton

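# Assumes `center` was created before the loop, e.g. torch.zeros(1, out_dim, device=device).
center_momentum = 0.9  # illustrative value for the running center

for global_views, local_views in train_loader:
    global_views = [v.to(device) for v in global_views]
    local_views = [v.to(device) for v in local_views]

    # The teacher sees only the two global views and receives no gradients.
    with torch.no_grad():
        teacher_outs = [teacher(v) for v in global_views[:2]]

    # The student sees every view; its first two outputs correspond to the global views.
    student_views = global_views + local_views
    student_outs = [student(v) for v in student_views]

    loss = 0.0
    n_terms = 0
    for t_idx, t_out in enumerate(teacher_outs):
        for s_idx, s_out in enumerate(student_outs):
            if s_idx == t_idx:
                continue  # skip pairs where both networks saw the same view
            loss += dino_loss(s_out, t_out, center)
            n_terms += 1
    loss = loss / n_terms

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    update_teacher(student, teacher)

    # Refresh the running center from this batch's teacher outputs.
    with torch.no_grad():
        batch_center = torch.cat(teacher_outs).mean(dim=0, keepdim=True)
        center = center_momentum * center + (1 - center_momentum) * batch_center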

This hides many engineering details, but the conceptual loop is correct.

References and Further Reading