import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil

seed = 1337
torch.manual_seed(seed)

batch_size = 50
gradient_accumulation_steps = 5  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6

# Explicit checks instead of `assert`, which is stripped under `python -O`.
if batch_size % gradient_accumulation_steps != 0:
    raise ValueError("batch_size must be divisible by gradient_accumulation_steps")
if h_dim % num_heads != 0:
    raise ValueError("h_dim must be divisible by num_heads")

############################### MODEL #############################################################


def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product evaluated through a simulator fed only non-negative inputs.

    Both operands are split into positive and negative parts
    (A = A⁺ - A⁻, B = B⁺ - B⁻). Each part is divided by its own maximum so
    ``sim`` only ever sees values in [0, 1], the simulator output is scaled back,
    and the four partial products are recombined as
    A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻.

    Args:
        sim: callable applied to pairs of non-negative 4-D tensors; presumably an
            (optical) batched matmul ``sim(L, R) ≈ L @ R`` — TODO confirm against
            ``optical_matrix_multiplication``.
        tensor_1: left operand; a tensor with fewer than 4 dims gets a new
            leading axis of size 1.
        tensor_2: right operand; same promotion as ``tensor_1``.

    Returns:
        The recombined product with the leading axis removed (index 0). A part
        with no entries of its sign (maximum 0) skips ``sim`` and contributes
        zeros of shape ``(tensor_1.shape[0], tensor_1.shape[1],
        tensor_1.shape[2], tensor_2.shape[3])`` (after promotion).
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1.unsqueeze(0)
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2.unsqueeze(0)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    a_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    a_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    b_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    b_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)

    # Each maximum may be 0 when the operand has no entries of that sign.
    a_pos_max, a_neg_max = torch.max(a_pos), torch.max(a_neg)
    b_pos_max, b_neg_max = torch.max(b_pos), torch.max(b_neg)

    def _scaled_product(left, left_max, right, right_max):
        # Normalize into [0, 1] for the simulator, then undo the normalization.
        if left_max > 0 and right_max > 0:
            return sim(left / left_max, right / right_max) * left_max * right_max
        # An all-zero part contributes nothing; skip the (expensive) simulator call.
        return torch.zeros(out_shape, device=tensor_1.device)

    # Evaluated left to right: A⁺B⁺, A⁺B⁻, A⁻B⁺, A⁻B⁻ (same call order as before).
    result = (_scaled_product(a_pos, a_pos_max, b_pos, b_pos_max)
              - _scaled_product(a_pos, a_pos_max, b_neg, b_neg_max)
              - _scaled_product(a_neg, a_neg_max, b_pos, b_pos_max)
              + _scaled_product(a_neg, a_neg_max, b_neg, b_neg_max))
    return result[0]


# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied along axis 1 (sequence) of the input.

    The rotation angles for positions ``0..max_seq_len-1`` are precomputed into
    the ``emb`` buffer (shape ``(max_seq_len, dim)``).
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        """Map the last-dim halves (x1, x2) to (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x`` (sequence on axis 1, features last) starting at position ``offset``."""
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)


# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: ``weight * tanh(alpha * x) + bias``, used in place of LayerNorm."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias


# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) causal transformer block whose two attention matmuls go through simulators.

    ``sim_scores`` computes Q·Kᵀ and ``sim_output`` computes softmax(scores)·V,
    both via :func:`new_formula`. Each result is multiplied by a trainable scalar
    (``k1`` / ``k2``) owned by this layer.
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        self.h_dim = h_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.max_seq_len = max_seq_len
        # The simulators are shared by all layers and registered once on the parent
        # model. Storing them straight into __dict__ bypasses nn.Module registration,
        # so they do not appear again under every layer's state_dict keys.
        self.__dict__['sim_scores'] = sim_scores
        self.__dict__['sim_output'] = sim_output
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        """(B, T, n*h) -> (B*n, T, h); identity when there is a single head."""
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """(B*n, T, h) -> (B, T, n*h); identity when there is a single head."""
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal self-attention with both matmuls routed through the simulators."""
        B, T, C = x.shape
        q = self.rope(self.split_to_heads(self.q_proj(x), B, T, C))
        k = self.rope(self.split_to_heads(self.k_proj(x), B, T, C))
        v = self.split_to_heads(self.v_proj(x), B, T, C)
        # Scaled dot-product attention divides by sqrt of the *per-head* dimension
        # (identical to h_dim when num_heads == 1).
        head_dim = self.h_dim // self.num_heads
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (head_dim ** -0.5)
        causal_mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, B, T, C))

    def forward(self, x):
        # F.dropout1d defaults to training=True; pass the module mode explicitly so
        # dropout is disabled under model.eval() (evaluation and generation).
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x


def _make_trainable_lens_mul(count_rows, count_columns, pixel_size, long_sequence):
    """Build a TrainableLensOpticalMul whose right matrix is ``count_rows x count_columns`` pixels.

    ``long_sequence`` selects the geometry this script uses for ``max_seq_len == 512``
    (distance 0.15, lens_size 8192 * 2); otherwise distance 0.01 and the library's
    default lens size are used.
    """
    geometry = dict(distance=0.15, lens_size=8192 * 2) if long_sequence else dict(distance=0.01)
    return omm.TrainableLensOpticalMul(
        omm.Config(right_matrix_count_columns=count_columns,
                   right_matrix_count_rows=count_rows,
                   right_matrix_width=pixel_size * count_columns,
                   right_matrix_height=pixel_size * count_rows,
                   min_height_gap=pixel_size,
                   right_matrix_split_x=2,
                   right_matrix_split_y=2,
                   left_matrix_split_x=2,
                   left_matrix_split_y=2,
                   result_matrix_split=2,
                   **geometry)
    )


class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT whose attention matmuls run through two shared trainable-lens optical simulators."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size

        head_dim = h_dim // num_heads
        long_sequence = max_seq_len == 512
        # Build both simulators before wrapping, matching the original construction order.
        sim_scores = _make_trainable_lens_mul(head_dim, max_seq_len, pixel_size, long_sequence)  # Q·Kᵀ
        sim_output = _make_trainable_lens_mul(max_seq_len, head_dim, pixel_size, long_sequence)  # P·V
        self.sim_scores = omm.ScatterDataParallel(sim_scores)
        self.sim_output = omm.ScatterDataParallel(sim_output)

        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output,
                             num_heads=self.num_heads, dropout_rate=self.dropout_rate,
                             max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return ``(logits, loss)``; logits are (B, T, vocab), loss is None without targets."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    @torch.no_grad()
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``start_idx`` (B, T).

        The context is cropped to the last ``max_seq_len`` tokens. Runs without
        autograd; the caller is responsible for putting the model in eval mode.
        """
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log both simulators' lens operators and every layer's k1/k2 scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)

###################################################################################################


MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

# Read the clock once so every timestamp field belongs to the same instant.
run_started = datetime.now()
logs_dir = f'logs/{run_started.date()}_{run_started.hour:02d}_{run_started.minute:02d}_{run_started.second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)

# Resolve first: for `python train.py` the unresolved parent is '.', whose name is ''
# and would make the snapshot target the (already existing) logs_dir itself.
script_dir = Path(sys.argv[0]).resolve().parent
script_snapshot_path = Path(logs_dir + script_dir.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
# Snapshot this version of the repository. Output directories are skipped so the copy
# cannot recurse into its own destination when the script lives next to logs/.
shutil.copytree(script_dir, script_snapshot_path,
                ignore=shutil.ignore_patterns('logs', 'checkpoints'))
script_snapshot_path.chmod(0o500)  # with read-only permission

# Standalone checkpoints directory. It reuses the run id (timestamp + comment) from
# `logs_dir` instead of re-reading the clock after the repository copy, so one run's
# logs and checkpoints always carry the same name.
checkpoints_dir = f'./checkpoints/{Path(logs_dir).relative_to("logs").as_posix()}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# Character-level vocabulary built over the train/valid/test WikiText splits.
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}


def encode(s):
    """Map a string to a list of character ids."""
    return [wtoi[w] for w in s]


def decode(idx):
    """Map a sequence of character ids back to a string."""
    return ''.join([itow[i] for i in idx])


def get_batch(data, seq_len, batch_size):
    """Sample ``batch_size`` random windows; targets are the inputs shifted by one token."""
    ix = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Token-weighted perplexity of ``model`` over strided windows of ``data``.

    Windows of length ``max_seq_len`` start every ``stride`` tokens, with the stride
    chosen so that roughly 10000 windows are evaluated. Leaves the model in eval mode.

    Raises:
        ValueError: if ``data`` is too short to form a single (input, target) window.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    # A window starting at `start` needs data[start + 1 : start + max_seq_len + 1], so
    # the last valid start is len(data) - max_seq_len - 1 (range's stop is exclusive).
    start_positions = list(range(0, len(data) - max_seq_len, stride))
    total_sequences = len(start_positions)
    if total_sequences == 0:
        raise ValueError(f"data has {len(data)} tokens; need at least {max_seq_len + 1} for perplexity")

    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:i + batch_size]

        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)

        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)

        # The model returns the mean loss over all tokens in the batch; re-weight by
        # token count so batches of different sizes combine correctly.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################


@torch.no_grad()
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Generate ``max_new_tokens`` characters after the prompt ids ``start_idxs`` and decode them.

    Uses an immutable default and disables autograd, since generation never backpropagates.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    dropout_rate=dropout_rate,  # keep the model in sync with the value saved in training_config
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
writer.add_text('model', str(m), 0)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)

#################################### Train #########################################

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################


def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model/optimizer state plus config and vocabulary to ``checkpoint_<step:07d>.pt``."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Each micro-batch loss is a mean over equally sized micro-batches, so dividing by
    # the step count makes the accumulated gradient equal to the full-batch mean.
    for _ in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        _, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 makes clipping effectively a no-op; it is used to read the total norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
m.log_trainable_optic_params(writer, max_iters)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
# Flush pending TensorBoard events; otherwise the last scalars/texts may be lost at exit.
writer.close()
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")