Contribution:

  1. To enable more grounded vision and language representation learning, introduce a contrastive loss (as in CLIP) to ALign the image and text representations BEfore Fusing (ALBEF) them through cross-modal attention.
  2. To improve learning from noisy web data, propose momentum distillation, a self-training method which learns from pseudo-targets produced by a momentum model.

Model Architecture

ALBEF contains

  • an image encoder: 12-layer visual transformer ViT-B/16
    • An input image $I$ is encoded into a sequence of embeddings $\{v_{cls}, v_1, \dots, v_N\}$.
  • a text encoder: first 6 layers of BERT
    • An input text $T$ is encoded into a sequence of embeddings $\{w_{cls}, w_1, \dots, w_N\}$, which is fed to the multimodal encoder.
  • a multimodal encoder: last 6 layers of BERT with additional cross-attention layers

The image features are fused with the text features through cross attention at each layer of the multimodal encoder.
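A minimal PyTorch-style sketch of this layout (illustrative stand-in encoders and module names, not the released implementation; the real model initializes from ViT-B/16 and BERT-base weights):

```python
import torch
import torch.nn as nn

class CrossAttnLayer(nn.Module):
    """One multimodal layer: text self-attention, then cross-attention to image features."""
    def __init__(self, dim=768, heads=12):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, text, image):
        text = self.norm1(text + self.self_attn(text, text, text)[0])
        text = self.norm2(text + self.cross_attn(text, image, image)[0])  # fuse image into text
        return self.norm3(text + self.ffn(text))

class ALBEFSketch(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        layer = nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True)
        self.image_encoder = nn.TransformerEncoder(layer, num_layers=12)  # stand-in for ViT-B/16
        self.text_encoder = nn.TransformerEncoder(layer, num_layers=6)    # stand-in for BERT layers 1-6
        self.multimodal_encoder = nn.ModuleList(
            [CrossAttnLayer(dim) for _ in range(6)]                       # BERT layers 7-12 + cross-attention
        )

    def forward(self, image_patches, text_tokens):
        v = self.image_encoder(image_patches)  # {v_cls, v_1, ..., v_N}
        w = self.text_encoder(text_tokens)     # {w_cls, w_1, ..., w_N}
        for layer in self.multimodal_encoder:
            w = layer(w, v)                    # cross-attention at every multimodal layer
        return v, w
```

The design point to notice is that fusion happens only in the multimodal encoder, where every layer lets the text stream cross-attend to the already-encoded image features.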

Pre-training Objectives

Pre-train ALBEF with three objectives:

  1. image-text contrastive learning (ITC) on the unimodal encoders; this is the “Align” step (a code sketch of this loss follows the list).

    • Aims to learn better unimodal representations before fusion.

    • It learns a similarity function $s = g_v(v_{cls})^\intercal g_w(w_{cls})$, such that parallel image-text pairs have higher similarity scores.

    • $g_v$ and $g_w$ are linear transformations that map the [CLS] embeddings to normalized lower-dimensional (256-d) representations. (downsampling/normalization)

    • The normalized features from the momentum encoders are denoted as $g'_v(v'_{cls})$ and $g'_w(w'_{cls})$. We define $s(I,T) = g_v(v_{cls})^\intercal g'_w(w'_{cls})$ and $s(T,I) = g_w(w_{cls})^\intercal g'_v(v'_{cls})$.

    • For each image and text, we calculate the softmax-normalized image-to-text and text-to-image similarity as:

      $$ p_m^{i2t}(I) = \frac{\exp(s(I, T_m)/\tau)}{\sum_{m=1}^{M}\exp(s(I, T_m)/\tau)}, \qquad p_m^{t2i}(T) = \frac{\exp(s(T, I_m)/\tau)}{\sum_{m=1}^{M}\exp(s(T, I_m)/\tau)} $$

      • $\tau$ is a learnable temperature parameter and $M$ is the number of candidates kept in the momentum queues.

    • The image-text contrastive loss is defined as the cross-entropy $H$ between $\mathbf{p}$ and $\mathbf{y}$:

      $$ L_{itc} = \frac{1}{2}\,\mathbb{E}_{(I,T)\sim D}\big[ H(\mathbf{y}^{i2t}(I), \mathbf{p}^{i2t}(I)) + H(\mathbf{y}^{t2i}(T), \mathbf{p}^{t2i}(T)) \big] $$

      • $y^{i2t}(I)$ and $y^{t2i}(T)$ denote the ground-truth one-hot similarity, where negative pairs have a probability of 0 and the positive pair has a probability of 1.
  2. masked language modeling (MLM) on the multimodal encoder.

    • Randomly mask out the input tokens with a probability of 15% and replace them with the special token [MASK]; utilizes both the image and the contextual text to predict the masked words.

    • MLM minimizes a cross-entropy loss:

      $$ L_{mlm} = \mathbb{E}_{(I,\hat{T})\sim D}\, H(\mathbf{y}^{msk}, \mathbf{p}^{msk}(I, \hat{T})) $$

      • $\hat{T}$ : a masked text
      • $\mathbf{p}^{msk}(I, \hat{T})$ : the model’s predicted probability for a masked token.
      • $\mathbf{y}^{msk}$: one-hot vocabulary distribution where the ground-truth token has a probability of 1.
  3. image-text matching (ITM) on the multimodal encoder.

    • Predicts whether a pair of image and text is positive (matched) or negative (not matched).

    • We use the multimodal encoder’s output embedding of the [CLS] token as the joint representation of the image-text pair, and append a fully-connected (FC) layer followed by softmax to predict a two-class probability $p^{itm}$.

    • The ITM loss:

      $$ L_{itm} = \mathbb{E}_{(I,T)\sim D}\, H(\mathbf{y}^{itm}, \mathbf{p}^{itm}(I, T)) $$

      • $\mathbf{y}^{itm}$: 2-dimensional one-hot vector representing the ground-truth label.
    • Propose contrastive hard negative mining for ITM, which samples in-batch negatives that have high contrastive similarity to the anchor and are therefore more informative.
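As referenced in objective 1, here is a minimal sketch of the ITC computation (a simplification that uses in-batch negatives instead of ALBEF's momentum queues; `g_v`, `g_w`, and `temp` stand for the projection heads and learnable temperature):

```python
import torch
import torch.nn.functional as F

def itc_loss(v_cls, w_cls, g_v, g_w, temp=0.07):
    """Image-text contrastive loss over a batch of B aligned (image, text) pairs."""
    img = F.normalize(g_v(v_cls), dim=-1)   # (B, 256) projected + normalized image [CLS]
    txt = F.normalize(g_w(w_cls), dim=-1)   # (B, 256) projected + normalized text  [CLS]

    sim_i2t = img @ txt.t() / temp          # (B, B) similarities s(I, T) / tau
    sim_t2i = txt @ img.t() / temp          # (B, B) similarities s(T, I) / tau

    # One-hot targets y: the matched pair sits on the diagonal.
    targets = torch.arange(img.size(0), device=img.device)

    # Cross-entropy H(y, p) in both directions, averaged (the 1/2 factor in L_itc).
    return 0.5 * (F.cross_entropy(sim_i2t, targets) + F.cross_entropy(sim_t2i, targets))
```

The same similarity matrices are what the hard-negative mining for ITM draws from: for each image, a negative text is sampled with probability proportional to its contrastive similarity, so harder negatives are picked more often.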

The full pre-training objective of ALBEF is:

$$ L = L_{itc} + L_{mlm} + L_{itm} $$

Momentum Distillation

Challenge: web data is noisy.

  • Negative texts for an image may also match the image’s content (ITC);
  • There may exist other words, different from the annotation, that describe the image equally well (MLM).

However, the one-hot labels for ITC and MLM penalize all negative predictions regardless of their correctness.

Solution: learn from pseudo-targets generated by the **momentum model** (exponential-moving-average versions of the unimodal and multimodal encoders); a minimal EMA update sketch follows the loss definitions below.

During training, we train the base model such that its predictions match the ones from the momentum model.

  • $ITC_{MoD}$ loss:

    $$ L_{itc}^{mod} = (1-\alpha)\,L_{itc} + \frac{\alpha}{2}\,\mathbb{E}_{(I,T)\sim D}\big[ \mathrm{KL}(\mathbf{q}^{i2t}(I)\,\|\,\mathbf{p}^{i2t}(I)) + \mathrm{KL}(\mathbf{q}^{t2i}(T)\,\|\,\mathbf{p}^{t2i}(T)) \big] $$

    • $\mathbf{q}^{i2t}(I)$ and $\mathbf{q}^{t2i}(T)$ are soft pseudo-targets computed from the momentum encoders' similarities, and $\alpha$ weights the distillation term.

  • $MLM_{MoD}$ loss:

    $$ L_{mlm}^{mod} = (1-\alpha)\,L_{mlm} + \alpha\,\mathbb{E}_{(I,\hat{T})\sim D}\, \mathrm{KL}(\mathbf{q}^{msk}(I,\hat{T})\,\|\,\mathbf{p}^{msk}(I,\hat{T})) $$

    • $\mathbf{q}^{msk}(I,\hat{T})$ denotes the momentum model’s predicted probability distribution for the masked token.
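The momentum model above is maintained as an exponential moving average of the base encoders. A minimal sketch of that update (an assumed generic EMA rule, with the coefficient `m` close to 1):

```python
import torch

@torch.no_grad()
def momentum_update(base_model, momentum_model, m=0.995):
    """EMA update: the momentum encoders slowly track the base encoders."""
    for p, p_m in zip(base_model.parameters(), momentum_model.parameters()):
        p_m.data.mul_(m).add_(p.data, alpha=1.0 - m)
```

The pseudo-targets $\mathbf{q}$ come from a forward pass through these EMA encoders and are mixed into the ITC and MLM losses with weight $\alpha$.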

In effect, ALBEF is trained with five loss terms: two for ITC ($L_{itc}$ plus its KL distillation term), two for MLM (likewise), and one $L_{itm}$.
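Written out, the momentum-distilled objective behind this count is just the combination of the formulas above:

$$ L = \underbrace{(1-\alpha)\,L_{itc} + \tfrac{\alpha}{2}\,\mathbb{E}\big[\mathrm{KL}(\mathbf{q}^{i2t}\,\|\,\mathbf{p}^{i2t}) + \mathrm{KL}(\mathbf{q}^{t2i}\,\|\,\mathbf{p}^{t2i})\big]}_{L_{itc}^{mod}} + \underbrace{(1-\alpha)\,L_{mlm} + \alpha\,\mathbb{E}\,\mathrm{KL}(\mathbf{q}^{msk}\,\|\,\mathbf{p}^{msk})}_{L_{mlm}^{mod}} + L_{itm} $$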

Downstream V+L Tasks

  • Image-Text Retrieval: image-to-text retrieval (TR) and text-to-image retrieval (IR).

  • Visual Entailment: predict whether the relationship between an image and a text is entailment, neutral, or contradictory; consider VE as a three-way classification problem.

  • Visual Question Answering:

    • use a 6-layer transformer decoder to generate the answer.

  • Natural Language for Visual Reasoning: predict whether a text describes a pair of images.

  • Visual Grounding: localize the region in an image that corresponds to a specific textual description.

Reference:

  • Li, J., Selvaraju, R. R., Gotmare, A., Joty, S., Xiong, C., and Hoi, S. C. H. “Align before Fuse: Vision and Language Representation Learning with Momentum Distillation.” NeurIPS 2021. arXiv:2107.07651