A Deep Dive into Decoder-Only Transformers
Last Updated: June 8, 2025
1. The Task and Core Concept: Autoregressive Generation
The core task of a decoder-only model is autoregressive text generation. The model learns to predict the very next word given a sequence of preceding words. It generates text one token at a time, feeding its own previous output back in as input to predict the next token. This simple, self-supervised objective, when applied at a massive scale, enables the model to learn grammar, facts, reasoning abilities, and style.
2. Architecture: The Power of Masking
A decoder-only model is a stack of decoder layers from the original Transformer architecture. Its critical component is Causal Self-Attention (or Masked Self-Attention), which ensures that the prediction for a token at position i can only depend on the known outputs at positions before i.
This is achieved by adding a mask to the attention score matrix before the softmax operation, which drives the post-softmax attention weights for all future tokens to zero.
The Mathematics of Causal Self-Attention
We modify the standard attention formula by adding a mask matrix $M$:
\[M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}\]
\[\text{CausalAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V\]
Code Example: Generating the Causal Attention Mask
Here is a practical PyTorch snippet demonstrating how this mask is generated on the fly.
import torch
# Define the sequence length for a given batch
sequence_length = 4
# Create the mask that prevents attending to future tokens
# `torch.triu` keeps the upper triangle; diagonal=1 excludes the main diagonal, so each token can still attend to itself.
causal_mask_base = torch.triu(torch.ones(sequence_length, sequence_length), diagonal=1).bool()
# tensor([[False, True, True, True],
# [False, False, True, True],
# [False, False, False, True],
# [False, False, False, False]])
# The final mask uses a large negative number for masked positions and 0.0 for unmasked.
final_mask = torch.zeros(sequence_length, sequence_length)
final_mask.masked_fill_(causal_mask_base, float('-inf'))
# tensor([[0., -inf, -inf, -inf],
# [0., 0., -inf, -inf],
# [0., 0., 0., -inf],
# [0., 0., 0., 0.]])
# This final_mask is then added to the attention scores before the softmax.
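Continuing the snippet above, here is a minimal single-head sketch of how that addition fits into scaled dot-product attention; the head dimension d_k = 8 and the random Q, K, V tensors are purely illustrative.
# Continuing from the snippet above: apply the mask inside scaled dot-product attention.
d_k = 8  # arbitrary head dimension for this illustration
Q = torch.randn(sequence_length, d_k)
K = torch.randn(sequence_length, d_k)
V = torch.randn(sequence_length, d_k)

scores = (Q @ K.T) / d_k**0.5                    # [4, 4] raw attention scores
masked_scores = scores + final_mask              # future positions become -inf
weights = torch.softmax(masked_scores, dim=-1)   # each row sums to 1; future positions get weight 0
output = weights @ V                             # [4, 8] causally-attended values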
3. Positional Embeddings: Encoding Order
Self-attention on its own is permutation-invariant: it has no notion of token order, so we must explicitly inject positional information. The final input embedding for each token is the sum of its token embedding and its positional embedding.
3.1 Absolute Positional Embeddings (e.g., GPT-3)
- Mechanism: A unique vector is learned for each absolute position in the sequence (1st, 2nd, 3rd, etc.) from an embedding matrix of size [max_sequence_length, hidden_size] (see the sketch after this list).
- Limitation: This method has a hard maximum sequence length and may not generalize well beyond it.
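A minimal sketch of this mechanism follows; the sizes and variable names are illustrative, not taken from any specific model.
import torch
import torch.nn as nn

max_sequence_length, hidden_size, vocab_size = 1024, 768, 50257  # illustrative sizes

token_embedding = nn.Embedding(vocab_size, hidden_size)
position_embedding = nn.Embedding(max_sequence_length, hidden_size)

token_ids = torch.randint(0, vocab_size, (1, 16))           # [batch=1, seq_len=16]
positions = torch.arange(token_ids.shape[1]).unsqueeze(0)   # [1, 16] -> positions 0..15

# The input to the first decoder layer is the element-wise sum of the two embeddings.
x = token_embedding(token_ids) + position_embedding(positions)  # [1, 16, 768]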
3.2 Deep Dive: Rotary Positional Embeddings (RoPE)
Modern state-of-the-art models like Llama use a more advanced method called Rotary Positional Embeddings (RoPE).
- Intuition: Instead of adding position information, RoPE rotates our Query and Key vectors based on their position. The key insight is that the dot product between two vectors is related to the cosine of the angle between them. If we rotate our Query vector by an angle that depends on its position m, and our Key vector by an angle that depends on its position n, the dot product between the new vectors will depend on the relative angle, and hence on m-n. By making the rotation angle a function of a token’s position, the attention score naturally becomes dependent on the relative positions of the tokens, not their absolute ones.
- Mathematical Formulation:
  - First, we view our d-dimensional embedding vectors (Queries and Keys) as a sequence of d/2 two-dimensional vectors. Let’s take one such pair $(x_{2i}, x_{2i+1})$ from a vector $x$ at position $m$.
  - We define a rotation angle $\theta_{m,i} = m \cdot 10000^{-2i/d}$. The angle depends on both the token’s position $m$ and the feature index $i$; different dimensions are rotated at different frequencies.
  - We apply a 2D rotation matrix to this pair: \(R_m = \begin{pmatrix} \cos(\theta_{m,i}) & -\sin(\theta_{m,i}) \\ \sin(\theta_{m,i}) & \cos(\theta_{m,i}) \end{pmatrix}\) The positionally-encoded feature pair is then computed as $R_m \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$.
  - This is done for both the Query vector $q_m$ at position $m$ and the Key vector $k_n$ at position $n$. The crucial property is that the dot product of the rotated vectors, $\langle R_m q_m, R_n k_n \rangle$, can be shown through trigonometric identities to be a function of their original dot product and their relative position, $m-n$. A code sketch of this rotation appears after this list.
- Benefits: This elegant method provides robust relative position information and has shown excellent performance when generalizing to sequence lengths far beyond what the model saw during training.
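Below is a minimal, self-contained sketch of the rotation described above. The function name apply_rope and the tensor layout are our own choices for illustration and do not follow any particular library.
import torch

def apply_rope(x, base=10000.0):
    """Apply rotary position embeddings to x of shape [seq_len, dim], with dim even.

    Each pair (x_{2i}, x_{2i+1}) at position m is rotated by theta_{m,i} = m * base^(-2i/dim).
    """
    seq_len, dim = x.shape
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)    # [seq_len, 1]
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # [dim/2]
    angles = positions * freqs                                             # [seq_len, dim/2]
    cos, sin = angles.cos(), angles.sin()

    x_even, x_odd = x[:, 0::2], x[:, 1::2]                                 # each [seq_len, dim/2]
    rotated_even = x_even * cos - x_odd * sin                              # 2D rotation of each pair
    rotated_odd = x_even * sin + x_odd * cos

    out = torch.empty_like(x)
    out[:, 0::2], out[:, 1::2] = rotated_even, rotated_odd
    return out

# Example usage: rotate a [seq_len=16, dim=64] block of query (or key) vectors.
q = torch.randn(16, 64)
q_rotated = apply_rope(q)
In a full model, the same rotation is applied to both the query and key projections inside every attention layer, so their dot product depends only on relative position.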
4. Training: The Next-Token Prediction Objective
4.1 Tokenization and Input-Output Pairs
Using a sub-word tokenizer like BPE, the training data is created by simply shifting the tokenized text sequence by one position (see the sketch after this example).
- Input (X): [The, quick, brown, fox, jumps]
- Label (Y): [quick, brown, fox, jumps, .]
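In code, this shift is just slicing. A minimal sketch, with hypothetical token IDs standing in for the BPE output:
import torch

# Hypothetical token IDs for "The quick brown fox jumps ." after BPE tokenization.
tokens = torch.tensor([464, 2068, 7586, 21831, 18045, 13])

x = tokens[:-1]  # Input:  [464, 2068, 7586, 21831, 18045]
y = tokens[1:]   # Label:  [2068, 7586, 21831, 18045, 13]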
4.2 Loss Function: Cross-Entropy
The model is trained to minimize the Cross-Entropy Loss between its predicted probability distribution for the next token and the one-hot encoded vector of the true next token. For a single sequence $X = (x_1, …, x_S)$, the loss is the average negative log-likelihood of the true next token at each position:
\[L_{CE}(\theta) = - \frac{1}{S} \sum_{t=1}^{S} \log P(x_t | x_{<t}; \theta)\]
4.3 Evaluation Metric: Perplexity (and its relation to Loss)
While we use Cross-Entropy as the loss function to be minimized during training, we often use Perplexity (PPL) as the primary metric for evaluating and comparing language models.
- Relationship: Perplexity is the direct exponential of the cross-entropy loss.
\[\text{Perplexity} = e^{L_{CE}} = \exp\left(- \frac{1}{S} \sum_{t=1}^{S} \log P(x_t | x_{<t}; \theta)\right)\]
- Intuition: Why “Perplexity”? It is a measure of how “perplexed” or “surprised” the model is by the test data. A lower perplexity indicates that the model is less surprised, meaning it consistently assigns a higher probability to the correct next token. A perplexity of 100 can be interpreted as the model being, on average, as uncertain about the next word as if it were choosing uniformly from 100 different words. A perfect model would have a perplexity of 1.
- In Summary: We train the model by minimizing cross-entropy loss. We evaluate and compare models by reporting their perplexity. Lower is better for both.
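As a quick sanity check on this relationship, here is a short sketch with random logits (values are purely illustrative) showing that perplexity is just the exponentiated cross-entropy loss.
import torch
import torch.nn.functional as F

S, V = 5, 100                         # sequence length and vocabulary size (illustrative)
logits = torch.randn(S, V)            # model outputs for each position
targets = torch.randint(0, V, (S,))   # the true next tokens

loss = F.cross_entropy(logits, targets)   # average negative log-likelihood
perplexity = torch.exp(loss)              # PPL = exp(cross-entropy)
print(f"loss = {loss.item():.3f}, perplexity = {perplexity.item():.2f}")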
5. Prominent Architectures: GPT and Llama Families
The GPT (Generative Pre-trained Transformer) Family
Developed by OpenAI, this family pioneered the scaling of decoder-only models, trained on datasets like BookCorpus and WebText, culminating in GPT-3’s demonstration of in-context learning.
The Llama (Large Language Model Meta AI) Family
Developed by Meta AI, the Llama family focused on training efficiency, demonstrating that smaller models trained on vast amounts of data (1-2 Trillion tokens from sources like C4) could outperform larger models. Their open release spurred massive innovation, and Llama 2 incorporated RLHF for safety.
Key Differences: Llama vs. GPT
| Feature | GPT Family (e.g., GPT-3) | Llama Family (e.g., Llama 2) |
| :--- | :--- | :--- |
| Access | Closed-source, accessible via API. | Open-source (weights available), allowing local deployment and research. |
| Key Innovation Focus | Scaling parameters to achieve emergent abilities (in-context learning). | Training efficiency (more tokens for smaller models), open access, and RLHF for safety. |
| Architectural Tweaks | Standard Transformer decoder. | Pre-normalization (RMSNorm): Normalizes inputs to layers for better training stability. |
| Activation Function | GELU | SwiGLU: A variant of the Gated Linear Unit, which has shown better performance than standard activations like ReLU or GELU. |
| Positional Embeddings | Learned, absolute positional embeddings. | Rotary Positional Embeddings (RoPE): Encodes relative position, improving long-sequence performance. |
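For concreteness, here is a minimal sketch of a SwiGLU-style feed-forward block of the kind used in Llama. The class and parameter names are ours and the hidden size is illustrative; this is not Meta's reference implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Gated feed-forward block in the SwiGLU style (sketch; sizes are illustrative)."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        # SiLU(x W_gate) acts as a learned gate on the parallel projection x W_up.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

# Example usage; Llama-style models scale hidden_dim to roughly (8/3) * dim.
ffn = SwiGLU(dim=256, hidden_dim=688)
y = ffn(torch.randn(2, 10, 256))  # [2, 10, 256]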
6. Fine-Tuning and Inference
- Fine-Tuning: A “base model” is further trained on a curated dataset of (instruction, desired_output) pairs to align its behavior.
- Inference: Generating text is an iterative process. The strategy we use to select a token from the model’s probability distribution is critical.
6.1 A Deep Dive into Sampling Strategies
Once the decoder model has completed its forward pass, it outputs a vector of logits for the next token. This vector has a size equal to the entire vocabulary (e.g., [1, 50257] for GPT-2). To be useful, these raw scores must be converted into a probability distribution.
The Starting Point: The Softmax Function
First, we apply the softmax function to the logits to get a probability distribution P, where every value is between 0 and 1 and all values sum to 1.
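For example (the logit values here are random placeholders; 50257 is the GPT-2 vocabulary size mentioned above):
import torch
import torch.nn.functional as F

logits = torch.randn(1, 50257)    # raw scores from the LM head
P = F.softmax(logits, dim=-1)     # [1, 50257]; every entry in (0, 1), each row sums to 1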
At this point, we have a probability for every single word in the vocabulary being the next word. The question is: how do we choose one?
Let’s explore the common strategies, from simplest to most sophisticated.
1. Greedy Search (Deterministic)
This is the most straightforward method.
- Mechanism: Simply select the token with the highest probability. This is equivalent to performing an argmax on the logits vector (see the one-line sketch after this list).
- Intuition: “Always choose what the model thinks is best.”
- Shortcoming: While safe, greedy decoding is bland and repetitive. The model will often get stuck in loops of commonly associated words, lacking any creativity or variation. It is almost never used for creative text generation.
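A one-line sketch of greedy selection over a tiny, made-up logits vector:
import torch

logits = torch.tensor([[1.2, 3.1, 0.5, 8.2, -1.0]])      # [batch=1, vocab=5]
next_token = torch.argmax(logits, dim=-1, keepdim=True)  # tensor([[3]])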
2. Temperature Sampling
This is not a selection strategy on its own, but a modifier that affects all other sampling methods. It allows us to control the “randomness” of the model’s predictions.
- Mechanism: We rescale the logits before applying the softmax function using a temperature parameter, $T$ (see the sketch after this list):
\[P = \text{softmax}\left(\frac{\text{logits}}{T}\right)\]
- Intuition:
- $T > 1$ (e.g., 1.2): Flattens the distribution. This makes less likely tokens more probable, increasing randomness. The model becomes more “creative” and “daring.”
- $T < 1$ (e.g., 0.7): Sharpens the distribution. This makes high-probability tokens even more likely, reducing randomness. The model becomes more “focused” and “confident.”
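A minimal sketch of temperature-scaled sampling (the function name is ours, for illustration):
import torch
import torch.nn.functional as F

def sample_with_temperature(logits, temperature=1.0):
    # Divide the logits by T before the softmax: T < 1 sharpens, T > 1 flattens the distribution.
    probs = F.softmax(logits / temperature, dim=-1)
    # Draw one token from the resulting distribution.
    return torch.multinomial(probs, num_samples=1)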
3. Top-k Sampling
This is the first true sampling strategy that addresses the problem of Greedy Search.
- Intuition: “Don’t even consider the crazy, low-probability options. Only choose from a fixed-size ‘shortlist’ of the most likely candidates.”
- Mechanism:
  - Identify the k tokens with the highest probabilities from the full distribution P.
  - Set the probability of all other tokens to zero.
  - Re-normalize the probabilities of the top k tokens so they sum to 1.
  - Sample from this new, smaller distribution.
- Problem: A fixed k is not adaptive. In some contexts, the number of reasonable next words might be very large (e.g., at the start of a story), and in others very small (e.g., after “The Eiffel Tower is in…”). Top-k struggles with this dynamic.
- Code Snippet for Top-k:
import torch
import torch.nn.functional as F

def top_k_sampling(logits, k=50):
    # logits shape: [batch_size, vocab_size]
    # 1. Find the top k logits and their values.
    #    The `topk` function returns both values and indices.
    top_k_values, top_k_indices = torch.topk(logits, k)  # [B, k]
    # 2. Create a new logits tensor filled with -inf.
    filtered_logits = torch.full_like(logits, float('-inf'))
    # 3. Use `scatter_` to place the top k values back into the filtered logits.
    #    This effectively sets all non-top-k logits to -inf.
    filtered_logits.scatter_(1, top_k_indices, top_k_values)
    # 4. Apply softmax to get a re-normalized probability distribution.
    probabilities = F.softmax(filtered_logits, dim=-1)
    # 5. Sample from this new distribution.
    next_token = torch.multinomial(probabilities, num_samples=1)
    return next_token
4. Top-p (Nucleus) Sampling
This is one of the most widely used sampling methods, because it creates an adaptive shortlist rather than a fixed one.
- Intuition: “Instead of a fixed-size list, choose from the smallest possible list of candidates whose combined probability meets a certain threshold.”
- Mechanism:
- Sort the vocabulary tokens by their probability in descending order.
- Calculate the cumulative sum of these sorted probabilities.
- Find all tokens whose cumulative probability is greater than the threshold p (e.g., p=0.95). These tokens are removed from the candidate list. The remaining set is the “nucleus.”
- Re-normalize the probabilities of the nucleus tokens so they sum to 1.
- Sample from this adaptive nucleus distribution.
The tensor manipulations in Top-p sampling can be complex, and understanding the shapes at each step is key to grasping the algorithm. Let’s break down the code with detailed inline comments explaining the tensor sizes. We will assume a batch_size (B) of 4 and a vocab_size (V) of 50,000 for this example.
Top-p (Nucleus) Sampling with Tensor Size Explanations
import torch
import torch.nn.functional as F
def top_p_sampling(logits, p=0.95):
# Assume:
B = 4 # batch_size
V = 50000 # vocab_size
# logits shape: [B, V] -> e.g., [4, 50000]
# This is the raw output from the language model head for the last token.
# 1. Sort logits in descending order to easily find the nucleus of high-probability tokens.
# We need both the sorted values and their original indices to reconstruct the filter later.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
# sorted_logits shape: [B, V] -> e.g., [4, 50000]
# sorted_indices shape: [B, V] -> e.g., [4, 50000]
# Convert sorted logits to probabilities
sorted_probabilities = F.softmax(sorted_logits, dim=-1)
# sorted_probabilities shape: [B, V] -> e.g., [4, 50000]
# Example for one row: [0.1, 0.08, 0.05, 0.02, ..., 0.00001]
# 2. Calculate the cumulative sum of the probabilities.
cumulative_probs = torch.cumsum(sorted_probabilities, dim=-1)
# cumulative_probs shape: [B, V] -> e.g., [4, 50000]
# Example for one row: [0.1, 0.18, 0.23, 0.25, ..., 1.0]
# 3. Create a mask of tokens to remove. These are tokens that are NOT in the nucleus.
# The nucleus is the smallest set of tokens whose cumulative probability is >= p.
# So, we find all tokens where the cumulative probability already exceeds p.
sorted_indices_to_remove = cumulative_probs > p
# sorted_indices_to_remove shape: [B, V], dtype=torch.bool
# Example for one row (if p=0.9): [False, False, ..., True, True, True]
# 4. Shift the mask to the right to ensure we keep the first token that pushes the
# cumulative probability over the threshold p.
# We shift all elements one to the right, and the first element becomes False (0).
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# Now the mask correctly identifies only tokens that are truly outside the nucleus.
# 5. Go from the sorted view back to the original vocabulary order.
# We create a boolean mask of the same shape as the original logits.
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
# We use `scatter_` to place `True` values at the original positions of the tokens we want to remove.
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
# indices_to_remove shape: [B, V], dtype=torch.bool
# 6. Apply the mask to the original logits.
# `masked_fill_` sets the value of logits to -inf wherever the mask is True.
filtered_logits = logits.masked_fill(indices_to_remove, float('-inf'))
# filtered_logits shape: [B, V] -> e.g., [4, 50000]
# Now contains the original logit values for nucleus tokens, and -inf for all others.
# 7. Apply softmax to the filtered logits to get re-normalized probabilities.
# The -inf values will become 0 after softmax, and the probabilities of the
# nucleus tokens will be re-distributed to sum to 1.
probabilities = F.softmax(filtered_logits, dim=-1)
# probabilities shape: [B, V] -> e.g., [4, 50000]
# 8. Sample one token from this new, filtered distribution.
# `torch.multinomial` performs a weighted random draw.
next_token = torch.multinomial(probabilities, num_samples=1)
# next_token shape: [B, 1] -> e.g., [4, 1]
return next_token
Step-by-Step Walkthrough
Let’s trace a single sequence (B=1) with a tiny vocabulary (V=10) and p=0.98 to make it concrete (a high threshold, chosen so that the nucleus in this toy example contains several tokens).
- Start with Logits:
  logits = [1.2, 3.1, 0.5, 8.2, -1.0, 5.5, 6.1, 0.1, 2.5, 4.3] (Shape: [1, 10])
- Sort Logits: We get the sorted values and their original indices.
  sorted_logits = [8.2, 6.1, 5.5, 4.3, 3.1, 2.5, 1.2, 0.5, 0.1, -1.0]
  sorted_indices = [3, 6, 5, 9, 1, 8, 0, 2, 7, 4]
- Get Sorted Probabilities (after softmax):
  sorted_probabilities ≈ [0.819, 0.100, 0.055, 0.017, 0.005, ...]
- Get Cumulative Probabilities:
  cumulative_probs ≈ [0.819, 0.919, 0.974, 0.991, 0.996, ..., 1.0]
- Find Indices to Remove: Find where cumulative_probs > p (where p=0.98).
  sorted_indices_to_remove (initial) = [F, F, F, T, T, T, T, T, T, T]
- Shift the Mask:
  sorted_indices_to_remove (shifted) = [F, F, F, F, T, T, T, T, T, T]
  This is the key step. We keep the token that pushed us over the p threshold (the 4th one, with probability ≈ 0.017). The nucleus now consists of the first 4 tokens.
- Map Mask to Original Indices: We use sorted_indices to put True (remove) at the correct original positions. We will remove all tokens except those at original indices 3, 6, 5, 9.
- Filter Logits: The original logits tensor has the scores for tokens outside the nucleus set to -inf.
- Re-normalize with Softmax: We apply a final softmax. Only the 4 tokens in the nucleus will have a non-zero probability, and their probabilities will sum to 1.
- Sample: torch.multinomial picks one token from this final 4-element distribution. The result is a single token ID.

In practice, libraries like Hugging Face transformers combine all these techniques into a single .generate() method where you can specify temperature, top_k, and top_p simultaneously, allowing for fine-grained control over the generation process.
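For reference, a typical call might look like the following sketch, using the publicly available gpt2 checkpoint; the exact parameter values are arbitrary.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=True,       # enable sampling instead of greedy decoding
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    max_new_tokens=40,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))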
7. From-Scratch Implementation of a Decoder-Only Model
This implementation uses the computationally efficient, fused-layer approach for multi-head attention.
import torch
import torch.nn as nn
from torch.nn import functional as F
# --- Hyperparameters for our Toy Model ---
batch_size = 32 # How many independent sequences will we process in parallel?
block_size = 128 # What is the maximum context length for predictions?
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 256 # Embedding dimension
n_head = 4 # Number of attention heads
n_layer = 1 # Number of decoder layers (for simplicity)
dropout = 0.2
eval_interval = 500
max_steps = 5001
# --- 1. Data Preparation and Tokenizer ---
# For this toy example, we'll use a simple character-level tokenizer.
try:
with open('input.txt', 'r', encoding='utf-8') as f:
text = f.read()
except FileNotFoundError:
print("Warning: 'input.txt' not found. Using dummy text for demonstration.")
text = "hello world, this is a demonstration of a toy gpt model for a top-tier graduate school class."
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
def get_batch(split):
data = train_data if split == 'train' else val_data
# Generate random starting points for each sequence in the batch
ix = torch.randint(len(data) - block_size, (batch_size,))
# Input sequences (the context)
x = torch.stack([data[i:i+block_size] for i in ix])
# Target sequences (the next character to predict)
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
return x.to(device), y.to(device)
# --- 2. Model Components (Optimized Implementation) ---
class MultiHeadAttention(nn.Module):
""" The efficient, fused implementation of Multi-Head Causal Self-Attention """
def __init__(self, n_embd, num_heads):
super().__init__()
self.num_heads = num_heads
assert n_embd % num_heads == 0
self.head_size = n_embd // num_heads
# A single, fused linear layer for Q, K, V projections
self.c_attn = nn.Linear(n_embd, 3 * n_embd)
# The final output projection layer
self.c_proj = nn.Linear(n_embd, n_embd)
self.resid_dropout = nn.Dropout(dropout)
# Causal mask buffer
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))
.view(1, 1, block_size, block_size))
def forward(self, x):
# Input x shape: [B, T, C] (Batch, Time/seq_len, Channels/n_embd)
B, T, C = x.shape
# 1. Fused Projection & Splitting
qkv = self.c_attn(x) # Shape: [B, T, 3 * C]
        q, k, v = qkv.split(C, dim=2) # Each is [B, T, C]; C equals n_embd
# 2. Reshape for Multi-Head computation
# (B, T, C) -> (B, T, num_heads, head_size) -> (B, num_heads, T, head_size)
q = q.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
k = k.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
v = v.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
# 3. Batched Scaled Dot-Product Attention
# (B, num_heads, T, head_size) @ (B, num_heads, head_size, T) -> (B, num_heads, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / self.head_size**0.5)
att = att.masked_fill(self.tril[:,:,:T,:T] == 0, float('-inf')) # Apply causal mask
att = F.softmax(att, dim=-1)
# (B, num_heads, T, T) @ (B, num_heads, T, head_size) -> (B, num_heads, T, head_size)
y = att @ v
# 4. Concatenate and Project back
y = y.transpose(1, 2).contiguous().view(B, T, C) # Shape: [B, T, C]
return self.resid_dropout(self.c_proj(y))
class FeedForward(nn.Module):
""" A simple linear layer followed by a non-linearity """
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd), nn.ReLU(),
nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout),
)
def forward(self, x): return self.net(x)
class DecoderBlock(nn.Module):
""" A single Transformer block: communication followed by computation """
def __init__(self, n_embd, n_head):
super().__init__()
self.sa = MultiHeadAttention(n_embd, n_head)
self.ffwd = FeedForward(n_embd)
self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
def forward(self, x):
# Pre-normalization and residual connections
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
# --- 3. Full Model ---
class SimpleDecoderOnlyModel(nn.Module):
def __init__(self):
super().__init__()
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[DecoderBlock(n_embd, n_head=n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd) # Final layer norm
self.lm_head = nn.Linear(n_embd, vocab_size) # Language model head
def forward(self, idx, targets=None):
B, T = idx.shape
# Get token and position embeddings
# idx shape: [B, T]
tok_emb = self.token_embedding_table(idx) # Shape: [B, T, C]
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # Shape: [T, C]
x = tok_emb + pos_emb # Shape: [B, T, C]
# Pass through decoder blocks
x = self.blocks(x) # Shape: [B, T, C]
x = self.ln_f(x) # Shape: [B, T, C]
# Final projection to vocabulary size
logits = self.lm_head(x) # Shape: [B, T, vocab_size]
# Calculate loss if targets are provided
loss = None
if targets is not None:
B, T, C = logits.shape
logits_for_loss = logits.view(B*T, C)
targets_for_loss = targets.view(B*T)
loss = F.cross_entropy(logits_for_loss, targets_for_loss)
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop context to the last block_size tokens to respect positional embedding limits
idx_cond = idx[:, -block_size:]
# Get the predictions
logits, loss = self(idx_cond)
# Focus only on the logit for the last time step
logits = logits[:, -1, :] # Becomes (B, C)
# Apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# Sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
# --- 4. Training Loop ---
model = SimpleDecoderOnlyModel()
m = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print("Starting training...")
for steps in range(max_steps):
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
if steps % eval_interval == 0:
print(f"Step {steps}, Training Loss: {loss.item():.4f}")
# --- 5. Generation from the model ---
print("\n--- Generating Text from Trained Model ---")
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_tokens = m.generate(context, max_new_tokens=200)[0].tolist()
print(decode(generated_tokens))