A Comprehensive Tutorial on GLIP (Grounded Language-Image Pre-training)
Introduction
GLIP (Grounded Language-Image Pre-training) is a unified model architecture that bridges the gap between vision and language by integrating object detection and phrase grounding tasks. It leverages both visual and textual data to perform object detection conditioned on textual descriptions, enabling the model to recognize objects based on their semantic meanings.
Key Features:
- Unified Architecture: Combines object detection and phrase grounding into a single framework.
- Text-Conditioned Detection: Incorporates textual embeddings into object queries, allowing for detection based on language cues.
- Transformer-Based: Utilizes a transformer encoder-decoder structure to process visual and textual information jointly.
How GLIP Works in Detail
GLIP operates by integrating textual embeddings into the object detection pipeline. Here's a high-level overview of the components, followed by a short shape walkthrough with dummy tensors:
- Backbone Network: Extracts visual features from input images using a convolutional neural network (e.g., ResNet).
- Language Transformer: Processes textual inputs (phrases or sentences) to generate text embeddings using models like BERT.
- Object Queries: Text embeddings are transformed into object queries that guide the detection process.
- Transformer Encoder-Decoder: The visual features and object queries are processed through a transformer to produce refined feature representations.
- Prediction Heads: The model outputs bounding boxes and class logits for each object query.
- Loss Function: Combines classification and regression losses to train the model.
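Before building each component, here is a quick shape walkthrough with dummy tensors; the sizes (`hidden_dim=256`, `num_queries=100`, `num_classes=91`) are assumptions that match the defaults used later in this tutorial.

```python
import torch

# Dummy tensors illustrating the shapes that flow through the pipeline
batch_size, hidden_dim, num_queries, num_classes = 2, 256, 100, 91

visual_features = torch.randn(batch_size, 2048, 7, 7)              # backbone output (for a 224x224 input)
text_embeddings = torch.randn(batch_size, 768)                     # BERT pooled output
object_queries = torch.randn(batch_size, num_queries, hidden_dim)  # text-conditioned object queries
pred_boxes = torch.randn(batch_size, num_queries, 4)               # bounding-box head output
pred_logits = torch.randn(batch_size, num_queries, num_classes)    # classification head output
```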
Backbone Structure
The backbone network is responsible for extracting rich visual features from input images. We will use a pre-trained ResNet-50 model.
```python
import torch
import torch.nn as nn
import torchvision.models as models

class Backbone(nn.Module):
    def __init__(self):
        super(Backbone, self).__init__()
        # Load a pre-trained ResNet-50 model
        resnet = models.resnet50(pretrained=True)
        # Remove the fully connected layer and average pooling
        self.features = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        # x: Input image tensor of shape [batch_size, 3, H, W]
        x = self.features(x)
        # x: Visual feature tensor of shape [batch_size, 2048, H', W']
        return x
```
Explanation:
- Input: `x` is a batch of images with shape `[batch_size, 3, H, W]`.
- Processing: The ResNet-50 model extracts features, reducing spatial dimensions.
- Output: A feature map of shape `[batch_size, 2048, H', W']`, where `H'` and `W'` are the reduced spatial dimensions.
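As a quick sanity check, the backbone can be run on a dummy batch; the 224×224 input size below is an arbitrary choice.

```python
backbone = Backbone()
dummy_images = torch.randn(2, 3, 224, 224)   # [batch_size, 3, H, W]
features = backbone(dummy_images)
print(features.shape)  # torch.Size([2, 2048, 7, 7]) -- ResNet-50 downsamples by a factor of 32
```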
Language Transformer Structure
We use a pre-trained BERT model to encode textual inputs into embeddings.
```python
from transformers import BertModel, BertTokenizer

class LanguageTransformer(nn.Module):
    def __init__(self):
        super(LanguageTransformer, self).__init__()
        # Load pre-trained BERT model and tokenizer
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, text_list):
        # text_list: List of strings with length [batch_size]
        encoded_input = self.tokenizer(
            text_list, return_tensors='pt', padding=True, truncation=True
        )
        # encoded_input is a dict containing:
        #   'input_ids': Tensor of shape [batch_size, seq_len]
        #   'attention_mask': Tensor of shape [batch_size, seq_len]
        # Move the token tensors to the same device as the BERT weights
        device = next(self.bert.parameters()).device
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        outputs = self.bert(
            input_ids=encoded_input['input_ids'],
            attention_mask=encoded_input['attention_mask']
        )
        # outputs.last_hidden_state: Tensor of shape [batch_size, seq_len, hidden_size]
        # outputs.pooler_output: Tensor of shape [batch_size, hidden_size]
        # Use the pooled output as text embeddings
        text_embeddings = outputs.pooler_output  # Shape: [batch_size, hidden_size]
        return text_embeddings
```
Explanation:
- Input: `text_list` is a list of strings of length `[batch_size]`.
- Tokenization: Converts text into token IDs and attention masks.
- BERT Encoding: Processes tokens to generate embeddings.
- Output: `text_embeddings` of shape `[batch_size, hidden_size]` (768 for BERT-base).
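A short usage sketch; the phrases are arbitrary examples, and the first call downloads the `bert-base-uncased` weights.

```python
language_model = LanguageTransformer()
phrases = ["a red car parked on the street", "a person riding a bicycle"]
embeddings = language_model(phrases)
print(embeddings.shape)  # torch.Size([2, 768])
```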
Incorporating Text Embeddings into Object Queries
Text embeddings are transformed into object queries for the transformer decoder.
```python
class ObjectQueryGenerator(nn.Module):
    def __init__(self, hidden_dim, num_queries):
        super(ObjectQueryGenerator, self).__init__()
        # Linear layer to project BERT embeddings to the transformer's hidden dimension
        self.query_proj = nn.Linear(768, hidden_dim)
        self.num_queries = num_queries  # Number of object queries

    def forward(self, text_embeddings):
        # text_embeddings: Tensor of shape [batch_size, 768]
        projected_embeddings = self.query_proj(text_embeddings)
        # projected_embeddings: Tensor of shape [batch_size, hidden_dim]
        # Expand embeddings to create multiple object queries
        queries = projected_embeddings.unsqueeze(1).repeat(1, self.num_queries, 1)
        # queries: Tensor of shape [batch_size, num_queries, hidden_dim]
        return queries
```
Explanation:
- Projection: Maps text embeddings to the transformer’s hidden dimension.
- Expansion: Generates multiple object queries per text embedding.
- Output: `queries` of shape `[batch_size, num_queries, hidden_dim]`.
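A usage sketch with random stand-in text embeddings; the hidden dimension and query count match the defaults used later in the tutorial.

```python
query_generator = ObjectQueryGenerator(hidden_dim=256, num_queries=100)
fake_text_embeddings = torch.randn(2, 768)   # stand-in for LanguageTransformer output
queries = query_generator(fake_text_embeddings)
print(queries.shape)  # torch.Size([2, 100, 256])
```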
Ground Truth Generation
Ground truth data includes bounding boxes and labels for objects in the images.
```python
def generate_ground_truth(annotations, label_to_index):
    """
    annotations: List of dictionaries, one per image in the batch.
    Each dictionary contains:
        - 'boxes': Tensor of shape [num_objects, 4] (x_min, y_min, x_max, y_max)
        - 'labels': List of strings with length [num_objects]
    label_to_index: Dict mapping label strings to integer class indices,
        so the labels can be used directly by the classification loss.
    """
    targets = []
    for ann in annotations:
        boxes = ann['boxes']  # Tensor of shape [num_objects, 4]
        # Convert string labels to integer class indices
        labels = torch.tensor([label_to_index[l] for l in ann['labels']], dtype=torch.long)
        targets.append({'boxes': boxes, 'labels': labels})
    return targets
```
Explanation:
- Input: `annotations`, a list of per-image bounding boxes and string labels, plus `label_to_index`, a mapping from class names to integer indices.
- Output: `targets`, a list of dictionaries with box tensors and integer label tensors, ready for loss computation (after matching; see the loss section below).
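For concreteness, here is a sketch of what a two-image batch of annotations might look like; the class names and the `label_to_index` vocabulary are illustrative assumptions.

```python
label_to_index = {'cat': 0, 'dog': 1}  # hypothetical class vocabulary

annotations = [
    {
        'boxes': torch.tensor([[30.0, 40.0, 200.0, 220.0]]),   # one object: a cat
        'labels': ['cat'],
    },
    {
        'boxes': torch.tensor([[10.0, 10.0, 120.0, 150.0],
                               [130.0, 20.0, 300.0, 180.0]]),  # two dogs
        'labels': ['dog', 'dog'],
    },
]

targets = generate_ground_truth(annotations, label_to_index)
print(targets[1]['labels'])  # tensor([1, 1])
```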
Loss Function
The loss function combines classification and bounding box regression losses.
```python
def compute_loss(outputs, targets):
    """
    outputs: Dict containing model predictions:
        - 'pred_boxes': Tensor [batch_size, num_queries, 4]
        - 'pred_logits': Tensor [batch_size, num_queries, num_classes]
    targets: List of dictionaries with ground truth, one per image.
        This simplified loss assumes each image's targets have already been
        aligned one-to-one with the queries (e.g., via the Hungarian matching
        sketched below), so that:
        - 'boxes': Tensor [num_queries, 4]
        - 'labels': Tensor [num_queries] (integer class indices, with a
          background/"no object" class for unmatched queries)
    """
    num_classes = outputs['pred_logits'].shape[-1]
    # Flatten predictions
    pred_boxes = outputs['pred_boxes'].view(-1, 4)              # [batch_size*num_queries, 4]
    pred_logits = outputs['pred_logits'].view(-1, num_classes)  # [batch_size*num_queries, num_classes]
    # Concatenate the aligned targets
    target_boxes = torch.cat([t['boxes'] for t in targets], dim=0)    # [batch_size*num_queries, 4]
    target_labels = torch.cat([t['labels'] for t in targets], dim=0)  # [batch_size*num_queries]
    # Classification loss (e.g., CrossEntropyLoss)
    classification_loss = nn.CrossEntropyLoss()(pred_logits, target_labels)
    # Bounding box regression loss (e.g., Smooth L1 Loss)
    regression_loss = nn.SmoothL1Loss()(pred_boxes, target_boxes)
    # Total loss
    total_loss = classification_loss + regression_loss
    return total_loss
```
Explanation:
- Classification Loss: Measures the error in predicted class probabilities.
- Regression Loss: Measures the error in predicted bounding boxes.
- Total Loss: Sum of classification and regression losses.
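The simplified loss above assumes each query is already paired with a ground-truth entry. DETR-style detectors establish that pairing with bipartite (Hungarian) matching; below is a minimal single-image sketch using scipy's `linear_sum_assignment`, with a cost that mixes classification probability and L1 box distance (the cost weights are arbitrary assumptions).

```python
import torch
from scipy.optimize import linear_sum_assignment

def match_queries_to_targets(pred_logits, pred_boxes, gt_labels, gt_boxes,
                             cls_weight=1.0, box_weight=5.0):
    """
    Hungarian matching for a single image (simplified sketch).
    pred_logits: [num_queries, num_classes]
    pred_boxes:  [num_queries, 4]
    gt_labels:   [num_objects] (integer class indices)
    gt_boxes:    [num_objects, 4]
    Returns (query_indices, target_indices) for the matched pairs.
    """
    probs = pred_logits.softmax(dim=-1)                 # [num_queries, num_classes]
    # Classification cost: negative probability of each ground-truth class
    cls_cost = -probs[:, gt_labels]                     # [num_queries, num_objects]
    # Box cost: L1 distance between predicted and ground-truth boxes
    box_cost = torch.cdist(pred_boxes, gt_boxes, p=1)   # [num_queries, num_objects]
    cost = cls_weight * cls_cost + box_weight * box_cost
    query_idx, target_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(query_idx), torch.as_tensor(target_idx)
```

Matched queries are supervised with their paired boxes and labels; unmatched queries are assigned a background ("no object") class before calling `compute_loss`.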
GLIP Model Integration
Now, we combine all components into the GLIP model.
```python
class GLIPModel(nn.Module):
    def __init__(self, hidden_dim=256, num_queries=100, num_classes=91):
        super(GLIPModel, self).__init__()
        self.backbone = Backbone()
        self.language_transformer = LanguageTransformer()
        self.query_generator = ObjectQueryGenerator(hidden_dim, num_queries)
        # Project the 2048-dim backbone features to the transformer's hidden dimension
        self.input_proj = nn.Linear(2048, hidden_dim)
        # Transformer encoder-decoder
        self.transformer = nn.Transformer(
            d_model=hidden_dim, nhead=8, num_encoder_layers=6, num_decoder_layers=6
        )
        # Prediction heads
        self.bbox_pred = nn.Linear(hidden_dim, 4)             # Bounding box regression
        self.class_pred = nn.Linear(hidden_dim, num_classes)  # Classification logits

    def forward(self, images, texts):
        # images: Tensor of shape [batch_size, 3, H, W]
        # texts: List of strings of length [batch_size]

        # Extract visual features
        visual_features = self.backbone(images)
        # visual_features: [batch_size, 2048, H', W']
        batch_size, c, h, w = visual_features.shape

        # Flatten spatial dimensions and permute
        visual_features = visual_features.view(batch_size, c, h * w).permute(0, 2, 1)
        # visual_features: [batch_size, seq_len_v, 2048], where seq_len_v = H' * W'

        # Project visual features to the transformer's hidden dimension
        visual_features = self.input_proj(visual_features)
        # visual_features: [batch_size, seq_len_v, hidden_dim]
        # (A full implementation would also add positional encodings here.)

        # Encode text inputs
        text_embeddings = self.language_transformer(texts)
        # text_embeddings: [batch_size, 768]

        # Generate object queries
        queries = self.query_generator(text_embeddings)
        # queries: [batch_size, num_queries, hidden_dim]

        # Prepare inputs for the transformer (which expects sequence-first tensors)
        src = visual_features.permute(1, 0, 2)  # [seq_len_v, batch_size, hidden_dim]
        tgt = queries.permute(1, 0, 2)          # [num_queries, batch_size, hidden_dim]

        # Transformer encoding and decoding
        memory = self.transformer.encoder(src)      # [seq_len_v, batch_size, hidden_dim]
        hs = self.transformer.decoder(tgt, memory)  # [num_queries, batch_size, hidden_dim]

        # Permute back to [batch_size, num_queries, hidden_dim]
        hs = hs.permute(1, 0, 2)

        # Predict bounding boxes and class logits
        pred_boxes = self.bbox_pred(hs)    # [batch_size, num_queries, 4]
        pred_logits = self.class_pred(hs)  # [batch_size, num_queries, num_classes]

        return {'pred_boxes': pred_boxes, 'pred_logits': pred_logits}
```
Explanation:
- Visual Feature Extraction: Backbone processes images to get visual features.
- Text Embedding: Language transformer encodes text into embeddings.
- Object Queries: Generated from text embeddings.
- Transformer Processing: Jointly processes visual features and object queries.
- Predictions: Outputs bounding boxes and class logits for each query.
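A quick end-to-end smoke test with dummy inputs; the image size and phrases are arbitrary, and the first run downloads the BERT weights.

```python
model = GLIPModel(hidden_dim=256, num_queries=100, num_classes=91)
images = torch.randn(2, 3, 224, 224)   # [batch_size, 3, H, W]
texts = ["a cat sitting on a sofa", "two dogs playing in a park"]
outputs = model(images, texts)
print(outputs['pred_boxes'].shape)   # torch.Size([2, 100, 4])
print(outputs['pred_logits'].shape)  # torch.Size([2, 100, 91])
```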
Mixing Text Embeddings with Object Queries
Text embeddings guide the object queries to focus on relevant objects.
```python
# A revised ObjectQueryGenerator that mixes the projected text embedding
# with a set of learnable query embeddings.
class ObjectQueryGenerator(nn.Module):
    def __init__(self, hidden_dim, num_queries):
        super(ObjectQueryGenerator, self).__init__()
        self.query_proj = nn.Linear(768, hidden_dim)
        # Learnable query embeddings must be created in __init__ so they are
        # registered as parameters and updated during training (creating them
        # inside forward() would re-initialize them on every call).
        self.learnable_queries = nn.Parameter(torch.randn(num_queries, hidden_dim))

    def forward(self, text_embeddings):
        # text_embeddings: Tensor of shape [batch_size, 768]
        projected_embeddings = self.query_proj(text_embeddings)
        # projected_embeddings: Tensor of shape [batch_size, hidden_dim]
        # Normalize embeddings (optional)
        projected_embeddings = projected_embeddings / projected_embeddings.norm(dim=-1, keepdim=True)
        # Combine the text embedding with the learnable queries (broadcast over queries)
        queries = projected_embeddings.unsqueeze(1) + self.learnable_queries.unsqueeze(0)
        # queries: [batch_size, num_queries, hidden_dim]
        return queries
```
Explanation:
- Combination: Text embeddings are added to learnable query embeddings.
- Result: Queries are conditioned on textual information, guiding detection.
Training Code
The training loop involves feeding data through the model, computing loss, and updating weights.
```python
def train(model, dataloader, optimizer, num_epochs, label_to_index):
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            images = batch['images']            # Tensor [batch_size, 3, H, W]
            texts = batch['texts']              # List of strings [batch_size]
            annotations = batch['annotations']  # Ground truth data

            # Forward pass
            outputs = model(images, texts)
            # outputs: Dict with 'pred_boxes' and 'pred_logits'

            # Generate ground truth targets
            targets = generate_ground_truth(annotations, label_to_index)
            # NOTE: a full implementation would align targets to queries here
            # (e.g., with the Hungarian matching sketched above) before the loss.

            # Compute loss
            loss = compute_loss(outputs, targets)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
```
Explanation:
- Data Loading: Each batch contains images, texts, and annotations.
- Forward Pass: Model predicts outputs based on inputs.
- Loss Computation: Calculates the loss using the custom loss function.
- Backpropagation: Updates model parameters.
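Because each image can contain a different number of objects, the dataloader needs a custom collate function that stacks images but keeps texts and annotations as Python lists. The sketch below shows one way to wire the pieces together; `my_dataset`, the batch size, the learning rate, and the `label_to_index` vocabulary are all placeholders.

```python
from torch.utils.data import DataLoader

def collate_fn(samples):
    # samples: list of dicts with 'image', 'text', and 'annotation' keys (assumed dataset format)
    # Assumes images have already been resized to a common resolution.
    return {
        'images': torch.stack([s['image'] for s in samples]),  # [batch_size, 3, H, W]
        'texts': [s['text'] for s in samples],                  # list of strings
        'annotations': [s['annotation'] for s in samples],      # list of per-image dicts
    }

model = GLIPModel(hidden_dim=256, num_queries=100, num_classes=91)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(my_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
label_to_index = {'cat': 0, 'dog': 1}  # hypothetical class vocabulary

train(model, dataloader, optimizer, num_epochs=10, label_to_index=label_to_index)
```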
Inference Code
Perform inference to get predictions from the trained model.
```python
def inference(model, images, texts, conf_threshold=0.5):
    model.eval()
    with torch.no_grad():
        # images: Tensor [batch_size, 3, H, W]
        # texts: List of strings [batch_size]
        outputs = model(images, texts)
        # outputs: Dict with 'pred_boxes' and 'pred_logits'

        # Apply softmax to class logits to get probabilities
        pred_probs = nn.Softmax(dim=-1)(outputs['pred_logits'])
        # pred_probs: Tensor [batch_size, num_queries, num_classes]

        # Get the highest class probability and index
        pred_scores, pred_labels = torch.max(pred_probs, dim=-1)
        # pred_scores: Tensor [batch_size, num_queries]
        # pred_labels: Tensor [batch_size, num_queries]

        # Retrieve predicted bounding boxes
        pred_boxes = outputs['pred_boxes']  # [batch_size, num_queries, 4]

        # Filter out predictions below the confidence threshold
        mask = pred_scores > conf_threshold
        filtered_boxes, filtered_labels, filtered_scores = [], [], []
        for i in range(images.size(0)):
            filtered_boxes.append(pred_boxes[i][mask[i]])
            filtered_labels.append(pred_labels[i][mask[i]])
            filtered_scores.append(pred_scores[i][mask[i]])

    return filtered_boxes, filtered_labels, filtered_scores
```
Explanation:
- Evaluation Mode: Disables training-specific layers like dropout.
- Prediction Processing: Converts logits to probabilities and selects top predictions.
- Filtering: Removes low-confidence predictions.
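A usage sketch with a dummy image and prompt (both are placeholders), reusing the model from earlier.

```python
images = torch.randn(1, 3, 224, 224)
texts = ["a dog next to a bicycle"]
boxes, labels, scores = inference(model, images, texts, conf_threshold=0.5)
print(len(boxes))      # 1 (one entry per image)
print(boxes[0].shape)  # [num_kept, 4] -- detections above the confidence threshold
```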
Handling Unknown Objects
If the object in the scene is unknown or not present in the predefined classes:
- Low Probability: The model may assign low or near-zero probability to unknown objects; a lightweight inference-time check based on this is sketched after this list.
- Class “Unknown”: Include an “unknown” class during training to handle unseen objects.
- Detection Failure: The model might not generate bounding boxes for unknown objects.
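Building on the first point above, one simple heuristic is to treat queries whose maximum class probability falls below a threshold as "unknown"; a minimal sketch (the threshold value is an arbitrary assumption):

```python
def label_with_unknown(pred_probs, unknown_threshold=0.5):
    """
    pred_probs: Tensor [batch_size, num_queries, num_classes] of softmax probabilities.
    Returns (labels, scores) where queries whose best score falls below the
    threshold are marked with -1, standing in for "unknown".
    """
    scores, labels = pred_probs.max(dim=-1)    # [batch_size, num_queries]
    labels = labels.clone()
    labels[scores < unknown_threshold] = -1    # -1 = not confidently recognized
    return labels, scores
```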
Summary of Tensor Shapes
- Images: `[batch_size, 3, H, W]`
- Visual Features: `[batch_size, 2048, H', W']`
- Flattened Visual Features: `[batch_size, seq_len_v, 2048]`, where `seq_len_v = H' * W'`
- Projected Visual Features: `[batch_size, seq_len_v, hidden_dim]`
- Text Embeddings: `[batch_size, 768]`
- Projected Text Embeddings: `[batch_size, hidden_dim]`
- Object Queries: `[batch_size, num_queries, hidden_dim]`
- Transformer Source (`src`): `[seq_len_v, batch_size, hidden_dim]`
- Transformer Target (`tgt`): `[num_queries, batch_size, hidden_dim]`
- Transformer Output (`hs`): `[num_queries, batch_size, hidden_dim]`
- Permuted Transformer Output: `[batch_size, num_queries, hidden_dim]`
- Predicted Boxes: `[batch_size, num_queries, 4]`
- Predicted Logits: `[batch_size, num_queries, num_classes]`
Conclusion
In this tutorial, we’ve explored the GLIP model in detail, including:
- The integration of visual and textual information.
- The architectural components and their functions.
- How text embeddings are incorporated into object queries.
- The training and inference processes with detailed code examples.
GLIP represents a significant advancement in unifying vision and language tasks, enabling more sophisticated and context-aware object detection systems.
Additional Notes
- Extensibility: The model can be extended to handle more complex language inputs or larger vocabularies.
- Performance Optimization: Techniques like mixed-precision training and distributed computing can improve training efficiency; a mixed-precision sketch follows this list.
- Data Requirements: Requires datasets that contain both visual and textual annotations.
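As one concrete example of the optimizations mentioned above, here is a sketch of a mixed-precision training step using PyTorch's automatic mixed precision utilities. It assumes the model has been moved to a CUDA device and reuses the `dataloader`, `optimizer`, and `label_to_index` objects from the training section.

```python
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    images = batch['images'].cuda()
    texts = batch['texts']
    targets = generate_ground_truth(batch['annotations'], label_to_index)
    targets = [{k: v.cuda() for k, v in t.items()} for t in targets]  # move targets to the GPU

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(images, texts)
        loss = compute_loss(outputs, targets)

    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```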
By following this tutorial, you should have a solid understanding of how GLIP works and how to implement it for your own applications.