# File: gpt2_optics/src/train_char_gpt2_koef_128.py
# (Python; listed as 367 lines, 14 KiB by the repository web viewer.)

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# ---- Reproducibility --------------------------------------------------------
seed = 1337
torch.manual_seed(seed)

# ---- Optimisation -----------------------------------------------------------
batch_size = 50
# NOTE: check the accumulation implementation for correctness against
# https://unsloth.ai/blog/gradient
gradient_accumulation_steps = 1
max_iters = 40_000
learning_rate = 1e-3
eval_interval = 300
eval_iters = 200

# ---- Model shape ------------------------------------------------------------
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6

# ---- Runtime ----------------------------------------------------------------
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Each optimizer step splits batch_size evenly across the accumulation steps.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, https://arxiv.org/abs/2104.09864).

    A table of rotation angles for positions ``0 .. max_seq_len - 1`` is built
    once and kept in the ``emb`` buffer. ``forward`` rotates the last
    dimension of its input by the angles of the positions it covers.

    Args:
        dim: Size of the last dimension of the tensors to rotate.
        max_seq_len: Number of positions in the precomputed angle table.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # Outer product position x frequency -> (max_seq_len, dim / 2).
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate so both halves of the feature vector share the same angles.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Return ``(-second_half, first_half)`` of the last dimension of ``x``."""
        first, second = torch.chunk(x, 2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x`` using the angles of positions ``offset .. offset + x.size(1) - 1``."""
        seq_len = x.size(1)
        angles = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        rotated = self.rotate_half(x) * angles.sin()
        return (x * angles.cos()) + rotated
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (https://jiachenzhu.github.io/DyT/).

    Computes ``tanh(alpha * x) * weight + bias`` with a learnable scalar
    ``alpha`` and learnable per-feature ``weight`` and ``bias``.

    Args:
        num_features: Length of the ``weight`` and ``bias`` vectors.
        alpha_init_value: Initial value of the scalar ``alpha``.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Registration order (alpha, weight, bias) is kept as-is.
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Apply the scaled tanh followed by the elementwise affine transform."""
        squashed = torch.tanh(x * self.alpha)
        return self.bias + self.weight * squashed
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Causal self-attention block followed by a GELU feed-forward block.

    Both sub-blocks are residual and are preceded by a ``DyT`` layer. Queries
    and keys are rotated with ``RoPE`` before the attention scores are
    computed. ``k1`` and ``k2`` are learnable scalars (initialised from a
    standard normal) that scale the attention scores and the attention output
    respectively.

    Args:
        h_dim: Width of the input/output features.
        num_heads: Number of attention heads; ``h_dim`` is split evenly.
        dropout_rate: Drop probability used on both residual branches.
        max_seq_len: Number of positions the rotary table is built for.
    """

    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Explicit attributes instead of dumping locals() into __dict__ (which
        # also stored the implicit ``__class__`` cell on the instance).
        self.h_dim = h_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.max_seq_len = max_seq_len
        # Sub-module creation order is unchanged so that random initialisation
        # and parameter ordering stay identical.
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        """Fold the heads into the batch axis; no-op for a single head."""
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of ``split_to_heads``; no-op for a single head."""
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal scaled dot-product attention over ``x``."""
        seq_len = x.shape[1]
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): the scale uses the full h_dim, not the per-head width;
        # kept unchanged so existing checkpoints behave the same.
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        # Build the causal mask directly on the input's device: True where the
        # key position lies in the future of the query position.
        positions = torch.arange(seq_len, device=x.device)
        future = positions[None, :] > positions[:, None]
        scores = scores.masked_fill(future, float('-inf'))  # an encoder would skip this mask
        attention = F.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))

    def forward(self, x):
        """Apply the attention and feed-forward residual branches to ``x``."""
        # F.dropout1d defaults to training=True, so the module's mode has to be
        # passed explicitly; otherwise dropout stays active after .eval().
        # NOTE(review): dropout1d treats dim 1 as the channel axis, i.e. it
        # zeroes whole time steps of a (B, T, H) input — confirm this is intended.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Decoder-only language model: embeddings, a stack of ``TransformerLayer``, LM head.

    Args:
        vocab_size: Number of distinct token ids.
        layers_num: Number of ``TransformerLayer`` blocks.
        h_dim: Embedding / hidden width.
        max_seq_len: Longest sequence the positional table covers.
        num_heads: Attention heads per layer.
        dropout_rate: Dropout probability forwarded to every layer.
        pixel_size: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Explicit attributes instead of dumping locals() into __dict__.
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return ``(logits, loss)``; ``loss`` is ``None`` when no targets are given.

        ``x`` holds token ids of shape (B, T); ``logits`` has shape
        (B, T, vocab_size). When ``targets`` (B, T) is given, ``loss`` is the
        cross-entropy between the logits and the targets.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # B,T,C
        loss = None
        if targets is not None:
            # cross_entropy expects the class axis second: (B, C, T).
            loss = F.cross_entropy(logits.transpose(1, 2), targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids building a graph per token
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``start_idx``.

        ``start_idx`` is a (B, T) tensor of token ids. At every step the last
        ``max_seq_len`` tokens are fed to the model and the next token is drawn
        from the softmax of the final position. Returns the prompt with the
        sampled tokens appended, shape (B, T + max_new_tokens). The module's
        train/eval mode is left untouched.
        """
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # B, C
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run label: script name, training file name, sequence length and an optional
# free-form suffix taken from the first command-line argument.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# Take the timestamp once so the log and checkpoint directories of one run
# always carry the same name (repeated datetime.now() calls can straddle a
# second/minute boundary).
run_timestamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
logs_dir = f'logs/{run_timestamp}_{comment}/'
writer = SummaryWriter(logs_dir)
# Resolve the script location: a bare "train.py" has parent "." whose name is
# empty, which would make the snapshot target collide with logs_dir itself.
source_dir = Path(sys.argv[0]).resolve().parent
script_snapshot_path = Path(logs_dir) / source_dir.name
print("Logs dir:", logs_dir)

# Output roots of this script, relative to the current working directory.
_run_output_roots = {Path('logs').resolve(), Path('checkpoints').resolve()}


def _skip_run_outputs(directory, names):
    """copytree ``ignore`` hook: skip this script's own logs/checkpoints roots.

    Without it, launching from inside the source directory would copy the
    snapshot destination into itself.
    """
    return [name for name in names if (Path(directory) / name).resolve() in _run_output_roots]


shutil.copytree(source_dir, script_snapshot_path, ignore=_skip_run_outputs)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission
# Standalone checkpoints directory, named like the logs directory.
checkpoints_dir = f'./checkpoints/{run_timestamp}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
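# The text splits are read from the wiki.*.tokens paths defined above.
# (A smaller alternative corpus: wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)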
# Load the three text splits.
train_text = train_data_path.read_text(encoding='utf-8')
val_text = val_data_path.read_text(encoding='utf-8')
test_text = test_data_path.read_text(encoding='utf-8')

# Character-level vocabulary over all splits, in sorted order.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {ch: idx for idx, ch in enumerate(chars)}
itow = dict(enumerate(chars))


def encode(s):
    """Map every character of ``s`` to its integer id (KeyError if unknown)."""
    return [wtoi[w] for w in s]


def decode(idx):
    """Map an iterable of integer ids back to the string they spell."""
    return ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Draw ``batch_size`` random windows of ``seq_len`` tokens from ``data``.

    Returns ``(x, y)`` on the module-level ``device``; ``y`` is ``x`` shifted
    one position to the right (the next-token targets).
    """
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    # Row r holds the indices starts[r], starts[r] + 1, ..., starts[r] + seq_len - 1.
    window = starts.unsqueeze(1) + torch.arange(seq_len)
    inputs = data[window].to(device)
    targets = data[window + 1].to(device)
    return inputs, targets
# Encode every split once into a 1-D tensor of character ids.
train_data, val_data, test_data = (
    torch.tensor(encode(split_text), dtype=torch.long)
    for split_text in (train_text, val_text, test_text)
)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate the perplexity of ``model`` on ``data``.

    Windows of ``max_seq_len`` tokens are taken every ``stride`` positions,
    where ``stride = max(1, len(data) // 10000)``, and scored in batches.
    The model is switched to eval mode and left in that mode.

    Args:
        model: Callable returning ``(logits, mean_loss)`` for ``(x, y)``.
        data: 1-D tensor of token ids.
        batch_size: Number of windows scored per forward pass.

    Returns:
        ``exp`` of the token-weighted mean loss (a NumPy float).

    Raises:
        ValueError: If ``data`` is too short to contain a single window.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    if total_sequences == 0:
        # Previously this fell through to a bare ZeroDivisionError below.
        raise ValueError(
            f"data has {len(data)} tokens, too short for a window of {max_seq_len} tokens plus its target"
        )
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:i + batch_size]
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # The model returns the loss averaged over all tokens of the batch
        _, mean_loss = model(x_batch, y_batch)
        # Undo the averaging so that a smaller final batch is weighted correctly
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Generate a continuation of a prompt and return it as text.

    Args:
        m: Model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: Sequence of prompt token ids (must not be empty).
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        The decoded prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``start_idxs`` is empty.
    """
    prompt = list(start_idxs)
    if not prompt:
        # An empty list would silently become a float tensor of shape (1, 0).
        raise ValueError("start_idxs must contain at least one token id")
    start_idx = torch.tensor([prompt], dtype=torch.long).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Build the model from the module-level hyper-parameters. dropout_rate is
# passed explicitly so the value recorded in the training config is the one
# the model really uses (it previously fell back to the class default).
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
params_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {params_count}'
writer.add_text('model', model_description, 0)
print(f"{params_count * 8} minimum number of tokens to train model.")
# NOTE(review): this divides by batch_size only, while every iteration draws
# batch_size windows of max_seq_len tokens — confirm the intended unit.
print(f"{(params_count * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        step: Training step; also used for the zero-padded file name.
        loss: Loss value to record with the checkpoint.
        config: Training configuration to store alongside the weights.
        wtoi: Character-to-id vocabulary mapping.
        itow: Id-to-character vocabulary mapping.
        checkpoint_dir: Directory to write into; created if it does not exist.

    Returns:
        ``pathlib.Path`` of the written ``checkpoint_<step>.pt`` file.
    """
    target_dir = Path(checkpoint_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = target_dir / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)
    return checkpoint_path
# Training config for checkpointing
# Hyper-parameters stored inside every checkpoint.
training_config = dict(
    vocab_size=vocab_size,
    layers_num=layers_num,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    batch_size=batch_size,
    learning_rate=learning_rate,
    gradient_accumulation_steps=gradient_accumulation_steps,
    pixel_size=pixel_size,
    max_iters=max_iters,
)
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]


def _log_samples(step, echo=False):
    """Sample one free completion and one per task prompt; log both at ``step``.

    The sampling order (free completion first, then the task prompts) is kept
    fixed so the random-number stream is consumed in the same order each time.
    """
    completion = complete(m, encode("\n\n"), max_seq_len)
    if echo:
        print(completion)
    writer.add_text('completions', completion, step)
    task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
    if echo:
        print(task_results)
    writer.add_text('completions/task', task_results, step)


# Baseline samples from the untrained model.
_log_samples(0, echo=True)

# Defined up front so the final report still works when max_iters == 0.
accumulated_loss = float('nan')
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the accumulated gradient matches that of one full batch.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # clip_grad_norm_ returns the total gradient norm; it is called here
        # with a very large max_norm (1e6) to log that value.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # Periodic validation, sample logging and checkpoint.
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        _log_samples(i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

# Final evaluation. Everything below is reported at the same step
# (max_iters), matching the step used for the final scalars and checkpoint.
final_step = max_iters
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{final_step}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {final_step}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), final_step)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), final_step)
print(f"\rTest Perplexity at {final_step}: {ppl}")
_log_samples(final_step, echo=True)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
# Flush pending events and release the event file.
writer.close()
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")