scripts version

pull/1/head
Vladimir Protsenko 1 month ago
parent 1ace114d0c
commit 1b8ddc31d4

@ -1,168 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2
from bpe_tokenizer import byte_pair_init, byte_pair_encode, byte_pair_decode
seed = 1337
torch.manual_seed(seed)  # fix RNG for reproducible init/sampling
models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2}  # CLI-selectable model classes
batch_size = 50
max_iters = 40000*10  # total optimizer steps
eval_interval = 300  # NOTE(review): not referenced anywhere in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): not referenced anywhere in this script
layers_num = 22
h_dim = 64
max_seq_len = 256
num_heads = 4
dropout_rate = 0.1  # NOTE(review): never passed to MODEL_CLASS below — models use their default
pixel_size = 3.6e-6  # pixel pitch handed to the optical-simulator models
merges_count = 20  # number of BPE merges; final vocab = len(chars) + merges_count
# Usage: CUDA_VISIBLE_DEVICES=1 python ./src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment
MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2])
val_data_path = Path(sys.argv[3])
test_data_path = Path(sys.argv[4])
comment = f"bpe_{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
print("Logs dir:", logs_dir)
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # snapshot this exact script version into the log dir
script_snapshot_path.chmod(0o400) # read-only so the snapshot cannot be edited later
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))  # character-level base vocabulary over all splits
print(f"Len chars: {len(chars)}")
wtoi = {w:i for i,w in enumerate(chars)}  # char -> base token id
itow = {i:w for i,w in enumerate(chars)}  # base token id -> char
import pickle
start_time = datetime.now()
# Cache the learned merges and the merged corpus.
# NOTE(review): the cache is not invalidated when the corpus or merges_count
# changes — delete ./data/bpe_text.pkl and ./data/merges.pkl after changing either.
if Path("./data/bpe_text.pkl").exists() and Path("./data/merges.pkl").exists():
    with open("./data/bpe_text.pkl", 'rb') as f: bpe_text = pickle.load(f)
    with open("./data/merges.pkl", 'rb') as f: merges = pickle.load(f)
else:
    # Fix: pass the configured merges_count instead of a hard-coded 20, which
    # would silently disagree with vocab_size below if the config changed.
    bpe_text, merges = byte_pair_init([wtoi[w] for w in text], vocab_size=len(chars), merges_count=merges_count)
    with open("./data/bpe_text.pkl", 'wb') as f: pickle.dump(bpe_text, f)
    with open("./data/merges.pkl", 'wb') as f: pickle.dump(merges, f)
print(f"Compression ratio: {len(text)/len(bpe_text)}, init took {datetime.now()-start_time}")
vocab_size = len(chars) + merges_count  # base chars + one id per merge
encode = lambda s: byte_pair_encode([wtoi[w] for w in s], merges)
decode = lambda idx: "".join([itow[i] for i in byte_pair_decode(idx, merges)])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random contiguous windows from data.

    Returns (inputs, targets) where targets are inputs shifted one token right.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
start_time = datetime.now()
train_bpe_encoded_path = Path("./data/train_bpe_encoded.pt")
val_bpe_encoded_path = Path("./data/val_bpe_encoded.pt")
test_bpe_encoded_path = Path("./data/test_bpe_encoded.pt")
# Reuse cached token tensors only when all three splits are present on disk;
# otherwise re-encode every split and write the cache.
cache_complete = (train_bpe_encoded_path.exists()
                  and val_bpe_encoded_path.exists()
                  and test_bpe_encoded_path.exists())
if cache_complete:
    train_data = torch.load(train_bpe_encoded_path)
    val_data = torch.load(val_bpe_encoded_path)
    test_data = torch.load(test_bpe_encoded_path)
else:
    train_data = torch.tensor(encode(train_text), dtype=torch.long)
    val_data = torch.tensor(encode(val_text), dtype=torch.long)
    test_data = torch.tensor(encode(test_text), dtype=torch.long)
    torch.save(train_data, train_bpe_encoded_path)
    torch.save(val_data, val_bpe_encoded_path)
    torch.save(test_data, test_bpe_encoded_path)
print(f"Encoded {datetime.now() - start_time}")
@torch.no_grad()
def perplexity(model, data):
    """Sliding-window perplexity: exp of the mean cross-entropy loss over
    roughly 10k windows of length max_seq_len drawn with a fixed stride."""
    stride = max(1, len(data) // 10000)
    last_start = len(data) - max_seq_len - 1
    window_losses = []
    for start in range(0, last_start, stride):
        inputs = data[start:(start + max_seq_len)].to(device)
        targets = data[(start + 1):(start + max_seq_len + 1)].to(device)
        _, loss = model(inputs[None, ...], targets[None, ...])
        window_losses.append(loss.item())
        print(f"\rppl {start}/{last_start}", end="")
    return np.exp(np.mean(window_losses))
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Sample max_new_tokens continuation tokens after start_idxs and decode to text."""
    seed_tokens = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=seed_tokens, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
print(m)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
# Fix: toggle train/eval mode around evaluation and sampling, matching the
# char-level training script; previously the model stayed in train mode forever.
m.eval()
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
for i in range(max_iters):
    m.train()
    xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    writer.add_scalar('loss', loss.item(), i)
    print(f"\r{i}/{max_iters} {loss.item()}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\rPerplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
# Final evaluation on the validation and test splits, plus a sample.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {loss.item()}")
print(f"\rPerplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# (fix: removed a duplicate writer.add_scalar('loss', ...) at step i — that
# step's loss was already logged inside the loop, producing a double point.)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)

@ -1,116 +0,0 @@
import numpy as np
import pandas as pd
def get_top_pair(tokens):
    """Return the most frequent adjacent token pair in `tokens` as a 2-list.

    Rewritten with collections.Counter: the previous pandas implementation
    relied on `value_counts().reset_index()` producing a column literally named
    'count' (pandas >= 2.0 only) and built a DataFrame just to count pairs.
    Ties may be broken differently (first-seen pair wins), which is an equally
    valid choice for BPE. Raises IndexError for inputs with fewer than 2 tokens,
    matching the original's failure on such input.
    """
    from collections import Counter
    pair_counts = Counter(zip(tokens[:-1], tokens[1:]))
    top_pair, _ = pair_counts.most_common(1)[0]
    return list(top_pair)
def merge(tokens, pair, new_idx):
    """Greedy left-to-right replacement of adjacent `pair` with `new_idx`.

    Returns a new np.array; overlapping occurrences are consumed left to right
    (e.g. [1,1,1] with pair (1,1) -> [new_idx, 1]).
    Fix: the original zip-based version raised NameError for inputs shorter
    than 2 tokens (the trailing `append(b)` referenced an unbound loop var);
    short inputs are now returned unchanged.
    """
    merged = []
    i = 0
    n = len(tokens)
    while i < n:
        # Merge only when a full pair starting at i matches.
        if i + 1 < n and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
            merged.append(new_idx)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return np.array(merged)
def unmerge(tokens, pair_idx, pair):
    """Expand every occurrence of the merged id `pair_idx` back into `pair`.

    Inverse of one `merge` step; returns a plain list.
    """
    expanded = []
    for token in tokens:
        if token == pair_idx:
            expanded.extend(pair)
        else:
            expanded.append(token)
    return expanded
def byte_pair_init(char_ids, vocab_size, merges_count=20):
    """Learn `merges_count` BPE merges over `char_ids`.

    Returns (merged token array, merges) where merges is a list of
    [pair, merged_id] entries, ids starting at `vocab_size`.
    """
    byte_text = np.array(char_ids)
    merges = []
    for step in range(merges_count):
        top_pair = get_top_pair(byte_text)
        new_idx = vocab_size + step  # next free token id
        merges.append([top_pair, new_idx])
        print(f"{top_pair} {new_idx}")
        byte_text = merge(byte_text, top_pair, new_idx)
    return np.array(byte_text), merges
def byte_pair_encode(char_ids, merges):
    """Encode base-char ids by replaying every learned merge in order."""
    encoded = np.array(char_ids)
    for pair, merged_id in merges:
        encoded = merge(encoded, pair, merged_id)
    return encoded
def byte_pair_decode(tokens, merges):
    """Decode BPE tokens back to base-char ids by undoing merges newest-first."""
    decoded = tokens
    for pair, merged_id in reversed(merges):
        decoded = unmerge(decoded, merged_id, pair)
    return decoded
# def get_top_pair(tokens):
# hist = pd.DataFrame(np.vstack([tokens[:-1], tokens[1:]]).T).value_counts().reset_index().astype(np.uint16)
# return np.array(hist.nlargest(1, columns='count').iloc[0, [0,1]])
# def merge(tokens, pair, new_idx):
# if len(tokens) % 2 != 0:
# tokens = np.append(tokens, np.array([0], dtype=np.uint16))
# # print("not even")
# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy()
# b = np.frombuffer(bytes(pair), dtype=np.uint32)
# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32)
# a[a==b] = c
# d = np.frombuffer(bytes(a), dtype=np.uint16)
# indices = np.where(d == 2**16-1)
# e = np.delete(d, indices)
# e = e[:-1]
# else:
# # print("even")
# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy()
# b = np.frombuffer(bytes(pair), dtype=np.uint32)
# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32)
# a[a==b] = c
# d = np.frombuffer(bytes(a), dtype=np.uint16)
# indices = np.where(d == 2**16-1)
# e = np.delete(d, indices)
# return e
# def unmerge(tokens, pair_idx, pair):
# new_tokens = []
# for idx in tokens:
# if idx == pair_idx:
# new_tokens.append(pair[0])
# new_tokens.append(pair[1])
# else:
# new_tokens.append(idx)
# return new_tokens
# def byte_pair_init(char_ids, vocab_size, merges_count=20):
# assert vocab_size < 2**16
# byte_text = np.array(char_ids, dtype=np.uint16)
# merges = []
# for i in range(merges_count):
# top_pair = get_top_pair(byte_text)
# new_idx = vocab_size + i
# print([top_pair, new_idx])
# merges.append([top_pair, new_idx])
# byte_text = merge(byte_text, top_pair, new_idx)
# byte_text = np.roll(merge(np.roll(byte_text, 1), top_pair, new_idx), -1)
# return byte_text, merges
# def byte_pair_encode(char_ids, merges):
# tokens = np.array(char_ids, dtype=np.uint16)
# for pair, pair_idx in merges:
# tokens = merge(tokens, pair, pair_idx)
# return tokens
# def byte_pair_decode(tokens, merges):
# for pair, pair_idx in reversed(merges):
# tokens = unmerge(tokens, pair_idx, pair)
# return tokens

@ -1,115 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
torch.manual_seed(1337)
#################################### Model #########################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes per-position rotation angles up to max_seq_len and rotates
    each half-pair of feature channels by a position-dependent angle.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so that channel i and channel i + dim/2 share an angle.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        first_half, second_half = x.chunk(2, dim=-1)
        return torch.cat((-second_half, first_half), dim=-1)

    def forward(self, x, offset=0):
        T = x.size(1)
        angles = self.emb[offset:offset + T].view(1, T, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (normalization-free; https://jiachenzhu.github.io/DyT/).

    Computes weight * tanh(alpha * x) + bias with a learned scalar alpha and
    per-feature affine parameters.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal self-attention block: DyT norm, RoPE, multi-head
    attention, then a GELU MLP, each with a residual connection."""

    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every ctor arg as an attribute (h_dim, num_heads, dropout_rate, max_seq_len).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim, not the per-head dim — confirm intended.
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # causal mask; encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # Fix: F.dropout1d defaults to training=True, so dropout previously
        # stayed active even after model.eval(), corrupting perplexity
        # evaluation and sampling. Gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Decoder-only transformer language model with learned position embeddings."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Keep every ctor argument as an attribute (pixel_size is accepted only
        # for interface parity with the optical models).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        h = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            h = layer(h)
        logits = self.lm_head(h)  # (B, T, vocab)
        if targets is None:
            return logits, None
        # cross_entropy wants the class dimension second: (B, C, T).
        return logits, F.cross_entropy(logits.transpose(1, 2), targets)

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            # Condition only on the most recent max_seq_len tokens.
            logits, _ = self(idx[:, -self.max_seq_len:])
            probs = F.softmax(logits[:, -1, :], dim=-1)  # distribution for next token
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -1,168 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2
from optics_char_gpt2_traindiag import OpticGPT2TrainDiag
from optics_char_gpt2_ff import OpticGPT2FF
from optics_char_gpt2_new_formula import OpticGPT2NewFormula
from char_gpt2_scaledmatmul import GPT2ScaledMM
from optics_char_gpt2_nokoef import OpticGPT2NOKoef
from optics_char_gpt2_nokoef_newf import OpticGPT2NOKoefNewF
import shutil
seed = 1337
torch.manual_seed(seed)  # fix RNG for reproducible init/sampling
# CLI-selectable model classes (first positional argument).
models = {
    'gpt2': GPT2,
    'optic_gpt2': OpticGPT2,
    'optic_gpt2_ff': OpticGPT2FF,
    'optic_gpt2_traindiag': OpticGPT2TrainDiag,
    'optic_gpt2_newformula': OpticGPT2NewFormula,
    'optic_gpt2_nokoef': OpticGPT2NOKoef,
    'optic_gpt2_nokoef_newformula': OpticGPT2NOKoefNewF,
    'gpt2_scaledmm': GPT2ScaledMM
}
batch_size = 50
gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = 40000  # total optimizer steps
eval_interval = 300  # NOTE(review): not referenced anywhere in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): not referenced anywhere in this script
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1  # NOTE(review): never passed to MODEL_CLASS below — models use their default
pixel_size = 3.6e-6  # pixel pitch handed to the optical-simulator models
# Micro-batches must divide the batch evenly for gradient accumulation.
assert batch_size % gradient_accumulation_steps == 0
# Usage: CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64
MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2])
val_data_path = Path(sys.argv[3])
test_data_path = Path(sys.argv[4])
comment = f"{sys.argv[1]}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[5] if len(sys.argv) >= 6 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # read + execute only (x needed to traverse the dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
train_text = train_data_path.read_text(encoding='utf-8')
val_text = val_data_path.read_text(encoding='utf-8')
test_text = test_data_path.read_text(encoding='utf-8')
text = train_text + val_text + test_text
# Character vocabulary is built over all three splits so every id is stable.
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w: i for i, w in enumerate(chars)}  # char -> id
itow = {i: w for i, w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of windows; targets are inputs shifted by one."""
    offsets = torch.randint(len(data) - seq_len, (batch_size, ))
    xs = [data[o:o + seq_len] for o in offsets]
    ys = [data[o + 1:o + 1 + seq_len] for o in offsets]
    return torch.stack(xs).to(device), torch.stack(ys).to(device)
# Encode each split to a flat 1-D LongTensor of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """exp(mean cross-entropy) over strided windows of length max_seq_len."""
    stride = max(1, len(data) // 10000)  # cap the number of windows near 10k
    limit = len(data) - max_seq_len - 1
    losses = []
    for pos in range(0, limit, stride):
        x = data[pos:(pos + max_seq_len)].to(device)
        y = data[(pos + 1):(pos + max_seq_len + 1)].to(device)
        _, loss = model(x[None, ...], y[None, ...])
        losses.append(loss.item())
        print(f"\rppl {pos}/{limit}", end="")
    return np.exp(np.mean(losses))
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Sample a continuation of start_idxs from model m and decode it to text."""
    context = torch.tensor([start_idxs]).to(device)
    tokens = m.generate(start_idx=context, max_new_tokens=max_new_tokens)
    return decode(tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
writer.add_text('model', str(m), 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
# Log one sample from the untrained model as a baseline.
m.eval()
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: micro-batches of batch_size // steps; each loss is
    # pre-divided so the summed gradient matches one full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; this call is used to *measure* the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # Periodic validation perplexity + sample (train mode is restored at loop top).
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
# Final evaluation on validation and test splits plus one last sample.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)

@ -1,211 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Scale `matrix` into roughly [-1, 1] by max_val; epsilon avoids div-by-zero."""
    denom = max_val + 1e-10
    return matrix / denom
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Matrix product via the optical simulator `sim`, which only handles
    non-negative inputs in [0, 1].

    Non-negative inputs take a single normalized pass. Otherwise both operands
    are shifted by a common constant s = |global min| so they become
    non-negative, and a*b is recovered from the identity
    (a+s)(b+s) - (a+s)s - s(b+s) + s*s = a*b, with each of the four products
    evaluated through `sim`.
    """
    tensor_1 = tensor_1[None, :, :, :]
    tensor_2 = tensor_2[None, :, :, :]
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        return sim(norm(tensor_1, max_abs), norm(tensor_2, max_abs))[0, :, :, :] * max_abs ** 2
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
    shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
    shifted_a = norm(tensor_1 + shift_a, max_abs)
    shifted_b = norm(tensor_2 + shift_b, max_abs)
    shift_a_n = norm(shift_a, max_abs)
    shift_b_n = norm(shift_b, max_abs)
    product = (sim(shifted_a, shifted_b)
               - sim(shifted_a, shift_b_n)
               - sim(shift_a_n, shifted_b)
               + sim(shift_a_n, shift_b_n))
    return product[0, :, :, :] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864)."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Channel i and channel i + dim/2 share the same rotation angle.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        lo, hi = x.chunk(2, dim=-1)
        return torch.cat((-hi, lo), dim=-1)

    def forward(self, x, offset=0):
        T = x.size(1)
        angles = self.emb[offset:offset + T].view(1, T, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer: weight * tanh(alpha * x) + bias (normalization-free)."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Causal self-attention block whose two matmuls (QK^T and attn @ V) run
    through the simulated optical multipliers sim_scores / sim_output."""
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all ctor args as attributes. NOTE(review): assigning via
        # __dict__ bypasses nn.Module registration, so sim_scores/sim_output
        # are not registered submodules here — confirm they are registered by
        # the owning model.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learned scalar gains applied to the two optical matmul outputs.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim rather than per-head dim — confirm intended.
        scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))  # causal mask
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # NOTE(review): F.dropout1d defaults to training=True, so dropout stays
        # active even in eval mode — likely should pass training=self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
        return x
class OpticGPT2(nn.Module):
    """GPT-2-style decoder LM whose attention matmuls run on simulated optics.

    Two OpticalMul simulators are built — sim_scores for QK^T (h/n x T times
    T-wide operands) and sim_output for attn @ V — and shared by every layer.
    The two config branches differ only in propagation distance / lens size,
    chosen by sequence length.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Keep all ctor args as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short sequences: short propagation distance, default lens size.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        # Long sequences: larger propagation distance and an explicit, larger lens.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # All layers share the same two simulators.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants the class dimension second: (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # Autoregressive inference: repeatedly sample the next token.
    def generate(self, start_idx, max_new_tokens):
        idx = start_idx
        for i in range(max_new_tokens):
            # Condition only on the most recent max_seq_len tokens.
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -1,217 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F, init
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
import math
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Divide `matrix` by max_val (with an epsilon guarding division by zero)."""
    return matrix / (max_val + 1e-10)
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Matrix product via the optical simulator `sim` (non-negative inputs only).

    This variant broadcasts a 2-D `tensor_2` (e.g. a weight matrix) against a
    batched `tensor_1`. Non-negative operands take one normalized pass; signed
    operands are shifted by s = |global min| and a*b is recovered from
    (a+s)(b+s) - (a+s)s - s(b+s) + s*s = a*b, each product going through `sim`.
    """
    tensor_1 = tensor_1[None, :, :, :]
    tensor_2 = tensor_2[None, None, :, :]  # 2-D weight broadcast over the batch
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        return sim(norm(tensor_1, max_abs), norm(tensor_2, max_abs))[0, :, :, :] * max_abs ** 2
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
    shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
    shifted_a = norm(tensor_1 + shift_a, max_abs)
    shifted_b = norm(tensor_2 + shift_b, max_abs)
    shift_a_n = norm(shift_a, max_abs)
    shift_b_n = norm(shift_b, max_abs)
    product = (sim(shifted_a, shifted_b)
               - sim(shifted_a, shift_b_n)
               - sim(shift_a_n, shifted_b)
               + sim(shift_a_n, shift_b_n))
    return product[0, :, :, :] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864)."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', t, inv_freq)
        # Duplicate so channel i and channel i + dim/2 share a rotation angle.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        a, b = x.chunk(2, dim=-1)
        return torch.cat((-b, a), dim=-1)

    def forward(self, x, offset=0):
        length = x.size(1)
        angles = self.emb[offset:offset + length].view(1, length, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer: weight * tanh(alpha * x) + bias (no normalization)."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
class OpticLinear(nn.Module):
    """Linear layer whose matmul is performed by a simulated optical multiplier.

    The weight is stored as (in_features, out_features) — the transpose of
    nn.Linear's layout — and applied via optics_matmul_shift; `k` is a learned
    scalar rescaling the optical product.
    """
    def __init__(
        self,
        in_features,
        out_features,
        bias = True,
        device = None,
        dtype = None,
        pixel_size = 3.6e-6
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((in_features, out_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.k = nn.Parameter(torch.randn(1))
        self.sim = omm.OpticalMul(
            omm.Config(
                right_matrix_count_columns = out_features,
                right_matrix_count_rows = in_features,
                right_matrix_width = pixel_size * out_features,
                right_matrix_height = pixel_size * in_features,
                min_height_gap = pixel_size,
                right_matrix_split_x = 2,
                right_matrix_split_y = 2,
                left_matrix_split_x = 2,
                left_matrix_split_y = 2,
                result_matrix_split = 2,
                distance = 0.01)
        )
        self.reset_parameters()
    def forward(self, input):
        """
        Runs the forward pass: k * (input @ weight via optics) (+ bias).
        """
        out = self.k * optics_matmul_shift(self.sim, input, self.weight)
        # Fix: the bias was previously added unconditionally, which raised a
        # TypeError (tensor + None) whenever the layer was built with bias=False.
        if self.bias is not None:
            out = out + self.bias
        return out
    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in ``__init__``.
        """
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        # NOTE(review): the weight is stored transposed vs nn.Linear, so
        # _calculate_fan_in_and_fan_out reports fan_in = out_features here —
        # confirm this initialization scale is intended.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)
    def extra_repr(self) -> str:
        """
        Return the extra representation of the module.
        """
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: multi-head self-attention followed by an
    OpticLinear feed-forward, with DyT (tanh) modules in place of LayerNorm."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every ctor argument as an attribute (h_dim, num_heads, dropout_rate, max_seq_len).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = OpticLinear(h_dim, 4*h_dim)
        self.ff2 = OpticLinear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        """(B, T, H) -> (B*num_heads, T, H/num_heads); identity for a single head."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of split_to_heads."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal multi-head self-attention with rotary position embedding on q and k."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scales by the full h_dim rather than the per-head dim — confirm intended.
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # Bug fix: functional dropout defaults to training=True, so without the
        # explicit flag dropout stayed active even after model.eval().
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2FF(nn.Module):
    """GPT-2-style causal LM whose transformer layers use an optical feed-forward (OpticLinear)."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash ctor args as attributes (vocab_size, layers_num, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    @torch.no_grad()  # bug fix: inference previously built an ever-growing autograd graph
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx of shape (B, T0)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            # Keep only the most recent max_seq_len tokens as context.
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -1,211 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Scale `matrix` by 1/max_val; the tiny epsilon guards against division by zero."""
    denom = max_val + 1e-10
    return matrix / denom
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Batched matrix product tensor_1 @ tensor_2 computed through the optical
    multiplier `sim` (assumes `sim` expects operands scaled to ~[0, 1] — TODO confirm).

    Non-negative operands are rescaled and multiplied in one optical pass.
    Operands with negative entries are lifted by a constant shift s, using
        (A+S)(B+S) - (A+S)S - S(B+S) + S S = A B,  S = s * ones,
    at the cost of four optical multiplications.
    """
    # sim apparently expects a leading batch axis; inputs are rank-3 (batch, rows, cols).
    tensor_1 = tensor_1[None,:,:,:]
    tensor_2 = tensor_2[None,:,:,:]
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        # Fast path: both operands already non-negative — single optical pass.
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
        # Undo the normalization: both operands were divided by max_abs.
        return sim(a, b)[0,:,:,:] * max_abs **2
    # s = |global minimum| makes both shifted operands non-negative.
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
    shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
    a_a_sh = tensor_1 + shift_a
    b_b_sh = tensor_2 + shift_b
    a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
    shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
    # Four optical products implementing the shift identity above.
    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary positional embedding applied over the last dimension of (B, T, D) tensors."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder over even feature indices, as in RoFormer.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the angle table spans the full feature dimension.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map the last-dim halves (a, b) to (-b, a)."""
        first_half, second_half = x.chunk(2, dim=-1)
        return torch.cat((-second_half, first_half), dim=-1)

    def forward(self, x, offset=0):
        """Rotate x (B, T, D) by the precomputed angles, starting at position `offset`."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing plus per-feature affine, a norm-free
    replacement for LayerNorm."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose attention matmuls (q@k^T and
    attention@v) run through externally supplied optical multipliers."""
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash ctor args as attributes (h_dim, sim_scores, sim_output, num_heads, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        """(B, T, H) -> (B*num_heads, T, H/num_heads); identity for a single head."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of split_to_heads."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal self-attention; both matmuls are computed optically."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scales by the full h_dim rather than the per-head dim — confirm intended.
        scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # Bug fix: functional dropout defaults to training=True, so without the
        # explicit flag dropout stayed active even after model.eval().
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2NOKoef(nn.Module):
    """Causal LM whose attention matmuls use a single shared pair of optical
    multipliers (scores and output), reused by every transformer layer."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash ctor args as attributes (vocab_size, layers_num, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short sequences: small propagation distance, default lens size.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        # Long sequences: larger propagation distance and explicit lens size.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # Presumably a multi-device scatter wrapper — confirm against omm docs.
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    @torch.no_grad()  # bug fix: inference previously built an ever-growing autograd graph
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -1,210 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Divide `matrix` by max_val (epsilon-protected against zero)."""
    return matrix * (1.0 / (max_val + 1e-10))
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Batched tensor_1 @ tensor_2 via the optical multiplier `sim` (assumes `sim`
    expects operands scaled to ~[0, 1] — TODO confirm).

    Negative-valued operands are handled by lifting both by a constant shift s:
        (A+S)(B+S) - (A+S)S - S(B+S) + S S = A B,  S = s * ones.
    """
    # sim apparently takes a leading batch axis; inputs are rank-3.
    tensor_1 = tensor_1[None,:,:,:]
    tensor_2 = tensor_2[None,:,:,:]
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        # Fast path: one optical pass, then undo the two normalizations.
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
        return sim(a, b)[0,:,:,:] * max_abs **2
    # s = |global minimum| guarantees non-negative shifted operands.
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
    shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
    a_a_sh = tensor_1 + shift_a
    b_b_sh = tensor_2 + shift_b
    a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
    shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
    # Four optical products implementing the shift identity above.
    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary positional embedding over the last dimension of (B, T, D) tensors."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # RoFormer frequency ladder over even feature indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map the last-dim halves (a, b) to (-b, a)."""
        lo, hi = x.chunk(2, dim=-1)
        return torch.cat((-hi, lo), dim=-1)

    def forward(self, x, offset=0):
        """Apply the rotation, starting at sequence position `offset`."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh normalization substitute: y = tanh(alpha * x) * weight + bias."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal block owning its own optical multipliers (trainable
    cylindrical lenses for short sequences) and learnable scalars k1/k2 that
    rescale the two optical attention matmuls."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128, pixel_size=3.6e-6):
        super().__init__()
        # Stash ctor args as attributes (h_dim, num_heads, dropout_rate, max_seq_len, pixel_size).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scales compensating optical amplitude for the two matmuls.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
        # Short sequences: small propagation distance, trainable cylindrical lenses.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=True)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=True)
            )
        # Long sequences: larger distance and explicit lens size.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2)
            )
        self.sim_scores = omm.DataParallel(self.sim_scores)
        self.sim_output = omm.DataParallel(self.sim_output)

    def split_to_heads(self, x, B, T, H):
        """(B, T, H) -> (B*num_heads, T, H/num_heads); identity for a single head."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of split_to_heads."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal self-attention; both matmuls are optical, rescaled by k1/k2."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scales by the full h_dim rather than the per-head dim — confirm intended.
        scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # Bug fix: functional dropout defaults to training=True, so without the
        # explicit flag dropout stayed active even after model.eval().
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainDiag(nn.Module):
    """Causal LM whose layers each own per-layer optical multipliers (TrainDiag variant)."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash ctor args as attributes (vocab_size, layers_num, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len, pixel_size=pixel_size)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    @torch.no_grad()  # bug fix: inference previously built an ever-growing autograd graph
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Global training hyperparameters.
seed = 1337
torch.manual_seed(seed)  # reproducibility of init and sampling
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2         # transformer depth
h_dim = 64             # model width
max_seq_len = 128      # context window (tokens)
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6    # physical pixel pitch for the optical variants (metres)
# Each micro-batch must divide the batch evenly.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary positional embedding over the last dimension of (B, T, D) tensors."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder over even feature indices, as in RoFormer.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map the last-dim halves (a, b) to (-b, a)."""
        lo, hi = x.chunk(2, dim=-1)
        return torch.cat((-hi, lo), dim=-1)

    def forward(self, x, offset=0):
        """Rotate x by the precomputed angle table, starting at position `offset`."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: y = tanh(alpha * x) * weight + bias, a norm-free LayerNorm substitute."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: multi-head self-attention plus a GELU
    feed-forward, with DyT modules in place of LayerNorm."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash ctor args as attributes (h_dim, num_heads, dropout_rate, max_seq_len).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        """(B, T, H) -> (B*num_heads, T, H/num_heads); identity for a single head."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of split_to_heads."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal multi-head self-attention with RoPE on q and k."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scales by the full h_dim rather than the per-head dim — confirm intended.
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # Bug fix: functional dropout defaults to training=True, so without the
        # explicit flag dropout stayed active even after model.eval().
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Vanilla GPT-2-style causal LM. `pixel_size` is accepted (and unused) so the
    constructor signature matches the optical model variants."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Stash ctor args as attributes (vocab_size, layers_num, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    @torch.no_grad()  # bug fix: inference previously built an ever-growing autograd graph
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx of shape (B, T0)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
# Experiment wiring: model choice, data paths, TensorBoard logging, source snapshot.
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length + optional CLI comment (sys.argv[1]).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
# Snapshot destination inside the logs dir, named after the script's parent directory.
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over all three splits, so no split contains
# characters unknown to the model.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from a 1-D token tensor."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    xs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    ys = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return xs, ys
# Each split encoded as a flat 1-D tensor of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate token-level perplexity of `model` on a 1-D token tensor `data`.

    Slides windows of the global `max_seq_len` across `data`, with a stride that
    caps the number of windows at roughly 10000, batches them, and returns
    exp(mean cross-entropy). Uses the globals `max_seq_len` and `device`; leaves
    the model in eval mode.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the token-id sequence `start_idxs` and decode it to text.

    Bug fix: the default argument was a mutable list ([0]); an immutable tuple
    avoids the shared-mutable-default pitfall and builds the same tensor.
    Uses the globals `device` and `decode`.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the selected model class and move it to the training device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule-of-thumb sizing (~8 tokens per parameter); presumably a compute-optimal
# heuristic, not a hard requirement — confirm.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state"""
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(payload, destination)
# Training config for checkpointing
# Everything needed to rebuild the model and resume training from a checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline sampling before any training (step 0).
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: micro-batches of batch_size // gradient_accumulation_steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively disables clipping; the call is used only to
        # measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation, sampling, and checkpointing.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, plus final samples.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles.

    Precomputes an angle table of shape (max_seq_len, dim) where each even
    feature index i gets frequency 10000^(-2i/dim), duplicated to cover the
    full feature dimension.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, dim/2)
        # Duplicate so the table matches the full feature dimension.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): 90-degree rotation of each feature pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Apply the rotation to x of shape (B, T, dim), starting at `offset`."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh normalization substitute: y = tanh(alpha * x) * weight + bias.

    alpha is a learnable scalar; weight/bias are per-feature affine parameters
    (https://jiachenzhu.github.io/DyT/).
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal self-attention block: multi-head attention + GELU MLP,
    with DyT in place of LayerNorm and rotary (RoPE) position encoding."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes (h_dim, num_heads, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Cache the causal mask once instead of rebuilding it every forward;
        # persistent=False keeps it out of state_dict so old checkpoints load.
        self.register_buffer('causal_mask',
                             torch.tril(torch.ones(max_seq_len, max_seq_len)),
                             persistent=False)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head attention.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # BUGFIX: scale by the per-head dim (the 1/sqrt(d_k) of Vaswani et al.),
        # not the full h_dim; the two only coincide when num_heads == 1.
        scores = (q @ k.transpose(1, 2)) * ((self.h_dim // self.num_heads) ** -0.5)
        T = x.shape[1]
        # Causal mask: position t may only attend to positions <= t.
        scores = scores.masked_fill(self.causal_mask[:T, :T] == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so the original
        # kept dropout active even in eval mode, corrupting perplexity and
        # sampling; gate it on self.training.
        # NOTE(review): dropout1d on (B, T, H) drops along dim 1 — confirm
        # whole-timestep dropout is intended rather than plain dropout.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only LM: token + learned positional embeddings, a stack
    of TransformerLayer blocks, and a linear LM head."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Stash constructor args as attributes; pixel_size is accepted (and kept)
        # only for interface parity with the optical model variant.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss). x: (B, T) int64 ids with T <= max_seq_len;
        loss is the mean cross-entropy when targets (B, T) are given, else None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects the class dim second: (B, C, T).
        # FIX: idiomatic `targets is not None` (was `not targets is None`).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    @torch.no_grad()  # FIX: sampling never needs autograd graphs; saves memory
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens ids after start_idx (B, T0)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)            # FIX: drop unused loss binding
            probs = F.softmax(logits[:,-1,:], dim=-1)  # next-token distribution
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
# Run configuration: model class, dataset paths, and a tagged run name.
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# NOTE(review): datetime.now() is called once per field; a second boundary
# between calls could yield an inconsistent timestamp in the directory name.
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
# Snapshot destination: logs_dir + name of the directory containing the script.
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits so val/test never contain
# out-of-vocabulary characters.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]   # str -> list[int]
decode = lambda idx: ''.join([itow[i] for i in idx])  # list[int] -> str
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of contiguous (input, next-token target) windows
    from the 1-D token tensor `data` and move them to the module-global device."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# Each split as one long 1-D tensor of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity of `model` over the 1-D token tensor `data`.

    Windows of length max_seq_len (module global) are taken every `stride`
    tokens — at most ~10k windows — and the token-weighted mean cross-entropy
    is exponentiated. Leaves the model in eval mode; callers restore training.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    # All window start offsets, then processed batch_size at a time.
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for lo in range(0, n_windows, batch_size):
        chunk = starts[lo:lo + batch_size]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # Model returns the mean loss over all tokens in the batch.
        _, mean_loss = model(xs, ys)
        # Re-weight by token count so partial final batches count correctly.
        n_tok = ys.numel()
        loss_sum += mean_loss.item() * n_tok
        token_count += n_tok
        done = min(lo + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a sampled continuation of the prompt token-id list `start_idxs`.

    NOTE(review): the mutable default is never mutated here, so it is safe.
    """
    prompt = torch.tensor([start_idxs]).to(device)
    out = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(out[0].tolist())
# Instantiate the configured model and move it to the compute device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Log the architecture and parameter count once at step 0.
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule-of-thumb budget of ~8 tokens per parameter (NOTE(review): heuristic).
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (model, optimizer, vocab, config) to
    checkpoint_dir/checkpoint_<step>.pt."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts, re-sampled at every evaluation to eyeball progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) completions logged at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main loop: one optimiser step per iteration, each built from
# `gradient_accumulation_steps` micro-batches.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the summed gradients match the full-batch mean gradient.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()  # ends as the mean micro-batch loss
    if i % 100 == 0:
        # NOTE(review): max_norm=1e6 means this effectively only *measures*
        # the total gradient norm for logging; it almost never clips.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation: validation perplexity, sample completions, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on the validation and test splits, plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -11,6 +11,21 @@ import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
@ -117,23 +132,6 @@ class GPT2(nn.Module):
###################################################################################################
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal self-attention block: multi-head attention + GELU MLP,
    with DyT in place of LayerNorm and rotary (RoPE) position encoding."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes (h_dim, num_heads, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Cache the causal mask once instead of rebuilding it every forward;
        # persistent=False keeps it out of state_dict so old checkpoints load.
        self.register_buffer('causal_mask',
                             torch.tril(torch.ones(max_seq_len, max_seq_len)),
                             persistent=False)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head attention.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # BUGFIX: scale by the per-head dim (the 1/sqrt(d_k) of Vaswani et al.),
        # not the full h_dim; the two only coincide when num_heads == 1.
        scores = (q @ k.transpose(1, 2)) * ((self.h_dim // self.num_heads) ** -0.5)
        T = x.shape[1]
        # Causal mask: position t may only attend to positions <= t.
        scores = scores.masked_fill(self.causal_mask[:T, :T] == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so the original
        # kept dropout active even in eval mode, corrupting perplexity and
        # sampling; gate it on self.training.
        # NOTE(review): dropout1d on (B, T, H) drops along dim 1 — confirm
        # whole-timestep dropout is intended rather than plain dropout.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only LM: token + learned positional embeddings, a stack
    of TransformerLayer blocks, and a linear LM head."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Stash constructor args as attributes; pixel_size is accepted (and kept)
        # only for interface parity with the optical model variant.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss). x: (B, T) int64 ids with T <= max_seq_len;
        loss is the mean cross-entropy when targets (B, T) are given, else None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects the class dim second: (B, C, T).
        # FIX: idiomatic `targets is not None` (was `not targets is None`).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    @torch.no_grad()  # FIX: sampling never needs autograd graphs; saves memory
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens ids after start_idx (B, T0)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)            # FIX: drop unused loss binding
            probs = F.softmax(logits[:,-1,:], dim=-1)  # next-token distribution
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
model.eval()
stride = max(1, len(data) // 10000)
total_loss_sum = 0.0
total_tokens_count = 0
# Precompute all valid start positions
start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
total_sequences = len(start_positions)
# Process sequences in batches
for i in range(0, total_sequences, batch_size):
batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
# Efficiently stack sequences into batch tensors
x_batch = torch.stack([
data[start:start + max_seq_len]
for start in batch_starts
]).to(device)
y_batch = torch.stack([
data[start + 1:start + max_seq_len + 1]
for start in batch_starts
]).to(device)
# Forward pass (model should return mean loss averaged over all tokens in batch)
_, mean_loss = model(x_batch, y_batch)
# Accumulate weighted loss (mean_loss is already averaged over tokens)
num_tokens = y_batch.numel()
total_loss_sum += mean_loss.item() * num_tokens
total_tokens_count += num_tokens
# Progress update
processed = min(i + batch_size, total_sequences)
print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
print() # Final newline
return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
"""Save model checkpoint with complete training state"""
checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': config,
'wtoi': wtoi,
'itow': itow,
}, checkpoint_path)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
m.train()
optimizer.zero_grad(set_to_none=True)
accumulated_loss = 0.0
for j in range(gradient_accumulation_steps):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
logits, loss = m(xb, yb)
loss = loss / gradient_accumulation_steps
loss.backward()
accumulated_loss += loss.item()
if i % 100 == 0:
writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
optimizer.step()
writer.add_scalar('loss', accumulated_loss, i)
print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
if i % 5000 == 0:
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
writer.add_text('completions/task', task_results, i)
save_checkpoint(
model=m,
optimizer=optimizer,
step=i,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
model=m,
optimizer=optimizer,
step=max_iters,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,369 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal self-attention block: multi-head attention + GELU MLP,
    with DyT in place of LayerNorm, rotary (RoPE) position encoding, and
    learnable scalar gains k1 (pre-softmax) / k2 (post-attention)."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes (h_dim, num_heads, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains on the attention logits and attention output.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
        # Cache the causal mask once instead of rebuilding it every forward;
        # persistent=False keeps it out of state_dict so old checkpoints load.
        self.register_buffer('causal_mask',
                             torch.tril(torch.ones(max_seq_len, max_seq_len)),
                             persistent=False)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head attention.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # BUGFIX: scale by the per-head dim (the 1/sqrt(d_k) of Vaswani et al.),
        # not the full h_dim; the two only coincide when num_heads == 1.
        scores = (self.k1 * (q @ k.transpose(1, 2))) * ((self.h_dim // self.num_heads) ** -0.5)
        T = x.shape[1]
        # Causal mask: position t may only attend to positions <= t.
        scores = scores.masked_fill(self.causal_mask[:T, :T] == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so the original
        # kept dropout active even in eval mode, corrupting perplexity and
        # sampling; gate it on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only LM: token + learned positional embeddings, a stack
    of TransformerLayer blocks, and a linear LM head."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Stash constructor args as attributes; pixel_size is accepted (and kept)
        # only for interface parity with the optical model variant.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss). x: (B, T) int64 ids with T <= max_seq_len;
        loss is the mean cross-entropy when targets (B, T) are given, else None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects the class dim second: (B, C, T).
        # FIX: idiomatic `targets is not None` (was `not targets is None`).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    @torch.no_grad()  # FIX: sampling never needs autograd graphs; saves memory
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens ids after start_idx (B, T0)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)            # FIX: drop unused loss binding
            probs = F.softmax(logits[:,-1,:], dim=-1)  # next-token distribution
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
# Experiment wiring: model class, data paths, and a timestamped run directory.
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# NOTE(review): separate datetime.now() calls can straddle a second boundary, so
# logs_dir and checkpoints_dir below may disagree by one second.
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over all three splits, so characters that
# appear only in val/test still receive ids.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Draw `batch_size` random contiguous windows from `data`.

    Returns (x, y) on the module-level `device`, where y is x shifted one
    position ahead (next-token targets).
    """
    starts = torch.randint(len(data) - seq_len, (batch_size,))
    windows = [(data[s:s + seq_len], data[s + 1:s + 1 + seq_len]) for s in starts]
    xs, ys = zip(*windows)
    return torch.stack(xs).to(device), torch.stack(ys).to(device)
# Encode each split once up front; tensors live on CPU and are moved to
# `device` per batch inside get_batch/perplexity.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate corpus perplexity with overlapping windows of length `max_seq_len`.

    Relies on the module-level `max_seq_len` and `device` globals. The stride
    caps the work at roughly 10k windows; because windows overlap, tokens are
    scored multiple times, so this is a smoothed estimate rather than exact
    held-out perplexity. Side effect: leaves `model` in eval mode. Returns a
    numpy float.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a text completion of `max_new_tokens` tokens seeded with `start_idxs`."""
    prompt = torch.tensor([start_idxs]).to(device)
    sampled = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(sampled[0].tolist())
# Instantiate the selected model and move it to the chosen device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Token budget of ~8 tokens per parameter (presumably a Chinchilla-style rule of
# thumb -- confirm). NOTE(review): the iteration estimate divides by batch_size
# only, ignoring the max_seq_len tokens contributed per sample -- verify.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist the full training state (weights, optimizer, vocab maps, config).

    The file name embeds the step number zero-padded to 7 digits, e.g.
    checkpoint_0005000.pt.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(state, target)
# Training config for checkpointing
# Hyper-parameters bundled into every checkpoint so a run can be reconstructed
# from the .pt file alone.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts, re-sampled during training to eyeball progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) completions logged at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: several micro-batch backwards before one step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): with max_norm=1e6 this call effectively only *logs* the
        # gradient norm every 100 steps; real clipping would use a small
        # max_norm on every step -- confirm intent.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sample generations, and a checkpoint (also at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on the validation and test splits, plus sample generations.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# @ -0,0 +1,369 @@  -- stray git-diff hunk header left over from extraction; commented out so the file parses
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility: one fixed seed for parameter init and batch sampling.
seed = 1337
torch.manual_seed(seed)
# Optimization hyper-parameters.
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Model shape.
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # stored by the model but unused by GPT2 itself; presumably for the optical variant -- confirm
# Micro-batch size must divide evenly. NOTE(review): assert is stripped under -O.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer): rotates pairs of features by a
    position-dependent angle so attention dot products encode relative offsets."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Per-pair inverse frequencies, then the full (position x angle) table.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angle_table = torch.outer(positions, inv_freq)
        # A buffer (not a parameter): moves with the module across devices.
        self.register_buffer('emb', torch.cat([angle_table, angle_table], dim=-1))

    def rotate_half(self, x):
        """Map the (a, b) halves of the last dim to (-b, a)."""
        first_half, second_half = x.chunk(2, dim=-1)
        return torch.cat((-second_half, first_half), dim=-1)

    def forward(self, x, offset=0):
        """Rotate x of shape (B, T, dim) using angles for positions offset..offset+T."""
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """DyT: tanh-based, normalization-free stand-in for LayerNorm
    (https://jiachenzhu.github.io/DyT/): y = tanh(alpha * x) * weight + bias."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Learnable scalar tanh slope plus a per-feature affine transform.
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: multi-head self-attention followed by
    a 4x GELU MLP, with DyT in place of LayerNorm and RoPE positions.

    k1/k2 are learnable scalar gains on the attention logits and the attention
    output. Heads are folded into the batch dimension ((b n) t h) rather than
    kept as a separate axis.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Store hyper-parameters explicitly; updating __dict__ from locals()
        # also captured the implicit __class__ cell created by zero-arg super().
        self.h_dim = h_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.max_seq_len = max_seq_len
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        """(B, T, H) -> (B*num_heads, T, H/num_heads); identity when single-headed."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        """Inverse of split_to_heads: (B*num_heads, T, H/num_heads) -> (B, T, H)."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal self-attention over x of shape (B, T, h_dim)."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim rather than the per-head dim;
        # presumably the learnable k1 absorbs the difference -- confirm intent.
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # Bug fix: functional F.dropout1d defaults to training=True, so dropout
        # previously kept firing even after model.eval(); gate it on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal GPT-2-style decoder LM: token + learned positional embeddings,
    a stack of TransformerLayer blocks, and a linear LM head.

    `pixel_size` is accepted for interface parity with the optical variant and
    is stored but unused here.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1, pixel_size=None):
        super().__init__()
        # Store hyper-parameters explicitly; updating __dict__ from locals()
        # also smuggled in the implicit __class__ cell created by zero-arg super().
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        loss = None
        if targets is not None:
            # cross_entropy expects the class dim second: (B, C, T).
            loss = F.cross_entropy(logits.transpose(1, 2), targets)
        return logits, loss
    @torch.no_grad()  # inference only: avoids retaining graphs across the sampling loop
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:,-1,:], dim=-1)  # distribution at the last position
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
model.eval()
stride = max(1, len(data) // 10000)
total_loss_sum = 0.0
total_tokens_count = 0
# Precompute all valid start positions
start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
total_sequences = len(start_positions)
# Process sequences in batches
for i in range(0, total_sequences, batch_size):
batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
# Efficiently stack sequences into batch tensors
x_batch = torch.stack([
data[start:start + max_seq_len]
for start in batch_starts
]).to(device)
y_batch = torch.stack([
data[start + 1:start + max_seq_len + 1]
for start in batch_starts
]).to(device)
# Forward pass (model should return mean loss averaged over all tokens in batch)
_, mean_loss = model(x_batch, y_batch)
# Accumulate weighted loss (mean_loss is already averaged over tokens)
num_tokens = y_batch.numel()
total_loss_sum += mean_loss.item() * num_tokens
total_tokens_count += num_tokens
# Progress update
processed = min(i + batch_size, total_sequences)
print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
print() # Final newline
return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
"""Save model checkpoint with complete training state"""
checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': config,
'wtoi': wtoi,
'itow': itow,
}, checkpoint_path)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
m.train()
optimizer.zero_grad(set_to_none=True)
accumulated_loss = 0.0
for j in range(gradient_accumulation_steps):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
logits, loss = m(xb, yb)
loss = loss / gradient_accumulation_steps
loss.backward()
accumulated_loss += loss.item()
if i % 100 == 0:
writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
optimizer.step()
writer.add_scalar('loss', accumulated_loss, i)
print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
if i % 5000 == 0:
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
writer.add_text('completions/task', task_results, i)
save_checkpoint(
model=m,
optimizer=optimizer,
step=i,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
model=m,
optimizer=optimizer,
step=max_iters,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# @ -0,0 +1,369 @@  -- stray git-diff hunk header left over from extraction; commented out so the file parses
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.q_proj = nn.Linear(h_dim, h_dim)
self.k_proj = nn.Linear(h_dim, h_dim)
self.v_proj = nn.Linear(h_dim, h_dim)
self.o_proj = nn.Linear(h_dim, h_dim)
self.ff1 = nn.Linear(h_dim, 4*h_dim)
self.ff2 = nn.Linear(4*h_dim, h_dim)
self.ln1 = DyT(h_dim)
self.ln2 = DyT(h_dim)
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
self.k1 = nn.Parameter(torch.randn(1))
self.k2 = nn.Parameter(torch.randn(1))
def split_to_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
attention = nn.functional.softmax(scores, dim=2)
return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class GPT2(nn.Module):
    """Minimal GPT-2-style decoder LM: token + learned positional embeddings,
    a stack of TransformerLayer blocks, and a linear LM head.

    `pixel_size` is accepted for interface parity with the optical variant and
    is stored but unused here.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1, pixel_size=None):
        super().__init__()
        # Store hyper-parameters explicitly; updating __dict__ from locals()
        # also smuggled in the implicit __class__ cell created by zero-arg super().
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        loss = None
        if targets is not None:
            # cross_entropy expects the class dim second: (B, C, T).
            loss = F.cross_entropy(logits.transpose(1, 2), targets)
        return logits, loss
    @torch.no_grad()  # inference only: avoids retaining graphs across the sampling loop
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:,-1,:], dim=-1)  # distribution at the last position
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
model.eval()
stride = max(1, len(data) // 10000)
total_loss_sum = 0.0
total_tokens_count = 0
# Precompute all valid start positions
start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
total_sequences = len(start_positions)
# Process sequences in batches
for i in range(0, total_sequences, batch_size):
batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
# Efficiently stack sequences into batch tensors
x_batch = torch.stack([
data[start:start + max_seq_len]
for start in batch_starts
]).to(device)
y_batch = torch.stack([
data[start + 1:start + max_seq_len + 1]
for start in batch_starts
]).to(device)
# Forward pass (model should return mean loss averaged over all tokens in batch)
_, mean_loss = model(x_batch, y_batch)
# Accumulate weighted loss (mean_loss is already averaged over tokens)
num_tokens = y_batch.numel()
total_loss_sum += mean_loss.item() * num_tokens
total_tokens_count += num_tokens
# Progress update
processed = min(i + batch_size, total_sequences)
print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
print() # Final newline
return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer, vocab, config)
    to ``checkpoint_dir/checkpoint_<step>.pt``."""
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(checkpoint, destination)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
m.train()
optimizer.zero_grad(set_to_none=True)
accumulated_loss = 0.0
for j in range(gradient_accumulation_steps):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
logits, loss = m(xb, yb)
loss = loss / gradient_accumulation_steps
loss.backward()
accumulated_loss += loss.item()
if i % 100 == 0:
writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
optimizer.step()
writer.add_scalar('loss', accumulated_loss, i)
print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
if i % 5000 == 0:
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
writer.add_text('completions/task', task_results, i)
save_checkpoint(
model=m,
optimizer=optimizer,
step=i,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
model=m,
optimizer=optimizer,
step=max_iters,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@@ -0,0 +1,369 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Training / model hyperparameters for the character-level GPT-2 run.
seed = 1337
torch.manual_seed(seed)  # fixed seed for reproducible init and sampling
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): defined but not used in the visible code
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): defined but not used in the visible code
layers_num = 2  # number of transformer blocks
h_dim = 64  # model width
max_seq_len = 64  # context window length
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # presumably the optical multiplier's pixel pitch in metres; unused by the pure GPT2 variant
# Micro-batches must divide the effective batch exactly.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes the rotation-angle table for ``max_seq_len`` positions at
    construction time and applies the rotation to the last dimension of the
    input in ``forward``.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # Outer product: one angle per (position, frequency) pair.
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate the angles so the table covers both halves of the vector.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map (x1, x2) -> (-x2, x1) on the split last dimension."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x`` (batch, seq, dim) by the angles of positions offset..offset+seq."""
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (normalization-free transformers, jiachenzhu.github.io/DyT).

    Computes ``tanh(alpha * x) * weight + bias`` with a learnable scalar
    ``alpha`` and per-feature affine parameters; a drop-in LayerNorm substitute.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block with RoPE and DyT normalization.

    Args:
        h_dim: model width; must be divisible by ``num_heads``.
        num_heads: attention heads; heads are folded into the batch
            dimension (see ``split_to_heads`` / ``gather_heads``).
        dropout_rate: dropout probability applied after attention and MLP.
        max_seq_len: maximum sequence length for the RoPE angle table.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes (h_dim, num_heads, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains on attention scores and attention output.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        # Causal mask: position t may attend only to positions <= t.
        # Built directly on the module's device instead of CPU + .to() copy.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1], device=self.q_proj.bias.device))
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout stayed
        # active even after model.eval() (making eval-time perplexity and
        # generation stochastic). Pass the module's mode explicitly.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only transformer language model with learned
    positional embeddings.

    ``pixel_size`` is accepted only for interface parity with the optical
    variant and is unused here.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Stash all constructor args as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        blocks = [
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)
        ]
        self.layers = nn.ModuleList(blocks)
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets is None."""
        h = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for block in self.layers:
            h = block(h)
        logits = self.lm_head(h)  # (B, T, vocab)
        if targets is None:
            return logits, None
        # cross_entropy wants class dim second: (B, C, T)
        return logits, F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            window = idx[:, -self.max_seq_len:]  # crop to the context length
            logits, _ = self(window)
            last_step = logits[:, -1, :]  # (B, vocab)
            probs = F.softmax(last_step, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, next_token], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample ``batch_size`` random windows from ``data``.

    Returns (x, y) on the global ``device``, where y is x shifted one token
    ahead (next-token prediction targets).
    """
    offsets = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[o:o+seq_len] for o in offsets])
    y = torch.stack([data[o+1:o+1+seq_len] for o in offsets])
    return x.to(device), y.to(device)
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided token-level perplexity of ``model`` on ``data`` (1D token tensor).

    Evaluates windows of the global ``max_seq_len`` starting every ``stride``
    tokens (stride chosen so roughly 10k windows are scored), and returns
    exp(mean token loss). Prints a progress line while running.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    # All window start offsets that leave room for input + shifted target.
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_seq = len(starts)
    loss_sum = 0.0
    token_count = 0
    for off in range(0, n_seq, batch_size):
        chunk = starts[off:min(off + batch_size, n_seq)]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # model returns the loss already averaged over all tokens in the batch
        _, mean_loss = model(xs, ys)
        n_tok = ys.numel()
        loss_sum += mean_loss.item() * n_tok  # re-weight so the final mean is per token
        token_count += n_tok
        done = min(off + batch_size, n_seq)
        print(f"\rppl {done}/{n_seq} ({done/n_seq*100:.1f}%)", end="", flush=True)
    print()  # final newline after the progress line
    return np.exp(loss_sum / token_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation from model ``m`` and decode it to text.

    Args:
        m: language model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: token ids of the prompt. The default was the mutable
            list ``[0]``; mutable defaults are shared across calls, so it is
            replaced by an equivalent tuple (backward-compatible).
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string for prompt + generated tokens.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer, vocab, config)
    to ``checkpoint_dir/checkpoint_<step>.pt``."""
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(checkpoint, destination)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main training loop: one optimizer step per iteration, with optional
# gradient accumulation, periodic validation/checkpointing every 5000 steps,
# and a final val/test evaluation after the loop.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # The effective batch (batch_size) is split into
    # gradient_accumulation_steps micro-batches.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the summed gradients match a single full-batch step.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively disables clipping; the call is used only
        # to measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval + checkpoint (also fires at i == 0, i.e. right after the
    # first update).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation; i here is the last loop value (max_iters - 1).
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# NOTE(review): this re-logs 'loss' at step i, duplicating the point already
# written inside the loop.
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@@ -0,0 +1,473 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative multiplier ``sim``.

    The optical multiplier only handles non-negative inputs in [0, 1], so each
    operand is split into positive/negative parts, each part is normalized by
    its max, ``sim`` is applied to the four combinations, and the results are
    recombined:  A @ B = A+B+ - A+B- - A-B+ + A-B-.

    3D inputs are promoted to 4D with a leading singleton batch dimension;
    the result is returned with that dimension stripped (index 0), matching
    the 3D (batch*heads, T, h) tensors the attention code passes in.

    Args:
        sim: callable mapping two non-negative 4D tensors to their product.
        tensor_1: 3D or 4D left operand.
        tensor_2: 3D or 4D right operand.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1.unsqueeze(0)
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2.unsqueeze(0)
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # Result shape of a batched matmul between the two operands.
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(a, b):
        """sim(a, b) rescaled back; zeros when either part is identically 0
        (its max is 0, so the scaled product would be 0 anyway)."""
        a_max = torch.max(a)
        b_max = torch.max(b)
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        # Allocate directly on the right device/dtype instead of the original
        # torch.zeros_like(torch.empty(...)).to(device) round-trip.
        return torch.zeros(out_shape, device=device, dtype=tensor_1.dtype)

    result = _term(A_pos, B_pos) - _term(A_pos, B_neg) - _term(A_neg, B_pos) + _term(A_neg, B_neg)
    return result[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes the rotation-angle table for ``max_seq_len`` positions at
    construction time and applies the rotation to the last dimension of the
    input in ``forward``.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # Outer product: one angle per (position, frequency) pair.
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate the angles so the table covers both halves of the vector.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map (x1, x2) -> (-x2, x1) on the split last dimension."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x`` (batch, seq, dim) by the angles of positions offset..offset+seq."""
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (normalization-free transformers, jiachenzhu.github.io/DyT).

    Computes ``tanh(alpha * x) * weight + bias`` with a learnable scalar
    ``alpha`` and per-feature affine parameters; a drop-in LayerNorm substitute.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose two matrix products
    (scores = Q Kᵀ and output = attention V) are routed through optical
    multiplier simulators via ``new_formula``.

    Args:
        h_dim: model width; must be divisible by ``num_heads``.
        sim_scores: non-negative multiplier used for the attention scores.
        sim_output: non-negative multiplier used for the attention output.
        num_heads: attention heads, folded into the batch dimension.
        dropout_rate: dropout probability applied after attention and MLP.
        max_seq_len: maximum sequence length for the RoPE angle table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes (h_dim, sim_scores, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains on attention scores and attention output.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scores through the optical multiplier instead of a plain matmul.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask, built directly on the module's device.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1], device=self.q_proj.bias.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout stayed
        # active even after model.eval(); pass the module's mode explicitly.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """GPT2 variant whose attention matmuls run through optical multiplier
    simulators (``omm.OpticalMul``).

    A single pair of simulators (scores / output) is built once and shared by
    all transformer layers. Two hardware configurations are selected by
    sequence length; the >= 512 branch uses a larger propagation distance and
    an explicit lens size.

    NOTE(review): ``omm`` is not imported in this file's visible import
    block — presumably an optical-matrix-multiplication package imported
    elsewhere; confirm, otherwise construction raises NameError.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash all constructor args as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short-context configuration: propagation distance 0.01.
        if max_seq_len < 512:
            # Multiplier for Q @ K^T: right operand is (h_dim/num_heads, max_seq_len).
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            # Multiplier for attention @ V: right operand is (max_seq_len, h_dim/num_heads).
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        # Long-context configuration: distance 0.15 and a larger lens.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # All layers share the same simulator pair.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned (not sinusoidal) positional embeddings, used alongside RoPE
        # inside the layers.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        # Returns (logits, loss); loss is None when targets is None.
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants the class dim second, hence (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # Autoregressive sampling: repeatedly feed the last max_seq_len tokens
    # and sample the next token from the softmax distribution.
    def generate(self, start_idx, max_new_tokens):
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
model.eval()
stride = max(1, len(data) // 10000)
total_loss_sum = 0.0
total_tokens_count = 0
# Precompute all valid start positions
start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
total_sequences = len(start_positions)
# Process sequences in batches
for i in range(0, total_sequences, batch_size):
batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
# Efficiently stack sequences into batch tensors
x_batch = torch.stack([
data[start:start + max_seq_len]
for start in batch_starts
]).to(device)
y_batch = torch.stack([
data[start + 1:start + max_seq_len + 1]
for start in batch_starts
]).to(device)
# Forward pass (model should return mean loss averaged over all tokens in batch)
_, mean_loss = model(x_batch, y_batch)
# Accumulate weighted loss (mean_loss is already averaged over tokens)
num_tokens = y_batch.numel()
total_loss_sum += mean_loss.item() * num_tokens
total_tokens_count += num_tokens
# Progress update
processed = min(i + batch_size, total_sequences)
print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
print() # Final newline
return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
"""Save model checkpoint with complete training state"""
checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': config,
'wtoi': wtoi,
'itow': itow,
}, checkpoint_path)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
m.train()
optimizer.zero_grad(set_to_none=True)
accumulated_loss = 0.0
for j in range(gradient_accumulation_steps):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
logits, loss = m(xb, yb)
loss = loss / gradient_accumulation_steps
loss.backward()
accumulated_loss += loss.item()
if i % 100 == 0:
writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
optimizer.step()
writer.add_scalar('loss', accumulated_loss, i)
print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
if i % 5000 == 0:
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
writer.add_text('completions/task', task_results, i)
save_checkpoint(
model=m,
optimizer=optimizer,
step=i,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
model=m,
optimizer=optimizer,
step=max_iters,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,473 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# ----- Run configuration (character-level LM, optical attention) -----
seed = 1337
torch.manual_seed(seed)  # fixed seed for reproducible runs
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): unused in the visible code — kept for config parity
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): unused in the visible code
layers_num = 2   # transformer depth
h_dim = 64       # model width
max_seq_len = 256  # context length; also selects the optical-sim branch in OpticGPT2
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # SLM pixel pitch (meters) for the optical simulator
# Batch must split evenly across accumulation micro-steps.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only primitive.

    Splits both operands into positive/negative parts, normalizes each part
    to [0, 1], runs the (intensity-only) ``sim`` operator on every pos/neg
    combination and recombines the four partial products so that the result
    equals the full signed product:

        A @ B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻

    Args:
        sim: callable taking two non-negative 4-D tensors and returning
            their (optically simulated) matrix product.
        tensor_1: left operand, 3-D (N, T, H) or 4-D (B, N, T, H).
        tensor_2: right operand, 3-D or 4-D, matmul-compatible with
            ``tensor_1``.

    Returns:
        3-D tensor: the product with the leading batch dim stripped.
        NOTE(review): only batch index 0 is returned — callers pass 3-D
        inputs, so the added batch dim is always singleton.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # Per-part maxima; a zero maximum means the part has no support and its
    # term is an exact zero (normalizing would divide by zero).
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(part_a, max_a, part_b, max_b):
        # Run sim on [0, 1]-normalized inputs and undo the scaling.
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        # Allocate the zero term directly on the operand's device (the
        # original built it on CPU via zeros_like(empty) and copied it over).
        return torch.zeros(out_shape, device=device, dtype=tensor_1.dtype)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding with a precomputed angle table."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Frequencies 10000^(-2i/dim) for each even channel index i.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate the table so it covers all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # Map (x1, x2) -> (-x2, x1) across the channel halves.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # Rotate each position's channel pairs by its precomputed angle.
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: a normalization-free replacement for LayerNorm.

    Applies elementwise ``tanh(alpha * x)`` followed by a learned
    per-feature affine transform.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder block whose attention matmuls run through optical
    matrix-multiplication simulators via ``new_formula``.

    Attention Is All You Need: https://arxiv.org/pdf/1706.03762v7
    NeoBERT: https://arxiv.org/html/2502.19587v1
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Plain __dict__ update (not setattr): the shared sim_* modules are
        # deliberately NOT registered as submodules of every layer.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learned scalar gains for the two optically computed products.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled scores through the optical simulator; k1 is a learned gain.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask, built directly on the model's device (the original
        # allocated it on CPU and copied it over on every forward).
        T = x.shape[1]
        tril = torch.tril(torch.ones(T, T, device=self.q_proj.bias.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout kept
        # firing during evaluation even after m.eval(); gate on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """GPT-2-style causal character LM whose attention products are computed
    by optical matrix-multiplication simulators shared across all layers.

    NOTE(review): ``omm`` is referenced here but not imported in this file's
    visible import block — presumably ``import optical_matrix_multiplication
    as omm`` is required; confirm before running.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash all constructor args as plain attributes (bypasses nn.Module __setattr__).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short-context optical setup: short propagation distance, default lens.
        if max_seq_len < 512:
            # sim_scores multiplies Q (T, h) by Kᵀ (h, T) -> attention scores.
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            # sim_output multiplies attention (T, T) by V (T, h).
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        # Long-context setup: larger propagation distance and explicit lens size.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # All layers share the two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (used in addition to RoPE inside the layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects (B, C, T) for per-position sequence targets.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss

    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Sample max_new_tokens ids autoregressively, starting from start_idx (B, T)."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            # Multinomial sampling from the last-position distribution.
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
# WikiText file paths (hard-coded; an earlier revision took them from argv).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary built over the union of all three splits, so
# val/test never contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size, batch_device=None):
    """Sample a random batch of contiguous (input, target) sequences.

    Args:
        data: 1-D LongTensor of token ids.
        seq_len: length of each sampled sequence.
        batch_size: number of sequences per batch.
        batch_device: device for the returned tensors; defaults to the
            module-level ``device`` (backward compatible).

    Returns:
        (x, y) of shape (batch_size, seq_len) each; y is x shifted right by one.
    """
    if batch_device is None:
        batch_device = device
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[i:i + seq_len] for i in starts]).to(batch_device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in starts]).to(batch_device)
    return x, y
# Tokenize each split once up front; tensors stay on CPU, batches are moved
# to the compute device on demand by get_batch/perplexity.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32, seq_len=None, compute_device=None):
    """Corpus perplexity of ``model`` over ``data`` with a strided sliding window.

    Args:
        model: callable returning ``(logits, mean_loss)``; ``mean_loss`` must
            be the token-averaged cross-entropy for the batch.
        data: 1-D LongTensor of token ids.
        batch_size: number of windows evaluated per forward pass.
        seq_len: window length; defaults to the module-level ``max_seq_len``
            (backward compatible with the previous hard-wired global).
        compute_device: device for the batches; defaults to the module-level
            ``device``.

    Returns:
        ``np.float64`` — exp of the token-weighted mean loss.
    """
    if seq_len is None:
        seq_len = max_seq_len
    if compute_device is None:
        compute_device = device
    model.eval()
    # Stride the window so at most ~10k start positions are evaluated.
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions.
    start_positions = list(range(0, len(data) - seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process windows in batches.
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:i + batch_size]  # slicing already clamps to the end
        x_batch = torch.stack([
            data[start:start + seq_len]
            for start in batch_starts
        ]).to(compute_device)
        y_batch = torch.stack([
            data[start + 1:start + seq_len + 1]
            for start in batch_starts
        ]).to(compute_device)
        _, mean_loss = model(x_batch, y_batch)
        # Re-weight the batch-mean loss by its token count so the final
        # average is exact even for a ragged last batch.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update.
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of ``start_idxs`` and decode it to a string.

    Args:
        m: model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: sequence of seed token ids. Default changed from a
            mutable list ``[0]`` to the tuple ``(0,)`` — same value, no
            shared-mutable-default hazard.
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string for the full sequence, seed included.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model and move it to the compute device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule-of-thumb data budget of ~8 tokens per parameter (NOTE(review): heuristic).
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state.

    Persists weights, optimizer state, the run config and both vocabulary
    maps so a run can be resumed or sampled from later.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(state, target)
# Training config for checkpointing — stored alongside weights so a
# checkpoint can be reloaded with the matching architecture.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts, sampled periodically to eyeball model quality.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) generations at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: micro-batches of batch_size // steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the summed gradient matches a single full batch.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips — the call is used to *measure* the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sample generations and checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on val + test, final sample generations, final checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,473 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# ----- Run configuration (character-level LM, optical attention) -----
seed = 1337
torch.manual_seed(seed)  # fixed seed for reproducible runs
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): unused in the visible code — kept for config parity
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): unused in the visible code
layers_num = 2   # transformer depth
h_dim = 64       # model width
max_seq_len = 512  # context length; selects the long-context optical-sim branch in OpticGPT2
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # SLM pixel pitch (meters) for the optical simulator
# Batch must split evenly across accumulation micro-steps.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only primitive.

    Splits both operands into positive/negative parts, normalizes each part
    to [0, 1], runs the (intensity-only) ``sim`` operator on every pos/neg
    combination and recombines the four partial products so that the result
    equals the full signed product:

        A @ B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻

    Args:
        sim: callable taking two non-negative 4-D tensors and returning
            their (optically simulated) matrix product.
        tensor_1: left operand, 3-D (N, T, H) or 4-D (B, N, T, H).
        tensor_2: right operand, 3-D or 4-D, matmul-compatible with
            ``tensor_1``.

    Returns:
        3-D tensor: the product with the leading batch dim stripped.
        NOTE(review): only batch index 0 is returned — callers pass 3-D
        inputs, so the added batch dim is always singleton.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # Per-part maxima; a zero maximum means the part has no support and its
    # term is an exact zero (normalizing would divide by zero).
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(part_a, max_a, part_b, max_b):
        # Run sim on [0, 1]-normalized inputs and undo the scaling.
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        # Allocate the zero term directly on the operand's device (the
        # original built it on CPU via zeros_like(empty) and copied it over).
        return torch.zeros(out_shape, device=device, dtype=tensor_1.dtype)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding with a precomputed angle table."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Frequencies 10000^(-2i/dim) for each even channel index i.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate the table so it covers all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # Map (x1, x2) -> (-x2, x1) across the channel halves.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # Rotate each position's channel pairs by its precomputed angle.
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: a normalization-free replacement for LayerNorm.

    Applies elementwise ``tanh(alpha * x)`` followed by a learned
    per-feature affine transform.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder block whose attention matmuls run through optical
    matrix-multiplication simulators via ``new_formula``.

    Attention Is All You Need: https://arxiv.org/pdf/1706.03762v7
    NeoBERT: https://arxiv.org/html/2502.19587v1
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Plain __dict__ update (not setattr): the shared sim_* modules are
        # deliberately NOT registered as submodules of every layer.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learned scalar gains for the two optically computed products.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled scores through the optical simulator; k1 is a learned gain.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask, built directly on the model's device (the original
        # allocated it on CPU and copied it over on every forward).
        T = x.shape[1]
        tril = torch.tril(torch.ones(T, T, device=self.q_proj.bias.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout kept
        # firing during evaluation even after m.eval(); gate on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """GPT-2-style causal character LM whose attention products are computed
    by optical matrix-multiplication simulators shared across all layers.

    NOTE(review): ``omm`` is referenced here but not imported in this file's
    visible import block — presumably ``import optical_matrix_multiplication
    as omm`` is required; confirm before running.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash all constructor args as plain attributes (bypasses nn.Module __setattr__).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short-context optical setup: short propagation distance, default lens.
        if max_seq_len < 512:
            # sim_scores multiplies Q (T, h) by Kᵀ (h, T) -> attention scores.
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            # sim_output multiplies attention (T, T) by V (T, h).
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        # Long-context setup: larger propagation distance and explicit lens size.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # All layers share the two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (used in addition to RoPE inside the layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects (B, C, T) for per-position sequence targets.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss

    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Sample max_new_tokens ids autoregressively, starting from start_idx (B, T)."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            # Multinomial sampling from the last-position distribution.
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
# WikiText file paths (hard-coded; an earlier revision took them from argv).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary built over the union of all three splits, so
# val/test never contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size, batch_device=None):
    """Sample a random batch of contiguous (input, target) sequences.

    Args:
        data: 1-D LongTensor of token ids.
        seq_len: length of each sampled sequence.
        batch_size: number of sequences per batch.
        batch_device: device for the returned tensors; defaults to the
            module-level ``device`` (backward compatible).

    Returns:
        (x, y) of shape (batch_size, seq_len) each; y is x shifted right by one.
    """
    if batch_device is None:
        batch_device = device
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[i:i + seq_len] for i in starts]).to(batch_device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in starts]).to(batch_device)
    return x, y
# Tokenize each split once up front; tensors stay on CPU, batches are moved
# to the compute device on demand by get_batch/perplexity.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32, seq_len=None, compute_device=None):
    """Corpus perplexity of ``model`` over ``data`` with a strided sliding window.

    Args:
        model: callable returning ``(logits, mean_loss)``; ``mean_loss`` must
            be the token-averaged cross-entropy for the batch.
        data: 1-D LongTensor of token ids.
        batch_size: number of windows evaluated per forward pass.
        seq_len: window length; defaults to the module-level ``max_seq_len``
            (backward compatible with the previous hard-wired global).
        compute_device: device for the batches; defaults to the module-level
            ``device``.

    Returns:
        ``np.float64`` — exp of the token-weighted mean loss.
    """
    if seq_len is None:
        seq_len = max_seq_len
    if compute_device is None:
        compute_device = device
    model.eval()
    # Stride the window so at most ~10k start positions are evaluated.
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions.
    start_positions = list(range(0, len(data) - seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process windows in batches.
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:i + batch_size]  # slicing already clamps to the end
        x_batch = torch.stack([
            data[start:start + seq_len]
            for start in batch_starts
        ]).to(compute_device)
        y_batch = torch.stack([
            data[start + 1:start + seq_len + 1]
            for start in batch_starts
        ]).to(compute_device)
        _, mean_loss = model(x_batch, y_batch)
        # Re-weight the batch-mean loss by its token count so the final
        # average is exact even for a ragged last batch.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update.
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of ``start_idxs`` and decode it to a string.

    Args:
        m: model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: sequence of seed token ids. Default changed from a
            mutable list ``[0]`` to the tuple ``(0,)`` — same value, no
            shared-mutable-default hazard.
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string for the full sequence, seed included.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model and move it to the compute device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule-of-thumb data budget of ~8 tokens per parameter (NOTE(review): heuristic).
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state.

    Persists weights, optimizer state, the run config and both vocabulary
    maps so a run can be resumed or sampled from later.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(state, target)
# Training config for checkpointing — stored alongside weights so a
# checkpoint can be reloaded with the matching architecture.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts, sampled periodically to eyeball model quality.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) generations at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: micro-batches of batch_size // steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the summed gradient matches a single full batch.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips — the call is used to *measure* the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sample generations and checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on val + test, final sample generations, final checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -2,19 +2,31 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
#################################### Model #########################################
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
# return matrix / (max_val + 1e-10)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
@ -130,12 +142,12 @@ class TransformerLayer(nn.Module):
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class OpticGPT2NewFormula(nn.Module):
class OpticGPT2(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
pixel_size = 3.6e-6):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
if max_seq_len != 512:
if max_seq_len < 512:
self.sim_scores = omm.OpticalMul(
omm.Config(right_matrix_count_columns = max_seq_len,
right_matrix_count_rows = h_dim // num_heads,
@ -164,7 +176,7 @@ class OpticGPT2NewFormula(nn.Module):
distance = 0.01,
trainable_cylind_lens=False)
)
if max_seq_len == 512:
if max_seq_len >= 512:
self.sim_scores = omm.OpticalMul(
omm.Config(right_matrix_count_columns = max_seq_len,
right_matrix_count_rows = h_dim // num_heads,
@ -195,8 +207,6 @@ class OpticGPT2NewFormula(nn.Module):
lens_size = 8192 * 2,
trainable_cylind_lens=False)
)
self.sim_scores = omm.DataParallel(self.sim_scores)
self.sim_output = omm.DataParallel(self.sim_output)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
@ -224,4 +234,240 @@ class OpticGPT2NewFormula(nn.Module):
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx
return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
model.eval()
stride = max(1, len(data) // 10000)
total_loss_sum = 0.0
total_tokens_count = 0
# Precompute all valid start positions
start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
total_sequences = len(start_positions)
# Process sequences in batches
for i in range(0, total_sequences, batch_size):
batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
# Efficiently stack sequences into batch tensors
x_batch = torch.stack([
data[start:start + max_seq_len]
for start in batch_starts
]).to(device)
y_batch = torch.stack([
data[start + 1:start + max_seq_len + 1]
for start in batch_starts
]).to(device)
# Forward pass (model should return mean loss averaged over all tokens in batch)
_, mean_loss = model(x_batch, y_batch)
# Accumulate weighted loss (mean_loss is already averaged over tokens)
num_tokens = y_batch.numel()
total_loss_sum += mean_loss.item() * num_tokens
total_tokens_count += num_tokens
# Progress update
processed = min(i + batch_size, total_sequences)
print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
print() # Final newline
return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
"""Save model checkpoint with complete training state"""
checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': config,
'wtoi': wtoi,
'itow': itow,
}, checkpoint_path)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
m.train()
optimizer.zero_grad(set_to_none=True)
accumulated_loss = 0.0
for j in range(gradient_accumulation_steps):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
logits, loss = m(xb, yb)
loss = loss / gradient_accumulation_steps
loss.backward()
accumulated_loss += loss.item()
if i % 100 == 0:
writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
optimizer.step()
writer.add_scalar('loss', accumulated_loss, i)
print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
if i % 5000 == 0:
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
writer.add_text('completions/task', task_results, i)
save_checkpoint(
model=m,
optimizer=optimizer,
step=i,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
model=m,
optimizer=optimizer,
step=max_iters,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,472 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Compute ``tensor_1 @ tensor_2`` through a sign-split optical multiply.

    The optical simulator ``sim`` can only multiply non-negative matrices
    scaled into [0, 1].  Each operand is split into its positive and negative
    parts (A = A+ - A-, B = B+ - B-) and the four cross products are combined
    as A@B = A+B+ - A+B- - A-B+ + A-B-.

    Args:
        sim: callable performing the (non-negative, normalized) matrix product.
        tensor_1: left operand; a leading batch dim is added when it is 3-D.
        tensor_2: right operand, same rank convention.

    Returns:
        The product with the leading batch dimension stripped (3-D tensor).
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    out_shape = (tensor_1.shape[0], tensor_1.shape[1],
                 tensor_1.shape[2], tensor_2.shape[3])

    def _scaled_product(a, b):
        # Normalize each non-negative part to [0, 1], multiply optically, then
        # undo the scaling.  A zero max means that part is identically zero,
        # so its contribution is an all-zeros block (allocated directly on the
        # target device instead of the original empty/zeros_like/clone chain).
        max_a = torch.max(a)
        max_b = torch.max(b)
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    a_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    a_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    b_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    b_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    result = (_scaled_product(a_pos, b_pos)
              - _scaled_product(a_pos, b_neg)
              - _scaled_product(a_neg, b_pos)
              + _scaled_product(a_neg, b_neg))
    return result[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates channel pairs by position-dependent angles."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse wavelength per pair of channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the angle table spans all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squash + per-feature affine, replacing LayerNorm."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Scalar steepness of the tanh, shared across all features.
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        # Per-feature scale/shift, mirroring LayerNorm's gamma/beta.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
class OpticLinear(nn.Module):
    """Drop-in linear layer whose matrix product runs through an optical simulator.

    The weight is stored as (in_features, out_features) -- transposed relative
    to ``nn.Linear`` -- because ``new_formula`` computes ``input @ weight``.
    A learnable scalar ``k`` rescales the optical product.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias = True,
        device = None,
        dtype = None,
        pixel_size = 3.6e-6
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((in_features, out_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # Learnable global gain compensating the optical system's attenuation.
        self.k = nn.Parameter(torch.randn(1))
        # NOTE(review): `omm` (import optical_matrix_multiplication as omm) is
        # referenced here but never imported in this file -- confirm the import
        # exists at module level in the deployed version.
        self.sim = omm.OpticalMul(
            omm.Config(
                right_matrix_count_columns = out_features ,
                right_matrix_count_rows = in_features,
                right_matrix_width = pixel_size * out_features ,
                right_matrix_height = pixel_size * in_features,
                min_height_gap = pixel_size,
                right_matrix_split_x = 2,
                right_matrix_split_y = 2,
                left_matrix_split_x = 2,
                left_matrix_split_y = 2,
                result_matrix_split = 2,
                distance = 0.01)
        )
        self.reset_parameters()

    def forward(self, input):
        """Run the optical product; the bias is added only when it exists.

        (Original crashed with ``bias=False`` by adding ``None``.)
        """
        out = self.k * new_formula(self.sim, input, self.weight)
        return out if self.bias is None else out + self.bias

    def reset_parameters(self) -> None:
        """Initialize weight/bias the same way ``nn.Linear`` does.

        Uses ``nn.init`` (already in scope) -- the original referenced the
        un-imported names ``init`` and ``math``, a NameError at first call.
        """
        # a=sqrt(5) in kaiming_uniform matches uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        # See https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = fan_in ** -0.5 if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """Return the extra representation of the module."""
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose feed-forward uses OpticLinear."""

    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every constructor argument as an attribute (h_dim, num_heads, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        # Feed-forward path runs through the optical matmul layers.
        self.ff1 = OpticLinear(h_dim, 4*h_dim)
        self.ff2 = OpticLinear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scale uses full h_dim rather than the per-head dim;
        # kept as-is to preserve trained behavior -- confirm intent.
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout was being
        # applied even in eval mode; gate it on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2FF(nn.Module):
    """Character-level GPT-2 variant whose transformer feed-forwards are optical."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Record every constructor argument as an attribute.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        blocks = [
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)
        ]
        self.layers = nn.ModuleList(blocks)
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (RoPE is applied inside attention too).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        h = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for block in self.layers:
            h = block(h)
        logits = self.lm_head(h)  # B,T,C
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss

    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx`."""
        idx = start_idx
        for _ in range(max_new_tokens):
            window = idx[:, -self.max_seq_len:]
            logits, _ = self(window)
            last = logits[:, -1, :]  # B, C
            probs = F.softmax(last, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2FF
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    # Targets are the same windows shifted one token to the right.
    y = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided sliding-window perplexity of `model` over the token tensor `data`.

    Caps work at roughly 10k windows by widening the stride on long datasets.
    Returns exp(mean per-token NLL) as a numpy float.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    # All window start offsets that leave room for a full input + target.
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for batch_idx in range(0, n_windows, batch_size):
        chunk = starts[batch_idx:batch_idx + batch_size]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # Model returns the mean loss over every token in the batch.
        _, mean_loss = model(xs, ys)
        # Re-weight by token count so a short final batch averages correctly.
        n_tokens = ys.numel()
        loss_sum += mean_loss.item() * n_tokens
        token_count += n_tokens
        done = min(batch_idx + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the token ids `start_idxs` and decode it to text.

    Args:
        m: model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: prompt token ids; default is a tuple (not a list) to avoid
            the shared-mutable-default pitfall -- callers passing lists still work.
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string, prompt included.
    """
    prompt = torch.tensor([list(start_idxs)]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer, vocab, config)
    to ``<checkpoint_dir>/checkpoint_<step>.pt``."""
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, destination)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main training loop: one optimizer step per iteration, with gradient
# accumulation over micro-batches and periodic eval/checkpointing.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Each micro-batch holds batch_size // gradient_accumulation_steps samples;
    # losses are pre-divided so the summed gradients equal a full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively disables clipping; the call is used only to
        # measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval: validation perplexity, sample completions, checkpoint.
    # NOTE(review): also fires at i == 0 (untrained model) -- presumably an
    # intentional baseline; confirm.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits after training completes.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
# Sample unconditional and task-prompt completions from the final model.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,471 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
device = tensor_1.device
A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0)
A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0)
B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0)
B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0)
max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений
max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений
max_B_pos = torch.max(B_pos)
max_B_neg = torch.max(B_neg)
zero_template = torch.zeros_like(
torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3]))
if max_A_pos > 0 and max_B_pos > 0:
t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
else:
t1 = zero_template.clone().to(device)
if max_A_pos > 0 and max_B_neg > 0:
t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
else:
t2 = zero_template.clone().to(device)
if max_A_neg > 0 and max_B_pos > 0:
t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
else:
t3 = zero_template.clone().to(device)
if max_A_neg > 0 and max_B_neg > 0:
t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
else:
t4 = zero_template.clone().to(device)
return (t1 - t2 - t3 + t4)[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal block where BOTH attention matmuls run through optical sims.

    ``sim_scores`` multiplies Q by K^T; ``sim_output`` multiplies the attention
    weights by V.  Both are passed in and shared across layers (stored as plain
    attributes via __dict__, so they are not registered as this layer's submodules).
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every constructor argument (including the shared sims) as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Optical matmul replaces q @ k^T; scale kept at h_dim ** -0.5 as-is.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        # Optical matmul replaces attention @ v.
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout was being
        # applied even in eval mode; gate it on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Decoder-only character-level transformer whose attention matrix products
    are routed through optical matrix-multiplication simulators (omm.OpticalMul).

    NOTE(review): `omm` is not imported in this file's visible import block --
    confirm the module providing OpticalMul/Config is in scope before running.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # stash all ctor args as attributes (vocab_size, h_dim, max_seq_len, ...)
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # One shared pair of optical multipliers reused by every layer:
        # sim_scores evaluates Q@K^T, sim_output evaluates attention@V.
        # Short sequences (< 512) use distance=0.01; long ones (>= 512) use
        # distance=0.15 plus an explicit lens_size -- presumably tuned for the
        # larger aperture; confirm against the omm.Config documentation.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
                )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
                )
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # learned absolute position embeddings, used in addition to RoPE inside layers
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects the class dimension second, hence b t c -> b c t
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Sample `max_new_tokens` tokens autoregressively, feeding back at most
        the last `max_seq_len` tokens of context each step."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# run tag: script stem + train file name + seq len + optional CLI suffix (sys.argv[1])
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
# NOTE(review): requires `import shutil` at the top of the file -- confirm it is present
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# char-level vocabulary built over ALL splits so val/test characters are always encodable
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random contiguous windows from the 1-D token tensor `data`.

    Returns (x, y) moved to the global `device`, where y is x shifted one token ahead.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size,))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# tokenize each split once up front
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over the 1-D token tensor `data`.

    Windows of the global `max_seq_len` are strided so at most ~10k windows are
    evaluated; batch mean losses are token-weighted before exponentiation.
    Relies on module globals `max_seq_len` and `device`.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a continuation of `start_idxs` sampled autoregressively from model `m`."""
    # mutable default is harmless here: start_idxs is only read, never mutated
    prompt = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# rough data-requirement heuristic: ~8 training tokens per parameter
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state"""
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, destination)
# Training config for checkpointing
# Hyperparameters bundled into every checkpoint so a run can be reconstructed later.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# fixed probe prompts logged periodically to eyeball qualitative progress
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# baseline (untrained) completions logged at step 0
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # gradient accumulation: micro-batches of batch_size // gradient_accumulation_steps
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()  # already scaled, so this sums to the mean micro-batch loss
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips -- used only to *measure* the total grad norm
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # periodic eval (also fires at i == 0): val perplexity, samples, checkpoint
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# final evaluation on val and test splits + closing sample completions
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# @ -0,0 +1,471 @@  -- diff hunk marker left over from a pasted patch; not code
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): appears unused in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): appears unused in this script
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch of the optical multiplier in meters -- presumably; confirm vs omm.Config
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only multiplier `sim`.

    `sim` multiplies matrices with entries in [0, 1] (e.g. an intensity-based
    optical multiplier), so each operand is split into positive and negative
    parts, each part is normalized by its max, the four cross products are
    evaluated by `sim`, rescaled, and recombined as
    A@B = A+@B+ - A+@B- - A-@B+ + A-@B-.

    Args:
        sim: callable taking two 4-D tensors with values in [0, 1] and
             returning their (batched) matrix product.
        tensor_1, tensor_2: 3-D tensors (promoted with a leading singleton
             dim) or 4-D tensors.
    Returns:
        The product with the leading dimension dropped (result[0]); callers
        pass 3-D inputs, so this restores their original rank.
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    # non-negative decomposition: A = A_pos - A_neg, B = B_pos - B_neg
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(A, B):
        # sim() expects inputs normalized to [0, 1]; undo the scaling afterwards.
        # A max of 0 means that part is identically zero, so the term vanishes.
        a_max = torch.max(A)
        b_max = torch.max(B)
        if a_max > 0 and b_max > 0:
            return sim(A / a_max, B / b_max) * a_max * b_max
        # build zeros directly on the target device (original made a CPU
        # tensor and cloned/moved it for every vanishing term)
        return torch.zeros(out_shape, device=device)

    result = _term(A_pos, B_pos) - _term(A_pos, B_neg) - _term(A_neg, B_pos) + _term(A_neg, B_neg)
    return result[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding for (batch, seq, dim) tensors; `dim` must be even.

    The angle table for `max_seq_len` positions is precomputed once and stored
    as a (non-trainable) buffer named `emb`.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # theta_j = 10000^(-2j/dim); angles[t, j] = t * theta_j (outer product)
        theta = 10000.0 ** (-torch.arange(0, dim, 2).float() / dim)
        positions = torch.arange(max_seq_len).float()
        angles = positions[:, None] * theta[None, :]
        # duplicated along the last axis so it lines up with rotate_half's halves
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): the 90-degree rotation RoPE applies per pair
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: tanh(alpha * x) followed by a learned per-feature affine."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # scalar steepness of the tanh squashing, learned jointly with the affine
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block (DyT norm + RoPE) whose two attention matrix
    products (Q@K^T and attention@V) are evaluated through the external
    `sim_scores` / `sim_output` modules via the module-level `new_formula`.

    Args:
        h_dim: model width; must be divisible by num_heads.
        sim_scores: module computing the Q@K^T product (e.g. optical-mul sim).
        sim_output: module computing the attention@V product.
        num_heads: attention heads; heads are folded into the batch dimension.
        dropout_rate: dropout probability on both residual branches.
        max_seq_len: maximum sequence length for the RoPE angle table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # stash all ctor args as attributes (h_dim, num_heads, dropout_rate, ...)
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h)
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scales by h_dim**-0.5 rather than the per-head
        # (h_dim // num_heads)**-0.5 of Vaswani et al. -- confirm intentional.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # causal mask: position t may only attend to positions <= t
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so without
        # training=self.training it kept dropping activations after model.eval(),
        # corrupting every perplexity/eval/generation pass.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Decoder-only character-level transformer whose attention matrix products
    are routed through optical matrix-multiplication simulators (omm.OpticalMul).

    NOTE(review): `omm` is not imported in this file's visible import block --
    confirm the module providing OpticalMul/Config is in scope before running.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # stash all ctor args as attributes (vocab_size, h_dim, max_seq_len, ...)
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # One shared pair of optical multipliers reused by every layer:
        # sim_scores evaluates Q@K^T, sim_output evaluates attention@V.
        # Short sequences (< 512) use distance=0.01; long ones (>= 512) use
        # distance=0.15 plus an explicit lens_size -- presumably tuned for the
        # larger aperture; confirm against the omm.Config documentation.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
                )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
                )
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # learned absolute position embeddings, used in addition to RoPE inside layers
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects the class dimension second, hence b t c -> b c t
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Sample `max_new_tokens` tokens autoregressively, feeding back at most
        the last `max_seq_len` tokens of context each step."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# run tag: script stem + train file name + seq len + optional CLI suffix (sys.argv[1])
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# char-level vocabulary built over ALL splits so val/test characters are always encodable
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random contiguous windows from the 1-D token tensor `data`.

    Returns (x, y) moved to the global `device`, where y is x shifted one token ahead.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size,))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# tokenize each split once up front
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over the 1-D token tensor `data`.

    Windows of the global `max_seq_len` are strided so at most ~10k windows are
    evaluated; batch mean losses are token-weighted before exponentiation.
    Relies on module globals `max_seq_len` and `device`.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a continuation of `start_idxs` sampled autoregressively from model `m`."""
    # mutable default is harmless here: start_idxs is only read, never mutated
    prompt = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# rough data-requirement heuristic: ~8 training tokens per parameter
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state"""
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, destination)
# Training config for checkpointing
# Hyperparameters bundled into every checkpoint so a run can be reconstructed later.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# fixed probe prompts logged periodically to eyeball qualitative progress
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# baseline (untrained) completions logged at step 0
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # gradient accumulation: micro-batches of batch_size // gradient_accumulation_steps
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()  # already scaled, so this sums to the mean micro-batch loss
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips -- used only to *measure* the total grad norm
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # periodic eval (also fires at i == 0): val perplexity, samples, checkpoint
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# final evaluation on val and test splits + closing sample completions
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# @ -0,0 +1,471 @@  -- diff hunk marker left over from a pasted patch; not code
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): appears unused in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): appears unused in this script
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch of the optical multiplier in meters -- presumably; confirm vs omm.Config
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only multiplier `sim`.

    `sim` multiplies matrices with entries in [0, 1] (e.g. an intensity-based
    optical multiplier), so each operand is split into positive and negative
    parts, each part is normalized by its max, the four cross products are
    evaluated by `sim`, rescaled, and recombined as
    A@B = A+@B+ - A+@B- - A-@B+ + A-@B-.

    Args:
        sim: callable taking two 4-D tensors with values in [0, 1] and
             returning their (batched) matrix product.
        tensor_1, tensor_2: 3-D tensors (promoted with a leading singleton
             dim) or 4-D tensors.
    Returns:
        The product with the leading dimension dropped (result[0]); callers
        pass 3-D inputs, so this restores their original rank.
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    # non-negative decomposition: A = A_pos - A_neg, B = B_pos - B_neg
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(A, B):
        # sim() expects inputs normalized to [0, 1]; undo the scaling afterwards.
        # A max of 0 means that part is identically zero, so the term vanishes.
        a_max = torch.max(A)
        b_max = torch.max(B)
        if a_max > 0 and b_max > 0:
            return sim(A / a_max, B / b_max) * a_max * b_max
        # build zeros directly on the target device (original made a CPU
        # tensor and cloned/moved it for every vanishing term)
        return torch.zeros(out_shape, device=device)

    result = _term(A_pos, B_pos) - _term(A_pos, B_neg) - _term(A_neg, B_pos) + _term(A_neg, B_neg)
    return result[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding for (batch, seq, dim) tensors; `dim` must be even.

    The angle table for `max_seq_len` positions is precomputed once and stored
    as a (non-trainable) buffer named `emb`.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # theta_j = 10000^(-2j/dim); angles[t, j] = t * theta_j (outer product)
        theta = 10000.0 ** (-torch.arange(0, dim, 2).float() / dim)
        positions = torch.arange(max_seq_len).float()
        angles = positions[:, None] * theta[None, :]
        # duplicated along the last axis so it lines up with rotate_half's halves
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): the 90-degree rotation RoPE applies per pair
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: tanh(alpha * x) followed by a learned per-feature affine."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # scalar steepness of the tanh squashing, learned jointly with the affine
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder block: causal multi-head self-attention followed by a
    GELU MLP, with DyT in place of LayerNorm and RoPE position rotation.
    Both attention matmuls are routed through `new_formula`, which emulates
    them with the supplied optical-multiplication modules.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all ctor args as attributes. Writing through __dict__ bypasses
        # nn.Module.__setattr__, so sim_scores/sim_output are NOT registered as
        # submodules here — presumably because one pair is shared across layers
        # by the caller (see OpticGPT2); TODO confirm this is intentional.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op when there is a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scores are scaled by h_dim**-0.5, not the per-head
        # (h_dim/num_heads)**-0.5 used in standard MHA — confirm intentional.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: position t may only attend to positions <= t.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUGFIX: the functional F.dropout1d defaults to training=True, so
        # dropout previously stayed active in eval mode (perplexity and
        # generation). Tie it to self.training so model.eval() disables it.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Character-level GPT-2-style decoder-only transformer whose attention
    matrix multiplications are emulated by optical-multiplication modules
    (`omm.OpticalMul`), shared by all layers."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash all ctor args as plain attributes (bypasses nn.Module registration
        # for these names; the real submodules are assigned explicitly below).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Two optical presets selected by context length: short contexts use a
        # 0.01 propagation distance; contexts >= 512 use 0.15 and an explicit
        # large lens_size. The matrix geometry is otherwise identical.
        if max_seq_len < 512:
            # sim_scores computes Q @ K^T: right operand is (h_dim/num_heads) x max_seq_len.
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
                )
            # sim_output computes attention @ V: right operand is max_seq_len x (h_dim/num_heads).
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
                )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
                )
        # All layers share the same two optical modules.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings, sliced to the input length in forward.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss). `x` is (B, T) token ids; loss is None when
        `targets` is None, otherwise token-averaged cross-entropy."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects class dim second, hence (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: append `max_new_tokens` multinomial samples,
        conditioning on at most the last `max_seq_len` tokens each step."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# BUGFIX: the timestamp was previously rebuilt from repeated datetime.now()
# calls (once per component, and again for the checkpoints dir), so the log
# and checkpoint directories could get different — even internally
# inconsistent — names across a second/midnight boundary. Sample once.
now = datetime.now()
stamp = f"{now.date()}_{now.hour:02d}_{now.minute:02d}_{now.second:02d}"
logs_dir = f'logs/{stamp}_{comment}/'
writer = SummaryWriter(logs_dir)
# Snapshot the whole source directory next to the logs for reproducibility
# (pathlib join instead of fragile string concatenation).
script_snapshot_path = Path(logs_dir) / Path(sys.argv[0]).parent.name
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # read+execute only, so the snapshot cannot be edited
# Standalone checkpoints directory sharing the same timestamped naming scheme.
checkpoints_dir = f'./checkpoints/{stamp}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Load the three raw text splits into memory.
train_text = train_data_path.read_text(encoding='utf-8')
val_text = val_data_path.read_text(encoding='utf-8')
test_text = test_data_path.read_text(encoding='utf-8')
# Character-level vocabulary built from the union of all three splits, so
# every split can be encoded without unknown symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {ch: idx for idx, ch in enumerate(chars)}
itow = dict(enumerate(chars))
encode = lambda s: [wtoi[ch] for ch in s]
decode = lambda idx: ''.join(itow[i] for i in idx)
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random windows from the 1-D token tensor `data`;
    returns (x, y) on `device`, where y is x shifted one token ahead."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    y = torch.stack([data[s + 1:s + seq_len + 1] for s in starts]).to(device)
    return x, y
# Encode each split once into a flat 1-D LongTensor; batches slice into these.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over the 1-D token tensor `data`.

    Windows of length `max_seq_len` (module global) are taken every `stride`
    tokens (capped near 10k windows), evaluated in mini-batches, and the
    exponential of the token-averaged cross-entropy is returned.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    # All valid window start positions, precomputed.
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_windows, batch_size):
        chunk = starts[begin:min(begin + batch_size, n_windows)]
        # Stack the windows for this mini-batch; targets are shifted by one.
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # The model returns loss already averaged over every token in the batch,
        # so re-weight by token count before accumulating.
        _, mean_loss = model(xs, ys)
        batch_tokens = ys.numel()
        loss_sum += mean_loss.item() * batch_tokens
        token_count += batch_tokens
        done = min(begin + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Autoregressively extend the prompt `start_idxs` (sequence of token ids)
    by `max_new_tokens` sampled tokens and return the decoded string,
    prompt included.

    The default is an immutable tuple: the previous `start_idxs=[0]` was a
    shared mutable default (classic Python pitfall), though never mutated here.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model on the target device and log its size.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
).to(device)
param_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {param_count}'
writer.add_text('model', model_description, 0)
print(f"{param_count * 8} minimum number of tokens to train model.")
print(f"{(param_count * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
"""Save model checkpoint with complete training state"""
checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': config,
'wtoi': wtoi,
'itow': itow,
}, checkpoint_path)
# Training config for checkpointing: hyperparameters bundled into every
# checkpoint so the model can be rebuilt without re-reading this script.
# NOTE(review): relies on a module-level `gradient_accumulation_steps` — it is
# defined in the config section of the newer script version; confirm it exists
# in this one.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline sampling before any training (logged at iteration 0).
m.eval()
# Fixed probe prompts used throughout training to eyeball model progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join(complete(m, encode(prompt), 32) for prompt in task_prompts)
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main optimization loop: one optimizer step per iteration, with the batch
# split across `gradient_accumulation_steps` micro-batches.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        # Each micro-batch holds batch_size // gradient_accumulation_steps sequences.
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the summed gradients match the full-batch mean gradient.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): clip_grad_norm_ runs only every 100th iteration and with
        # max_norm=1e6, so in practice it just logs the gradient norm rather
        # than clipping — confirm that per-step clipping is not intended.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # Periodic evaluation: validation perplexity, sample completions, checkpoint.
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation after the loop; `i` still holds the last iteration index.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
# Held-out test-set perplexity, computed once at the end.
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
# Final free-running sample plus the fixed task-prompt completions.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -2,16 +2,31 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
#################################### Model #########################################
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
@ -125,7 +140,7 @@ class TransformerLayer(nn.Module):
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class OpticGPT2NOKoefNewF(nn.Module):
class OpticGPT2(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
pixel_size = 3.6e-6):
super().__init__()
@ -190,8 +205,6 @@ class OpticGPT2NOKoefNewF(nn.Module):
lens_size = 8192 * 2,
trainable_cylind_lens=False)
)
self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
self.sim_output = omm.ScatterDataParallel(self.sim_output)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
@ -219,4 +232,240 @@ class OpticGPT2NOKoefNewF(nn.Module):
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx
return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
model.eval()
stride = max(1, len(data) // 10000)
total_loss_sum = 0.0
total_tokens_count = 0
# Precompute all valid start positions
start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
total_sequences = len(start_positions)
# Process sequences in batches
for i in range(0, total_sequences, batch_size):
batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
# Efficiently stack sequences into batch tensors
x_batch = torch.stack([
data[start:start + max_seq_len]
for start in batch_starts
]).to(device)
y_batch = torch.stack([
data[start + 1:start + max_seq_len + 1]
for start in batch_starts
]).to(device)
# Forward pass (model should return mean loss averaged over all tokens in batch)
_, mean_loss = model(x_batch, y_batch)
# Accumulate weighted loss (mean_loss is already averaged over tokens)
num_tokens = y_batch.numel()
total_loss_sum += mean_loss.item() * num_tokens
total_tokens_count += num_tokens
# Progress update
processed = min(i + batch_size, total_sequences)
print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
print() # Final newline
return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
vocab_size=vocab_size,
h_dim=h_dim,
max_seq_len=max_seq_len,
num_heads=num_heads,
pixel_size=pixel_size,
layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
"""Save model checkpoint with complete training state"""
checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': config,
'wtoi': wtoi,
'itow': itow,
}, checkpoint_path)
# Training config for checkpointing
training_config = {
'vocab_size': vocab_size,
'layers_num': layers_num,
'h_dim': h_dim,
'max_seq_len': max_seq_len,
'num_heads': num_heads,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'learning_rate': learning_rate,
'gradient_accumulation_steps': gradient_accumulation_steps,
'pixel_size': pixel_size,
'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
task_prompts = [
"1 2 3 4 5",
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
m.train()
optimizer.zero_grad(set_to_none=True)
accumulated_loss = 0.0
for j in range(gradient_accumulation_steps):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
logits, loss = m(xb, yb)
loss = loss / gradient_accumulation_steps
loss.backward()
accumulated_loss += loss.item()
if i % 100 == 0:
writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
optimizer.step()
writer.add_scalar('loss', accumulated_loss, i)
print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
if i % 5000 == 0:
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
writer.add_text('completions/task', task_results, i)
save_checkpoint(
model=m,
optimizer=optimizer,
step=i,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
model=m,
optimizer=optimizer,
step=max_iters,
loss=accumulated_loss,
config=training_config,
wtoi=wtoi,
itow=itow,
checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
Loading…
Cancel
Save