# Source listing: gpt2_optics/src/train_optics_trainable_foca... (path truncated by the repository viewer)
# Viewer metadata: 490 lines, 20 KiB, Python

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
# Fixed seed for torch's global RNG; the torch.randn / torch.randint /
# torch.multinomial calls further down all draw from this generator.
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Multiply two signed tensors with a simulator that only sees values in [0, 1].

    Each operand is split into its positive and negative parts
    (A = A+ - A-, B = B+ - B-), so that

        A @ B = A+ @ B+  -  A+ @ B-  -  A- @ B+  +  A- @ B-

    Every part is divided by its own global maximum before being handed to
    ``sim`` and the product is rescaled afterwards, so ``sim`` only ever
    receives non-negative inputs normalised to [0, 1].

    Args:
        sim: callable ``sim(left, right)`` returning the product of two 4-D
            tensors (presumably an optical matmul simulator -- see ``omm``).
        tensor_1: left operand; a 3-D tensor gets a leading axis of size 1.
        tensor_2: right operand; handled the same way.

    Returns:
        The signed product with the leading axis removed again.

    NOTE(review): the leading axis is dropped unconditionally, so a genuinely
    4-D input only yields the slice at index 0. Callers in this file always
    pass 3-D tensors; confirm before using this with 4-D inputs.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]

    a_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    a_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    b_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    b_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    max_a_pos = torch.max(a_pos)  # may be 0 when there are no positive values
    max_a_neg = torch.max(a_neg)  # may be 0 when there are no negative values
    max_b_pos = torch.max(b_pos)
    max_b_neg = torch.max(b_neg)

    # Shape / dtype / device of the fallback used when a part is identically
    # zero. Building it directly on the operands' device and dtype avoids the
    # CPU float32 detour, which silently changed the result dtype whenever
    # all four terms were zero.
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])
    out_dtype = torch.result_type(tensor_1, tensor_2)
    if not out_dtype.is_floating_point:
        out_dtype = torch.get_default_dtype()
    out_device = tensor_1.device

    def partial_product(left, left_max, right, right_max):
        # A part whose maximum is 0 is all zeros: skip the simulator (and the
        # division by zero) and contribute an explicit zero tensor instead.
        if left_max > 0 and right_max > 0:
            return sim(left / left_max, right / right_max) * left_max * right_max
        return torch.zeros(out_shape, dtype=out_dtype, device=out_device)

    t1 = partial_product(a_pos, max_a_pos, b_pos, max_b_pos)
    t2 = partial_product(a_pos, max_a_pos, b_neg, max_b_neg)
    t3 = partial_product(a_neg, max_a_neg, b_pos, max_b_pos)
    t4 = partial_product(a_neg, max_a_neg, b_neg, max_b_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to the last dimension of ``x``.

    A table of rotation angles of shape ``(max_seq_len, dim)`` is built once
    and stored in the ``emb`` buffer; ``forward`` slices the rows matching the
    current sequence positions and rotates the features accordingly.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # angles[t, k] = t * inv_freq[k]
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # The same angles are used for both halves of the feature vector.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Map the two halves ``(a, b)`` of the last dimension to ``(-b, a)``."""
        first_half, second_half = x.chunk(2, dim=-1)
        return torch.cat((-second_half, first_half), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x`` (sequence on axis 1), starting at position ``offset``."""
        length = x.size(1)
        angles = self.emb[offset:offset + length].view(1, length, -1)
        rotated = self.rotate_half(x)
        return (x * angles.cos()) + (rotated * angles.sin())
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer: ``tanh(alpha * x) * weight + bias``.

    ``alpha`` is a single learnable scalar; ``weight`` and ``bias`` hold one
    learnable value per feature and start at 1 and 0 respectively.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose attention matmuls go through simulators.

    Both attention products (``q @ k^T`` and ``softmax(scores) @ v``) are
    computed by ``new_formula`` with the ``sim_scores`` / ``sim_output``
    callables, each scaled by a learnable per-layer scalar (``k1`` / ``k2``).
    ``DyT`` is used in place of layer normalisation and ``RoPE`` is applied
    to queries and keys.

    Args:
        h_dim: model width.
        sim_scores: simulator used for the query/key product.
        sim_output: simulator used for the attention/value product.
        num_heads: number of attention heads; ``h_dim`` is split across them.
        dropout_rate: dropout probability for both residual branches.
        max_seq_len: longest sequence the rotary table is built for.
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Store every constructor argument as a plain attribute. Writing
        # through __dict__ bypasses nn.Module.__setattr__, so the shared
        # simulators are NOT registered as sub-modules of this layer.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # The trainable optical parameters live in the simulator classes;
        # each layer only adds one trainable scalar per simulated product.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        """Reshape ``(B, T, n*h)`` to ``((B*n), T, h)``; a no-op for a single head."""
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of ``split_to_heads``."""
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal self-attention over ``x`` of shape ``(B, T, h_dim)``."""
        seq_len = x.shape[1]
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): the scale uses the full h_dim, not the per-head
        # width; the two coincide only for num_heads == 1. Kept as is so
        # existing checkpoints keep their meaning.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Lower-triangular mask built directly on the scores' device:
        # position t may only attend to positions <= t.
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device))
        scores = scores.masked_fill(~causal, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # F.dropout1d defaults to training=True, so the module's own mode has
        # to be passed explicitly; otherwise dropout stays active in eval().
        # NOTE(review): on a (B, T, C) input dropout1d treats axis 1 as the
        # channel axis, i.e. it zeroes whole positions -- confirm intended.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndFocalDistLens(nn.Module):
    """GPT-2 style language model whose attention matmuls run through ``omm`` simulators.

    Two ``omm.TrainableFocalDistLensOpticalMul`` simulators are created once
    (one for the score product, one for the output product), wrapped in
    ``omm.ScatterDataParallel`` and shared by every ``TransformerLayer``.

    Args:
        vocab_size: number of distinct tokens.
        layers_num: number of transformer layers.
        h_dim: model width.
        max_seq_len: maximum context length.
        num_heads: attention heads per layer.
        dropout_rate: dropout probability passed to every layer.
        pixel_size: pixel pitch used to size the simulated matrices.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Expose every constructor argument as a plain attribute.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The 512-token setup uses a longer propagation distance and an
        # explicit lens size; every other length uses distance 0.01 and the
        # library's default lens size.
        if max_seq_len == 512:
            geometry = dict(distance=0.15, lens_size=8192 * 2)
        else:
            geometry = dict(distance=0.01)
        # q @ k^T : right matrix is (head_dim x max_seq_len)
        scores_mul = self._build_lens_mul(head_dim, max_seq_len, pixel_size, **geometry)
        # attention @ v : right matrix is (max_seq_len x head_dim)
        output_mul = self._build_lens_mul(max_seq_len, head_dim, pixel_size, **geometry)
        self.sim_scores = omm.ScatterDataParallel(scores_mul)
        self.sim_output = omm.ScatterDataParallel(output_mul)
        self.layers = nn.ModuleList([
            TransformerLayer(
                h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    @staticmethod
    def _build_lens_mul(count_rows, count_columns, pixel_size, **geometry):
        """Create one simulator for a right matrix of ``count_rows x count_columns`` pixels."""
        return omm.TrainableFocalDistLensOpticalMul(
            omm.Config(right_matrix_count_columns = count_columns,
                       right_matrix_count_rows = count_rows,
                       right_matrix_width = pixel_size * count_columns,
                       right_matrix_height = pixel_size * count_rows,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **geometry)
        )

    def forward(self, x, targets=None):
        """Return ``(logits, loss)``; ``loss`` is ``None`` when no targets are given.

        ``x`` holds token ids of shape ``(B, T)``; ``logits`` has shape
        ``(B, T, vocab_size)`` and ``loss`` is the cross-entropy over all
        positions.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # B,T,C
        loss = None
        if targets is not None:
            # cross_entropy expects the class axis second: (B, C, T)
            loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``start_idx``.

        At every step the context is cropped to the last ``max_seq_len``
        tokens, the next token is sampled from the softmax of the final
        position's logits and appended. Runs without gradient tracking, as
        sampling never needs an autograd graph.
        """
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Write the simulators' lens operators and each layer's k1/k2 scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Tags are grouped per layer in TensorBoard (e.g. optic_scalars/layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
# --- optimisation ---------------------------------------------------------
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
learning_rate = 0.001
max_iters = 40_000

# --- evaluation (not referenced in the visible part of this script) --------
eval_interval = 300
eval_iters = 200

# --- model ----------------------------------------------------------------
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6

# --- runtime --------------------------------------------------------------
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Every accumulation step must receive the same, whole number of samples.
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndFocalDistLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Optional first CLI argument is appended to the run name.
run_suffix = '_' + sys.argv[1] if len(sys.argv) > 1 else ''
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{run_suffix}"
# Take the timestamp exactly once: the logs and checkpoints directories then
# always carry the same stamp, and its fields cannot straddle a clock tick.
run_timestamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
logs_dir = f'logs/{run_timestamp}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# Snapshot the directory containing the launched script next to the logs,
# then drop write permission on the snapshot's top-level directory.
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)
script_snapshot_path.chmod(0o500)
# Standalone checkpoints directory, named like the logs directory.
checkpoints_dir = f'./checkpoints/{run_timestamp}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Load the three corpus splits as UTF-8 text.
train_text = train_data_path.read_text(encoding='utf-8')
val_text = val_data_path.read_text(encoding='utf-8')
test_text = test_data_path.read_text(encoding='utf-8')

# Character-level vocabulary over all splits, in sorted order.
text = ''.join((train_text, val_text, test_text))
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {symbol: index for index, symbol in enumerate(chars)}
itow = dict(enumerate(chars))


def encode(s):
    """Map a string to the list of its character ids."""
    return [wtoi[symbol] for symbol in s]


def decode(idx):
    """Map an iterable of character ids back to a string."""
    return ''.join(itow[token] for token in idx)
def get_batch(data, seq_len, batch_size):
    """Sample ``batch_size`` random windows of ``seq_len`` tokens from ``data``.

    Returns ``(x, y)`` on ``device``, where ``y`` is ``x`` shifted one token
    to the right (the next-token targets).
    """
    starts = torch.randint(len(data) - seq_len, (batch_size, )).tolist()
    inputs = [data[start:start + seq_len] for start in starts]
    targets = [data[start + 1:start + 1 + seq_len] for start in starts]
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
# Encode each split once into a 1-D tensor of character ids.
train_data, val_data, test_data = (
    torch.tensor(encode(split_text), dtype=torch.long)
    for split_text in (train_text, val_text, test_text)
)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate the perplexity of ``model`` on the token tensor ``data``.

    Windows of ``max_seq_len`` tokens are taken every ``stride`` positions
    (the stride is chosen so that roughly 10000 windows are evaluated), run
    through the model in batches, and the token-weighted mean loss is
    exponentiated.

    The model is switched to eval mode and left there.

    Args:
        model: callable returning ``(logits, mean_loss)`` for ``(x, y)``.
        data: 1-D tensor of token ids.
        batch_size: number of windows per forward pass.

    Returns:
        ``exp(mean loss)`` as a NumPy float.

    Raises:
        ValueError: if ``data`` is too short to contain a single window.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # All window start positions; each window needs one extra token for the
    # shifted targets.
    start_positions = range(0, len(data) - max_seq_len - 1, stride)
    total_sequences = len(start_positions)
    if total_sequences == 0:
        # Without this guard the function ended in an opaque ZeroDivisionError.
        raise ValueError(
            f"data has {len(data)} tokens, which is too short for a window of "
            f"max_seq_len={max_seq_len} plus targets"
        )
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:i + batch_size]
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # The model is expected to return the loss averaged over all tokens
        # of the batch, so it is re-weighted by the token count below.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=None, max_new_tokens=100):
    """Generate a continuation of a prompt and return it as decoded text.

    Args:
        m: model exposing ``generate(start_idx=..., max_new_tokens=...)``.
        start_idxs: prompt as a sequence of token ids; defaults to ``[0]``.
        max_new_tokens: number of tokens to sample after the prompt.

    Returns:
        The decoded prompt followed by the sampled continuation.

    Raises:
        ValueError: if the prompt is empty.
    """
    # None sentinel instead of a mutable list default.
    if start_idxs is None:
        start_idxs = [0]
    if len(start_idxs) == 0:
        raise ValueError("start_idxs must contain at least one token id")
    # Shape (1, T): a batch holding the single prompt.
    start_idx = torch.tensor([start_idxs], dtype=torch.long).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Build the model from the hyperparameters above. dropout_rate is passed
# explicitly: it is recorded in the checkpoint config, so the model must not
# silently fall back to its own default.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    layers_num=layers_num,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    pixel_size=pixel_size,
)
m = m.to(device)
# Log the module tree and the total parameter count as text.
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save a training checkpoint and return the path it was written to.

    The file is named ``checkpoint_<step, zero-padded to 7 digits>.pt`` and
    stores the step, model and optimizer state dicts, the loss, the training
    config and both vocabulary mappings.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        step: training step used in the file name and stored in the file.
        loss: loss value recorded alongside the state.
        config: training configuration to store.
        wtoi: symbol -> id mapping.
        itow: id -> symbol mapping.
        checkpoint_dir: target directory; created if it does not exist.

    Returns:
        ``pathlib.Path`` of the written checkpoint file.
    """
    target_dir = Path(checkpoint_dir)
    # Do not rely on the caller having created the directory beforehand.
    target_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = target_dir / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)
    return checkpoint_path
# Hyperparameters stored inside every checkpoint so that a run can be
# reconstructed from the checkpoint file alone.
training_config = dict(
    vocab_size=vocab_size,
    layers_num=layers_num,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    batch_size=batch_size,
    learning_rate=learning_rate,
    gradient_accumulation_steps=gradient_accumulation_steps,
    pixel_size=pixel_size,
    max_iters=max_iters,
)
#################################### Train #########################################
# Baseline samples from the untrained model, logged at step 0.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]

# Free-running sample seeded with two newlines.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

# One short continuation per task prompt, joined into a single text entry.
prompt_outputs = []
for task_prompt in task_prompts:
    prompt_outputs.append(complete(m, encode(task_prompt), 32))
task_results = "\n".join(prompt_outputs)
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Samples per accumulation step (constant for the whole run).
micro_batch_size = batch_size // gradient_accumulation_steps

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)

    # Accumulate gradients over the micro-batches of this step; every loss is
    # scaled so that the summed gradient matches one full batch.
    accumulated_loss = 0.0
    for _ in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=micro_batch_size)
        _, batch_loss = m(xb, yb)
        scaled_loss = batch_loss / gradient_accumulation_steps
        scaled_loss.backward()
        accumulated_loss += scaled_loss.item()

    # Every 100 steps record the total gradient norm (the clipping threshold
    # of 1e6 only takes effect for norms above that value).
    if i % 100 == 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6)
        writer.add_scalar('Gradient_Norm/Total', grad_norm.item(), i)

    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")

    # Every 5000 steps: validation perplexity, sample texts, optical
    # parameters and a checkpoint.
    if i % 5000 != 0:
        continue
    m.eval()
    ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
    writer.add_scalar('val_perplexity', ppl.item(), i)
    print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
    writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
    task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
    writer.add_text('completions/task', task_results, i)
    m.log_trainable_optic_params(writer, i)
    save_checkpoint(
        model=m,
        optimizer=optimizer,
        step=i,
        loss=accumulated_loss,
        config=training_config,
        wtoi=wtoi,
        itow=itow,
        checkpoint_dir=checkpoints_dir,
    )
# Final evaluation after the last optimisation step. Everything below is
# reported under one step number (max_iters), so console messages, scalars,
# texts and the final checkpoint all agree.
final_step = max_iters
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{final_step}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {final_step}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), final_step)
# The last training loss was already logged inside the loop, so it is not
# written a second time here.
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), final_step)
print(f"\rTest Perplexity at {final_step}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, final_step)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, final_step)
m.log_trainable_optic_params(writer, final_step)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=final_step,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
# Flush pending events to disk and release the event file.
writer.close()
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")