diff --git a/src/bpe_main.py b/src/bpe_main.py deleted file mode 100644 index 6bd30e5..0000000 --- a/src/bpe_main.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -from char_gpt2 import GPT2 -from optics_char_gpt2 import OpticGPT2 -from bpe_tokenizer import byte_pair_init, byte_pair_encode, byte_pair_decode - -seed = 1337 -torch.manual_seed(seed) -models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2} - -batch_size = 50 -max_iters = 40000*10 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 22 -h_dim = 64 -max_seq_len = 256 -num_heads = 4 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -merges_count = 20 - -# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment -MODEL_CLASS = models[sys.argv[1]] -train_data_path = Path(sys.argv[2]) -val_data_path = Path(sys.argv[3]) -test_data_path = Path(sys.argv[4]) -comment = f"bpe_{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}" - -logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' -writer = SummaryWriter(logs_dir) -script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name) -print("Logs dir:", logs_dir) -script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script -script_snapshot_path.chmod(0o400) # with read-only permission -#################################### Dataset ######################################### - -# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -with open(train_data_path, encoding='utf-8') as f: - train_text = f.read() - -with open(val_data_path, 
encoding='utf-8') as f: - val_text = f.read() - -with open(test_data_path, encoding='utf-8') as f: - test_text = f.read() - -text = train_text + val_text + test_text - -chars = sorted(set(text)) -print(f"Len chars: {len(chars)}") - -wtoi = {w:i for i,w in enumerate(chars)} -itow = {i:w for i,w in enumerate(chars)} - -import pickle - -start_time = datetime.now() -if Path("./data/bpe_text.pkl").exists() and Path("./data/merges.pkl").exists(): - with open("./data/bpe_text.pkl", 'rb') as f: bpe_text = pickle.load(f) - with open("./data/merges.pkl", 'rb') as f: merges = pickle.load(f) -else: - bpe_text, merges = byte_pair_init([wtoi[w] for w in text], vocab_size=len(chars), merges_count=20) - with open("./data/bpe_text.pkl", 'wb') as f: pickle.dump(bpe_text, f) - with open("./data/merges.pkl", 'wb') as f: pickle.dump(merges, f) - -print(f"Compression ratio: {len(text)/len(bpe_text)}, init took {datetime.now()-start_time}") - -vocab_size = len(chars) + merges_count - -encode = lambda s: byte_pair_encode([wtoi[w] for w in s], merges) -decode = lambda idx: "".join([itow[i] for i in byte_pair_decode(idx, merges)]) - -def get_batch(data, seq_len, batch_size): - ix = torch.randint(len(data)-seq_len, (batch_size, )) - x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) - y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) - return x, y - -start_time = datetime.now() -train_bpe_encoded_path = Path("./data/train_bpe_encoded.pt") -val_bpe_encoded_path = Path("./data/val_bpe_encoded.pt") -test_bpe_encoded_path = Path("./data/test_bpe_encoded.pt") -if train_bpe_encoded_path.exists() and val_bpe_encoded_path.exists() and test_bpe_encoded_path.exists(): - train_data = torch.load(train_bpe_encoded_path) - val_data = torch.load(val_bpe_encoded_path) - test_data = torch.load(test_bpe_encoded_path) -else: - train_data = torch.tensor(encode(train_text), dtype=torch.long) - val_data = torch.tensor(encode(val_text), dtype=torch.long) - test_data = 
torch.tensor(encode(test_text), dtype=torch.long) - torch.save(train_data, train_bpe_encoded_path) - torch.save(val_data, val_bpe_encoded_path) - torch.save(test_data, test_bpe_encoded_path) -print(f"Encoded {datetime.now() - start_time}") - -@torch.no_grad() -def perplexity(model, data): - stride = max(1, len(data) // 10000) - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, loss = model(x[None,...], y[None,...]) - losses.append(loss.item()) - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") - return np.exp(np.mean(losses)) - -#################################### Model #########################################mo -def complete(m, start_idxs=[0], max_new_tokens=100): - start_idx = torch.tensor([start_idxs]).to(device) - generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) - return decode(generated_tokens[0].tolist()) - -m = MODEL_CLASS( - vocab_size=vocab_size, - h_dim=h_dim, - max_seq_len=max_seq_len, - num_heads=num_heads, - pixel_size=pixel_size, - layers_num=layers_num -) -m = m.to(device) -print(m) -#################################### Train ######################################### - -optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) - -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, 0) - -for i in range(max_iters): - xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size) - logits, loss = m(xb, yb) - optimizer.zero_grad(set_to_none=True) - loss.backward() - optimizer.step() - writer.add_scalar('loss', loss.item(), i) - print(f"\r{i}/{max_iters} {loss.item()}", end="") - if i % 5000 == 0: - ppl = perplexity(model=m, data=val_data) - writer.add_scalar('val_perplexity', ppl.item(), i) - print(f"\rPerplexity at {i}: {ppl}") - writer.add_text('completions', 
complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) - -ppl = perplexity(model=m, data=val_data) -print(f"\r{i+1}/{max_iters} {loss.item()}") -print(f"\rPerplexity at {i}: {ppl}") -writer.add_scalar('val_perplexity', ppl.item(), i+1) -writer.add_scalar('loss', loss.item(), i) - -ppl = perplexity(model=m, data=test_data) -writer.add_scalar('test_perplexity', ppl.item(), i+1) -print(f"\rTest Perplexity at {i}: {ppl}") - -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, i+1) \ No newline at end of file diff --git a/src/bpe_tokenizer.py b/src/bpe_tokenizer.py deleted file mode 100644 index d66748a..0000000 --- a/src/bpe_tokenizer.py +++ /dev/null @@ -1,116 +0,0 @@ -import numpy as np -import pandas as pd - -def get_top_pair(tokens): - hist = pd.DataFrame(np.vstack([tokens[:-1], tokens[1:]]).T).value_counts().reset_index().astype(np.int32) - return list(hist.nlargest(1, columns='count').iloc[0, [0,1]]) - -def merge(tokens, pair, new_idx): - new_tokens = [] - skip = False - for a,b in zip(tokens[:-1], tokens[1:]): - if skip: - skip = False - continue - if a == pair[0] and b == pair[1]: - new_tokens.append(new_idx) - skip = True - else: - new_tokens.append(a) - if not skip: - new_tokens.append(b) - return np.array(new_tokens) - -def unmerge(tokens, pair_idx, pair): - new_tokens = [] - for idx in tokens: - if idx == pair_idx: - new_tokens.append(pair[0]) - new_tokens.append(pair[1]) - else: - new_tokens.append(idx) - return new_tokens - -def byte_pair_init(char_ids, vocab_size, merges_count=20): - byte_text = np.array(char_ids) - merges = [] - for i in range(merges_count): - top_pair = get_top_pair(byte_text) - new_idx = vocab_size + i - merges.append([top_pair, new_idx]) - print(f"{top_pair} {new_idx}") - byte_text = merge(byte_text, top_pair, new_idx) - return np.array(byte_text), merges - -def byte_pair_encode(char_ids, merges): - tokens = np.array(char_ids) - for pair, pair_idx in 
merges: - tokens = merge(tokens, pair, pair_idx) - return tokens - -def byte_pair_decode(tokens, merges): - for pair, pair_idx in reversed(merges): - tokens = unmerge(tokens, pair_idx, pair) - return tokens - - -# def get_top_pair(tokens): -# hist = pd.DataFrame(np.vstack([tokens[:-1], tokens[1:]]).T).value_counts().reset_index().astype(np.uint16) -# return np.array(hist.nlargest(1, columns='count').iloc[0, [0,1]]) - -# def merge(tokens, pair, new_idx): -# if len(tokens) % 2 != 0: -# tokens = np.append(tokens, np.array([0], dtype=np.uint16)) -# # print("not even") -# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy() -# b = np.frombuffer(bytes(pair), dtype=np.uint32) -# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32) -# a[a==b] = c -# d = np.frombuffer(bytes(a), dtype=np.uint16) -# indices = np.where(d == 2**16-1) -# e = np.delete(d, indices) -# e = e[:-1] -# else: -# # print("even") -# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy() -# b = np.frombuffer(bytes(pair), dtype=np.uint32) -# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32) -# a[a==b] = c -# d = np.frombuffer(bytes(a), dtype=np.uint16) -# indices = np.where(d == 2**16-1) -# e = np.delete(d, indices) -# return e - -# def unmerge(tokens, pair_idx, pair): -# new_tokens = [] -# for idx in tokens: -# if idx == pair_idx: -# new_tokens.append(pair[0]) -# new_tokens.append(pair[1]) -# else: -# new_tokens.append(idx) -# return new_tokens - -# def byte_pair_init(char_ids, vocab_size, merges_count=20): -# assert vocab_size < 2**16 -# byte_text = np.array(char_ids, dtype=np.uint16) -# merges = [] -# for i in range(merges_count): -# top_pair = get_top_pair(byte_text) -# new_idx = vocab_size + i -# print([top_pair, new_idx]) -# merges.append([top_pair, new_idx]) -# byte_text = merge(byte_text, top_pair, new_idx) -# byte_text = np.roll(merge(np.roll(byte_text, 1), top_pair, new_idx), -1) -# return byte_text, merges - -# def 
byte_pair_encode(char_ids, merges): -# tokens = np.array(char_ids, dtype=np.uint16) -# for pair, pair_idx in merges: -# tokens = merge(tokens, pair, pair_idx) -# return tokens - -# def byte_pair_decode(tokens, merges): -# for pair, pair_idx in reversed(merges): -# tokens = unmerge(tokens, pair_idx, pair) -# return tokens \ No newline at end of file diff --git a/src/char_gpt2.py b/src/char_gpt2.py deleted file mode 100644 index c1cd2cb..0000000 --- a/src/char_gpt2.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -torch.manual_seed(1337) - -#################################### Model ######################################### - - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - 
return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line - attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(attention @ v, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class GPT2(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): - super().__init__() - self.__dict__.update({k:v for k,v in 
locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/main.py b/src/main.py deleted file mode 100644 index d6a06c4..0000000 --- a/src/main.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -from char_gpt2 import GPT2 -from optics_char_gpt2 import OpticGPT2 -from optics_char_gpt2_traindiag import OpticGPT2TrainDiag -from optics_char_gpt2_ff import OpticGPT2FF -from optics_char_gpt2_new_formula import OpticGPT2NewFormula -from char_gpt2_scaledmatmul import GPT2ScaledMM -from optics_char_gpt2_nokoef import OpticGPT2NOKoef -from optics_char_gpt2_nokoef_newf import OpticGPT2NOKoefNewF -import shutil -seed = 1337 -torch.manual_seed(seed) -models = { - 
'gpt2': GPT2, - 'optic_gpt2': OpticGPT2, - 'optic_gpt2_ff': OpticGPT2FF, - 'optic_gpt2_traindiag': OpticGPT2TrainDiag, - 'optic_gpt2_newformula': OpticGPT2NewFormula, - 'optic_gpt2_nokoef': OpticGPT2NOKoef, - 'optic_gpt2_nokoef_newformula': OpticGPT2NOKoefNewF, - 'gpt2_scaledmm': GPT2ScaledMM -} - -batch_size = 50 -gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = 40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim = 64 -max_seq_len = 512 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64 - -MODEL_CLASS = models[sys.argv[1]] -train_data_path = Path(sys.argv[2]) -val_data_path = Path(sys.argv[3]) -test_data_path = Path(sys.argv[4]) -comment = f"{sys.argv[1]}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[5] if len(sys.argv) >= 6 else ''}" - -logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' -writer = SummaryWriter(logs_dir) -script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) -print("Logs dir:", logs_dir) -# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script -shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository -script_snapshot_path.chmod(0o500) # with read-only permission - -#################################### Dataset ######################################### - -# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -with open(train_data_path, encoding='utf-8') as f: - train_text = f.read() - -with open(val_data_path, encoding='utf-8') as f: - val_text = f.read() 
- -with open(test_data_path, encoding='utf-8') as f: - test_text = f.read() - -text = train_text + val_text + test_text - -chars = sorted(set(text)) -vocab_size = len(chars) - -wtoi = {w:i for i,w in enumerate(chars)} -itow = {i:w for i,w in enumerate(chars)} - -encode = lambda s: [wtoi[w] for w in s] -decode = lambda idx: ''.join([itow[i] for i in idx]) - -def get_batch(data, seq_len, batch_size): - ix = torch.randint(len(data)-seq_len, (batch_size, )) - x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) - y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) - return x, y - -train_data = torch.tensor(encode(train_text), dtype=torch.long) -val_data = torch.tensor(encode(val_text), dtype=torch.long) -test_data = torch.tensor(encode(test_text), dtype=torch.long) - -@torch.no_grad() -def perplexity(model, data): - stride = max(1, len(data) // 10000) - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, loss = model(x[None,...], y[None,...]) - losses.append(loss.item()) - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") - return np.exp(np.mean(losses)) - -#################################### Model #########################################mo -def complete(m, start_idxs=[0], max_new_tokens=100): - start_idx = torch.tensor([start_idxs]).to(device) - generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) - return decode(generated_tokens[0].tolist()) - -m = MODEL_CLASS( - vocab_size=vocab_size, - h_dim=h_dim, - max_seq_len=max_seq_len, - num_heads=num_heads, - pixel_size=pixel_size, - layers_num=layers_num -) -m = m.to(device) -writer.add_text('model', str(m), 0) - -#################################### Train ######################################### - -optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) - -m.eval() -completion = complete(m, 
encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, 0) - -for i in range(max_iters): - m.train() - optimizer.zero_grad(set_to_none=True) - accumulated_loss = 0.0 - for j in range(gradient_accumulation_steps): - xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) - logits, loss = m(xb, yb) - loss = loss / gradient_accumulation_steps - loss.backward() - accumulated_loss += loss.item() - if i % 100 == 0: - writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) - optimizer.step() - writer.add_scalar('loss', accumulated_loss, i) - print(f"\r{i}/{max_iters} {accumulated_loss}", end="") - if i % 5000 == 0: - m.eval() - ppl = perplexity(model=m, data=val_data) - writer.add_scalar('val_perplexity', ppl.item(), i) - print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") - writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) - -m.eval() -ppl = perplexity(model=m, data=val_data) -print(f"\r{i+1}/{max_iters} {accumulated_loss}") -print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") -writer.add_scalar('val_perplexity', ppl.item(), i+1) -writer.add_scalar('loss', accumulated_loss, i) - -ppl = perplexity(model=m, data=test_data) -writer.add_scalar('test_perplexity', ppl.item(), i+1) -print(f"\rTest Perplexity at {i}: {ppl}") - -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, i+1) \ No newline at end of file diff --git a/src/optics_char_gpt2.py b/src/optics_char_gpt2.py deleted file mode 100644 index 3989ed6..0000000 --- a/src/optics_char_gpt2.py +++ /dev/null @@ -1,211 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from 
torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / (max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,:,:,:] - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, 
offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - self.k1 = nn.Parameter(torch.randn(1)) - self.k2 = nn.Parameter(torch.randn(1)) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = self.k1 * 
optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) - attention = nn.functional.softmax(scores, dim=2) - output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) - return self.o_proj(self.gather_heads(output, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - if max_seq_len < 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - if max_seq_len >= 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = 
pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! 
- def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_ff.py b/src/optics_char_gpt2_ff.py deleted file mode 100644 index 59bbd93..0000000 --- a/src/optics_char_gpt2_ff.py +++ /dev/null @@ -1,217 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F, init -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -import math -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / (max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - # print(tensor_1.shape, tensor_2.shape) - - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,None,:,:] - # print(tensor_1.shape, tensor_2.shape) - # raise RuntimeError - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = 
norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -class OpticLinear(nn.Module): - def __init__( - self, - in_features, - out_features, - bias = True, - device = None, - dtype = None, - pixel_size = 3.6e-6 - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - torch.empty((in_features, out_features), **factory_kwargs) - ) - if bias: - self.bias 
= nn.Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter("bias", None) - self.k = nn.Parameter(torch.randn(1)) - self.sim = omm.OpticalMul( - omm.Config( - right_matrix_count_columns = out_features , - right_matrix_count_rows = in_features, - right_matrix_width = pixel_size * out_features , - right_matrix_height = pixel_size * in_features, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01) - ) - self.reset_parameters() - - def forward(self, input): - """ - Runs the forward pass. - """ - return self.k * optics_matmul_shift(self.sim, input, self.weight) + self.bias - - def reset_parameters(self) -> None: - """ - Resets parameters based on their initialization used in ``__init__``. - """ - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see - # https://github.com/pytorch/pytorch/issues/57109 - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - init.uniform_(self.bias, -bound, bound) - - def extra_repr(self) -> str: - """ - Return the extra representation of the module. 
- """ - return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" - - - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = OpticLinear(h_dim, 4*h_dim) - self.ff2 = OpticLinear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line - attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(attention @ v, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2FF(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, 
dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef.py b/src/optics_char_gpt2_nokoef.py deleted file mode 100644 index 0cd65d3..0000000 --- a/src/optics_char_gpt2_nokoef.py +++ /dev/null @@ -1,211 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / 
(max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,:,:,:] - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, 
alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) - attention = nn.functional.softmax(scores, dim=2) - output = optics_matmul_shift(self.sim_output, attention, v) - return self.o_proj(self.gather_heads(output, *x.shape)) - - def forward(self, x): - x = x + 
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2NOKoef(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - if max_seq_len < 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - if max_seq_len >= 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - 
right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - self.sim_scores = omm.ScatterDataParallel(self.sim_scores) - self.sim_output = omm.ScatterDataParallel(self.sim_output) - - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! 
- def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_traindiag.py b/src/optics_char_gpt2_traindiag.py deleted file mode 100644 index 26de9d1..0000000 --- a/src/optics_char_gpt2_traindiag.py +++ /dev/null @@ -1,210 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / (max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,:,:,:] - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), 
norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128, pixel_size=3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - 
self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - self.k1 = nn.Parameter(torch.randn(1)) - self.k2 = nn.Parameter(torch.randn(1)) - if max_seq_len < 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) - if max_seq_len >= 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = 
pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2) - ) - self.sim_scores = omm.DataParallel(self.sim_scores) - self.sim_output = omm.DataParallel(self.sim_output) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) - attention = nn.functional.softmax(scores, dim=2) - output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) - return self.o_proj(self.gather_heads(output, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2TrainDiag(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len, pixel_size=pixel_size) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = 
nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/train_char_gpt2_128.py b/src/train_char_gpt2_128.py new file mode 100644 index 0000000..a407bee --- /dev/null +++ b/src/train_char_gpt2_128.py @@ -0,0 +1,367 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def 
__init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t 
h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil

seed = 1337
torch.manual_seed(seed)

# ---- Hyperparameters ----------------------------------------------------------------------------
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0

############################### MODEL #############################################################

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to a (B, T, dim) query/key tensor."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # (max_seq_len, dim): angle table, duplicated so cos/sin cover both halves
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): the 90-degree rotation used by the RoPE identity
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return (x * emb.cos()) + (self.rotate_half(x) * emb.sin())


# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: elementwise tanh(alpha*x) with a learned affine, as a LayerNorm substitute."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias


# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) decoder block: causal multi-head self-attention + 4x GELU MLP."""

    def __init__(self, h_dim, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Stash all ctor args as attributes (ints/floats only, so bypassing
        # nn.Module.__setattr__ here is harmless).
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # FIX: scale by the per-head dimension (1/sqrt(d_k), Vaswani et al. 2017), not the
        # full model width. Identical result for the num_heads == 1 config used here.
        head_dim = self.h_dim // self.num_heads
        scores = (q @ k.transpose(1, 2)) * (head_dim ** -0.5)
        # FIX: build the causal mask directly on the input's device instead of creating it
        # on CPU and copying it over on every forward pass.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1], device=x.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))  # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # FIX: pass training=self.training. F.dropout1d defaults to training=True, so the
        # original kept dropping activations even under model.eval(), corrupting every
        # perplexity evaluation and every generated completion.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x


class GPT2(nn.Module):
    """Minimal GPT-2-style causal LM: token + learned positional embeddings, N decoder blocks."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1, pixel_size=None):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits[B,T,C], loss) — loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B,T,C
        # cross_entropy wants (B, C, T) for sequence targets
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: feed back the last max_seq_len tokens each step."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

###################################################################################################

MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission

# Create standalone checkpoints directory
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Vocabulary is built over the union of the splits so val/test never hit an unknown char.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])


def get_batch(data, seq_len, batch_size):
    """Sample batch_size random (x, y) windows; y is x shifted one token right."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity over `data` (token-weighted mean NLL, exponentiated)."""
    model.eval()
    stride = max(1, len(data) // 10000)  # cap the eval at ~10k windows
    total_loss_sum = 0.0
    total_tokens_count = 0

    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        x_batch = torch.stack([data[start:start + max_seq_len] for start in batch_starts]).to(device)
        y_batch = torch.stack([data[start + 1:start + max_seq_len + 1] for start in batch_starts]).to(device)

        # Model returns the mean loss over all tokens in the batch.
        _, mean_loss = model(x_batch, y_batch)

        # Re-weight by token count so partial final batches don't skew the mean.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # Final newline
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a sampled continuation of the prompt token ids `start_idxs`."""
    # NOTE: default is a tuple, not a list, to avoid the shared-mutable-default pitfall.
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)

print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (weights, optimizer, vocab, config)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 never actually clips — this call is used only to *log* the total grad norm
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil

seed = 1337
torch.manual_seed(seed)

# ---- Hyperparameters (same script as the 256 variant; only max_seq_len differs) -----------------
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0

############################### MODEL #############################################################

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to a (B, T, dim) query/key tensor."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return (x * emb.cos()) + (self.rotate_half(x) * emb.sin())


# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh normalization substitute: tanh(alpha*x) with learned affine."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias


# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) decoder block: causal multi-head self-attention + 4x GELU MLP."""

    def __init__(self, h_dim, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Ints/floats only, so bypassing nn.Module.__setattr__ is harmless here.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # FIX: scale by per-head dim (1/sqrt(d_k)); identical when num_heads == 1.
        head_dim = self.h_dim // self.num_heads
        scores = (q @ k.transpose(1, 2)) * (head_dim ** -0.5)
        # FIX: create the causal mask on the input's device (no per-call host->device copy).
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1], device=x.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))  # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # FIX: training=self.training — F.dropout1d defaults to training=True, so the
        # original applied dropout even under model.eval(), skewing eval perplexity.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x


class GPT2(nn.Module):
    """Minimal GPT-2-style causal LM: token + learned positional embeddings, N decoder blocks."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1, pixel_size=None):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits[B,T,C], loss); loss is None when targets are omitted."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling over the last max_seq_len tokens per step."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

###################################################################################################

MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission

# Create standalone checkpoints directory
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Vocabulary covers all splits so val/test never contain an unknown char.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])


def get_batch(data, seq_len, batch_size):
    """Sample batch_size random (x, y) windows; y is x shifted one token right."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity over `data` (token-weighted mean NLL, exponentiated)."""
    model.eval()
    stride = max(1, len(data) // 10000)  # cap the eval at ~10k windows
    total_loss_sum = 0.0
    total_tokens_count = 0

    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        x_batch = torch.stack([data[start:start + max_seq_len] for start in batch_starts]).to(device)
        y_batch = torch.stack([data[start + 1:start + max_seq_len + 1] for start in batch_starts]).to(device)

        _, mean_loss = model(x_batch, y_batch)  # mean loss over all tokens in batch

        # Weight by token count so a short final batch doesn't skew the mean.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # Final newline
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a sampled continuation of the prompt token ids `start_idxs`."""
    # Tuple default avoids the shared-mutable-default pitfall.
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)

print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (weights, optimizer, vocab, config)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 never actually clips — the call only *logs* the total gradient norm
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary positional embedding for (B, T, dim) activations.

    Precomputes the rotation angles for every position up to ``max_seq_len``
    and stores them in the buffer ``emb`` of shape (max_seq_len, dim).
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder over even feature indices, as in the paper.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, dim/2)
        # Duplicate so both halves of the feature vector share the same angles.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Map the (a, b) feature halves to (-b, a) — RoPE's 90-degree rotation."""
        half = x.size(-1) // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x`` by the angles of positions [offset, offset + T)."""
        T = x.size(1)
        angles = self.emb[offset:offset + T].view(1, T, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) + + def forward(self, x): + x = x + 
class GPT2(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned positional embeddings feed a stack of
    ``TransformerLayer`` blocks, followed by a linear LM head.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Keep every constructor argument as an attribute (checkpoint config relies on them).
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size  # unused by this model; kept for interface parity
        self.layers = nn.ModuleList(
            TransformerLayer(h_dim=h_dim, num_heads=num_heads,
                             dropout_rate=dropout_rate, max_seq_len=max_seq_len)
            for _ in range(layers_num)
        )
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); ``loss`` is None when no targets are given."""
        h = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            h = layer(h)
        logits = self.lm_head(h)  # (B, T, vocab)
        if targets is None:
            return logits, None
        # cross_entropy expects the class dimension second: (B, C, T) vs targets (B, T).
        loss = F.cross_entropy(logits.transpose(1, 2), targets)
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively (multinomial sampling)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            window = idx[:, -self.max_seq_len:]  # crop context to the model's window
            logits, _ = self(window)
            next_probs = F.softmax(logits[:, -1, :], dim=-1)  # distribution at the last position
            sampled = torch.multinomial(next_probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, sampled], dim=1)
        return idx
def get_batch(data, seq_len, batch_size, device=None):
    """Sample a random batch of contiguous (input, next-token-target) sequences.

    Args:
        data: 1-D tensor of token ids.
        seq_len: length of each sampled sequence.
        batch_size: number of sequences in the batch.
        device: target device for the batch; defaults to the module-level
            ``device`` global, so existing callers are unaffected.

    Returns:
        (x, y) tensors of shape (batch_size, seq_len), where y is x shifted
        one position to the right (next-token targets).
    """
    if device is None:
        device = globals()['device']  # backward-compatible fallback to the script-level device
    # High bound len(data) - seq_len keeps i + 1 + seq_len <= len(data), so y always fits.
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state.

    Persists the model/optimizer state dicts plus the training config and
    the character vocabulary maps, so a run can be fully restored.
    """
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(payload, target)
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh (DyT): a normalization-free alternative to LayerNorm.

    Applies an elementwise tanh with a learnable scalar gain ``alpha``,
    followed by a per-feature affine transform (``weight``, ``bias``).
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
@torch.no_grad()
def perplexity(model, data, batch_size=32, seq_len=None, device=None):
    """Compute token-level perplexity of ``model`` over ``data``.

    Slides windows of length ``seq_len`` across ``data`` (subsampled by a
    stride so at most ~10k windows are evaluated), accumulates the
    token-weighted mean cross-entropy, and returns exp(average loss).

    Args:
        model: callable returning (logits, mean_loss) for (x, y) batches;
            switched to eval mode here (the training loop restores train mode).
        data: 1-D tensor of token ids.
        batch_size: number of windows per forward pass.
        seq_len: window length; defaults to the script-level ``max_seq_len``.
        device: device for the batches; defaults to the script-level ``device``.

    Returns:
        Perplexity as a numpy float.

    Raises:
        ValueError: if ``data`` is too short to form a single window
            (previously this crashed with ZeroDivisionError).
    """
    if seq_len is None:
        seq_len = globals()['max_seq_len']  # backward-compatible global fallback
    if device is None:
        device = globals()['device']
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    # Precompute all valid start positions; y needs one extra trailing token.
    start_positions = list(range(0, len(data) - seq_len - 1, stride))
    total_sequences = len(start_positions)
    if total_sequences == 0:
        raise ValueError(f"data of length {len(data)} is too short for seq_len={seq_len}")

    # Process windows in batches.
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        x_batch = torch.stack([data[s:s + seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + seq_len + 1] for s in batch_starts]).to(device)

        # Model returns the mean loss averaged over all tokens in the batch.
        _, mean_loss = model(x_batch, y_batch)

        # Re-weight by token count so a smaller final batch is handled correctly.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        # Progress update on a single line.
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_koef_512.py b/src/train_char_gpt2_koef_512.py new file mode 100644 index 0000000..093084a --- /dev/null +++ b/src/train_char_gpt2_koef_512.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + 
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) causal self-attention block with a GELU feed-forward.

    Learnable scalars ``k1``/``k2`` rescale the attention logits and the
    attention output respectively (non-standard; kept from the original design).
    """

    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes (h_dim, num_heads, dropout_rate, max_seq_len).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for the single-head case.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): logits scaled by full h_dim rather than the per-head dim;
        # the learnable k1 can absorb the difference, so left unchanged.
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))  # causal mask; an encoder would skip this
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so the original applied
        # dropout even under m.eval(), corrupting perplexity and generation.
        # Tie dropout to the module's training flag instead.
        # NOTE(review): dropout1d on (B, T, H) zeroes entire positions along dim 1,
        # not individual features — confirm this is the intended regularization.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_koef_64.py b/src/train_char_gpt2_koef_64.py new file mode 100644 index 0000000..6b68bba --- /dev/null +++ b/src/train_char_gpt2_koef_64.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha 
= nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), 
p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_128.py b/src/train_optics_char_gpt2_128.py new file mode 100644 index 0000000..1933153 --- /dev/null +++ b/src/train_optics_char_gpt2_128.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos 
* max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate 
= 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + 
omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, 
sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # 
# Create a standalone checkpoints directory alongside the tensorboard logs
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Vocabulary is built over train+val+test so encode() never meets an unknown char.
text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])


def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity of `model` over the token stream `data`."""
    model.eval()
    # Cap total work at roughly 10k windows regardless of corpus size.
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    # Precompute all valid start positions.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    # Process windows in batches.
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]

        x_batch = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)

        # Model returns a token-mean loss; re-weight by token count so the
        # overall average stays exact even for a ragged final batch.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # final newline after the progress line
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of token ids `start_idxs` and decode it to text.

    FIX: default is an immutable tuple; the original used a mutable list default.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
params_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {params_count}'
writer.add_text('model', model_description, 0)

# Rule of thumb: ~8 tokens per parameter.
# NOTE(review): iteration estimate divides by batch_size only, not by
# batch_size * max_seq_len tokens per step -- presumably a deliberate
# overestimate; confirm.
print(f"{params_count * 8} minimum number of tokens to train model.")
print(f"{(params_count * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (+ vocab tables)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config embedded in every checkpoint for later reconstruction.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()
# Zero-shot probes logged throughout training to eyeball progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale invariant to accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 never actually clips; the call is only used to measure the norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(model=m, optimizer=optimizer, step=i, loss=accumulated_loss,
                        config=training_config, wtoi=wtoi, itow=itow, checkpoint_dir=checkpoints_dir)

# ---- Final evaluation on val and test + final snapshot ----
# FIX: dropped the redundant post-loop writer.add_scalar('loss', ..., i)
# (the loop already logged step i); final time log now uses elapsed time,
# consistent with the in-loop logs.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(model=m, optimizer=optimizer, step=max_iters, loss=accumulated_loss,
                config=training_config, wtoi=wtoi, itow=itow, checkpoint_dir=checkpoints_dir)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil

# NOTE(review): `omm` (the optical matrix-multiplication simulator) is used
# below but never imported in this file -- confirm the module path and add the
# import (e.g. `import omm`) before running.

seed = 1337
torch.manual_seed(seed)

# ------------------------------ Hyperparameters ------------------------------
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch of the simulated optical device, in meters
assert batch_size % gradient_accumulation_steps == 0

############################### MODEL #############################################################

def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product tensor_1 @ tensor_2 through a non-negative multiplier.

    The optical simulator `sim` can only multiply matrices whose entries lie in
    [0, 1], so each operand is split into positive and negative parts,
    normalized per part, multiplied pairwise, rescaled and recombined:

        A @ B = A+ @ B+ - A+ @ B- - A- @ B+ + A- @ B-

    3-D inputs are promoted to 4-D (leading batch dim) for `sim` and squeezed
    back before returning.
    """
    # FIX: remember whether we added the leading dim so we only squeeze what we
    # added (the original always dropped dim 0, silently discarding data for
    # genuine 4-D inputs).
    added_dim = tensor_1.dim() < 4
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    # Per-part maxima; any of these may be 0 when a part is identically zero.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    def _part_product(a, a_max, b, b_max):
        """sim() on [0,1]-normalized parts, rescaled; zeros when a part is empty."""
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        # FIX: allocate zeros once, directly on the right device and dtype,
        # instead of zeros_like(torch.empty(...)) on CPU + clone().to(device).
        return torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                           tensor_1.shape[2], tensor_2.shape[3],
                           device=device, dtype=tensor_1.dtype)

    t1 = _part_product(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _part_product(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _part_product(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _part_product(A_neg, max_A_neg, B_neg, max_B_neg)

    result = t1 - t2 - t3 + t4
    return result[0, :, :, :] if added_dim else result

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to (batch, seq, dim) activations."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Standard RoPE frequency table; registered as a buffer so it follows
        # the module across .to(device).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): the 90-degree rotation RoPE is built on.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)

# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic tanh: tanh(alpha * x) * weight + bias, a LayerNorm drop-in."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias

# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matmuls run through the
    optical simulator.

    `sim_scores` / `sim_output` are shared simulator modules owned by the parent
    model; stashing them via __dict__ (see __init__) keeps them out of this
    layer's submodule registry so their parameters are not registered once per
    layer.
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Store all ctor args (incl. the shared sims) as plain attributes,
        # deliberately bypassing nn.Module.__setattr__ registration.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalars compensating for the simulator's output scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask, built directly on the layer's device (no CPU round-trip).
        T = x.shape[1]
        tril = torch.tril(torch.ones(T, T, device=self.q_proj.bias.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so the original
        # applied dropout even under model.eval(); gate it on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x

class OpticGPT2(nn.Module):
    """Character-level GPT-2-style LM whose attention matmuls run through `omm`."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The original duplicated four near-identical omm.Config blocks; the
        # two regimes differ only in `distance` (+ `lens_size` for long
        # sequences), so factor the common fields out.
        if max_seq_len < 512:
            regime = dict(distance=0.01)
        else:
            regime = dict(distance=0.15, lens_size=8192 * 2)
        common = dict(min_height_gap=pixel_size,
                      right_matrix_split_x=2, right_matrix_split_y=2,
                      left_matrix_split_x=2, left_matrix_split_y=2,
                      result_matrix_split=2,
                      trainable_cylind_lens=False)
        # Scores sim: (T, head_dim) x (head_dim, T); right operand is k^T.
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns=max_seq_len,
                       right_matrix_count_rows=head_dim,
                       right_matrix_width=pixel_size * max_seq_len,
                       right_matrix_height=pixel_size * head_dim,
                       **common, **regime))
        # Output sim: (T, T) x (T, head_dim); right operand is v.
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns=head_dim,
                       right_matrix_count_rows=max_seq_len,
                       right_matrix_width=pixel_size * head_dim,
                       right_matrix_height=pixel_size * max_seq_len,
                       **common, **regime))

        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output,
                             num_heads=self.num_heads, dropout_rate=self.dropout_rate,
                             max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, C)
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    @torch.no_grad()  # inference only; the original built autograd graphs for nothing
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx`."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, C): last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

###################################################################################################

MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Optional extra run tag can be passed as argv[1].
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# Snapshot this version of the whole source directory for reproducibility.
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)
script_snapshot_path.chmod(0o500)  # with read-only permission
# Create a standalone checkpoints directory alongside the tensorboard logs
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Vocabulary is built over train+val+test so encode() never meets an unknown char.
text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])


def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity of `model` over the token stream `data`."""
    model.eval()
    # Cap total work at roughly 10k windows regardless of corpus size.
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    # Precompute all valid start positions.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    # Process windows in batches.
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]

        x_batch = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)

        # Model returns a token-mean loss; re-weight by token count so the
        # overall average stays exact even for a ragged final batch.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # final newline after the progress line
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of token ids `start_idxs` and decode it to text.

    FIX: default is an immutable tuple; the original used a mutable list default.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
params_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {params_count}'
writer.add_text('model', model_description, 0)

# Rule of thumb: ~8 tokens per parameter.
# NOTE(review): iteration estimate divides by batch_size only, not by
# batch_size * max_seq_len tokens per step -- presumably a deliberate
# overestimate; confirm.
print(f"{params_count * 8} minimum number of tokens to train model.")
print(f"{(params_count * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (+ vocab tables)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config embedded in every checkpoint for later reconstruction.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()
# Zero-shot probes logged throughout training to eyeball progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale invariant to accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 never actually clips; the call is only used to measure the norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(model=m, optimizer=optimizer, step=i, loss=accumulated_loss,
                        config=training_config, wtoi=wtoi, itow=itow, checkpoint_dir=checkpoints_dir)

# ---- Final evaluation on val and test + final snapshot ----
# FIX: dropped the redundant post-loop writer.add_scalar('loss', ..., i)
# (the loop already logged step i); final time log now uses elapsed time,
# consistent with the in-loop logs.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(model=m, optimizer=optimizer, step=max_iters, loss=accumulated_loss,
                config=training_config, wtoi=wtoi, itow=itow, checkpoint_dir=checkpoints_dir)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_512.py b/src/train_optics_char_gpt2_512.py new file mode 100644 index 0000000..e4eb00d --- /dev/null +++ b/src/train_optics_char_gpt2_512.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos 
* max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate 
= 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + 
class OpticGPT2(nn.Module):
    """Decoder-only character LM whose attention matmuls run through optical sims.

    The two `omm.OpticalMul` modules (scores: Q@K^T geometry, output: attn@V
    geometry) are built once here and shared by every TransformerLayer.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # Shared optical-geometry kwargs.  Only the propagation distance (and, for
        # long sequences, the lens size) differs between the two regimes — this
        # replaces two 30-line copy-pasted Config constructions.
        common = dict(
            min_height_gap=pixel_size,
            right_matrix_split_x=2, right_matrix_split_y=2,
            left_matrix_split_x=2, left_matrix_split_y=2,
            result_matrix_split=2,
            trainable_cylind_lens=False,
        )
        if max_seq_len < 512:
            common['distance'] = 0.01
        else:
            common['distance'] = 0.15
            common['lens_size'] = 8192 * 2
        # scores sim: right operand is K^T, shape (head_dim, max_seq_len).
        self.sim_scores = omm.OpticalMul(omm.Config(
            right_matrix_count_columns=max_seq_len,
            right_matrix_count_rows=head_dim,
            right_matrix_width=pixel_size * max_seq_len,
            right_matrix_height=pixel_size * head_dim,
            **common))
        # output sim: right operand is V, shape (max_seq_len, head_dim).
        self.sim_output = omm.OpticalMul(omm.Config(
            right_matrix_count_columns=head_dim,
            right_matrix_count_rows=max_seq_len,
            right_matrix_width=pixel_size * head_dim,
            right_matrix_height=pixel_size * max_seq_len,
            **common))

        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim,
                             sim_scores=self.sim_scores, sim_output=self.sim_output,
                             num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Returns (logits (B, T, vocab), mean cross-entropy loss or None)."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # B, T, C
        # cross_entropy expects (B, C, T); a plain transpose avoids the einops dep.
        loss = F.cross_entropy(logits.transpose(1, 2), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling; re-runs the full (cropped) context each step — no KV cache."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random contiguous windows from `data` (1-D LongTensor).

    Returns (x, y) on the module-global `device`; y is x shifted one token right
    (next-token targets).
    """
    ix = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y

@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity of `model` over `data`.

    The stride caps the work at roughly 10k windows regardless of corpus size.
    NOTE(review): leaves `model` in eval mode; the caller must restore train().
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    # Precompute all valid start positions.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        x_batch = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)

        # Model returns the token-averaged loss; re-weight by token count so the
        # (possibly smaller) final batch does not skew the average.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # final newline after the \r progress line
    return np.exp(total_loss_sum / total_tokens_count)

def complete(m, start_idxs=(0, ), max_new_tokens=100):
    """Decode a sampled continuation of `start_idxs` as text.

    fix: the default was a mutable list ([0]); a tuple avoids the shared-default
    anti-pattern while remaining backward compatible (callers pass lists).
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state"""
    # Bundle everything needed to resume training or rebuild the tokenizer.
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt')
optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/optics_char_gpt2_new_formula.py b/src/train_optics_char_gpt2_64.py similarity index 51% rename from src/optics_char_gpt2_new_formula.py rename to src/train_optics_char_gpt2_64.py index 6b59105..877c41c 100644 --- a/src/optics_char_gpt2_new_formula.py +++ b/src/train_optics_char_gpt2_64.py @@ -2,19 +2,31 @@ import torch import torch.nn as nn from torch.nn import functional as F from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator import numpy as np from torch.utils.tensorboard import SummaryWriter from datetime import datetime -from pathlib import Path import sys -torch.manual_seed(1337) +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) -#################################### Model ######################################### +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 -# def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: -# return matrix / (max_val + 1e-10) +############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 @@ -130,12 +142,12 @@ class TransformerLayer(nn.Module): x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) return x -class OpticGPT2NewFormula(nn.Module): +class OpticGPT2(nn.Module): def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size 
= 3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - if max_seq_len != 512: + if max_seq_len < 512: self.sim_scores = omm.OpticalMul( omm.Config(right_matrix_count_columns = max_seq_len, right_matrix_count_rows = h_dim // num_heads, @@ -164,7 +176,7 @@ class OpticGPT2NewFormula(nn.Module): distance = 0.01, trainable_cylind_lens=False) ) - if max_seq_len == 512: + if max_seq_len >= 512: self.sim_scores = omm.OpticalMul( omm.Config(right_matrix_count_columns = max_seq_len, right_matrix_count_rows = h_dim // num_heads, @@ -195,8 +207,6 @@ class OpticGPT2NewFormula(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) - self.sim_scores = omm.DataParallel(self.sim_scores) - self.sim_output = omm.DataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, @@ -224,4 +234,240 @@ class OpticGPT2NewFormula(nn.Module): probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script 
+shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in 
range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function 
######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += 
loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product routed through a non-negative-only multiplier `sim`.

    The optical multiplier only accepts non-negative inputs in [0, 1], so each
    operand is split into positive/negative parts (A = A+ - A-), each part is
    normalized by its max, multiplied by `sim`, rescaled, and the four partial
    products are recombined: A @ B = A+B+ - A+B- - A-B+ + A-B-.

    3-D inputs are promoted to 4-D with a singleton batch dim, which is stripped
    again on return — assumes callers pass 3-D tensors; TODO confirm no caller
    passes a real batch dim > 1 (it would be silently dropped by the final [0]).
    """
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    max_A_pos = torch.max(A_pos)  # may be 0 if the tensor has no positive entries
    max_A_neg = torch.max(A_neg)  # may be 0 if the tensor has no negative entries
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def partial(A, B, a_max, b_max):
        # A zero max means that part is identically zero: skip the sim call.
        if a_max > 0 and b_max > 0:
            return sim(A / a_max, B / b_max) * a_max * b_max
        # fix: build zeros directly on the right device with the input dtype,
        # instead of zeros_like(torch.empty(...)) on CPU cloned+moved per branch.
        return torch.zeros(out_shape, device=device, dtype=tensor_1.dtype)

    t1 = partial(A_pos, B_pos, max_A_pos, max_B_pos)
    t2 = partial(A_pos, B_neg, max_A_pos, max_B_neg)
    t3 = partial(A_neg, B_pos, max_A_neg, max_B_pos)
    t4 = partial(A_neg, B_neg, max_A_neg, max_B_neg)

    return (t1 - t2 - t3 + t4)[0, :, :, :]
class OpticLinear(nn.Module):
    """nn.Linear-alike whose matmul is routed through an optical multiplier sim.

    The weight is stored as (in_features, out_features) — transposed relative to
    nn.Linear — because `new_formula` computes input @ weight directly.
    NOTE(review): depends on module-level `omm` and `new_formula`; `omm` is not
    among this file's visible imports — confirm it is imported at the top.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias = True,
        device = None,
        dtype = None,
        pixel_size = 3.6e-6
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((in_features, out_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.k = nn.Parameter(torch.randn(1))  # learnable output scale for the sim
        self.sim = omm.OpticalMul(
            omm.Config(
                right_matrix_count_columns = out_features,
                right_matrix_count_rows = in_features,
                right_matrix_width = pixel_size * out_features,
                right_matrix_height = pixel_size * in_features,
                min_height_gap = pixel_size,
                right_matrix_split_x = 2,
                right_matrix_split_y = 2,
                left_matrix_split_x = 2,
                left_matrix_split_y = 2,
                result_matrix_split = 2,
                distance = 0.01)
        )
        self.reset_parameters()

    def forward(self, input):
        """Optical input @ weight, plus bias when present."""
        out = self.k * new_formula(self.sim, input, self.weight)
        # fix: the original added self.bias unconditionally and crashed with
        # bias=False (None + Tensor).
        return out if self.bias is None else out + self.bias

    def reset_parameters(self) -> None:
        """Same initialization scheme as nn.Linear (kaiming uniform, a=sqrt(5))."""
        # fix: the original referenced bare `init` and `math`, neither of which
        # is imported in this file (NameError) — use nn.init and ** 0.5 instead.
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = fan_in ** -0.5 if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """Return the extra representation of the module."""
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
+ """ + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = OpticLinear(h_dim, 4*h_dim) + self.ff2 = OpticLinear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2FF(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, 
class OpticGPT2FF(nn.Module):
    """Character-level GPT-2 variant: standard attention, optical feed-forward
    (OpticLinear) inside each TransformerLayer."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        # Stores all ctor args as plain attributes (pixel_size kept for checkpoint config).
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Returns (logits (B, T, vocab), mean cross-entropy loss or None)."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # B, T, C
        # cross_entropy expects (B, C, T); a plain transpose avoids the einops dep,
        # and `targets is not None` replaces the `not targets is None` anti-idiom.
        loss = F.cross_entropy(logits.transpose(1, 2), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling; re-runs the full (cropped) context each step — no KV cache."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i 
in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function 
######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += 
loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# ============================================================================
# src/train_optics_char_gpt2_nokoef_128.py
#
# Character-level GPT-2-style language model (context length 128) whose two
# attention matrix products (Q @ K^T and attn @ V) are routed through a
# simulated optical multiplier (omm.OpticalMul) instead of a plain matmul.
# ============================================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
import omm  # FIX: omm.OpticalMul / omm.Config are used below but were never imported

seed = 1337
torch.manual_seed(seed)

# ------------------------------ hyperparameters ------------------------------
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch of the simulated optical setup, in meters
assert batch_size % gradient_accumulation_steps == 0

############################### MODEL #############################################################

def new_formula(sim, tensor_1, tensor_2):
    """Compute tensor_1 @ tensor_2 through the optical simulator ``sim``.

    The optical multiplier only accepts non-negative inputs normalised to
    [0, 1], so each operand is split into positive and negative parts, each
    part is scaled by its global max, the four cross products are simulated,
    then rescaled and recombined as
        A @ B = A+ @ B+ - A+ @ B- - A- @ B+ + A- @ B-.
    Terms whose scale factor is zero (no values of that sign) are skipped.
    """
    # The simulator expects 4-D input; promote 3-D (B, T, H) to (1, B, T, H).
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    max_A_pos = torch.max(A_pos)  # may be 0 when there are no positive values
    max_A_neg = torch.max(A_neg)  # may be 0 when there are no negative values
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    def zero_term():
        # Shape of one product term: (..., rows of A, cols of B), already on
        # the right device (avoids the CPU zeros_like + clone + to() dance).
        return torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                           tensor_1.shape[2], tensor_2.shape[3], device=device)

    if max_A_pos > 0 and max_B_pos > 0:
        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
    else:
        t1 = zero_term()

    if max_A_pos > 0 and max_B_neg > 0:
        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
    else:
        t2 = zero_term()

    if max_A_neg > 0 and max_B_pos > 0:
        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
    else:
        t3 = zero_term()

    if max_A_neg > 0 and max_B_neg > 0:
        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
    else:
        t4 = zero_term()

    # Drop the leading broadcast dimension added above.
    return (t1 - t2 - t3 + t4)[0, :, :, :]

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to per-head query/key vectors."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # (max_seq_len, dim): rotation angles duplicated for the two halves.
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)

# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh-based replacement for LayerNorm."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias

# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder layer: causal optical attention + GELU MLP."""

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4,
                 dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Stash all ctor args as plain attributes; sim_scores/sim_output are
        # shared simulator modules owned by the parent model.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Q @ K^T through the optical simulator, then standard scaled softmax.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))  # causal mask
        attention = nn.functional.softmax(scores, dim=2)
        # attn @ V also runs through the (second) optical simulator.
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # FIX: F.dropout1d defaults to training=True, so dropout stayed active
        # even under model.eval(), corrupting perplexity evaluation and
        # generation; gate it on self.training like nn.Dropout does.
        x = x + F.dropout1d(self.attention(self.ln1(x)),
                            p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))),
                            p=self.dropout_rate, training=self.training)
        return x

class OpticGPT2(nn.Module):
    """Character-level GPT-2-style LM with optically simulated attention."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64,
                 num_heads=1, dropout_rate=0.1, pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # Long sequences require a larger propagation distance and lens in the
        # optical simulation; short sequences use the compact configuration.
        if max_seq_len >= 512:
            extra = dict(distance=0.15, lens_size=8192 * 2)
        else:
            extra = dict(distance=0.01)
        common = dict(min_height_gap=pixel_size,
                      right_matrix_split_x=2, right_matrix_split_y=2,
                      left_matrix_split_x=2, left_matrix_split_y=2,
                      result_matrix_split=2,
                      trainable_cylind_lens=False, **extra)
        # Scores product: right operand is K^T with shape (head_dim, T).
        self.sim_scores = omm.OpticalMul(omm.Config(
            right_matrix_count_columns=max_seq_len,
            right_matrix_count_rows=head_dim,
            right_matrix_width=pixel_size * max_seq_len,
            right_matrix_height=pixel_size * head_dim,
            **common))
        # Output product: right operand is V with shape (T, head_dim).
        self.sim_output = omm.OpticalMul(omm.Config(
            right_matrix_count_columns=head_dim,
            right_matrix_count_rows=max_seq_len,
            right_matrix_width=pixel_size * head_dim,
            right_matrix_height=pixel_size * max_seq_len,
            **common))

        # Both simulators are shared by every transformer layer.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores,
                             sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when ``targets`` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        loss = (F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
                if targets is not None else None)
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``start_idx``."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, vocab): last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

###################################################################################################

MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of the repository
script_snapshot_path.chmod(0o500)  # read-only (but traversable) snapshot

# Standalone checkpoints directory, timestamped like the logs directory.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Build the character vocabulary over all three splits so that the val/test
# encodings never hit an unknown character.
text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])

def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from data."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y

train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)

@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of ``model`` over ``data``.

    Windows are subsampled by a stride so at most ~10k windows are evaluated.
    """
    model.eval()
    stride = max(1, len(data) // 10000)  # caps the number of evaluated windows
    total_loss_sum = 0.0
    total_tokens_count = 0

    # Precompute all valid window start positions.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]

        x_batch = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)

        # The model returns the mean loss over all tokens in the batch;
        # re-weight by token count so an unequal final batch averages correctly.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()  # final newline after the \r progress line
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a sampled continuation of the prompt token ids ``start_idxs``.

    Default changed from a mutable list to a tuple (same behavior, avoids the
    shared-mutable-default pitfall).
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())

m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)

# Rule-of-thumb data budget: ~8 tokens per trainable parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)

# Training config stored inside every checkpoint for reproducibility.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

# Qualitative probes logged alongside the scalar metrics at step 0 and at
# every periodic evaluation.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len,
                           batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale batch-equivalent
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively disables clipping; the call is used only to
        # *measure* the total gradient norm for logging.
        writer.add_scalar('Gradient_Norm/Total',
                          torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # Periodic evaluation: val perplexity, sample completions, checkpoint.
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

# Final evaluation on the validation and test splits.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_nokoef_256.py b/src/train_optics_char_gpt2_nokoef_256.py new file mode 100644 index 0000000..905408c --- /dev/null +++ b/src/train_optics_char_gpt2_nokoef_256.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / 
max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, 
num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim 
class OpticGPT2(nn.Module):
    """Character-level GPT-2-style LM whose attention matmuls run through
    optical matrix-multiplication simulators (omm.OpticalMul).

    NOTE(review): this revision dropped `import optical_matrix_multiplication
    as omm` (and the ScatterDataParallel wrappers) but still calls
    `omm.OpticalMul` here — constructing the model raises NameError until
    that import is restored. Flagged rather than re-added since the module
    is project-local.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64,
                 num_heads=1, dropout_rate=0.1, pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The two simulator configs only differ by propagation distance (and
        # lens size) between the short- and long-sequence regimes; everything
        # else was duplicated verbatim in the original if/if branches.
        if max_seq_len < 512:
            regime = {'distance': 0.01}
        else:
            regime = {'distance': 0.15, 'lens_size': 8192 * 2}
        # q @ k^T simulator: right operand is (head_dim, seq_len).
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns=max_seq_len,
                       right_matrix_count_rows=head_dim,
                       right_matrix_width=pixel_size * max_seq_len,
                       right_matrix_height=pixel_size * head_dim,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       trainable_cylind_lens=False,
                       **regime))
        # attn @ v simulator: right operand is (seq_len, head_dim).
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns=head_dim,
                       right_matrix_count_rows=max_seq_len,
                       right_matrix_width=pixel_size * head_dim,
                       right_matrix_height=pixel_size * max_seq_len,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       trainable_cylind_lens=False,
                       **regime))

        # All layers share the two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores,
                             sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate,
                             max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        loss = (F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
                if targets is not None else None)
        return logits, loss

    @torch.no_grad()  # sampling never needs autograd graphs
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop to context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, vocab): last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

###################################################################################################

MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of the repository
script_snapshot_path.chmod(0o500)  # read/execute only, so the snapshot stays immutable
# Create a standalone checkpoints directory mirroring the logs naming scheme.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Vocabulary is built over all three splits so val/test never contain
# out-of-vocabulary characters.
text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])

def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y

train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)

@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Token-weighted perplexity of `model` over `data`, strided to ~10k windows."""
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        x_batch = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)

        # Model returns the mean loss over all tokens in the batch; re-weight
        # by token count so the final average stays exact for a ragged tail.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################
@torch.no_grad()  # FIX: sampling built autograd graphs for nothing
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a text completion of the prompt token ids `start_idxs`."""
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())

m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
params_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {params_count}'
writer.add_text('model', model_description, 0)

# Rule of thumb: ~8 training tokens per parameter.
print(f"{params_count * 8} minimum number of tokens to train model.")
print(f"{(params_count * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (weights, optimizer, vocab, config)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)

# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

# Baseline (untrained) completions for the fixed probe prompts.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss/steps so the
    # summed gradient matches one full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively only *measures* the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

# Final evaluation on validation and test splits.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# NOTE(review): this re-logs 'loss' at step i, duplicating the in-loop point.
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_nokoef_512.py b/src/train_optics_char_gpt2_nokoef_512.py new file mode 100644 index 0000000..bd9b26c --- /dev/null +++ b/src/train_optics_char_gpt2_nokoef_512.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / 
def new_formula(sim, tensor_1, tensor_2):
    """Sign-decomposed matrix product routed through the simulator `sim`.

    `sim(A, B)` is assumed to behave like a batched matrix multiply that only
    accepts non-negative operands scaled into [0, 1] (an optical multiplier).
    Each operand is split into positive/negative parts, each part is
    normalised by its max, multiplied via `sim`, rescaled, and the four
    partial products recombined:  A @ B = A+B+ - A+B- - A-B+ + A-B-.

    Args:
        sim: callable taking two 4-D tensors and returning their (batched) product.
        tensor_1: (B, T, H) or (N, B, T, H) left operand.
        tensor_2: (B, H, T) or (N, B, H, T) right operand.
    Returns:
        The product with the leading batch dim stripped: result[0, :, :, :].
    """
    # Promote 3-D operands to 4-D so `sim` always sees a leading batch dim.
    tensor_1 = tensor_1[None, :, :, :] if tensor_1.dim() < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if tensor_2.dim() < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    # Any of these maxima may be 0 when the corresponding part is identically zero.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    def _term(part_a, max_a, part_b, max_b):
        # Skip the simulation entirely when one factor is all zeros; build the
        # zero result directly on the target device (the original built an
        # uninitialised CPU tensor via zeros_like(empty) and cloned/moved it).
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        return torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                           tensor_1.shape[2], tensor_2.shape[3], device=device)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)

    return (t1 - t2 - t3 + t4)[0, :, :, :]

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied over the last dim of (B, T, dim) inputs."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # (max_seq_len, dim): each frequency appears twice, once per half.
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): the 2-D rotation companion of the cos term.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return (x * emb.cos()) + (self.rotate_half(x) * emb.sin())

# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic tanh: learnable per-feature affine of tanh(alpha * x); a norm-free LayerNorm stand-in."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) causal self-attention block whose two matmuls are routed
    through optical-multiplication simulators via `new_formula`:
    `sim_scores` computes q @ k^T, `sim_output` computes attn @ v.
    """

    def __init__(self, h_dim, sim_scores, sim_output,
                 num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Stash ctor args (including the shared sim modules) on the instance;
        # this deliberately bypasses nn.Module registration so the shared
        # simulators are not duplicated in every layer's state_dict.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head.
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # (B*n, T, h) -> (B, T, n*h); identity for single-head.
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask; rebuilt each call — cheap relative to the optical sims.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: functional dropout defaults to training=True, so the
        # original applied dropout even under model.eval(); gating it on
        # self.training makes evaluation/inference deterministic.
        x = x + F.dropout1d(self.attention(self.ln1(x)),
                            p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))),
                            p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Character-level GPT-2-style LM whose attention matmuls run through
    optical matrix-multiplication simulators (omm.OpticalMul).

    NOTE(review): this revision dropped `import optical_matrix_multiplication
    as omm` (and the ScatterDataParallel wrappers) but still calls
    `omm.OpticalMul` here — constructing the model raises NameError until
    that import is restored. Flagged rather than re-added since the module
    is project-local.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64,
                 num_heads=1, dropout_rate=0.1, pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The two simulator configs only differ by propagation distance (and
        # lens size) between the short- and long-sequence regimes; everything
        # else was duplicated verbatim in the original if/if branches.
        if max_seq_len < 512:
            regime = {'distance': 0.01}
        else:
            regime = {'distance': 0.15, 'lens_size': 8192 * 2}
        # q @ k^T simulator: right operand is (head_dim, seq_len).
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns=max_seq_len,
                       right_matrix_count_rows=head_dim,
                       right_matrix_width=pixel_size * max_seq_len,
                       right_matrix_height=pixel_size * head_dim,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       trainable_cylind_lens=False,
                       **regime))
        # attn @ v simulator: right operand is (seq_len, head_dim).
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns=head_dim,
                       right_matrix_count_rows=max_seq_len,
                       right_matrix_width=pixel_size * head_dim,
                       right_matrix_height=pixel_size * max_seq_len,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       trainable_cylind_lens=False,
                       **regime))

        # All layers share the two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores,
                             sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate,
                             max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        loss = (F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
                if targets is not None else None)
        return logits, loss

    @torch.no_grad()  # sampling never needs autograd graphs
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop to context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, vocab): last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

###################################################################################################

MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of the repository
script_snapshot_path.chmod(0o500)  # read/execute only, so the snapshot stays immutable
# Create a standalone checkpoints directory mirroring the logs naming scheme.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Vocabulary is built over all three splits so val/test never contain
# out-of-vocabulary characters.
text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])

def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y

train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)

@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Token-weighted perplexity of `model` over `data`, strided to ~10k windows."""
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        x_batch = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)

        # Model returns the mean loss over all tokens in the batch; re-weight
        # by token count so the final average stays exact for a ragged tail.
        _, mean_loss = model(x_batch, y_batch)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print()
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################
@torch.no_grad()  # FIX: sampling built autograd graphs for nothing
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a text completion of the prompt token ids `start_idxs`."""
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())

m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
params_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {params_count}'
writer.add_text('model', model_description, 0)

# Rule of thumb: ~8 training tokens per parameter.
print(f"{params_count * 8} minimum number of tokens to train model.")
print(f"{(params_count * 8)//(batch_size)} minimum number of iterations to train this model.")

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (weights, optimizer, vocab, config)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)

# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

# Baseline (untrained) completions for the fixed probe prompts.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss/steps so the
    # summed gradient matches one full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively only *measures* the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

# Final evaluation on validation and test splits.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# NOTE(review): this re-logs 'loss' at step i, duplicating the in-loop point.
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef_newf.py b/src/train_optics_char_gpt2_nokoef_64.py similarity index 51% rename from src/optics_char_gpt2_nokoef_newf.py rename to src/train_optics_char_gpt2_nokoef_64.py index d99f149..7dc825a 100644 --- a/src/optics_char_gpt2_nokoef_newf.py +++ b/src/train_optics_char_gpt2_nokoef_64.py @@ -2,16 +2,31 @@ import torch import torch.nn as nn from torch.nn import functional as F from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator import numpy as np from torch.utils.tensorboard import SummaryWriter from datetime import datetime -from pathlib import Path import sys -torch.manual_seed(1337) +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 -#################################### Model ######################################### +############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 @@ -125,7 +140,7 @@ class TransformerLayer(nn.Module): x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) return x -class OpticGPT2NOKoefNewF(nn.Module): +class OpticGPT2(nn.Module): def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size = 3.6e-6): super().__init__() @@ -190,8 +205,6 @@ class 
###################################################################################################

# Entry-point wiring: which model class to train and where the data lives.
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

# Time-stamped TensorBoard run directory named after this script + dataset.
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of the repository
script_snapshot_path.chmod(0o500)  # read/execute only: keep the snapshot immutable

# Create standalone checkpoints directory with the same naming scheme.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################
https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = 
y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': 
h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, 
encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file