Pseudo-labeling
- We need an aligned multimodal dataset, but no existing dataset covers all of the target tasks on the same images. Their recipe:
- Train specialist models on disjoint datasets
- Pseudo-label large RGB dataset
- Train a single model on the resulting aligned multi-modal dataset
- All 4M models were trained on an aligned multimodal dataset created by pseudo-labeling Conceptual Captions 12M (CC12M).
- CC12M already contains aligned RGB images and captions
- Surface normals & depth ⇒ DPT-Hybrid
- Semantic segmentation ⇒ Mask2Former (Swin-B backbone)
- Bounding boxes ⇒ ViTDet with a ViT-H backbone. Filter the detections by removing all instances with a confidence score below 0.6.
- CLIP feature maps ⇒ last transformer layer of CLIP-B/16
- To visualize them in the paper, project onto the first three principal components to get RGB channels.
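A minimal sketch of that PCA-to-RGB projection for a single feature map; the shapes and the use of torch.pca_lowrank are my own choices, not taken from the paper:

```python
import torch

def clip_features_to_rgb(feat: torch.Tensor) -> torch.Tensor:
    """Project a dense feature map onto its first 3 principal components.

    feat: [C, H, W] CLIP patch features (e.g., last transformer layer of
          CLIP-B/16, reshaped to a spatial grid).
    Returns an [H, W, 3] tensor scaled to [0, 1] for display.
    """
    C, H, W = feat.shape
    x = feat.permute(1, 2, 0).reshape(-1, C)          # [H*W, C]
    x = x - x.mean(dim=0, keepdim=True)               # center features
    _, _, v = torch.pca_lowrank(x, q=3, center=False) # principal directions
    rgb = x @ v                                        # [H*W, 3]
    rgb = (rgb - rgb.min(0).values) / (rgb.max(0).values - rgb.min(0).values + 1e-8)
    return rgb.reshape(H, W, 3)
```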
Multimodal Transformer
- Weight tying: As is commonly done for autoregressive models, the parameters of the decoder's input embedding and output embedding (i.e., the final linear layer) are shared for sequence modalities
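A small PyTorch sketch of this kind of weight tying for one sequence modality (illustrative, not the 4M implementation):

```python
import torch.nn as nn

class SequenceHead(nn.Module):
    """Input embedding and output projection share the same weight matrix."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)     # token id -> vector
        self.to_logits = nn.Linear(dim, vocab_size, bias=False)
        self.to_logits.weight = self.embed.weight      # weight tying

    def forward(self, hidden):
        # hidden: [batch, seq_len, dim] decoder outputs for this modality
        return self.to_logits(hidden)                  # [batch, seq_len, vocab_size]
```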
Multimodal encoder
- The encoder is a standard Transformer encoder
- Each modality has an input embedding layer to map token indices to vectors
- To each token of a specific modality, we add a learnable modality embedding and either 1D (for sequences) or 2D (for dense modalities) sine-cosine positional embeddings
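A rough sketch of how such a per-modality input embedding could look (1D case shown; a dense modality would index a 2D sine-cosine grid instead; all names are illustrative):

```python
import math
import torch
import torch.nn as nn

def sincos_pos_emb(num_positions: int, dim: int) -> torch.Tensor:
    """Fixed 1D sine-cosine positional embeddings, shape [num_positions, dim] (dim even)."""
    pos = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    emb = torch.zeros(num_positions, dim)
    emb[:, 0::2] = torch.sin(pos * div)
    emb[:, 1::2] = torch.cos(pos * div)
    return emb

class ModalityInputEmbedding(nn.Module):
    """Token embedding + learnable modality embedding + fixed positional embedding."""

    def __init__(self, vocab_size: int, dim: int, max_len: int):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.modality_emb = nn.Parameter(torch.zeros(1, 1, dim))       # one per modality
        self.register_buffer("pos_emb", sincos_pos_emb(max_len, dim))  # fixed, not learned

    def forward(self, token_ids, positions):
        # token_ids, positions: [batch, num_tokens] long tensors
        return self.token_emb(token_ids) + self.modality_emb + self.pos_emb[positions]
```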
Multimodal decoder
- The decoder handles tokens from both dense image-like and sequence-like modalities, with each type requiring a different approach.
Common aspects
- they can all freely attend to any encoder tokens in the cross-attention layers, ensuring full access to the encoded information
- they employ attention masks to separate decoder tokens of different modalities.
- This ensures that the decoder produces consistent outputs for each specific modality, irrespective of what other outputs are being generated simultaneously
Difference
- For dense image-like modalities, the decoder input consists of mask tokens along with modality and positional information.
- The decoder’s role is to predict this masked content.
- For sequence-like modalities, the input to the decoder comprises modality, positional, and content information.
- The decoder is tasked to predict the next token in the sequence
- To ensure that each token is only influenced by preceding tokens (and not by any future tokens), we apply a causal mask to the self-attention, as is standard in autoregressive models.
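One way these two constraints (separating decoder tokens by modality, plus causal masking for sequence modalities) could be expressed as a single boolean self-attention mask, purely as an illustration:

```python
import torch

def decoder_self_attention_mask(modality_ids: torch.Tensor,
                                is_sequence: torch.Tensor) -> torch.Tensor:
    """Build a [T, T] boolean mask (True = blocked) for the decoder's self-attention.

    modality_ids: [T] integer modality id of every decoder token.
    is_sequence:  [T] bool, True for tokens that belong to a sequence-like modality.
    Tokens may only attend within their own modality; tokens of sequence-like
    modalities are additionally restricted to preceding positions (causal).
    """
    T = modality_ids.shape[0]
    same_modality = modality_ids.unsqueeze(0) == modality_ids.unsqueeze(1)  # [T, T]
    blocked = ~same_modality                                                # separate modalities
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)     # future positions
    blocked |= causal & is_sequence.unsqueeze(1)  # causal only for sequence-modality queries
    return blocked
```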
Multimodal masking strategy
- Similar to MultiMAE
- Dropping masked-out tokens and only encoding the small set of visible ones when performing masked image modeling has been shown to yield significant increases in training efficiency
- decouples the decoding training load from the number of supported modalities
Mask sampling parameter
- They sample the number of input and target tokens per modality using a symmetric Dirichlet distribution with concentration parameter α (see the sampling sketch after this list)
- If α is low, the sampling procedure will often choose cases where most of the tokens are sampled from only one modality.
- If α is high, however, most samples will contain tokens from all modalities in roughly equal proportions.
- They ablate different choices of α and find that:
- models trained with higher α values generally transfer better to RGB tasks
- models trained with lower input α values perform better at transferring to novel input modalities
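A toy version of this Dirichlet budget sampling (the rounding fix-up is my own detail):

```python
import torch

def sample_modality_budgets(total_tokens: int, num_modalities: int,
                            alpha: float) -> torch.Tensor:
    """Split a token budget across modalities with a symmetric Dirichlet(alpha).

    Low alpha -> proportions concentrate on a few modalities;
    high alpha -> proportions close to uniform across all modalities.
    """
    dist = torch.distributions.Dirichlet(torch.full((num_modalities,), alpha))
    proportions = dist.sample()                       # sums to 1
    budgets = (proportions * total_tokens).floor().long()
    budgets[0] += total_tokens - budgets.sum()        # hand rounding leftovers to one modality
    return budgets

# e.g. sample_modality_budgets(128, 7, alpha=0.2) often puts most tokens in one modality
```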
Input and target masking
- After sampling the per-modality number of input and target tokens
- sample tokens from dense modalities uniformly at random
- Perform span masking on sequence modalities
Span masking
- From T5: given a probability p of masking each token, randomly mask out tokens in the sequence and replace each consecutive span of masked-out tokens with a sentinel token (e.g., S_1, S_2, S_3, …). A toy implementation follows this list.
- The target sequence then consists of the masked-out spans delimited by the sentinel tokens, followed by a final sentinel token to signal the end of the sequence.
- Unlike dense modalities, it is not possible to strictly respect the token budget when sampling from sequences, due to their variable length.
- Instead, they treat the token budget as a strict upper bound and mask sequences as follows:
- For the input, they sample a masking probability from a uniform distribution and use it for span masking. If the sequence length after masking is greater than the input budget, they progressively increase the masking probability until the sequence fits within the assigned budget.
- For the target, if the sequence does not fit within the budget, they randomly truncate it while ensuring that the first token of the truncated sequence is a sentinel token.
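A simplified, self-contained sketch of this span-masking procedure (string sentinels are used for readability; a real tokenizer would reserve dedicated sentinel ids):

```python
import random

def span_mask(tokens, mask_prob, rng=random):
    """T5-style span corruption: mask each token with prob. mask_prob, then replace
    every consecutive run of masked tokens with one sentinel <S_i>. The target is
    the masked-out spans, delimited by their sentinels, plus a final sentinel."""
    inp, tgt = [], []
    sentinel, in_span = 0, False
    for tok in tokens:
        if rng.random() < mask_prob:
            if not in_span:                      # start of a new masked span
                inp.append(f"<S_{sentinel}>")
                tgt.append(f"<S_{sentinel}>")
                sentinel += 1
                in_span = True
            tgt.append(tok)                      # masked-out content goes to the target
        else:
            inp.append(tok)
            in_span = False
    tgt.append(f"<S_{sentinel}>")                # final sentinel marks end of sequence
    return inp, tgt

# Example: span_mask("a red car parked on the street".split(), 0.4)
# might give  input:  ['a', '<S_0>', 'parked', 'on', '<S_1>', 'street']
#             target: ['<S_0>', 'red', 'car', '<S_1>', 'the', '<S_2>']
```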
Masking budget
Input masking budget
- The difficulty of the multimodal masked modeling task is largely determined by the number of visible (non-masked) input tokens
- fewer visible tokens make the task more challenging.
- Since the modalities contain a lot of spatial information about each other, it is necessary to lower the number of visible tokens to keep the objective difficult enough.
- There is no obvious default value; a good input budget has to be found empirically.
Output masking budget
- Decoding all masked-out tokens can quickly become infeasible as the number of modalities grows and the masking ratio is kept high.
- Decoding only a small random subset of all targets performs better (for a fixed training duration) than decoding a larger number of tokens
- while also significantly reducing the computational costs.
Mixture of masking strategies
- Pre-training using RGB as the sole input modality, instead of training on all modalities, performs significantly better at transfers that, likewise, use only RGB as input.
- On the flip-side, pre-training with all modalities performs significantly better when transferring to unseen input modalities.
- We can find a compromise by sampling batch elements such that approximately half the time the input tokens are RGB-only, and half the time they are sampled from all modalities (sketched after this list).
- Results in a very good generalist model.
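A sketch of how this 50/50 mixture could be drawn per batch element, reusing the Dirichlet idea from above (the "rgb" modality name and the p_rgb_only parameter are assumptions):

```python
import random
import torch

def sample_input_budgets(total_tokens: int, modality_names: list[str],
                         alpha: float, p_rgb_only: float = 0.5) -> dict[str, int]:
    """Half the time give the whole input budget to RGB, otherwise spread it
    over all modalities with a symmetric Dirichlet(alpha)."""
    if random.random() < p_rgb_only:
        return {name: (total_tokens if name == "rgb" else 0) for name in modality_names}
    props = torch.distributions.Dirichlet(
        torch.full((len(modality_names),), alpha)).sample()
    budgets = (props * total_tokens).floor().long()
    budgets[0] += total_tokens - budgets.sum()        # absorb rounding leftovers
    return dict(zip(modality_names, budgets.tolist()))
```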
Generation
Generation procedure
- Unlike traditional autoregressive models, which need to generate tokens one-by-one in a pre-determined order, the 4M encoder-decoder can speed up inference by parallelizing the decoding process, because the distributions over all masked tokens can be predicted at the same time.
MaskGIT (for image-like modalities)
- This parallel decoding scheme generates the entire image in a fixed, pre-determined number of decoding steps.
- At every prediction step, encode all visible tokens and decode all masked out tokens.
- Sample from their predicted distributions and keep the k most confident tokens, where k is determined by the generation schedule.
- Add the predicted tokens to the input and repeat the previous steps.
- A masking schedule needs to be set. The usual practice is to have the model predict increasingly more tokens per step, as the "information content" of the modality becomes clearer (see the decoding-loop sketch below).
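A compact sketch of such a parallel decoding loop with a cosine schedule; predict_logits is a hypothetical callable standing in for the encode-then-decode step:

```python
import math
import torch

@torch.no_grad()
def maskgit_decode(predict_logits, num_tokens: int, steps: int):
    """MaskGIT-style parallel decoding: predict all masked tokens, keep the most
    confident predictions, re-mask the rest, and repeat for a fixed number of steps."""
    MASK = -1                                                  # placeholder id for masked positions
    tokens = torch.full((num_tokens,), MASK, dtype=torch.long)
    for step in range(steps):
        # Cosine schedule: how many tokens should remain masked after this step.
        keep_masked = math.floor(num_tokens * math.cos(math.pi / 2 * (step + 1) / steps))
        logits = predict_logits(tokens)                        # [num_tokens, vocab_size]
        probs = logits.softmax(dim=-1)
        sampled = torch.multinomial(probs, 1).squeeze(-1)      # sample a token at every position
        conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        conf[tokens != MASK] = float("inf")                    # already-decoded tokens are kept
        order = conf.argsort(descending=True)                  # most confident first
        to_fix = order[: num_tokens - keep_masked]
        tokens[to_fix] = torch.where(tokens[to_fix] == MASK,
                                     sampled[to_fix], tokens[to_fix])
    return tokens
```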
Random order autoregressive (ROAR) (for image-like modalities)
- Same idea as MaskGIT (in terms of progressively sampling more tokens). However, unlike MaskGIT, you do not decode all masked-out tokens at every step, but instead randomly select tokens to decode.
Left-to-right autoregressive (for sequences)
- Just generate sequence modalities such as captions by auto-regressively decoding them, using the Transformer decoder.
Multimodal weighted guidance
- Inspired by classifier-free guidance in diffusion models.
- Multimodal guidance can be achieved by computing a weighted sum of the logits of the unconditional case and each conditional case.
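- A hedged reconstruction of that weighted sum in the usual classifier-free-guidance form (notation mine): $\hat{\ell} = \ell_{\text{uncond}} + \sum_i w_i\,(\ell_{\text{cond},i} - \ell_{\text{uncond}})$, where $w_i$ is the guidance weight for condition $i$; a single condition with $w = 1$ recovers ordinary conditional prediction.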
Technical details
Training duration
- The model was trained for 8 days on 128 A100 GPUs.
- used activation checkpointing + ZeRO-2 (FSDP)
Training loss
- When predicting multiple modalities at the same time, there are several ways to compute their losses.
- Per-token loss: Treat every predicted token the same and average all their losses.
- This setting is biased against modalities that don’t contain a lot of tokens, such as captions.
- Per-modality loss: First average the loss for every target modality individually and then average those.
- Computing the loss per-modality noticeably outperforms the per-token loss.
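A minimal sketch of the two averaging schemes; the dict-of-per-modality logits/targets interface is an assumption:

```python
import torch
import torch.nn.functional as F

def per_token_loss(logits: dict, targets: dict) -> torch.Tensor:
    """Average the cross-entropy over every predicted token, regardless of modality.
    Modalities with many tokens (e.g. dense maps) dominate small ones (e.g. captions)."""
    losses = [F.cross_entropy(logits[m], targets[m], reduction="none") for m in logits]
    return torch.cat(losses).mean()

def per_modality_loss(logits: dict, targets: dict) -> torch.Tensor:
    """First average within each target modality, then average across modalities,
    so every modality contributes equally to the total loss."""
    return torch.stack([F.cross_entropy(logits[m], targets[m]) for m in logits]).mean()
```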
Data loading
Training with repeat sampling
- Data loading can be a significant bottleneck when training efficient masked models.
- They use webdataset to load tar shards of 1000 samples each, instead of loading individual samples
- They keep a buffer of loaded images in RAM and randomly sample from it repeatedly – each time applying different random masking.
- Whenever an element in the buffer has been sampled more than the specified number of repeats (they use 4), they replace it with a fresh sample
- Improves training efficiency at no loss in performance
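A toy sketch of such a repeat-sampling buffer; the sample stream stands in for the webdataset iterator, and the caller applies a fresh random mask each time a sample is yielded:

```python
import random

class RepeatSamplingBuffer:
    """Keep decoded samples in RAM and reuse each one a few times (with fresh
    random masking applied downstream) before replacing it with a new sample."""

    def __init__(self, sample_stream, buffer_size: int = 1024, max_repeats: int = 4):
        self.stream = sample_stream            # iterator over decoded samples
        self.max_repeats = max_repeats
        self.buffer = [[next(self.stream), 0] for _ in range(buffer_size)]  # [sample, uses]

    def __iter__(self):
        while True:
            idx = random.randrange(len(self.buffer))
            sample, uses = self.buffer[idx]
            if uses >= self.max_repeats:       # used up: swap in a fresh sample
                sample = next(self.stream)
                self.buffer[idx] = [sample, 0]
            self.buffer[idx][1] += 1
            yield sample
```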