Home

CV

CV

A Comprehensive Tutorial on GLIP (Grounded Language-Image Pre-training)

Introduction

GLIP (Grounded Language-Image Pre-training) is a unified model architecture that bridges the gap between vision and language by integrating object detection and phrase grounding tasks. It leverages both visual and textual data to perform object detection conditioned on textual descriptions, enabling the model to recognize objects based on their semantic meanings.

Key Features:

  • Unified Architecture: Combines object detection and phrase grounding into a single framework.
  • Text-Conditioned Detection: Incorporates textual embeddings into object queries, allowing for detection based on language cues.
  • Transformer-Based: Utilizes a transformer encoder-decoder structure to process visual and textual information jointly.

GLIP

How GLIP Works in Detail

GLIP operates by integrating textual embeddings into the object detection pipeline. Here’s a high-level overview:

  1. Backbone Network: Extracts visual features from input images using a convolutional neural network (e.g., ResNet).
  2. Language Transformer: Processes textual inputs (phrases or sentences) to generate text embeddings using models like BERT.
  3. Object Queries: Text embeddings are transformed into object queries that guide the detection process.
  4. Transformer Encoder-Decoder: The visual features and object queries are processed through a transformer to produce refined feature representations.
  5. Prediction Heads: The model outputs bounding boxes and class logits for each object query.
  6. Loss Function: Combines classification and regression losses to train the model.

Backbone Structure

The backbone network is responsible for extracting rich visual features from input images. We will use a pre-trained ResNet-50 model.

import torch
import torch.nn as nn
import torchvision.models as models

class Backbone(nn.Module):
    def __init__(self):
        super(Backbone, self).__init__()
        # Load a pre-trained ResNet-50 model
        resnet = models.resnet50(pretrained=True)
        # Remove the fully connected layer and average pooling
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        
    def forward(self, x):
        # x: Input image tensor of shape [batch_size, 3, H, W]
        x = self.features(x)
        # x: Visual feature tensor of shape [batch_size, 2048, H', W']
        return x

Explanation:

  • Input: x is a batch of images with shape [batch_size, 3, H, W].
  • Processing: The ResNet-50 model extracts features, reducing spatial dimensions.
  • Output: A feature map of shape [batch_size, 2048, H', W'], where H' and W' are reduced spatial dimensions.

Language Transformer Structure

We use a pre-trained BERT model to encode textual inputs into embeddings.

from transformers import BertModel, BertTokenizer

class LanguageTransformer(nn.Module):
    def __init__(self):
        super(LanguageTransformer, self).__init__()
        # Load pre-trained BERT model and tokenizer
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        
    def forward(self, text_list):
        # text_list: List of strings with length [batch_size]
        encoded_input = self.tokenizer(
            text_list, return_tensors='pt', padding=True, truncation=True
        )
        # encoded_input is a dict containing:
        # 'input_ids': Tensor of shape [batch_size, seq_len]
        # 'attention_mask': Tensor of shape [batch_size, seq_len]
        
        outputs = self.bert(
            input_ids=encoded_input['input_ids'],
            attention_mask=encoded_input['attention_mask']
        )
        # outputs.last_hidden_state: Tensor of shape [batch_size, seq_len, hidden_size]
        # outputs.pooler_output: Tensor of shape [batch_size, hidden_size]
        
        # Use the pooled output as text embeddings
        text_embeddings = outputs.pooler_output  # Shape: [batch_size, hidden_size]
        return text_embeddings

Explanation:

  • Input: text_list is a list of strings of length [batch_size].
  • Tokenization: Converts text into token IDs and attention masks.
  • BERT Encoding: Processes tokens to generate embeddings.
  • Output: text_embeddings of shape [batch_size, hidden_size] (usually 768 for BERT-base).

Incorporating Text Embeddings into Object Queries

Text embeddings are transformed into object queries for the transformer decoder.

class ObjectQueryGenerator(nn.Module):
    def __init__(self, hidden_dim, num_queries):
        super(ObjectQueryGenerator, self).__init__()
        # Linear layer to project BERT embeddings to the transformer's hidden dimension
        self.query_proj = nn.Linear(768, hidden_dim)
        self.num_queries = num_queries  # Number of object queries
        
    def forward(self, text_embeddings):
        # text_embeddings: Tensor of shape [batch_size, 768]
        projected_embeddings = self.query_proj(text_embeddings)
        # projected_embeddings: Tensor of shape [batch_size, hidden_dim]
        
        # Expand embeddings to create multiple object queries
        queries = projected_embeddings.unsqueeze(1).repeat(1, self.num_queries, 1)
        # queries: Tensor of shape [batch_size, num_queries, hidden_dim]
        return queries

Explanation:

  • Projection: Maps text embeddings to the transformer’s hidden dimension.
  • Expansion: Generates multiple object queries per text embedding.
  • Output: queries of shape [batch_size, num_queries, hidden_dim].

Ground Truth Generation

Ground truth data includes bounding boxes and labels for objects in the images.

def generate_ground_truth(annotations):
    """
    annotations: List of dictionaries for each image in the batch.
    Each dictionary contains:
        - 'boxes': Tensor of shape [num_objects, 4] (x_min, y_min, x_max, y_max)
        - 'labels': List of strings with length [num_objects]
    """
    targets = []
    for ann in annotations:
        boxes = ann['boxes']  # Tensor of shape [num_objects, 4]
        labels = ann['labels']  # List of strings
        targets.append({'boxes': boxes, 'labels': labels})
    return targets

Explanation:

  • Input: annotations, a list containing bounding boxes and labels per image.
  • Output: targets, a list of dictionaries ready for loss computation.

Loss Function

The loss function combines classification and bounding box regression losses.

def compute_loss(outputs, targets):
    """
    outputs: Dict containing model predictions:
        - 'pred_boxes': Tensor [batch_size, num_queries, 4]
        - 'pred_logits': Tensor [batch_size, num_queries, num_classes]
    targets: List of dictionaries with ground truth:
        - 'boxes': Tensor [num_objects, 4]
        - 'labels': Tensor [num_objects]
    """
    # Flatten predictions and targets
    pred_boxes = outputs['pred_boxes'].view(-1, 4)  # [batch_size*num_queries, 4]
    pred_logits = outputs['pred_logits'].view(-1, num_classes)  # [batch_size*num_queries, num_classes]
    
    # Concatenate targets
    target_boxes = torch.cat([t['boxes'] for t in targets], dim=0)  # [total_num_objects, 4]
    target_labels = torch.cat([t['labels'] for t in targets], dim=0)  # [total_num_objects]
    
    # Classification loss (e.g., CrossEntropyLoss)
    classification_loss = nn.CrossEntropyLoss()(pred_logits, target_labels)
    
    # Bounding box regression loss (e.g., Smooth L1 Loss)
    regression_loss = nn.SmoothL1Loss()(pred_boxes, target_boxes)
    
    # Total loss
    total_loss = classification_loss + regression_loss
    return total_loss

Explanation:

  • Classification Loss: Measures the error in predicted class probabilities.
  • Regression Loss: Measures the error in predicted bounding boxes.
  • Total Loss: Sum of classification and regression losses.

GLIP Model Integration

Now, we combine all components into the GLIP model.

class GLIPModel(nn.Module):
    def __init__(self, hidden_dim=256, num_queries=100, num_classes=91):
        super(GLIPModel, self).__init__()
        self.backbone = Backbone()
        self.language_transformer = LanguageTransformer()
        self.query_generator = ObjectQueryGenerator(hidden_dim, num_queries)
        
        # Transformer encoder-decoder
        self.transformer = nn.Transformer(
            d_model=hidden_dim, nhead=8, num_encoder_layers=6, num_decoder_layers=6
        )
        
        # Prediction heads
        self.bbox_pred = nn.Linear(hidden_dim, 4)  # Bounding box regression
        self.class_pred = nn.Linear(hidden_dim, num_classes)  # Classification logits
        
    def forward(self, images, texts):
        # images: Tensor of shape [batch_size, 3, H, W]
        # texts: List of strings of length [batch_size]
        
        # Extract visual features
        visual_features = self.backbone(images)
        # visual_features: [batch_size, 2048, H', W']
        batch_size, c, h, w = visual_features.shape
        
        # Flatten spatial dimensions and permute
        visual_features = visual_features.view(batch_size, c, h * w).permute(0, 2, 1)
        # visual_features: [batch_size, seq_len_v, feature_dim_v]
        # seq_len_v = H' * W', feature_dim_v = 2048
        
        # Encode text inputs
        text_embeddings = self.language_transformer(texts)
        # text_embeddings: [batch_size, 768]
        
        # Generate object queries
        queries = self.query_generator(text_embeddings)
        # queries: [batch_size, num_queries, hidden_dim]
        
        # Prepare inputs for transformer
        src = visual_features.permute(1, 0, 2)  # [seq_len_v, batch_size, feature_dim_v]
        tgt = queries.permute(1, 0, 2)  # [num_queries, batch_size, hidden_dim]
        
        # Transformer encoding and decoding
        memory = self.transformer.encoder(src)  # [seq_len_v, batch_size, hidden_dim]
        hs = self.transformer.decoder(tgt, memory)  # [num_queries, batch_size, hidden_dim]
        
        # Permute back to [batch_size, num_queries, hidden_dim]
        hs = hs.permute(1, 0, 2)
        
        # Predict bounding boxes and class logits
        pred_boxes = self.bbox_pred(hs)  # [batch_size, num_queries, 4]
        pred_logits = self.class_pred(hs)  # [batch_size, num_queries, num_classes]
        
        return {'pred_boxes': pred_boxes, 'pred_logits': pred_logits}

Explanation:

  • Visual Feature Extraction: Backbone processes images to get visual features.
  • Text Embedding: Language transformer encodes text into embeddings.
  • Object Queries: Generated from text embeddings.
  • Transformer Processing: Jointly processes visual features and object queries.
  • Predictions: Outputs bounding boxes and class logits for each query.

Mixing Text Embeddings with Object Queries

Text embeddings guide the object queries to focus on relevant objects.

# In ObjectQueryGenerator.forward():
def forward(self, text_embeddings):
    # text_embeddings: Tensor of shape [batch_size, 768]
    projected_embeddings = self.query_proj(text_embeddings)
    # projected_embeddings: Tensor of shape [batch_size, hidden_dim]
    
    # Normalize embeddings (optional)
    projected_embeddings = projected_embeddings / projected_embeddings.norm(dim=-1, keepdim=True)
    
    # Create learnable query embeddings (optional)
    learnable_queries = nn.Parameter(torch.randn(self.num_queries, self.hidden_dim))
    # learnable_queries: [num_queries, hidden_dim]
    
    # Combine text embeddings with learnable queries
    queries = projected_embeddings.unsqueeze(1) + learnable_queries.unsqueeze(0)
    # queries: [batch_size, num_queries, hidden_dim]
    return queries

Explanation:

  • Combination: Text embeddings are added to learnable query embeddings.
  • Result: Queries are conditioned on textual information, guiding detection.

Training Code

The training loop involves feeding data through the model, computing loss, and updating weights.

def train(model, dataloader, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            images = batch['images']  # Tensor [batch_size, 3, H, W]
            texts = batch['texts']    # List of strings [batch_size]
            annotations = batch['annotations']  # Ground truth data
            
            # Forward pass
            outputs = model(images, texts)
            # outputs: Dict with 'pred_boxes' and 'pred_logits'
            
            # Generate ground truth targets
            targets = generate_ground_truth(annotations)
            
            # Compute loss
            loss = compute_loss(outputs, targets)
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Explanation:

  • Data Loading: Each batch contains images, texts, and annotations.
  • Forward Pass: Model predicts outputs based on inputs.
  • Loss Computation: Calculates the loss using the custom loss function.
  • Backpropagation: Updates model parameters.

Inference Code

Perform inference to get predictions from the trained model.

def inference(model, images, texts, conf_threshold=0.5):
    model.eval()
    with torch.no_grad():
        # images: Tensor [batch_size, 3, H, W]
        # texts: List of strings [batch_size]
        
        outputs = model(images, texts)
        # outputs: Dict with 'pred_boxes' and 'pred_logits'
        
        # Apply softmax to class logits to get probabilities
        pred_probs = nn.Softmax(dim=-1)(outputs['pred_logits'])
        # pred_probs: Tensor [batch_size, num_queries, num_classes]
        
        # Get the highest class probability and index
        pred_scores, pred_labels = torch.max(pred_probs, dim=-1)
        # pred_scores: Tensor [batch_size, num_queries]
        # pred_labels: Tensor [batch_size, num_queries]
        
        # Retrieve predicted bounding boxes
        pred_boxes = outputs['pred_boxes']  # [batch_size, num_queries, 4]
        
        # Filter out predictions below confidence threshold
        mask = pred_scores > conf_threshold
        filtered_boxes = []
        filtered_labels = []
        filtered_scores = []
        
        for i in range(images.size(0)):
            boxes = pred_boxes[i][mask[i]]
            labels = pred_labels[i][mask[i]]
            scores = pred_scores[i][mask[i]]
            filtered_boxes.append(boxes)
            filtered_labels.append(labels)
            filtered_scores.append(scores)
        
        return filtered_boxes, filtered_labels, filtered_scores

Explanation:

  • Evaluation Mode: Disables training-specific layers like dropout.
  • Prediction Processing: Converts logits to probabilities and selects top predictions.
  • Filtering: Removes low-confidence predictions.

Handling Unknown Objects

If the object in the scene is unknown or not present in the predefined classes:

  • Zero Probability: The model may assign low or zero probability to unknown objects.
  • Class “Unknown”: Include an “unknown” class during training to handle unseen objects.
  • Detection Failure: The model might not generate bounding boxes for unknown objects.

Summary of Tensor Shapes

  • Images: [batch_size, 3, H, W]
  • Visual Features: [batch_size, 2048, H', W']
  • Flattened Visual Features: [batch_size, seq_len_v, 2048], where seq_len_v = H' * W'
  • Text Embeddings: [batch_size, 768]
  • Projected Text Embeddings: [batch_size, hidden_dim]
  • Object Queries: [batch_size, num_queries, hidden_dim]
  • Transformer Source (src): [seq_len_v, batch_size, hidden_dim]
  • Transformer Target (tgt): [num_queries, batch_size, hidden_dim]
  • Transformer Output (hs): [num_queries, batch_size, hidden_dim]
  • Permuted Transformer Output: [batch_size, num_queries, hidden_dim]
  • Predicted Boxes: [batch_size, num_queries, 4]
  • Predicted Logits: [batch_size, num_queries, num_classes]

Conclusion

In this tutorial, we’ve explored the GLIP model in detail, including:

  • The integration of visual and textual information.
  • The architectural components and their functions.
  • How text embeddings are incorporated into object queries.
  • The training and inference processes with detailed code examples.

GLIP represents a significant advancement in unifying vision and language tasks, enabling more sophisticated and context-aware object detection systems.


Additional Notes

  • Extensibility: The model can be extended to handle more complex language inputs or larger vocabularies.
  • Performance Optimization: Techniques like mixed-precision training and distributed computing can improve training efficiency.
  • Data Requirements: Requires datasets that contain both visual and textual annotations.

By following this tutorial, you should have a solid understanding of how GLIP works and how to implement it for your own applications.