Merge pull request 'traindiag' (#1) from traindiag into main

Reviewed-on: #1
main
Vladimir Protsenko 1 month ago
commit b8ee7f43cd

2
.gitignore vendored

@ -214,3 +214,5 @@ __marimo__/
# Streamlit
.streamlit/secrets.toml
checkpoints/

@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import optical_matrix_multiplication as omm
import matplotlib.pyplot as plt
# Benchmark the optical matrix-multiplication simulator against exact matmul
# for several sequence lengths, reporting a relative RMS error metric (CKO).
device = 'cpu'
h_dim = 64
pixel_size = 3.6e-6  # physical pixel pitch of the optical setup, metres
batch_size = 100
test_lengths = [59, 64, 128, 256, 512]


def make_optical_mul(n_rows, n_cols, distance, lens_size):
    """Build an OpticalMul for an (n_rows x n_cols) right-hand matrix.

    The physical extents are derived from the pixel pitch; split factors are
    fixed at 2 along every axis.
    """
    return omm.OpticalMul(
        omm.Config(right_matrix_count_columns=n_cols,
                   right_matrix_count_rows=n_rows,
                   right_matrix_width=pixel_size * n_cols,
                   right_matrix_height=pixel_size * n_rows,
                   min_height_gap=pixel_size,
                   right_matrix_split_x=2,
                   right_matrix_split_y=2,
                   left_matrix_split_x=2,
                   left_matrix_split_y=2,
                   result_matrix_split=2,
                   distance=distance,
                   lens_size=lens_size)
    )


def cko(x, y):
    """Relative RMS deviation (in percent) between mean-normalised squared tensors."""
    x = x ** 2
    y = y ** 2
    return (((x / x.mean() - y / y.mean()) ** 2).mean()) ** 0.5 * 100


for max_seq_len in test_lengths:
    # Longer sequences need a larger propagation distance and a bigger lens.
    if max_seq_len < 512:
        distance, lens_size = 0.01, 8192
    else:
        distance, lens_size = 0.15, 8192 * 2
    # sim_scores multiplies (T x h) @ (h x T); sim_output multiplies (T x T) @ (T x h).
    sim_scores = make_optical_mul(h_dim, max_seq_len, distance, lens_size)
    sim_output = make_optical_mul(max_seq_len, h_dim, distance, lens_size)
    sim_scores = sim_scores.to(device=device)
    sim_output = sim_output.to(device=device)
    # Attention-score product: q @ k^T.
    q = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device)
    k = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device).transpose(-2, -1)
    true_scores = q @ k
    opt_scores = sim_scores(q, k)
    CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy()
    # Attention-output product: scores @ v.
    scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len)).to(device=device)
    v = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device)
    true_o = scores @ v
    opt_o = sim_output(scores, v)
    CKO_o = cko(true_o, opt_o).detach().cpu().numpy()
    print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}")
    print(f"CKO sim_output[{max_seq_len},{h_dim}] [{true_scores.shape[-2]}, {true_scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}")

@ -0,0 +1,477 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
# Global run configuration.
step = 1
pixel_size: float = 3.6e-6  # physical pixel pitch of the optical setup, metres

# Prefer GPU when present, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)


def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Scale *matrix* into [0, 1] by dividing by max_val (epsilon-guarded)."""
    denom = max_val + 1e-10
    return matrix / denom
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Multiply two (possibly negative) 3-D tensors through the optical simulator.

    The optical device can only multiply non-negative matrices normalised to
    [0, 1], so both operands are shifted into the non-negative range and the
    signed product is recovered from the expansion
        (A + S)(B + S') - (A + S)S' - S(B + S') + S S' = A B.

    Args:
        sim: callable taking two 4-D tensors with entries in [0, 1] and
            returning their (optical) matrix product.
        tensor_1, tensor_2: 3-D tensors; a leading batch axis is added internally.

    Returns:
        3-D tensor approximating tensor_1 @ tensor_2 (exact when sim is exact).
    """
    tensor_1 = tensor_1[None, :, :, :]
    tensor_2 = tensor_2[None, :, :, :]
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        # Fast path: already non-negative, a single optical pass suffices.
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        a = tensor_1 / (max_abs + 1e-10)
        b = tensor_2 / (max_abs + 1e-10)
        return sim(a, b)[0, :, :, :] * max_abs ** 2
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    # Shift both operands by |global min| so every entry is >= 0. ones_like
    # keeps the shift on the operands' own device/dtype instead of relying on
    # the module-level `device` global (bug fix: the original broke when the
    # inputs lived on a different device than the global).
    shift_a = min_abs * torch.ones_like(tensor_1)
    shift_b = min_abs * torch.ones_like(tensor_2)
    a_a_sh = tensor_1 + shift_a
    b_b_sh = tensor_2 + shift_b
    # (a_shifted - shift_a)(b_shifted - shift_b) =
    # a_shifted*b_shifted - a_shifted*shift_b - b_shifted*shift_a + shift_a*shift_b
    a_a_sh_norm = a_a_sh / (max_abs + 1e-10)
    b_b_sh_norm = b_b_sh / (max_abs + 1e-10)
    shift_a_norm = shift_a / (max_abs + 1e-10)
    shift_b_norm = shift_b / (max_abs + 1e-10)
    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0, :, :, :] * max_abs ** 2
# Run bookkeeping: per-run log and checkpoint directories stamped with the
# current date/time and this script's name.
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None  # set to a .pth path to resume weights
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
# NOTE(review): `сhekpoints_dir` starts with a CYRILLIC 'с' (and is misspelled).
# It works because it is used consistently, but it is a trap for anyone typing
# the name by hand — consider renaming project-wide.
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
writer = SummaryWriter(logs_dir)
# Snapshot this exact script into the log dir so the run stays reproducible.
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialise encoder/model/optimizer state plus run metadata to disk.

    Alongside the state dicts, every public (non-underscore) attribute of the
    encoder and model is stored, so hyper-parameters travel with the weights.
    """
    def public_attrs(module):
        # Hyper-parameters and other plain attributes, excluding internals.
        return {k: v for k, v in module.__dict__.items() if k[0] != '_' and k != 'training'}

    checkpoint = {
        'encoder': {'state_dict': encoder.state_dict(), **public_attrs(encoder)},
        'model': {'state_dict': model.state_dict(), **public_attrs(model)},
        'epoch': epoch,
        'optimizer': {'state_dict': optimizer.state_dict()},
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path,
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore encoder/model weights from *checkpoint_file* if it exists.

    Optimizer state is intentionally not restored. Operates on the
    module-level `encoder` and `model`.
    """
    if not os.path.exists(checkpoint_file):
        return
    checkpoint = torch.load(checkpoint_file)
    #optimizer.load_state_dict(checkpoint['optimizer'])
    encoder.load_state_dict(checkpoint['encoder']['state_dict'])
    model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Per-client credit-history sequences with binary default targets.

    Loads parquet features and CSV targets, re-indexes all categorical columns
    into one unified embedding vocabulary (index 0 reserved for padding), and
    pads every client's history to the longest sequence in the data.
    """
    def __init__(self,
        features_path, targets_path, train_test_split_ratio=0.9,
        train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
    ):
        # Pre-computed train/test client-id splits must already exist on disk.
        # NOTE(review): train_test_split_ratio is accepted but never used here.
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()  # longest per-client sequence
        self.id_columns = ['id', 'rn']  # client id + record number within the client
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        # Everything that is neither an id nor categorical is treated as numeric.
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        # Offset each column's codes by the cumulative cardinality of the
        # preceding columns so every (column, value) pair maps to a unique index.
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # Split the flat record table into per-client sequences and pad them
        # all to the longest history.
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        """Random training batch with 10% multiplicative feature dropout.

        NOTE(review): client ids are used directly as row indices into the
        padded tensors — this assumes ids are contiguous 0..N-1 after sorting;
        verify against the data.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        """Yield sequential test batches (no shuffling, no feature dropout)."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Embeds categorical + numeric credit features and projects to out_dim."""

    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # Keep constructor arguments as public attributes (save_checkpoint
        # serialises them alongside the weights).
        self.cat_columns = cat_columns
        self.num_columns = num_columns
        self.cat_features_max_id = cat_features_max_id
        self.category_feature_dim = category_feature_dim
        self.out_dim = out_dim
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        # One shared embedding table over the unified category vocabulary;
        # index 0 is the padding slot and its embedding stays zero.
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        # Learnable per-column affine transform for the numeric features.
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        cat_features_batch = cat_features_batch.to(self.device)
        num_features_batch = num_features_batch.to(self.device)
        embedded = self.cat_embeds(cat_features_batch.type(torch.int32))
        # Flatten the per-column embeddings into one vector per time step.
        flat_cat = embedded.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
        affine_num = self.num_scales * num_features_batch + self.num_shifts
        combined = torch.concat([flat_cat, affine_num], dim=-1)
        inputs = self.proj(combined)
        targets = targets_batch.to(self.device)
        return inputs, targets
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding over the last dim of (B, T, H) tensors.

    See RoFormer (https://arxiv.org/abs/2104.09864).
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder, one frequency per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so the table spans the full channel dimension.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing plus a per-feature affine map.

    A normalisation-free LayerNorm substitute (https://jiachenzhu.github.io/DyT/).
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block (DyT norms, RoPE) whose two attention matmuls
    run through optical multiplication simulators via optics_matmul_shift."""
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Keep all constructor args as attributes (save_checkpoint serialises them).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head attention.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scores are scaled by h_dim**-0.5 (full model width), not
        # head_dim**-0.5; equivalent only when num_heads == 1 — confirm intended.
        scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # Bug fix: functional dropout ignores module.train()/eval() unless the
        # current mode is passed explicitly, so the original applied dropout
        # during evaluation as well. training=self.training disables it in eval.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style classifier over client histories using optical attention matmuls.

    A CLS token (plus optional register tokens, see
    https://arxiv.org/html/2309.16588v2) is prepended to the sequence; the CLS
    position's final state feeds the classification head.
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # Keep all constructor args as attributes (save_checkpoint serialises them).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # Effective sequence length includes the CLS/register tokens.
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        head_dim = h_dim // num_heads
        # Longer sequences need a larger propagation distance and a bigger lens;
        # shorter ones use the omm.Config default lens size.
        if max_seq_len < 512:
            distance, lens_size = 0.01, None
        else:
            distance, lens_size = 0.15, 8192 * 2
        # sim_scores multiplies (T x head_dim) @ (head_dim x T);
        # sim_output multiplies (T x T) @ (T x head_dim).
        self.sim_scores = self._build_optical_mul(head_dim, self.max_seq_len, distance, lens_size)
        self.sim_output = self._build_optical_mul(self.max_seq_len, head_dim, distance, lens_size)
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    @staticmethod
    def _build_optical_mul(n_rows, n_cols, distance, lens_size):
        """Configure one optical multiplier for an (n_rows x n_cols) right matrix."""
        config_kwargs = dict(
            right_matrix_count_columns=n_cols,
            right_matrix_count_rows=n_rows,
            right_matrix_width=pixel_size * n_cols,
            right_matrix_height=pixel_size * n_rows,
            min_height_gap=pixel_size,
            right_matrix_split_x=2,
            right_matrix_split_y=2,
            left_matrix_split_x=2,
            left_matrix_split_y=2,
            result_matrix_split=2,
            distance=distance,
            trainable_cylind_lens=False,
        )
        if lens_size is not None:
            config_kwargs['lens_size'] = lens_size
        return omm.OpticalMul(omm.Config(**config_kwargs))

    def forward(self, x):
        # Prepend CLS (+ register) tokens, then add positional embeddings.
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # Classify from the CLS position only.
        x = self.classifier_head(x[:,0,:])
        return x[:,:] if self.class_num > 1 else x[:,0]
# ---- Training setup (runs at import time) ----
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    # NOTE(review): the target filename contains a space before ".csv" —
    # verify it matches the actual file on disk.
    targets_path="/wd/finbert_data/train_target .csv",
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# Model hyper-parameters.
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit -> BCEWithLogitsLoss below
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}'
writer.add_text('model', model_description, 0)
# Schedule-free AdamW: no LR scheduler, but requires explicit
# optimizer.train()/optimizer.eval() calls around training/evaluation.
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# Class-imbalance weighting for BCE: positives weighted by negatives/positives.
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter exposing a credit dataset's get_batch through the Dataset API.

    Each __getitem__ call returns one whole random training batch, so the
    DataLoader is used with batch_size=1 purely to parallelise sampling
    across workers.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder
        self.batch_size = batch_size

    def __len__(self):
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size

    def __getitem__(self, idx):
        # Bug fix: sample from the wrapped dataset rather than the module-level
        # `credit_train_dataset` global, so the constructor argument is honoured.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
# batch_size=1 because WrapperDataset already returns full batches; the 8
# workers only parallelise batch sampling.
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
val_auroc = AUROC(task='binary')
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate ROC-AUC on the test split and log it to TensorBoard.

    Uses the module-level model/encoder/optimizer/credit_train_dataset/test_auroc.
    NOTE(review): test_auroc is never reset, so the reported value accumulates
    over all evaluations performed so far — confirm this is intended.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedule-free optimizer needs explicit eval mode
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
# ---- Main training loop ----
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    # Evaluate before each training epoch.
    test(epoch)
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()
        # The DataLoader adds a leading dim of size 1; [0] unwraps the
        # pre-batched tensors produced by WrapperDataset.
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # NOTE(review): clip_grad_norm_ runs only on logging ticks, so the
            # (very loose, max_norm=1e6) clipping is applied sporadically —
            # confirm it is meant purely as a gradient-norm measurement.
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            # Periodic console/TensorBoard logging, at most once per second.
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
        if current_time - last_checkpoint_time > timedelta(hours=4):
            # Periodic checkpoint every 4 hours of wall time.
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# Final evaluation and checkpoint.
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -0,0 +1,485 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
# Global run configuration.
step = 1
pixel_size: float = 3.6e-6  # physical pixel pitch of the optical setup, metres
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)


def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product through the non-negative-only optical simulator.

    Decomposes each operand into positive and negative parts and combines four
    non-negative products:  A B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻.
    Each part is normalised to [0, 1] before the optical pass and rescaled after.

    Args:
        sim: callable multiplying two 4-D tensors with entries in [0, 1].
        tensor_1, tensor_2: 3-D or 4-D tensors (a batch axis is added when 3-D).

    Returns:
        3-D tensor: the first batch element of the reconstructed product
        (exact when sim is exact matmul).
    """
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # Any of these maxima may be zero when the corresponding part is absent.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def part(a, max_a, b, max_b):
        # One optical pass over normalised non-negative parts; skipped (all
        # zeros, allocated directly on the operands' device) when either part
        # is identically zero. Replaces the original CPU zeros_like(empty(...))
        # template that was cloned and moved per branch.
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    t1 = part(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = part(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = part(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = part(A_neg, max_A_neg, B_neg, max_B_neg)
    # NOTE: only the first batch element is returned, matching how callers pass
    # batch-of-1 tensors — confirm before feeding batches > 1.
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# Run bookkeeping: per-run log and checkpoint directories stamped with the
# current date/time and this script's name.
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None  # set to a .pth path to resume weights
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
# NOTE(review): `сhekpoints_dir` starts with a CYRILLIC 'с' (and is misspelled).
# It works because it is used consistently, but consider renaming project-wide.
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
writer = SummaryWriter(logs_dir)
# Snapshot this exact script into the log dir so the run stays reproducible.
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialise encoder/model/optimizer state plus run metadata to disk.

    Alongside the state dicts, every public (non-underscore) attribute of the
    encoder and model is stored, so hyper-parameters travel with the weights.
    """
    def public_attrs(module):
        # Hyper-parameters and other plain attributes, excluding internals.
        return {k: v for k, v in module.__dict__.items() if k[0] != '_' and k != 'training'}

    checkpoint = {
        'encoder': {'state_dict': encoder.state_dict(), **public_attrs(encoder)},
        'model': {'state_dict': model.state_dict(), **public_attrs(model)},
        'epoch': epoch,
        'optimizer': {'state_dict': optimizer.state_dict()},
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path,
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore encoder/model weights from *checkpoint_file* if it exists.

    Optimizer state is intentionally not restored. Operates on the
    module-level `encoder` and `model`.
    """
    if not os.path.exists(checkpoint_file):
        return
    checkpoint = torch.load(checkpoint_file)
    #optimizer.load_state_dict(checkpoint['optimizer'])
    encoder.load_state_dict(checkpoint['encoder']['state_dict'])
    model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Per-client credit-history sequences with binary default targets.

    Loads parquet features and CSV targets, re-indexes all categorical columns
    into one unified embedding vocabulary (index 0 reserved for padding), and
    pads every client's history to the longest sequence in the data.
    """
    def __init__(self,
        features_path, targets_path, train_test_split_ratio=0.9,
        train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
    ):
        # Pre-computed train/test client-id splits must already exist on disk.
        # NOTE(review): train_test_split_ratio is accepted but never used here.
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()  # longest per-client sequence
        self.id_columns = ['id', 'rn']  # client id + record number within the client
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        # Everything that is neither an id nor categorical is treated as numeric.
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        # Offset each column's codes by the cumulative cardinality of the
        # preceding columns so every (column, value) pair maps to a unique index.
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # Split the flat record table into per-client sequences and pad them
        # all to the longest history.
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        """Random training batch with 10% multiplicative feature dropout.

        NOTE(review): client ids are used directly as row indices into the
        padded tensors — this assumes ids are contiguous 0..N-1 after sorting;
        verify against the data.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        """Yield sequential test batches (no shuffling, no feature dropout)."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Embeds categorical + numeric credit features and projects to out_dim."""

    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # Keep constructor arguments as public attributes (save_checkpoint
        # serialises them alongside the weights).
        self.cat_columns = cat_columns
        self.num_columns = num_columns
        self.cat_features_max_id = cat_features_max_id
        self.category_feature_dim = category_feature_dim
        self.out_dim = out_dim
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        # One shared embedding table over the unified category vocabulary;
        # index 0 is the padding slot and its embedding stays zero.
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        # Learnable per-column affine transform for the numeric features.
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        cat_features_batch = cat_features_batch.to(self.device)
        num_features_batch = num_features_batch.to(self.device)
        embedded = self.cat_embeds(cat_features_batch.type(torch.int32))
        # Flatten the per-column embeddings into one vector per time step.
        flat_cat = embedded.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
        affine_num = self.num_scales * num_features_batch + self.num_shifts
        combined = torch.concat([flat_cat, affine_num], dim=-1)
        inputs = self.proj(combined)
        targets = targets_batch.to(self.device)
        return inputs, targets
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer): rotates each 2-D feature plane
    by a position-dependent angle, encoding position multiplicatively."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder over the even channel indices.
        inv_freq = (10000 ** (torch.arange(0, dim, 2).float() / dim)).reciprocal()
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so both halves of the feature dim share one angle table.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))
    def rotate_half(self, x):
        """Map halves (x1, x2) to (-x2, x1): a 90-degree rotation per plane."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        """Rotate *x* (batch, seq, dim) by its position; *offset* shifts positions."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: a normalization-free LayerNorm stand-in — learnable tanh
    squashing followed by a per-feature affine transform."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Scalar steepness of the tanh squashing, learned during training.
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two batched matmuls (QK^T and attn@V)
    are routed through optical matrix-multiplication simulators via the
    module-level `new_formula` helper, scaled by learnable k1/k2.

    sim_scores / sim_output are the shared optical simulators; they are stored
    straight into __dict__ (see below), so they are NOT registered as
    submodules here — presumably owned and registered by BertClassifier.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # File-wide idiom: stash all ctor args as attributes. Writing directly
        # into __dict__ bypasses nn.Module.__setattr__ registration.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))  # learnable scale for optical QK^T
        self.k2 = nn.Parameter(torch.randn(1))  # learnable scale for optical attn@V
    def split_to_heads(self, x, B, T, H):
        # (B, T, H) -> (B*heads, T, H/heads); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaling uses full h_dim, not the per-head dim — confirm intended.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Bug fix: F.dropout1d defaults to training=True, so dropout stayed
        # active even after model.eval(); tie it to the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style sequence classifier over per-client histories.

    Prepends a CLS token (plus optional register tokens), adds learned
    position embeddings, runs `layers_num` TransformerLayers sharing one pair
    of optical matmul simulators, and classifies from the CLS position.
    Returns (B, class_num) logits when class_num > 1, else (B,).
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # File-wide idiom: stash all ctor args as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # Effective sequence length includes the CLS(+register) tokens.
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        # Two simulator regimes: short sequences use distance 0.01; sequences
        # >= 512 use a farther plane (0.15) and a doubled lens.
        # sim_scores multiplies (T x d_head) by (d_head x T); sim_output
        # multiplies (T x T) attention by (T x d_head) values — hence the
        # mirrored rows/columns setup.
        # NOTE(review): branch tests raw max_seq_len while configs use
        # self.max_seq_len (raw + 1 + num_reg) — confirm the 512 boundary.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * self.max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = self.max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * self.max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * self.max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = self.max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * self.max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # All layers share the same two simulator instances.
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
    def forward(self, x):
        # Prepend CLS(+register) tokens, then add learned absolute positions.
        # Assumes x is (B, T, h_dim) with T + cls tokens <= self.max_seq_len.
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # Classify from the CLS position only.
        x = self.classifier_head(x[:,0,:])
        # Single-logit (binary) case returns shape (B,); x[:,:] is a no-op view.
        return x[:,:] if self.class_num > 1 else x[:,0]
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    targets_path="/wd/finbert_data/train_target .csv", # NOTE(review): space before ".csv" — confirm the file is really named this
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# Model hyperparameters.
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit -> binary classification with BCEWithLogitsLoss
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}'
writer.add_text('model', model_description, 0)
# One schedule-free optimizer over both modules' parameters.
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# Weight positives by the negative/positive ratio to counter class imbalance.
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter exposing CreditProductsDataset's random batching through the
    torch Dataset API so a DataLoader with num_workers > 0 draws batches in
    parallel worker processes.

    Each __getitem__ returns one WHOLE batch (not one sample); use it with
    DataLoader(batch_size=1) and strip the extra leading dim at the call site.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder  # NOTE(review): stored but never used here — confirm it is needed
        self.batch_size = batch_size
    def __len__(self):
        # Number of whole batches per epoch.
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size
    def __getitem__(self, idx):
        # Bug fix: sample from the wrapped dataset instead of the module-level
        # `credit_train_dataset` global, so the wrapper works for any instance.
        # idx is ignored: every access draws a fresh random batch.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
# batch_size=1 because WrapperDataset already yields whole batches; 8 workers draw them in parallel.
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
val_auroc = AUROC(task='binary')
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate ROC-AUC over the held-out test clients and log to TensorBoard.

    NOTE(review): test_auroc is never reset, so the metric accumulates state
    across successive evaluations — confirm this is intended.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedule-free optimizers require explicit train/eval switching
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    # Evaluate before each training epoch.
    test(epoch)
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()
        # DataLoader adds a leading dim of size 1; [0] strips it back off.
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # Log (and nominally clip at the huge max_norm=1e6) the gradient norm ~once per second.
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
        # Periodic checkpoint every 4 hours of wall time.
        if current_time - last_checkpoint_time > timedelta(hours=4):
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# Final evaluation and checkpoint after the last epoch.
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -0,0 +1,475 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
step = 1  # NOTE(review): appears unused in this script — confirm before removing
pixel_size: float = 3.6e-6  # simulator pixel pitch — presumably meters (3.6 um); scales matrix plane sizes
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Divide *matrix* by *max_val*, with a tiny epsilon guarding division by zero."""
    eps = 1e-10
    return matrix / (max_val + eps)
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Multiply two (possibly signed) batched matrices on optical simulator *sim*,
    which — presumably — only accepts non-negative inputs scaled into [0, 1].

    Fast path: if both operands are already non-negative, normalize by the
    joint max, run one simulation, and undo the scaling (product scales as m^2).
    General case: shift both operands by a common constant S so they become
    non-negative, then recover the true product from four simulated products:
        (A+S)(B+S) - (A+S)S - S(B+S) + S S = A B
    where S is the constant matrix filled with the shift value.
    """
    # sim operates on 4-D batches; add a leading singleton dim, dropped on return.
    tensor_1 = tensor_1[None,:,:,:]
    tensor_2 = tensor_2[None,:,:,:]
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
        # Undo normalization: (A/m)(B/m) scaled the product by 1/m^2.
        return sim(a, b)[0,:,:,:] * max_abs **2
    # Common shift that makes both operands non-negative.
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(device)  # uses module-level `device`
    shift_b = min_abs * torch.ones(tensor_2.shape).to(device)
    a_a_sh = tensor_1 + shift_a
    b_b_sh = tensor_2 + shift_b
    a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
    shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)  # (A+S)(B+S)
    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)   # (A+S)S
    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)   # S(B+S)
    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)    # S S
    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None  # set to a .pth path to resume from a checkpoint
# Timestamped run directories for TensorBoard logs and checkpoints.
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
# NOTE(review): the identifier and path below contain CYRILLIC 'с' characters
# ("сhekpoints") — kept as-is since they are used consistently file-wide; confirm intended.
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
writer = SummaryWriter(logs_dir)
# Snapshot this script into the log dir for reproducibility.
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialize training state to `<сhekpoints_dir>/epoch_<epoch>_<loss>.pth`.

    Besides the state_dicts, all non-underscore instance attributes of encoder
    and model (the ctor hyperparameters stashed by the __dict__.update idiom)
    are stored so the checkpoint is self-describing; the id-list csv paths pin
    the exact train/test split that was used.
    """
    checkpoint = {
        'encoder': {
            'state_dict': encoder.state_dict(),
            # hyperparameters stashed on the instance by __dict__.update
            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'model': {
            'state_dict': model.state_dict(),
            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'epoch': epoch,
        'optimizer': {
            'state_dict': optimizer.state_dict(),
        },
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore encoder/model weights (module-level globals) from *checkpoint_file*.

    NOTE(review): silently does nothing when the file is missing — consider
    raising FileNotFoundError instead. Optimizer state restore is disabled.
    """
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        encoder.load_state_dict(checkpoint['encoder']['state_dict'])
        model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Load client credit-history features (parquet) and binary targets (csv),
    re-index categorical codes into one shared embedding id space, and serve
    padded per-client tensors for training and evaluation.

    `train_test_split_ratio` is accepted for interface compatibility but is
    unused: the split comes from the two mandatory id-list csv files.
    """
    def __init__(self,
        features_path, targets_path, train_test_split_ratio=0.9,
        train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
    ):
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        # Bug fix: offset each column by the EXCLUSIVE cumulative sum of the
        # preceding columns' cardinalities (first column offset 0). The previous
        # inclusive cumsum aligned to columns[1:] over-shifted by one column, so
        # id ranges of adjacent columns could overlap (colliding embeddings)
        # whenever a column had more categories than its successor.
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() - self.cat_cardinalities
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + self.cat_cardinalities_integral
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        # Order rows by (client id, record number) so per-client splits are contiguous.
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # NOTE(review): int16 caps the unified code space at 32767 — confirm the
        # summed cardinalities stay below this.
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)
    def get_batch(self, batch_size=4):
        """Sample one random training batch with 10% feature dropout.

        NOTE(review): indexing padded tensors with raw client ids assumes ids
        are contiguous row positions 0..N-1 — confirm against data prep.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch
    def get_test_batch_iterator(self, batch_size=4):
        """Yield ordered, deterministic test batches (no feature dropout)."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Embed mixed categorical/numeric features into `out_dim` vectors.

    All categorical columns share one embedding table (index 0 = padding,
    zeroed via padding_idx); numeric columns get a learned per-column affine
    transform. The concatenation is projected to `out_dim` without bias.
    """
    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # File-wide idiom: stash every ctor argument as an instance attribute.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        # Per-numeric-column learned scale and shift.
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)
    @property
    def device(self):
        # Device of the module's parameters (assumes all live on one device).
        return next(self.parameters()).device
    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        """Return (projected_inputs, targets), both moved to the module device.

        Assumes cat_features_batch is (batch, seq, n_cat) integer codes and
        num_features_batch is (batch, seq, n_num) floats — TODO confirm.
        Targets are only moved to the device, not transformed.
        """
        cat_features_batch = cat_features_batch.to(self.device)
        num_features_batch = num_features_batch.to(self.device)
        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
        # Flatten the per-column embeddings into one feature axis.
        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
        inputs = self.proj(embed_tensor)
        targets = targets_batch.to(self.device)
        return inputs, targets
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer): rotates each 2-D feature plane
    by a position-dependent angle, encoding position multiplicatively."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder over the even channel indices.
        inv_freq = (10000 ** (torch.arange(0, dim, 2).float() / dim)).reciprocal()
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so both halves of the feature dim share one angle table.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))
    def rotate_half(self, x):
        """Map halves (x1, x2) to (-x2, x1): a 90-degree rotation per plane."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        """Rotate *x* (batch, seq, dim) by its position; *offset* shifts positions."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: a normalization-free LayerNorm stand-in — learnable tanh
    squashing followed by a per-feature affine transform."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Scalar steepness of the tanh squashing, learned during training.
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two batched matmuls (QK^T and attn@V)
    are routed through optical matrix-multiplication simulators via the
    module-level `optics_matmul_shift` helper.

    sim_scores / sim_output are the shared optical simulators; they are stored
    straight into __dict__ (see below), so they are NOT registered as
    submodules here — presumably owned and registered by BertClassifier.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # File-wide idiom: stash all ctor args as attributes. Writing directly
        # into __dict__ bypasses nn.Module.__setattr__ registration.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # self.k1 = nn.Parameter(torch.randn(1))
        # self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, H) -> (B*heads, T, H/heads); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaling uses full h_dim, not the per-head dim — confirm intended.
        scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Bug fix: F.dropout1d defaults to training=True, so dropout stayed
        # active even after model.eval(); tie it to the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style sequence classifier over per-client histories.

    Prepends a CLS token (plus optional register tokens), adds learned
    position embeddings, runs `layers_num` TransformerLayers sharing one pair
    of optical matmul simulators, and classifies from the CLS position.
    Returns (B, class_num) logits when class_num > 1, else (B,).
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # File-wide idiom: stash all ctor args as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # Effective sequence length includes the CLS(+register) tokens.
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        # Two simulator regimes: short sequences use distance 0.01; sequences
        # >= 512 use a farther plane (0.15) and a doubled lens.
        # sim_scores multiplies (T x d_head) by (d_head x T); sim_output
        # multiplies (T x T) attention by (T x d_head) values — hence the
        # mirrored rows/columns setup.
        # NOTE(review): branch tests raw max_seq_len while configs use
        # self.max_seq_len (raw + 1 + num_reg) — confirm the 512 boundary.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * self.max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = self.max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * self.max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * self.max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = self.max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * self.max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        # All layers share the same two simulator instances.
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
    def forward(self, x):
        # Prepend CLS(+register) tokens, then add learned absolute positions.
        # Assumes x is (B, T, h_dim) with T + cls tokens <= self.max_seq_len.
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # Classify from the CLS position only.
        x = self.classifier_head(x[:,0,:])
        # Single-logit (binary) case returns shape (B,); x[:,:] is a no-op view.
        return x[:,:] if self.class_num > 1 else x[:,0]
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    targets_path="/wd/finbert_data/train_target .csv", # NOTE(review): space before ".csv" — confirm the file is really named this
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# Model hyperparameters.
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit -> binary classification with BCEWithLogitsLoss
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}'
writer.add_text('model', model_description, 0)
# One schedule-free optimizer over both modules' parameters.
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# Weight positives by the negative/positive ratio to counter class imbalance.
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter exposing CreditProductsDataset's random batching through the
    torch Dataset API so a DataLoader with num_workers > 0 draws batches in
    parallel worker processes.

    Each __getitem__ returns one WHOLE batch (not one sample); use it with
    DataLoader(batch_size=1) and strip the extra leading dim at the call site.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder  # NOTE(review): stored but never used here — confirm it is needed
        self.batch_size = batch_size
    def __len__(self):
        # Number of whole batches per epoch.
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size
    def __getitem__(self, idx):
        # Bug fix: sample from the wrapped dataset instead of the module-level
        # `credit_train_dataset` global, so the wrapper works for any instance.
        # idx is ignored: every access draws a fresh random batch.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
# batch_size=1 because WrapperDataset already yields whole batches; 8 workers draw them in parallel.
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
val_auroc = AUROC(task='binary')
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate ROC-AUC over the held-out test clients and log to TensorBoard.

    NOTE(review): test_auroc is never reset, so the metric accumulates state
    across successive evaluations — confirm this is intended.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedule-free optimizers require explicit train/eval switching
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    # Evaluate before each training epoch.
    test(epoch)
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()
        # DataLoader adds a leading dim of size 1; [0] strips it back off.
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # Log (and nominally clip at the huge max_norm=1e6) the gradient norm ~once per second.
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
        # Periodic checkpoint every 4 hours of wall time.
        if current_time - last_checkpoint_time > timedelta(hours=4):
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# Final evaluation and checkpoint after the last epoch.
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -0,0 +1,485 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
step = 1  # NOTE(review): appears unused in this script - confirm before removing
pixel_size: float = 3.6e-6  # physical pixel pitch (meters) used to size the optical matrices
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only multiplier.

    Optical multipliers (`sim`) can only process non-negative values, so each
    operand is split into its positive and negative parts, each part is
    normalised to [0, 1], multiplied, rescaled, and the four partial products
    are recombined: A @ B = A+B+ - A+B- - A-B+ + A-B-.

    Args:
        sim: callable taking two 4-D tensors and returning their batched
            matrix product (e.g. an omm.OpticalMul instance or torch.matmul).
        tensor_1: (..., c, m, k) tensor; promoted to 4-D if rank < 4.
        tensor_2: (..., c, k, n) tensor; promoted to 4-D if rank < 4.

    Returns:
        The (c, m, n) product of the first batch element.
    """
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # a maximum of zero means the operand has no entries of that sign
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def partial(a, a_max, b, b_max):
        # skip the multiplier entirely when either part is all zeros; also
        # avoids dividing by zero during normalisation
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        return torch.zeros(out_shape, device=device)

    t1 = partial(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = partial(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = partial(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = partial(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None  # set to a .pth path to resume weights from a saved checkpoint
# timestamped per-run directories for TensorBoard logs and checkpoints
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
writer = SummaryWriter(logs_dir)
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialize training state (weights, hyperparameters, metrics, split
    file paths) to `<сhekpoints_dir>/epoch_<epoch>_<loss>.pth`."""

    def snapshot(module):
        # state_dict plus the module's plain public attributes (the
        # hyperparameters stashed on its __dict__ at construction time)
        public = {k: v for k, v in module.__dict__.items() if k[0] != '_' and k != 'training'}
        return {'state_dict': module.state_dict(), **public}

    checkpoint = {
        'encoder': snapshot(encoder),
        'model': snapshot(model),
        'epoch': epoch,
        'optimizer': {'state_dict': optimizer.state_dict()},
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path,
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore encoder and model weights from a file written by save_checkpoint.

    Mutates the global `encoder` and `model` in place; silently does nothing
    when the file does not exist.  NOTE(review): the optimizer state is
    deliberately left unrestored (see the commented line) - confirm intended.
    """
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        encoder.load_state_dict(checkpoint['encoder']['state_dict'])
        model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Loads per-client credit histories (categorical + numeric features) and
    binary targets, and serves padded sequence batches.

    All categorical columns are remapped into one shared, disjoint index
    space (index 0 reserved for padding) so a single embedding table can
    cover every column.
    """
    def __init__(self,
                 features_path, targets_path, train_test_split_ratio=0.9,
                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
                 ):
        # NOTE(review): train_test_split_ratio is accepted but never used here;
        # the split is always read from the two id files below - confirm.
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        # both precomputed train/test client-id splits must already exist on disk
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()  # longest per-client sequence
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        # everything that is neither an id column nor categorical is numeric
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        # offset each column's codes by the cumulative cardinality so no two
        # columns share embedding indices
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # split the flat frame into per-client sequences and pad them to a
        # common length
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        """Random training batch with multiplicative 10% feature dropout.

        NOTE(review): indexes the padded tensors directly with client ids,
        which assumes ids are the contiguous row positions 0..N-1 - verify.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        """Yield deterministic, ordered test batches (no feature dropout)."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Maps raw categorical codes and numeric features of a client history to
    a sequence of out_dim-dimensional token embeddings."""

    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        self.cat_columns = cat_columns
        self.num_columns = num_columns
        self.cat_features_max_id = cat_features_max_id
        self.category_feature_dim = category_feature_dim
        self.out_dim = out_dim
        self.total_h_dim = len(cat_columns) * category_feature_dim + len(num_columns)
        # one shared embedding table over the unified category index space;
        # index 0 is the padding code and its embedding stays zero
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, category_feature_dim, padding_idx=0)
        # learnable per-feature affine normalisation of the numeric inputs
        self.num_scales = nn.Parameter(torch.randn(1, len(num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(num_columns)))
        self.proj = nn.Linear(self.total_h_dim, out_dim, bias=False)

    @property
    def device(self):
        """Device the encoder's parameters currently live on."""
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        """Return (inputs, targets) moved to the encoder's device.

        inputs: (batch, seq, out_dim); targets passed through unchanged in value.
        """
        dev = self.device
        cat_codes = cat_features_batch.to(dev)
        numeric = num_features_batch.to(dev)
        embedded = self.cat_embeds(cat_codes.type(torch.int32))
        # flatten the per-column embeddings into one vector per timestep
        embedded = embedded.reshape(cat_codes.shape[0], cat_codes.shape[1], -1)
        affine = self.num_scales * numeric + self.num_shifts
        tokens = torch.concat([embedded, affine], dim=-1)
        return self.proj(tokens), targets_batch.to(dev)
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to (batch, seq, dim) tensors."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # inverse frequencies for each channel pair, as in the RoFormer paper
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # duplicate the angle table so it covers all `dim` channels
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Swap the two channel halves, negating the second: (a, b) -> (-b, a)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` by position-dependent angles starting at `offset`."""
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic tanh: learnable tanh squashing plus a per-feature affine,
    used in place of LayerNorm."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder block (DyT instead of LayerNorm, RoPE
    positions) whose two attention matmuls run through external optical
    multipliers via new_formula().

    sim_scores / sim_output are omm.OpticalMul instances shared between
    layers.  Because they arrive via the __dict__.update below - which
    bypasses nn.Module.__setattr__ - they are NOT registered as submodules
    of this layer; their parameters are owned by BertClassifier instead.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        # RoPE operates per head, hence dim = h_dim // num_heads
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # self.k1 = nn.Parameter(torch.randn(1))
        # self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, H) -> (B*num_heads, T, H/num_heads); no-op for a single head
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # inverse of split_to_heads
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Bidirectional (unmasked) self-attention; both matrix products are
        delegated to the optical multipliers through new_formula."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim, not the per-head dim - confirm intended
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # NOTE(review): F.dropout1d defaults to training=True, so dropout is
        # applied even when the module is in eval mode - confirm intended
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style encoder over client histories, classified from the CLS token.

    The two attention matrix products (QK^T and attention @ V) are executed
    by simulated optical multipliers (omm.OpticalMul) shared by all layers.
    """

    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]  # account for CLS (+register) tokens
        print(h_dim, self.max_seq_len)
        head_dim = h_dim // num_heads
        # Longer sequences need a larger propagation distance and lens.  The
        # two former branches duplicated the entire config and differed only
        # in these values, so the difference is isolated here instead.
        optics_overrides = (
            {'distance': 0.01} if max_seq_len < 512
            else {'distance': 0.15, 'lens_size': 8192 * 2}
        )
        # multiplier for QK^T: (T x head_dim) @ (head_dim x T)
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = self.max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * self.max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       trainable_cylind_lens = False,
                       **optics_overrides)
        )
        # multiplier for attention @ V: (T x T) @ (T x head_dim)
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = self.max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * self.max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       trainable_cylind_lens = False,
                       **optics_overrides)
        )
        # all layers share the same two optical multipliers
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    def forward(self, x):
        """x: (batch, seq, h_dim) embedded tokens.

        Returns logits of shape (batch, class_num), or (batch,) when
        class_num == 1 (single-logit BCE setup).
        """
        # prepend CLS (and register) tokens, then add learned absolute positions
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # classify from the CLS position only
        x = self.classifier_head(x[:,0,:])
        return x[:,:] if self.class_num > 1 else x[:,0]
# ---- dataset and model construction ----
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    targets_path="/wd/finbert_data/train_target .csv",  # NOTE(review): space inside the filename looks accidental - confirm the file really is named this
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# model hyperparameters
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit + BCEWithLogitsLoss below
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}'
writer.add_text('model', model_description, 0)
# schedule-free optimizer: no LR schedule, but needs explicit .train()/.eval()
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# class-imbalance weighting for the positive class in BCE
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter exposing CreditProductsDataset.get_batch through the torch
    Dataset interface so DataLoader workers can assemble batches in
    parallel.  Each __getitem__ returns one complete pre-built batch.
    """

    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder  # kept for interface compatibility; unused here
        self.batch_size = batch_size

    def __len__(self):
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size

    def __getitem__(self, idx):
        # BUG FIX: previously called the global credit_train_dataset,
        # silently ignoring the dataset passed to the constructor.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
# batch_size=1 because WrapperDataset already yields complete batches;
# the workers just parallelize the sampling itself
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
val_auroc = AUROC(task='binary')
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate ROC-AUC over the whole test split and log it to TensorBoard.

    NOTE(review): test_auroc accumulates across calls and is never reset, so
    the reported metric mixes predictions from earlier epochs - confirm intended.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedulefree optimizer needs explicit eval mode
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
# ---- training loop ----
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time     # throttles TensorBoard logging to roughly once per second
last_checkpoint_time = start_time  # throttles checkpointing to once per 4 hours
for epoch in range(epochs):
    test(epoch)  # evaluate on the held-out split before each training epoch
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()  # schedulefree optimizers require explicit train/eval switching
        # DataLoader batch_size=1 wraps each pre-assembled batch in an extra dim; [0] unwraps it
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # max_norm=1e6 effectively never clips; the call is only used to measure the total grad norm
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
        if current_time - last_checkpoint_time > timedelta(hours=4):
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# final evaluation and checkpoint after the last epoch
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -1,168 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2
from bpe_tokenizer import byte_pair_init, byte_pair_encode, byte_pair_decode
seed = 1337
torch.manual_seed(seed)
models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2}  # selectable model implementations
# training hyperparameters
batch_size = 50
max_iters = 40000*10
eval_interval = 300  # NOTE(review): unused below; evaluation actually happens every 5000 iters - confirm
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): also appears unused in this script
layers_num = 22
h_dim = 64
max_seq_len = 256
num_heads = 4
dropout_rate = 0.1
pixel_size = 3.6e-6   # SLM pixel pitch (meters) for the optical model
merges_count = 20     # number of BPE merge operations
# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment
MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2])
val_data_path = Path(sys.argv[3])
test_data_path = Path(sys.argv[4])
comment = f"bpe_{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}"
# timestamped per-run TensorBoard directory; a copy of this script is stored in it
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
print("Logs dir:", logs_dir)
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
script_snapshot_path.chmod(0o400) # with read-only permission
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# build the character vocabulary over ALL splits so encode() never meets an unknown char
text = train_text + val_text + test_text
chars = sorted(set(text))
print(f"Len chars: {len(chars)}")
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
import pickle
start_time = datetime.now()
# cache the (slow) BPE merge learning on disk
if Path("./data/bpe_text.pkl").exists() and Path("./data/merges.pkl").exists():
    with open("./data/bpe_text.pkl", 'rb') as f: bpe_text = pickle.load(f)
    with open("./data/merges.pkl", 'rb') as f: merges = pickle.load(f)
else:
    # NOTE(review): merges_count is hard-coded to 20 here instead of using the
    # global merges_count - they coincide today, but keep them in sync
    bpe_text, merges = byte_pair_init([wtoi[w] for w in text], vocab_size=len(chars), merges_count=20)
    with open("./data/bpe_text.pkl", 'wb') as f: pickle.dump(bpe_text, f)
    with open("./data/merges.pkl", 'wb') as f: pickle.dump(merges, f)
print(f"Compression ratio: {len(text)/len(bpe_text)}, init took {datetime.now()-start_time}")
vocab_size = len(chars) + merges_count  # base chars plus one id per learned merge
encode = lambda s: byte_pair_encode([wtoi[w] for w in s], merges)
decode = lambda idx: "".join([itow[i] for i in byte_pair_decode(idx, merges)])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random (input, next-token target) windows of
    length `seq_len` from the 1-D token tensor `data`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    y = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return x, y
start_time = datetime.now()
# cache the BPE-encoded splits on disk; encoding the full corpus is slow
train_bpe_encoded_path = Path("./data/train_bpe_encoded.pt")
val_bpe_encoded_path = Path("./data/val_bpe_encoded.pt")
test_bpe_encoded_path = Path("./data/test_bpe_encoded.pt")
if train_bpe_encoded_path.exists() and val_bpe_encoded_path.exists() and test_bpe_encoded_path.exists():
    train_data = torch.load(train_bpe_encoded_path)
    val_data = torch.load(val_bpe_encoded_path)
    test_data = torch.load(test_bpe_encoded_path)
else:
    train_data = torch.tensor(encode(train_text), dtype=torch.long)
    val_data = torch.tensor(encode(val_text), dtype=torch.long)
    test_data = torch.tensor(encode(test_text), dtype=torch.long)
    torch.save(train_data, train_bpe_encoded_path)
    torch.save(val_data, val_bpe_encoded_path)
    torch.save(test_data, test_bpe_encoded_path)
print(f"Encoded {datetime.now() - start_time}")
@torch.no_grad()
def perplexity(model, data):
    """Sliding-window perplexity of `model` over the 1-D token tensor `data`.

    The stride caps the evaluation at roughly 10k windows regardless of the
    corpus size.  Uses the global `max_seq_len` and `device`.
    """
    stride = max(1, len(data) // 10000)
    losses = []
    for i in range(0, len(data)-max_seq_len-1, stride):
        x = data[i:(i+max_seq_len)].to(device)
        y = data[(i+1):(i+max_seq_len+1)].to(device)
        logits, loss = model(x[None,...], y[None,...])  # batch of one window
        losses.append(loss.item())
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    # perplexity = exp(mean cross-entropy)
    return np.exp(np.mean(losses))
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a text continuation of the token prompt `start_idxs` from model `m`.

    Uses the global `device` and `decode`.  NOTE(review): mutable default
    argument `start_idxs=[0]` is never mutated here, but a None default
    would be safer.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
print(m)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
# sample from the untrained model as a qualitative baseline
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
for i in range(max_iters):
    xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    writer.add_scalar('loss', loss.item(), i)
    print(f"\r{i}/{max_iters} {loss.item()}", end="")
    if i % 5000 == 0:
        # periodic validation perplexity plus a sample completion
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\rPerplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
# final validation/test metrics and one last sample
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {loss.item()}")
print(f"\rPerplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', loss.item(), i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
@ -1,116 +0,0 @@
import numpy as np
import pandas as pd
def get_top_pair(tokens):
    """Return the most frequent adjacent token pair as a two-element list."""
    adjacent = np.vstack([tokens[:-1], tokens[1:]]).T
    pair_counts = pd.DataFrame(adjacent).value_counts().reset_index().astype(np.int32)
    best_row = pair_counts.nlargest(1, columns='count').iloc[0, [0, 1]]
    return list(best_row)
def merge(tokens, pair, new_idx):
    """Replace every non-overlapping left-to-right occurrence of `pair`
    (two adjacent tokens) with the single token `new_idx`.

    Returns a numpy array.  Fix: inputs with fewer than two tokens used to
    raise NameError (`b` was never bound because the zip was empty); such
    inputs are now returned unchanged.
    """
    tokens = np.asarray(tokens)
    if len(tokens) < 2:
        return tokens.copy()
    new_tokens = []
    skip = False  # True when the previous iteration already consumed `b`
    for a, b in zip(tokens[:-1], tokens[1:]):
        if skip:
            skip = False
            continue
        if a == pair[0] and b == pair[1]:
            new_tokens.append(new_idx)
            skip = True
        else:
            new_tokens.append(a)
    if not skip:
        # the loop sees the final token only as `b`; append it unless the
        # last merge consumed it
        new_tokens.append(b)
    return np.array(new_tokens)
def unmerge(tokens, pair_idx, pair):
    """Inverse of merge: expand every `pair_idx` token back into its two-token `pair`."""
    expanded = []
    for token in tokens:
        if token == pair_idx:
            expanded.extend(pair)
        else:
            expanded.append(token)
    return expanded
def byte_pair_init(char_ids, vocab_size, merges_count=20):
    """Learn `merges_count` BPE merges over `char_ids`.

    Returns (encoded_text, merges) where merges is a list of
    [[first, second], merged_id] entries in the order they were learned.
    """
    encoded = np.array(char_ids)
    merges = []
    for step_idx in range(merges_count):
        top_pair = get_top_pair(encoded)
        merged_id = vocab_size + step_idx  # new ids follow the base vocabulary
        merges.append([top_pair, merged_id])
        print(f"{top_pair} {merged_id}")
        encoded = merge(encoded, top_pair, merged_id)
    return np.array(encoded), merges
def byte_pair_encode(char_ids, merges):
    """Apply the learned merges, in learning order, to a fresh char-id sequence."""
    encoded = np.array(char_ids)
    for merge_pair, merged_id in merges:
        encoded = merge(encoded, merge_pair, merged_id)
    return encoded
def byte_pair_decode(tokens, merges):
    """Undo the merges in reverse order, recovering the original char ids."""
    for merge_pair, merged_id in reversed(merges):
        tokens = unmerge(tokens, merged_id, merge_pair)
    return tokens
# def get_top_pair(tokens):
# hist = pd.DataFrame(np.vstack([tokens[:-1], tokens[1:]]).T).value_counts().reset_index().astype(np.uint16)
# return np.array(hist.nlargest(1, columns='count').iloc[0, [0,1]])
# def merge(tokens, pair, new_idx):
# if len(tokens) % 2 != 0:
# tokens = np.append(tokens, np.array([0], dtype=np.uint16))
# # print("not even")
# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy()
# b = np.frombuffer(bytes(pair), dtype=np.uint32)
# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32)
# a[a==b] = c
# d = np.frombuffer(bytes(a), dtype=np.uint16)
# indices = np.where(d == 2**16-1)
# e = np.delete(d, indices)
# e = e[:-1]
# else:
# # print("even")
# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy()
# b = np.frombuffer(bytes(pair), dtype=np.uint32)
# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32)
# a[a==b] = c
# d = np.frombuffer(bytes(a), dtype=np.uint16)
# indices = np.where(d == 2**16-1)
# e = np.delete(d, indices)
# return e
# def unmerge(tokens, pair_idx, pair):
# new_tokens = []
# for idx in tokens:
# if idx == pair_idx:
# new_tokens.append(pair[0])
# new_tokens.append(pair[1])
# else:
# new_tokens.append(idx)
# return new_tokens
# def byte_pair_init(char_ids, vocab_size, merges_count=20):
# assert vocab_size < 2**16
# byte_text = np.array(char_ids, dtype=np.uint16)
# merges = []
# for i in range(merges_count):
# top_pair = get_top_pair(byte_text)
# new_idx = vocab_size + i
# print([top_pair, new_idx])
# merges.append([top_pair, new_idx])
# byte_text = merge(byte_text, top_pair, new_idx)
# byte_text = np.roll(merge(np.roll(byte_text, 1), top_pair, new_idx), -1)
# return byte_text, merges
# def byte_pair_encode(char_ids, merges):
# tokens = np.array(char_ids, dtype=np.uint16)
# for pair, pair_idx in merges:
# tokens = merge(tokens, pair, pair_idx)
# return tokens
# def byte_pair_decode(tokens, merges):
# for pair, pair_idx in reversed(merges):
# tokens = unmerge(tokens, pair_idx, pair)
# return tokens

@ -1,115 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
torch.manual_seed(1337)  # fixed seed for reproducible weight init and sampling
#################################### Model #########################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied to (batch, seq, dim) tensors."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # inverse frequencies for each channel pair, as in the RoFormer paper
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # duplicate the angle table so it covers all `dim` channels
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Swap the two channel halves, negating the second: (a, b) -> (-b, a)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` by position-dependent angles starting at `offset`."""
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic tanh: learnable tanh squashing plus a per-feature affine,
    used in place of LayerNorm."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer decoder block: causal multi-head self-attention
    with rotary positions, DyT in place of LayerNorm, and a GELU MLP."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # stash the constructor args (h_dim, num_heads, dropout_rate,
        # max_seq_len) as plain attributes
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        # RoPE operates per head, hence dim = h_dim // num_heads
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        # (B, T, H) -> (B*num_heads, T, H/num_heads); no-op for a single head
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # inverse of split_to_heads
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal scaled dot-product attention over all heads."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim, not the per-head dim - confirm intended
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # lower-triangular mask: each position attends only to itself and the past
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # NOTE(review): F.dropout1d defaults to training=True, so dropout is
        # applied even in eval mode - confirm intended
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
        return x
class GPT2(nn.Module):
    """Minimal character-level GPT-2-style decoder-only transformer.

    Args:
        vocab_size: number of token ids.
        layers_num: number of stacked TransformerLayer blocks.
        h_dim: embedding / model width.
        max_seq_len: context window; also the learned positional table length.
        num_heads: attention heads per layer.
        dropout_rate: dropout probability inside each layer.
        pixel_size: unused by this class; accepted so the constructor stays
            call-compatible with the optical variant (see OpticGPT2 callers).
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1, pixel_size=None):
        super().__init__()
        # Explicit assignments instead of the locals()-update hack.
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given.

        x: (B, T) int64 token ids with T <= max_seq_len.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        # cross_entropy wants class dim second for sequence targets: (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    @torch.no_grad()  # pure sampling: no autograd graph needed
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # last position only: (B, C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -1,136 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2
seed = 1337
torch.manual_seed(seed)
# Model registry: selected by the first CLI argument below.
models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2}
# Training hyperparameters.
batch_size = 50
max_iters = 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Model hyperparameters.
layers_num = 22
h_dim = 64
max_seq_len = 256
num_heads = 4
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch for the optical model variant
# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment
MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2])
val_data_path = Path(sys.argv[3])
test_data_path = Path(sys.argv[4])
comment = f"{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}"
# Timestamped TensorBoard run directory.
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
print("Logs dir:", logs_dir)
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
script_snapshot_path.chmod(0o400) # with read-only permission
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Char-level vocabulary built over ALL splits so val/test chars are encodable.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from data.

    Returns two (batch_size, seq_len) long tensors on the module-level `device`;
    targets are the inputs shifted one position to the right.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    xs = torch.stack([data[s:s + seq_len] for s in starts])
    ys = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return xs.to(device), ys.to(device)
# Encode each split once; tensors stay on CPU, batches move to `device` later.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Sliding-window perplexity of `model` over `data`.

    The stride thins evaluation to roughly 10k windows; each window scores
    `max_seq_len` tokens against their one-step-shifted targets.
    """
    stride = max(1, len(data) // 10000)
    last = len(data) - max_seq_len - 1
    losses = []
    for start in range(0, last, stride):
        inputs = data[start:(start + max_seq_len)].to(device)
        targets = data[(start + 1):(start + max_seq_len + 1)].to(device)
        _, loss = model(inputs[None, ...], targets[None, ...])
        losses.append(loss.item())
        print(f"\rppl {start}/{last}", end="")
    return np.exp(np.mean(losses))
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample `max_new_tokens` continuation tokens after `start_idxs` and decode them.

    Fix: the default was a mutable list (`[0]`); an immutable tuple avoids the
    shared-mutable-default pitfall while accepting the same call sites.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
    )
m = m.to(device)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
# Log a pre-training sample as a qualitative baseline (step 0).
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
for i in range(max_iters):
    xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    writer.add_scalar('loss', loss.item(), i)
    print(f"\r{i}/{max_iters} {loss.item()}", end="")
    # Periodic validation; note this also fires at i == 0.
    if i % 5000 == 0:
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\rPerplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
# Final evaluation; `i` here is the last loop index (max_iters - 1).
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {loss.item()}")
print(f"\rPerplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# NOTE(review): re-logs 'loss' at step i, duplicating the last in-loop write.
writer.add_scalar('loss', loss.item(), i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)

@ -5,5 +5,12 @@ __version__ = "3.0.0"
from .config import Config from .config import Config
from . import propagator from . import propagator
from .optical_mul import OpticalMul from .optical_mul import (
from .parallel import DataParallel OpticalMul,
TrainableLensOpticalMul,
TrainableScalarOpticalMul,
TrainableScalarAndLensOpticalMul,
TrainableFocalDistLensOpticalMul
)
from .parallel import DataParallel
from .parallel import ScatterDataParallel

@ -1,7 +1,15 @@
import torch as _torch import torch as _torch
import torch.nn as _nn import torch.nn as _nn
from .config import Config as _Config from .config import Config as _Config
from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorCylindLens as _PropCylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop
from .propagator import (
PropagatorTrainableCylindLens as _PropagatorTrainableCylindLens,
PropagatorTrainableFocalDistCylindLens as _PropagatorTrainableFocalDistCylindLens
)
from torch.utils.tensorboard import SummaryWriter
from typing import Optional
import matplotlib.pyplot as plt
class OpticalMul(_nn.Module): class OpticalMul(_nn.Module):
""" """
@ -14,12 +22,124 @@ class OpticalMul(_nn.Module):
Args: Args:
config: конфигурация расчётной системы. config: конфигурация расчётной системы.
""" """
super(OpticalMul, self).__init__() super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
self._propagator_two: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
    def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the left matrix, treated as a set of column vectors, for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrices containing the column vectors of the left matrix.
        """
        # Complex dtype; the flip here is undone by the flip in prepare_out.
        data = data.cfloat().flip(-1)
        data = data.unsqueeze(-2)
        # Kron with a ones block replicates each value into a pixel patch.
        data = _torch.kron(data.contiguous(), self._kron_vec_utils)
        return data
    def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the right matrix for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrix acting as the optical element at the center of the model.
        """
        # A trailing dim of size 2 on a >4-D tensor encodes (real, imag) parts.
        if (data.dim() > 4) and data.size(-1) == 2:
            data = _torch.view_as_complex(data)
        data = data.cfloat().transpose(-2, -1)
        data = data.unsqueeze(-3)
        data = _torch.kron(data.contiguous(), self._kron_mat_utils)
        return data
    def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
        """
        Extract the matrix-multiplication result from the system's output field.

        Args:
            field: output light-field distributions of the system.

        Returns:
            Column vector (amplitude distribution).
        """
        ### The commented-out `**2` / `**0.5` form is the more physically accurate
        ### intensity-based variant, but it needs far more memory during training.
        field = field.abs().squeeze(-1) #**2
        field = self._avg_pool(field)
        return field.flip(-1) #**0.5
    def forward(self,
                input: _torch.Tensor,
                other: _torch.Tensor) -> _torch.Tensor:
        """
        Perform the matrix multiplication input @ other.

        Args:
            input: matrix (B, C, H, W).
            other: matrix (B, C, W, K).

        Returns:
            Result of the matrix multiplication (B, C, H, K).

        Example:
            >>> mul = OpticalMul(...)
            >>> A = torch.rand((1, 1, 256, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 256)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 256, 256])
            >>> A = torch.rand((1, 1, 64, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 128)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 64, 128])
        """
        vec_field = self.prepare_vector(input)
        mat_field = self.prepare_matrix(other)
        # Propagate to the matrix plane, modulate by the matrix, propagate out.
        vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1))
        return self.prepare_out(vec_field)
class TrainableScalarOpticalMul(_nn.Module):
"""
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
"""
def __init__(self, config: _Config):
"""
Конструктор класса.
Args:
config: конфигурация расчётной системы.
"""
super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropСylindLens(config.matrix_plane, config) prop_four = _PropCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
@ -33,6 +153,421 @@ class OpticalMul(_nn.Module):
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
self.k = nn.Parameter(_torch.tensor(1))
    def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the left matrix, treated as a set of column vectors, for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrices containing the column vectors of the left matrix.
        """
        # Complex dtype; the flip here is undone by the flip in prepare_out.
        data = data.cfloat().flip(-1)
        data = data.unsqueeze(-2)
        data = _torch.kron(data.contiguous(), self._kron_vec_utils)
        return data
    def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the right matrix for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrix acting as the optical element at the center of the model.
        """
        # A trailing dim of size 2 on a >4-D tensor encodes (real, imag) parts.
        if (data.dim() > 4) and data.size(-1) == 2:
            data = _torch.view_as_complex(data)
        data = data.cfloat().transpose(-2, -1)
        data = data.unsqueeze(-3)
        data = _torch.kron(data.contiguous(), self._kron_mat_utils)
        return data
    def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
        """
        Extract the matrix-multiplication result from the system's output field.

        Args:
            field: output light-field distributions of the system.

        Returns:
            Column vector (amplitude distribution).
        """
        ### The commented-out `**2` / `**0.5` form is the more physically accurate
        ### intensity-based variant, but it needs far more memory during training.
        field = field.abs().squeeze(-1) #**2
        field = self._avg_pool(field)
        return field.flip(-1) #**0.5
    def forward(self,
                input: _torch.Tensor,
                other: _torch.Tensor) -> _torch.Tensor:
        """
        Perform the matrix multiplication input @ other, scaled by the
        trainable scalar gain ``self.k``.

        Args:
            input: matrix (B, C, H, W).
            other: matrix (B, C, W, K).

        Returns:
            Result of the matrix multiplication (B, C, H, K).

        Example:
            >>> mul = OpticalMul(...)
            >>> A = torch.rand((1, 1, 256, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 256)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 256, 256])
            >>> A = torch.rand((1, 1, 64, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 128)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 64, 128])
        """
        vec_field = self.prepare_vector(input)
        mat_field = self.prepare_matrix(other)
        vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1))
        # Trainable scalar compensates the system's overall amplitude scale.
        return self.k * self.prepare_out(vec_field)
class TrainableLensOpticalMul(_nn.Module):
    """
    Optical matrix-matrix multiplication system whose cylindrical lens at the
    matrix plane has a trainable phase profile.
    """
    def __init__(self, config: _Config):
        """
        Class constructor.

        Args:
            config: configuration of the simulated system.
        """
        super().__init__()
        prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
        prop_two = _PropCrossLens(config.first_lens_plane, config)
        prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
        prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config)
        prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
        prop_six = _PropCrossLens(config.second_lens_plane, config).T
        prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
        # The trainable cylindrical lens is kept as a separate stage (not fused
        # into a propagator chain) so its parameters can be optimized/logged.
        self._propagator_one: _Prop = prop_one + prop_two + prop_three
        self._propagator_cylind_lens: _Prop = prop_four
        self._propagator_three: _Prop = prop_five + prop_six + prop_seven
        kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
        kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
        self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
        self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
        self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
    def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the left matrix, treated as a set of column vectors, for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrices containing the column vectors of the left matrix.
        """
        data = data.cfloat().flip(-1)
        data = data.unsqueeze(-2)
        data = _torch.kron(data.contiguous(), self._kron_vec_utils)
        return data
    def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the right matrix for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrix acting as the optical element at the center of the model.
        """
        if (data.dim() > 4) and data.size(-1) == 2:
            data = _torch.view_as_complex(data)
        data = data.cfloat().transpose(-2, -1)
        data = data.unsqueeze(-3)
        # TODO data should be at least two seq length. For one we get
        # RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
        data = _torch.kron(data.contiguous(), self._kron_mat_utils)
        return data
    def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
        """
        Extract the matrix-multiplication result from the system's output field.

        Args:
            field: output light-field distributions of the system.

        Returns:
            Column vector (amplitude distribution).
        """
        ### The commented-out `**2` / `**0.5` form is the more physically accurate
        ### intensity-based variant, but it needs far more memory during training.
        field = field.abs().squeeze(-1) #**2
        field = self._avg_pool(field)
        return field.flip(-1) #**0.5
    def forward(self,
                input: _torch.Tensor,
                other: _torch.Tensor) -> _torch.Tensor:
        """
        Perform the matrix multiplication input @ other.

        Args:
            input: matrix (B, C, H, W).
            other: matrix (B, C, W, K).

        Returns:
            Result of the matrix multiplication (B, C, H, K).

        Example:
            >>> mul = OpticalMul(...)
            >>> A = torch.rand((1, 1, 256, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 256)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 256, 256])
            >>> A = torch.rand((1, 1, 64, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 128)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 64, 128])
        """
        vec_field = self.prepare_vector(input)
        mat_field = self.prepare_matrix(other)
        vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
        return self.prepare_out(vec_field)
    @_torch.no_grad()
    def log_cylind_lens_operator_x(
            self,
            writer: SummaryWriter,
            tag: str,
            global_step: Optional[int] = None,
        ):
        """Log the trainable lens's wrapped X phase profile to TensorBoard."""
        # 1. Apply exp to get the wrapped phase as it would be physically seen
        # This ensures values outside [-pi, pi] wrap correctly
        complex_op = _torch.exp(-1j * self._propagator_cylind_lens._operator_X_phi)
        wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π]
        # 2. Normalize for Image Visualization [0, 1]
        phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi)
        # 3. Log as a 1-pixel high row
        # Shape: [1, 1, Width]
        phase_row = phase_normalized.unsqueeze(0).unsqueeze(0)
        writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW')
        fig, ax = plt.subplots(figsize=(12, 4))
        ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}')
        ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}")
        ax.set_xlabel("Pixel Index")
        ax.set_ylabel("Phase (rad)")
        ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5])
        ax.grid(True, linestyle='--', alpha=0.6)
        # Send the figure to the "Plots" or "Images" tab in TensorBoard
        writer.add_figure(f"{tag}/phase_profile", fig, global_step)
        plt.close(fig) # Important: prevent memory leaks
class TrainableFocalDistLensOpticalMul(_nn.Module):
    """
    Optical matrix-matrix multiplication system whose cylindrical lens at the
    matrix plane has a trainable focal distance.
    """
    def __init__(self, config: _Config):
        """
        Class constructor.

        Args:
            config: configuration of the simulated system.
        """
        super().__init__()
        prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
        prop_two = _PropCrossLens(config.first_lens_plane, config)
        prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
        prop_four = _PropagatorTrainableFocalDistCylindLens(config.matrix_plane, config)
        prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
        prop_six = _PropCrossLens(config.second_lens_plane, config).T
        prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
        # The trainable lens is kept as a separate stage so its focal-distance
        # parameter can be optimized and logged.
        self._propagator_one: _Prop = prop_one + prop_two + prop_three
        self._propagator_cylind_lens: _Prop = prop_four
        self._propagator_three: _Prop = prop_five + prop_six + prop_seven
        kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
        kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
        self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
        self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
        self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
    def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the left matrix, treated as a set of column vectors, for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrices containing the column vectors of the left matrix.
        """
        data = data.cfloat().flip(-1)
        data = data.unsqueeze(-2)
        data = _torch.kron(data.contiguous(), self._kron_vec_utils)
        return data
    def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the right matrix for the system input.

        Args:
            data: complex-amplitude light-field distribution matrix.

        Returns:
            Matrix acting as the optical element at the center of the model.
        """
        if (data.dim() > 4) and data.size(-1) == 2:
            data = _torch.view_as_complex(data)
        data = data.cfloat().transpose(-2, -1)
        data = data.unsqueeze(-3)
        # TODO data should be at least two seq length. For one we get
        # RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
        data = _torch.kron(data.contiguous(), self._kron_mat_utils)
        return data
    def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
        """
        Extract the matrix-multiplication result from the system's output field.

        Args:
            field: output light-field distributions of the system.

        Returns:
            Column vector (amplitude distribution).
        """
        ### The commented-out `**2` / `**0.5` form is the more physically accurate
        ### intensity-based variant, but it needs far more memory during training.
        field = field.abs().squeeze(-1) #**2
        field = self._avg_pool(field)
        return field.flip(-1) #**0.5
    def forward(self,
                input: _torch.Tensor,
                other: _torch.Tensor) -> _torch.Tensor:
        """
        Perform the matrix multiplication input @ other.

        Args:
            input: matrix (B, C, H, W).
            other: matrix (B, C, W, K).

        Returns:
            Result of the matrix multiplication (B, C, H, K).

        Example:
            >>> mul = OpticalMul(...)
            >>> A = torch.rand((1, 1, 256, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 256)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 256, 256])
            >>> A = torch.rand((1, 1, 64, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 128)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 64, 128])
        """
        vec_field = self.prepare_vector(input)
        mat_field = self.prepare_matrix(other)
        vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
        return self.prepare_out(vec_field)
    @_torch.no_grad()
    def log_cylind_lens_operator_x(
            self,
            writer: SummaryWriter,
            tag: str,
            global_step: Optional[int] = None,
        ):
        """Log the trainable focal distance and the resulting wrapped X phase profile."""
        # 1. Apply exp to get the wrapped phase as it would be physically seen
        # This ensures values outside [-pi, pi] wrap correctly
        lens = self._propagator_cylind_lens
        writer.add_scalar(f"{tag}/focal_distance", lens._distance.detach().cpu().numpy(), global_step)
        complex_op = _torch.exp(-1j * lens._K / lens._distance * lens._linspace_by_x**2)
        wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π]
        # 2. Normalize for Image Visualization [0, 1]
        phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi)
        # 3. Log as a 1-pixel high row
        # Shape: [1, 1, Width]
        phase_row = phase_normalized.unsqueeze(0).unsqueeze(0)
        writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW')
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}')
        ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}\nFocal distance: {lens._distance.detach().cpu().numpy()}")
        ax.set_xlabel("Pixel Index")
        ax.set_ylabel("Phase (rad)")
        ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5])
        ax.grid(True, linestyle='--', alpha=0.6)
        # Send the figure to the "Plots" or "Images" tab in TensorBoard
        writer.add_figure(f"{tag}/phase_profile", fig, global_step)
        plt.close(fig) # Important: prevent memory leaks
class TrainableScalarAndLensOpticalMul(_nn.Module):
"""
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
"""
def __init__(self, config: _Config):
"""
Конструктор класса.
Args:
config: конфигурация расчётной системы.
"""
super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three
self._propagator_two: _Prop = prop_four
self._propagator_three: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
self.k = nn.Parameter(torch.tensor(1))
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
""" """
@ -110,7 +645,8 @@ class OpticalMul(_nn.Module):
vec_field = self.prepare_vector(input) vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other) mat_field = self.prepare_matrix(other)
vec_field = self._propagator_one(vec_field) vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_two(vec_field * mat_field) vec_field = self._propagator_two(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
return self.prepare_out(vec_field) return self.k * self.prepare_out(vec_field)

@ -94,4 +94,68 @@ class DataParallel(_nn.Module):
outputs = _nn.parallel.parallel_apply(replicas, stacked_input) outputs = _nn.parallel.parallel_apply(replicas, stacked_input)
return _nn.parallel.gather(outputs, self.output_device, dim) return _nn.parallel.gather(outputs, self.output_device, dim)
class ScatterDataParallel(_nn.Module):
    """
    DataParallel variant tuned for attention matrices of differing sizes.

    Both operand tensors are scattered across the devices along the same
    dimension so per-device chunk shapes stay consistent; results are gathered
    on ``output_device``.
    """
    def __init__(self, module: _nn.Module,
                 devices: Union[None, List[Union[int, _torch.device]]] = None,
                 output_device: Union[None, int, _torch.device] = None) -> None:
        """
        Args:
            module: the module to replicate; must accept (input, other).
            devices: CUDA devices to use; defaults to all visible GPUs.
            output_device: where results are gathered; defaults to devices[0].

        Raises:
            EnvironmentError: if CUDA is unavailable.
        """
        super().__init__()
        if not _torch.cuda.is_available():
            raise EnvironmentError("cuda is not available.")
        if not devices:
            devices = [_torch.device(f'cuda:{x}') for x in range(_torch.cuda.device_count())]
        # Fix: `if not output_device` also triggered for device index 0,
        # silently overriding an explicit `output_device=0`.
        if output_device is None:
            output_device = devices[0]
        self.module = module
        self.devices = devices
        self.output_device = output_device
    def buffers(self, *inputs) -> Iterator[_torch.Tensor]:
        # Delegate so only the wrapped module's (single) buffers are exposed.
        return self.module.buffers(*inputs)
    def parameters(self, *inputs) -> Iterator[_nn.parameter.Parameter]:
        # Delegate so the optimizer sees one copy of each parameter.
        return self.module.parameters(*inputs)
    def forward(self, input: _torch.Tensor, other: _torch.Tensor, **kwargs: Any) -> _torch.Tensor:
        '''
        Forward pass optimized for attention matrices.

        - Scatters along the batch dimension (1) instead of an arbitrary dim.
        - Scatters BOTH tensors so chunk shapes stay consistent.
        - Supports multi-dimensional attention tensors [batch, heads, seq, dim].
        '''
        # Choose the scatter dimension from the tensor structure.
        if input.dim() >= 3 and other.dim() >= 3:
            # Attention matrices: scatter along the batch dimension.
            scatter_dim = 1
        else:
            # Plain 2-D matrices: honor kwargs['dim'], defaulting to 2.
            scatter_dim = kwargs.get('dim', 2)
        # Stage the module and replicate it across devices.
        self.module = self.module.to(self.devices[0])
        replicas = _nn.parallel.replicate(self.module, self.devices)
        # Scatter BOTH tensors so per-device shapes agree.
        scattered_input = _nn.parallel.scatter(input, self.devices, scatter_dim)
        scattered_other = _nn.parallel.scatter(other, self.devices, scatter_dim)
        # scatter may produce fewer chunks than devices for small batches, so
        # pair up only as many chunks/replicas as all three lists provide.
        min_len = min(len(scattered_input), len(scattered_other), len(replicas))
        stacked_input = [(scattered_input[i], scattered_other[i]) for i in range(min_len)]
        outputs = _nn.parallel.parallel_apply(replicas[:min_len], stacked_input)
        return _nn.parallel.gather(outputs, self.output_device, scatter_dim)

@ -7,6 +7,7 @@ from typing import Tuple as _Tuple, Sequence as _Sequence
from abc import ABC as _ABC from abc import ABC as _ABC
import collections as _collections import collections as _collections
import copy as _copy
class Propagator(_ABC, _nn.Module): class Propagator(_ABC, _nn.Module):
""" """
@ -92,7 +93,14 @@ class Propagator(_ABC, _nn.Module):
""" """
return self.cat(propagator) return self.cat(propagator)
def forward(self, field: _torch.Tensor) -> _torch.Tensor: @staticmethod
def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
start = (total_rows - num_to_take) // 2
end = start + num_to_take
return slice(start, end)
def forward(self,
field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
""" """
Метод распространения светового поля в среде. Метод распространения светового поля в среде.
@ -103,8 +111,23 @@ class Propagator(_ABC, _nn.Module):
Распределение комплексной амплитуды светового поля, Распределение комплексной амплитуды светового поля,
после распространения. после распространения.
""" """
if (resul_shape is not None):
field_shape = field.shape[-2:]
operator_Y_shape = self.operator_Y.shape[-2:]
operator_X_shape = self.operator_X.shape[-2:]
slice_one = Propagator.__slice_calculation(operator_Y_shape[0], resul_shape[0])
slice_two = Propagator.__slice_calculation(operator_Y_shape[1], field_shape[0])
slice_three = Propagator.__slice_calculation(operator_X_shape[0], field_shape[1])
slice_four = Propagator.__slice_calculation(operator_X_shape[1], resul_shape[1])
return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
return self.operator_Y @ field @ self.operator_X return self.operator_Y @ field @ self.operator_X
def __repr__(self):
return f"Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
class PropagatorLens(Propagator): class PropagatorLens(Propagator):
""" """
Абстрактный класс распространения света в тонком оптическом элементе. Абстрактный класс распространения света в тонком оптическом элементе.
@ -143,10 +166,11 @@ class PropagatorCrossLens(PropagatorLens):
""" """
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2) operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2)
super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X), super(PropagatorCrossLens, self).__init__(
_torch.diag_embed(operator_Y)) _torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorСylindLens(PropagatorLens): class PropagatorCylindLens(PropagatorLens):
""" """
Класс распространения света в цилиндрической линзе, Класс распространения света в цилиндрической линзе,
представленной тонким оптическим элементом. представленной тонким оптическим элементом.
@ -162,8 +186,10 @@ class PropagatorСylindLens(PropagatorLens):
""" """
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)
super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), super(PropagatorCylindLens, self).__init__(
_torch.diag_embed(operator_Y)) _torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorSinc(Propagator): class PropagatorSinc(Propagator):
""" """
@ -218,3 +244,136 @@ class PropagatorSinc(Propagator):
difference_y, difference_y,
config) config)
return operator_X, operator_Y return operator_X, operator_Y
#######################################################################################################################
class PropagatorTrainableCylindLens(_ABC, _nn.Module):
    """
    Light propagation through a trainable cylindrical lens modelled as a
    thin transparent optical element.

    The per-pixel phase profile along the x axis is the trainable parameter;
    the y-axis operator is a fixed identity kept as a real-valued buffer.
    """

    def __init__(self,
                 plane: _ConfigDesignPlane,
                 config: _ConfigOpticBase
                 ):
        super().__init__()
        # NOTE(review): training the raw per-pixel phase can produce a
        # non-smooth profile; training only the focal length may be safer.
        self._operator_X_phi = _nn.Parameter(config.K / config.distance * plane.linspace_by_x**2)
        identity = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat))
        # complex buffer stored as a real view pair for serialization
        self.register_buffer('_operator_Y', _torch.view_as_real(identity), persistent=True)

    @property
    def operator_X(self) -> _torch.Tensor:
        """Diagonal operator propagating the field along the abscissa axis."""
        return _torch.diag_embed(_torch.exp(-1j * self._operator_X_phi))

    @property
    def operator_Y(self) -> _torch.Tensor:
        """Diagonal operator propagating the field along the ordinate axis."""
        return _torch.view_as_complex(self._operator_Y)

    @staticmethod
    def __centered_slice(total: int, take: int) -> slice:
        # Central window of `take` elements out of `total`.
        first = (total - take) // 2
        return slice(first, first + take)

    def forward(self,
                field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
        """
        Propagate the complex field amplitude through the element.

        Args:
            field: complex amplitude distribution of the light field.
            resul_shape: optional (rows, cols) of the result; when given,
                both operators are cropped to centered windows.

        Returns:
            Complex amplitude distribution after propagation.
        """
        op_y, op_x = self.operator_Y, self.operator_X
        if resul_shape is None:
            return op_y @ field @ op_x
        rows, cols = field.shape[-2:]
        crop = self.__centered_slice
        y_part = op_y[..., crop(op_y.shape[-2], resul_shape[0]), crop(op_y.shape[-1], rows)]
        x_part = op_x[..., crop(op_x.shape[-2], cols), crop(op_x.shape[-1], resul_shape[1])]
        return y_part @ field @ x_part
class PropagatorTrainableFocalDistCylindLens(_ABC, _nn.Module):
    """
    Light propagation through a cylindrical lens modelled as a thin optical
    element whose focal (propagation) distance is the only trainable value.
    """

    def __init__(self,
                 plane: _ConfigDesignPlane,
                 config: _ConfigOpticBase
                 ):
        super().__init__()
        # the distance is learned; wavenumber K and the x grid stay frozen
        self._distance = _nn.Parameter(_torch.tensor(config.distance))
        self.register_buffer('_K', _torch.tensor(config.K), persistent=True)
        self.register_buffer('_linspace_by_x', plane.linspace_by_x.detach().clone(), persistent=True)
        identity = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat))
        # complex identity stored as a real view pair for serialization
        self.register_buffer('_operator_Y', _torch.view_as_real(identity), persistent=True)

    @property
    def operator_X(self) -> _torch.Tensor:
        """Diagonal phase operator along the abscissa axis (depends on the trainable distance)."""
        phase = self._K / self._distance * self._linspace_by_x**2
        return _torch.diag_embed(_torch.exp(-1j * phase))

    @property
    def operator_Y(self) -> _torch.Tensor:
        """Identity operator along the ordinate axis."""
        return _torch.view_as_complex(self._operator_Y)

    @staticmethod
    def __centered_slice(total: int, take: int) -> slice:
        # Central window of `take` elements out of `total`.
        first = (total - take) // 2
        return slice(first, first + take)

    def forward(self,
                field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
        """
        Propagate the complex field amplitude through the element.

        Args:
            field: complex amplitude distribution of the light field.
            resul_shape: optional (rows, cols) of the result; when given,
                both operators are cropped to centered windows.

        Returns:
            Complex amplitude distribution after propagation.
        """
        op_y, op_x = self.operator_Y, self.operator_X
        if resul_shape is None:
            return op_y @ field @ op_x
        rows, cols = field.shape[-2:]
        crop = self.__centered_slice
        y_part = op_y[..., crop(op_y.shape[-2], resul_shape[0]), crop(op_y.shape[-1], rows)]
        x_part = op_x[..., crop(op_x.shape[-2], cols), crop(op_x.shape[-1], resul_shape[1])]
        return y_part @ field @ x_part

@ -1,177 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Scale *matrix* by max_val, epsilon-guarded against division by zero."""
    denominator = max_val + 1e-10
    return matrix / denominator
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Matrix product of two tensors through the optical simulator *sim*.

    Non-negative operands are rescaled into [0, 1], multiplied and scaled
    back.  Signed operands are first shifted to be non-negative; the four
    cross products then recover the true product by inclusion-exclusion:
    (a + s)(b + s) - (a + s)s - s(b + s) + s*s == a*b.
    """
    lhs = tensor_1[None, :, :, :]
    rhs = tensor_2[None, :, :, :]
    if torch.min(lhs) >= 0 and torch.min(rhs) >= 0:
        peak = abs(max(torch.max(lhs), torch.max(rhs)))
        product = sim(norm(lhs, peak), norm(rhs, peak))
        return product[0, :, :, :] * peak ** 2
    offset = abs(min(torch.min(lhs), torch.min(rhs)))
    peak = abs(max(torch.max(lhs), torch.max(rhs))) + offset
    lhs_shift = offset * torch.ones(lhs.shape).to(lhs.device)
    rhs_shift = offset * torch.ones(rhs.shape).to(lhs.device)
    lhs_n = norm(lhs + lhs_shift, peak)
    rhs_n = norm(rhs + rhs_shift, peak)
    lhs_shift_n = norm(lhs_shift, peak)
    rhs_shift_n = norm(rhs_shift, peak)
    combined = (sim(lhs_n, rhs_n) - sim(lhs_n, rhs_shift_n)
                - sim(lhs_shift_n, rhs_n) + sim(lhs_shift_n, rhs_shift_n))
    return combined[0, :, :, :] * peak ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied over the last dim of (B, T, H) inputs."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # inverse frequency per even channel pair, as in the RoFormer paper
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing used in place of normalization."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal self-attention block whose two attention matrix
    products (q @ k^T and attention @ v) run through optical simulators.

    Uses DyT instead of LayerNorm, RoPE for positions, and learnable scalars
    k1/k2 to compensate the simulators' amplitude scaling.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # stash all constructor args as plain attributes
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # learnable compensation scales for the optical matmuls
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # causal mask: position t may only attend to positions <= t
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout was also
        # applied in eval mode (corrupting every evaluation); tie it to the
        # module's training flag so model.eval() disables it.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """GPT-style character LM whose attention matrix products are computed by
    simulated optical matrix multiplication (omm.OpticalMul)."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # stash every constructor argument as a plain attribute
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # simulator for q @ k^T: right operand is (h_dim/num_heads, max_seq_len)
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = h_dim // num_heads,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * (h_dim // num_heads),
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       distance = 0.01)
        )
        # simulator for attention @ v: right operand is (max_seq_len, h_dim/num_heads)
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = h_dim // num_heads,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * (h_dim // num_heads),
                       right_matrix_height = pixel_size * max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       distance = 0.01)
        )
        # all layers share the same two simulator instances
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # learned absolute positional embeddings (layers also apply RoPE internally)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        # x: (B, T) token ids; targets, if given, is x shifted left by one.
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects class dim second, hence the (b t c -> b c t) permute
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        # Sample max_new_tokens tokens autoregressively, re-feeding at most
        # the last max_seq_len tokens of context each step.
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
# --- training hyperparameters ---
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# --- model hyperparameters ---
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch used by the optical model variants
# each optimizer step is split into this many equal sub-batches
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied over the last dim of (B, T, H) inputs."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # inverse frequency per even channel pair, as in the RoFormer paper
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing used in place of normalization."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder block: causal multi-head self-attention + GELU MLP.

    Uses DyT in place of LayerNorm, RoPE for positions, and residual
    connections around both sub-blocks.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # stash all constructor args as plain attributes
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # causal mask: position t may only attend to positions <= t
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout was also
        # applied in eval mode (corrupting perplexity evaluation and sampling);
        # tie it to the module's training flag so model.eval() disables it.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Character-level GPT-style decoder: token + learned positional
    embeddings, a stack of TransformerLayer blocks and a linear LM head."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        hidden = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            hidden = layer(hidden)
        logits = self.lm_head(hidden)  # (B, T, C)
        if targets is None:
            return logits, None
        # cross_entropy expects the class dim second
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            context = idx[:, -self.max_seq_len:]
            logits, _ = self(context)
            last = logits[:, -1, :]  # (B, C)
            probs = F.softmax(last, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset name + sequence length (+ optional CLI tag).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Load the three WikiText splits; the char vocabulary is built over all of
# them so val/test never contain out-of-vocabulary characters.
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
text = train_text + val_text + test_text
# Character-level vocabulary and the two id<->char lookup tables.
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random contiguous windows from *data*; y is x shifted
    by one token. Moves both tensors to the module-level `device`."""
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[s:s+seq_len] for s in starts]).to(device)
    y = torch.stack([data[s+1:s+1+seq_len] for s in starts]).to(device)
    return x, y
# Encode each split once up front as a long 1-D LongTensor of char ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of *model* over *data* (1-D tensor of ids).

    Windows of length max_seq_len start every `stride` tokens (stride chosen
    so at most ~10k windows are evaluated); per-batch mean losses are
    re-weighted by token count. Relies on module-level globals
    max_seq_len and device.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a text continuation of *start_idxs* (a sequence of token ids).

    BUG FIX: the default was a mutable list ([0]); use an immutable tuple.
    Callers passing lists are unaffected.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model and report its size to TensorBoard and stdout.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# rule of thumb: ~8 training tokens per parameter
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist the full training state (model, optimizer, vocab tables and
    config) to checkpoint_dir/checkpoint_<step>.pt so a run can be resumed."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing (everything needed to rebuild the model
# and resume the run from a saved checkpoint).
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline sampling from the untrained model (logged at step 0) for comparison.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: several smaller sub-batches per optimizer step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used to log the gradient norm
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling and checkpointing.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test sets plus final sampled completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
# --- training hyperparameters ---
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# --- model hyperparameters ---
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch used by the optical model variants
# each optimizer step is split into this many equal sub-batches
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied over the last dim of (B, T, H) inputs."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # inverse frequency per even channel pair, as in the RoFormer paper
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing used in place of normalization."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder block: causal multi-head self-attention + GELU MLP.

    Uses DyT in place of LayerNorm, RoPE for positions, and residual
    connections around both sub-blocks.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # stash all constructor args as plain attributes
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # causal mask: position t may only attend to positions <= t
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout was also
        # applied in eval mode (corrupting perplexity evaluation and sampling);
        # tie it to the module's training flag so model.eval() disables it.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Character-level GPT-style decoder: token + learned positional
    embeddings, a stack of TransformerLayer blocks and a linear LM head."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        hidden = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            hidden = layer(hidden)
        logits = self.lm_head(hidden)  # (B, T, C)
        if targets is None:
            return logits, None
        # cross_entropy expects the class dim second
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            context = idx[:, -self.max_seq_len:]
            logits, _ = self(context)
            last = logits[:, -1, :]  # (B, C)
            probs = F.softmax(last, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# WikiText token files are read as raw text below (character-level modeling).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: <script>_<dataset>_seq_<len>[_<cli suffix>]; used in log/checkpoint dir names.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary is built over train+val+test so no split has unseen symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random (input, target) windows of length seq_len.

    Targets are the inputs shifted one token to the right; both tensors are
    moved to the module-level `device`.
    """
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    inputs = torch.stack([s_and_x for s_and_x in (data[s:s+seq_len] for s in starts)])
    targets = torch.stack([data[s+1:s+1+seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate character perplexity of `model` over `data` (1-D LongTensor).

    Slides windows of module-level `max_seq_len` over `data` with a stride that
    caps the window count at ~10000, batches them, and exponentiates the
    token-weighted mean cross-entropy. NOTE(review): windows overlap, so tokens
    are scored multiple times with varying context — an approximation, not the
    standard strided evaluation.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Sampling helper / model build ################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode `max_new_tokens` sampled tokens following prompt ids `start_idxs`.

    NOTE(review): mutable default argument is harmless here (never mutated).
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule of thumb printed below: ~8 training tokens (characters) per parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize full training state (weights, optimizer state, vocab maps and
    run config) to checkpoint_dir/checkpoint_<step>.pt for resume/inspection."""
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt')
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Qualitative probes logged alongside perplexity to eyeball training progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) samples at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: the nominal batch is split into equal micro-batches.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    # NOTE(review): clip_grad_norm_ runs only on logged iterations and max_norm=1e6
    # makes it effectively a no-op — it is used here to *measure* the grad norm,
    # not to clip. Move it before optimizer.step() unconditionally to really clip.
    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling and checkpointing (also fires at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation; `i` here is the last loop index (max_iters - 1).
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyperparameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): eval_interval/eval_iters appear unused below
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch for optical-multiplication variants — presumably metres; unused by the pure-torch model
# NOTE(review): `assert` is stripped under `python -O`; raise ValueError if this must always hold.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embeddings (RoFormer, arXiv:2104.09864).

    Precomputes a (max_seq_len, dim) table of rotation angles and rotates the
    last dimension of (B, T, dim) inputs by position-dependent angles.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Pair frequencies 10000^(-2i/dim), one per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))
    def rotate_half(self, x):
        """Map the (x1, x2) halves of the last dim to (-x2, x1)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        """Rotate x (B, T, dim) by the angles of positions offset..offset+T-1."""
        table = self.emb[offset:offset + x.size(1)].unsqueeze(0)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (normalization-free, https://jiachenzhu.github.io/DyT/):
    y = tanh(alpha * x) * weight + bias with a learnable scalar alpha and
    per-feature affine parameters."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: multi-head self-attention plus a 4x
    GELU MLP, DyT normalization, RoPE positions and residual connections."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Keep constructor args as attributes (h_dim, num_heads, dropout_rate, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, H) -> (B*num_heads, T, H/num_heads); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal self-attention over x of shape (B, T, h_dim)."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scale by the per-head dim; the original divided by sqrt(h_dim), which
        # is wrong once num_heads > 1 (identical for num_heads == 1).
        scores = (q @ k.transpose(1, 2)) * ((self.h_dim // self.num_heads) ** -0.5)
        # Causal mask, rebuilt per call (NOTE(review): could be a cached buffer).
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
    def forward(self, x):
        # Gate dropout on self.training: F.dropout1d defaults to training=True,
        # so the original applied dropout (and 1/(1-p) rescaling) at eval time.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only character-level language model (GPT-2 style)."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Keep every constructor argument as an attribute; pixel_size is carried
        # for optical-hardware variants and is unused by this pure-torch model.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (in addition to RoPE inside layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss).

        Args:
            x: (B, T) LongTensor of token ids, T <= max_seq_len.
            targets: optional (B, T) LongTensor of next-token ids.
        Returns:
            logits (B, T, vocab_size); loss is mean cross-entropy or None.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants the class dim second: (B, C, T) against (B, T).
        loss = F.cross_entropy(logits.transpose(1, 2), targets) if targets is not None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens ids after start_idx (B, T0).

        Each step conditions on the last max_seq_len tokens and samples from the
        full softmax distribution (no temperature / top-k truncation).
        """
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# WikiText token files are read as raw text below (character-level modeling).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: <script>_<dataset>_seq_<len>[_<cli suffix>]; used in log/checkpoint dir names.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary is built over train+val+test so no split has unseen symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random (input, target) windows of length seq_len.

    Targets are the inputs shifted one token to the right; both tensors are
    moved to the module-level `device`.
    """
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s+seq_len] for s in starts])
    targets = torch.stack([data[s+1:s+1+seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate character perplexity of `model` over `data` (1-D LongTensor).

    Slides windows of module-level `max_seq_len` over `data` with a stride that
    caps the window count at ~10000, batches them, and exponentiates the
    token-weighted mean cross-entropy. NOTE(review): windows overlap, so tokens
    are scored multiple times with varying context — an approximation, not the
    standard strided evaluation.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Sampling helper / model build ################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode `max_new_tokens` sampled tokens following prompt ids `start_idxs`.

    NOTE(review): mutable default argument is harmless here (never mutated).
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule of thumb printed below: ~8 training tokens (characters) per parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize full training state (weights, optimizer state, vocab maps and
    run config) to checkpoint_dir/checkpoint_<step>.pt for resume/inspection."""
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt')
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Qualitative probes logged alongside perplexity to eyeball training progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) samples at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: the nominal batch is split into equal micro-batches.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    # NOTE(review): clip_grad_norm_ runs only on logged iterations and max_norm=1e6
    # makes it effectively a no-op — it is used here to *measure* the grad norm,
    # not to clip. Move it before optimizer.step() unconditionally to really clip.
    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling and checkpointing (also fires at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation; `i` here is the last loop index (max_iters - 1).
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyperparameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): eval_interval/eval_iters appear unused below
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch for optical-multiplication variants — presumably metres; unused by the pure-torch model
# NOTE(review): `assert` is stripped under `python -O`; raise ValueError if this must always hold.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embeddings (RoFormer, arXiv:2104.09864).

    Precomputes a (max_seq_len, dim) table of rotation angles and rotates the
    last dimension of (B, T, dim) inputs by position-dependent angles.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Pair frequencies 10000^(-2i/dim), one per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))
    def rotate_half(self, x):
        """Map the (x1, x2) halves of the last dim to (-x2, x1)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        """Rotate x (B, T, dim) by the angles of positions offset..offset+T-1."""
        table = self.emb[offset:offset + x.size(1)].unsqueeze(0)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (normalization-free, https://jiachenzhu.github.io/DyT/):
    y = tanh(alpha * x) * weight + bias with a learnable scalar alpha and
    per-feature affine parameters."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: multi-head self-attention plus a 4x
    GELU MLP, DyT normalization, RoPE positions and residual connections."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Keep constructor args as attributes (h_dim, num_heads, dropout_rate, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, H) -> (B*num_heads, T, H/num_heads); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal self-attention over x of shape (B, T, h_dim)."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scale by the per-head dim; the original divided by sqrt(h_dim), which
        # is wrong once num_heads > 1 (identical for num_heads == 1).
        scores = (q @ k.transpose(1, 2)) * ((self.h_dim // self.num_heads) ** -0.5)
        # Causal mask, rebuilt per call (NOTE(review): could be a cached buffer).
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
    def forward(self, x):
        # Gate dropout on self.training: F.dropout1d defaults to training=True,
        # so the original applied dropout (and 1/(1-p) rescaling) at eval time.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only character-level language model (GPT-2 style)."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Keep every constructor argument as an attribute; pixel_size is carried
        # for optical-hardware variants and is unused by this pure-torch model.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (in addition to RoPE inside layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss).

        Args:
            x: (B, T) LongTensor of token ids, T <= max_seq_len.
            targets: optional (B, T) LongTensor of next-token ids.
        Returns:
            logits (B, T, vocab_size); loss is mean cross-entropy or None.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants the class dim second: (B, C, T) against (B, T).
        loss = F.cross_entropy(logits.transpose(1, 2), targets) if targets is not None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens ids after start_idx (B, T0).

        Each step conditions on the last max_seq_len tokens and samples from the
        full softmax distribution (no temperature / top-k truncation).
        """
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# WikiText token files are read as raw text below (character-level modeling).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: <script>_<dataset>_seq_<len>[_<cli suffix>]; used in log/checkpoint dir names.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary is built over train+val+test so no split has unseen symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random (input, target) windows of length seq_len.

    Targets are the inputs shifted one token to the right; both tensors are
    moved to the module-level `device`.
    """
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s+seq_len] for s in starts])
    targets = torch.stack([data[s+1:s+1+seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate character perplexity of `model` over `data` (1-D LongTensor).

    Slides windows of module-level `max_seq_len` over `data` with a stride that
    caps the window count at ~10000, batches them, and exponentiates the
    token-weighted mean cross-entropy. NOTE(review): windows overlap, so tokens
    are scored multiple times with varying context — an approximation, not the
    standard strided evaluation.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Sampling helper / model build ################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode `max_new_tokens` sampled tokens following prompt ids `start_idxs`.

    NOTE(review): mutable default argument is harmless here (never mutated).
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule of thumb printed below: ~8 training tokens (characters) per parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state to `<checkpoint_dir>/checkpoint_<step>.pt`.

    Stores the model and optimizer state dicts together with the step number,
    last loss, config dict, and both vocabulary mappings, so training can be
    resumed or the model reloaded standalone.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(state, destination)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed prompts used to spot-check qualitative progress during training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Log untrained-baseline completions at step 0 for later comparison.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main training loop: one optimizer step per iteration, with gradient
# accumulation over `gradient_accumulation_steps` micro-batches.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so summed gradients match the mean over the full batch.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): max_norm=1e6 makes this effectively a norm *measurement*
        # (clipping at 1e6 should rarely trigger), and it only runs every 100 iters.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation; also fires at i == 0, giving an untrained baseline.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,369 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyperparameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50  # sequences per optimizer step (split across accumulation micro-batches)
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): defined but never used in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): defined but never used in this script
layers_num = 2  # transformer layers
h_dim = 64  # embedding / model width
max_seq_len = 128  # context window in tokens
num_heads = 1  # attention heads per layer
dropout_rate = 0.1
pixel_size = 3.6e-6  # forwarded into the model config; unused by the GPT2 class itself
assert batch_size % gradient_accumulation_steps == 0  # micro-batch size must divide evenly
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer).

    Precomputes a table of per-position rotation angles for `dim` channels
    and applies the rotation to the last dimension of the input.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency spectrum over the even channel indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table covers all `dim` channels; a buffer follows
        # the module across devices without being a trainable parameter.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map the halves (x1, x2) of the last dim to (-x2, x1)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` (batch, seq, dim) by the angles of positions offset..offset+seq."""
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh normalization substitute (https://jiachenzhu.github.io/DyT/).

    Elementwise `weight * tanh(alpha * x) + bias`, where `alpha` is a learned
    scalar and weight/bias are per-feature, used here in place of LayerNorm.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: self-attention + feed-forward.

    Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
    Normalization is DyT instead of LayerNorm; positions enter via RoPE.
    `k1`/`k2` are learned scalar gains on the attention scores and output
    (NeoBERT-style, https://arxiv.org/html/2502.19587v1).
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Explicit attributes instead of the opaque self.__dict__.update(locals()) hack.
        self.h_dim = h_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.max_seq_len = max_seq_len
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        # RoPE acts per head, so it rotates h_dim // num_heads channels.
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal scaled dot-product attention over `x` of shape (B, T, h_dim)."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        # Causal mask: position t may only attend to positions <= t.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True, so dropout previously
        # stayed active even after model.eval() (corrupting every evaluation);
        # gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only character-level language model (GPT-2 style).

    Learned absolute position embeddings are added at the input, in addition
    to the RoPE rotations applied inside each TransformerLayer. `pixel_size`
    is stored for config/checkpoint purposes; this class itself does not use it.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Explicit attributes instead of the opaque self.__dict__.update(locals()) hack.
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None.

        x: (B, T) token ids with T <= max_seq_len; targets: (B, T) next-token ids.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        # F.cross_entropy expects (B, C, T), hence the rearrange.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# Corpus locations (WikiText-style token files, read below as raw text).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over all three splits so every split is encodable.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random (input, next-token target) windows of `seq_len` from `data`."""
    ix = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
    return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate token-level perplexity of `model` over `data` (1-D tensor of token ids).

    Slides windows of the module-level `max_seq_len` across `data` with a
    stride chosen so at most ~10k windows are evaluated, batches them on the
    module-level `device`, and returns exp(token-weighted mean NLL).
    Side effects: puts the model in eval mode and prints a progress line.
    """
    model.eval()
    stride = max(1, len(data) // 10000)  # caps the number of evaluated windows at ~10k
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        # Targets are the inputs shifted one token to the right.
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the token-id prompt `start_idxs` and decode it to text.

    `start_idxs` is any sequence of token ids; the old mutable-list default
    ([0]) was replaced with an equivalent tuple to avoid the shared mutable
    default pitfall. Uses the module-level `device` and `decode`.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model and log its architecture + parameter count to TensorBoard.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# The "* 8" is a tokens-per-parameter data-budget heuristic — TODO confirm the
# intended source of the factor 8.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
# NOTE(review): divides tokens by sequences per step (batch_size), not tokens per
# step (batch_size * max_seq_len) — verify which is intended.
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state to `<checkpoint_dir>/checkpoint_<step>.pt`.

    Stores the model and optimizer state dicts together with the step number,
    last loss, config dict, and both vocabulary mappings, so training can be
    resumed or the model reloaded standalone.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(state, destination)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed prompts used to spot-check qualitative progress during training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Log untrained-baseline completions at step 0 for later comparison.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main training loop: one optimizer step per iteration, with gradient
# accumulation over `gradient_accumulation_steps` micro-batches.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so summed gradients match the mean over the full batch.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): max_norm=1e6 makes this effectively a norm *measurement*
        # (clipping at 1e6 should rarely trigger), and it only runs every 100 iters.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation; also fires at i == 0, giving an untrained baseline.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,369 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyperparameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50  # sequences per optimizer step (split across accumulation micro-batches)
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): defined but never used in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): defined but never used in this script
layers_num = 2  # transformer layers
h_dim = 64  # embedding / model width
max_seq_len = 256  # context window in tokens
num_heads = 1  # attention heads per layer
dropout_rate = 0.1
pixel_size = 3.6e-6  # forwarded into the model config; unused by the GPT2 class itself
assert batch_size % gradient_accumulation_steps == 0  # micro-batch size must divide evenly
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer).

    Precomputes a table of per-position rotation angles for `dim` channels
    and applies the rotation to the last dimension of the input.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency spectrum over the even channel indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table covers all `dim` channels; a buffer follows
        # the module across devices without being a trainable parameter.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map the halves (x1, x2) of the last dim to (-x2, x1)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` (batch, seq, dim) by the angles of positions offset..offset+seq."""
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh normalization substitute (https://jiachenzhu.github.io/DyT/).

    Elementwise `weight * tanh(alpha * x) + bias`, where `alpha` is a learned
    scalar and weight/bias are per-feature, used here in place of LayerNorm.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: self-attention + feed-forward.

    Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
    Normalization is DyT instead of LayerNorm; positions enter via RoPE.
    `k1`/`k2` are learned scalar gains on the attention scores and output
    (NeoBERT-style, https://arxiv.org/html/2502.19587v1).
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Explicit attributes instead of the opaque self.__dict__.update(locals()) hack.
        self.h_dim = h_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.max_seq_len = max_seq_len
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        # RoPE acts per head, so it rotates h_dim // num_heads channels.
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal scaled dot-product attention over `x` of shape (B, T, h_dim)."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        # Causal mask: position t may only attend to positions <= t.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True, so dropout previously
        # stayed active even after model.eval() (corrupting every evaluation);
        # gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal decoder-only character-level language model (GPT-2 style).

    Learned absolute position embeddings are added at the input, in addition
    to the RoPE rotations applied inside each TransformerLayer. `pixel_size`
    is stored for config/checkpoint purposes; this class itself does not use it.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Explicit attributes instead of the opaque self.__dict__.update(locals()) hack.
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None.

        x: (B, T) token ids with T <= max_seq_len; targets: (B, T) next-token ids.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        # F.cross_entropy expects (B, C, T), hence the rearrange.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` (B, T)."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# Corpus locations (WikiText-style token files, read below as raw text).
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over all three splits so every split is encodable.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random (input, next-token target) windows of `seq_len` from `data`."""
    ix = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
    return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate token-level perplexity of `model` over `data` (1-D tensor of token ids).

    Slides windows of the module-level `max_seq_len` across `data` with a
    stride chosen so at most ~10k windows are evaluated, batches them on the
    module-level `device`, and returns exp(token-weighted mean NLL).
    Side effects: puts the model in eval mode and prints a progress line.
    """
    model.eval()
    stride = max(1, len(data) // 10000)  # caps the number of evaluated windows at ~10k
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        # Targets are the inputs shifted one token to the right.
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the token-id prompt `start_idxs` and decode it to text.

    `start_idxs` is any sequence of token ids; the old mutable-list default
    ([0]) was replaced with an equivalent tuple to avoid the shared mutable
    default pitfall. Uses the module-level `device` and `decode`.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model and log its architecture + parameter count to TensorBoard.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# The "* 8" is a tokens-per-parameter data-budget heuristic — TODO confirm the
# intended source of the factor 8.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
# NOTE(review): divides tokens by sequences per step (batch_size), not tokens per
# step (batch_size * max_seq_len) — verify which is intended.
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state to `<checkpoint_dir>/checkpoint_<step>.pt`.

    Stores the model and optimizer state dicts together with the step number,
    last loss, config dict, and both vocabulary mappings, so training can be
    resumed or the model reloaded standalone.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save(state, destination)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed prompts used to spot-check qualitative progress during training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Log untrained-baseline completions at step 0 for later comparison.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main training loop: one optimizer step per iteration, with gradient
# accumulation over `gradient_accumulation_steps` micro-batches.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so summed gradients match the mean over the full batch.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): max_norm=1e6 makes this effectively a norm *measurement*
        # (clipping at 1e6 should rarely trigger), and it only runs every 100 iters.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation; also fires at i == 0, giving an untrained baseline.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,369 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyperparameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50  # sequences per optimizer step (split across accumulation micro-batches)
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): defined but never used in this script
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): defined but never used in this script
layers_num = 2  # transformer layers
h_dim = 64  # embedding / model width
max_seq_len = 512  # context window in tokens
num_heads = 1  # attention heads per layer
dropout_rate = 0.1
pixel_size = 3.6e-6  # forwarded into the model config; unused by the GPT2 class itself
assert batch_size % gradient_accumulation_steps == 0  # micro-batch size must divide evenly
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer).

    Precomputes a table of per-position rotation angles for `dim` channels
    and applies the rotation to the last dimension of the input.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency spectrum over the even channel indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table covers all `dim` channels; a buffer follows
        # the module across devices without being a trainable parameter.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        """Map the halves (x1, x2) of the last dim to (-x2, x1)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` (batch, seq, dim) by the angles of positions offset..offset+seq."""
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh normalization substitute (https://jiachenzhu.github.io/DyT/).

    Elementwise `weight * tanh(alpha * x) + bias`, where `alpha` is a learned
    scalar and weight/bias are per-feature, used here in place of LayerNorm.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: self-attention + GELU feed-forward.

    References: "Attention Is All You Need" (arXiv:1706.03762v7) and
    NeoBERT (arXiv:2502.19587v1). Normalization is DyT instead of LayerNorm;
    RoPE is applied to queries and keys.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every ctor arg as an attribute (h_dim, num_heads, dropout_rate, max_seq_len).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains on attention scores and attention output.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity in the single-head case.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal softmax attention with RoPE-rotated q and k."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by h_dim**-0.5, not per-head dim — confirm intended for num_heads > 1.
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so the original
        # applied dropout even in eval mode; tie it to the module's mode instead.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Decoder-only character-level language model.

    Token embeddings plus learned absolute position embeddings feed a stack
    of TransformerLayer blocks, followed by a linear LM head.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Keep every ctor arg as an attribute (vocab_size, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        blocks = [
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)
        ]
        self.layers = nn.ModuleList(blocks)
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        hidden = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for block in self.layers:
            hidden = block(hidden)
        logits = self.lm_head(hidden)  # B,T,C
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: append max_new_tokens sampled tokens."""
        idx = start_idx
        for _ in range(max_new_tokens):
            context = idx[:, -self.max_seq_len:]
            logits, _ = self(context)
            last_logits = logits[:, -1, :]  # B, C
            probs = F.softmax(last_logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# WikiText files read as one contiguous character stream each.
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset name + sequence length + optional CLI tag.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary over ALL splits so val/test characters are always known.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> token id
itow = {i:w for i,w in enumerate(chars)}  # token id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s+seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s+1:s+1+seq_len] for s in starts]).to(device)
    return inputs, targets
# Encode each split once into a flat tensor of token ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over `data` (1-D token id tensor).

    Strides through the data (capped near 10k windows), accumulates
    token-weighted cross-entropy and returns exp(mean loss). Relies on the
    module-level `max_seq_len` and `device` globals.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    # Every window start offset to evaluate.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    loss_weighted_sum = 0.0
    token_count = 0
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        xs = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)
        _, mean_loss = model(xs, ys)
        # mean_loss is already averaged over tokens; re-weight by token count.
        batch_tokens = ys.numel()
        loss_weighted_sum += mean_loss.item() * batch_tokens
        token_count += batch_tokens
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print()  # final newline after the \r progress line
    return np.exp(loss_weighted_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a sampled continuation of token ids `start_idxs` from model `m`."""
    # NOTE: the mutable default is kept for interface compatibility; it is never mutated.
    prompt = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Heuristic sizing hints: ~8 tokens per parameter — TODO confirm the intended rule of thumb.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer, vocab, config) to disk."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, target)
# Training config for checkpointing: everything needed to rebuild the model
# and audit/resume a run from a saved checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed prompts used to eyeball generation quality throughout training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) free-form sample and task completions, logged at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss / steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used to MEASURE the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation: val perplexity, sample completions, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation; `i` here is the last loop index (max_iters - 1).
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
# Held-out test perplexity, logged once at the end.
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,369 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility: fixed RNG seed for weight init and batch sampling.
seed = 1337
torch.manual_seed(seed)
# Optimization hyper-parameters.
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Model shape.
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch (meters) for the optical variant; unused by the plain GPT2 model
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary positional embedding (RoFormer, arXiv:2104.09864).

    Precomputes an angle table for `max_seq_len` positions and rotates the
    input features pairwise in `forward`.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Classic RoPE frequency schedule: one frequency per feature pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so the table covers both halves of the feature vector.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))
    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each feature pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        # Slice the precomputed table for the positions actually present.
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh layer (https://jiachenzhu.github.io/DyT/).

    Normalization-free replacement for LayerNorm: an elementwise
    `tanh(alpha * x)` squashing followed by a learned affine transform.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: self-attention + GELU feed-forward.

    References: "Attention Is All You Need" (arXiv:1706.03762v7) and
    NeoBERT (arXiv:2502.19587v1). Normalization is DyT instead of LayerNorm;
    RoPE is applied to queries and keys.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every ctor arg as an attribute (h_dim, num_heads, dropout_rate, max_seq_len).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains on attention scores and attention output.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity in the single-head case.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal softmax attention with RoPE-rotated q and k."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by h_dim**-0.5, not per-head dim — confirm intended for num_heads > 1.
        scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so the original
        # applied dropout even in eval mode; tie it to the module's mode instead.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Decoder-only character-level language model.

    Token embeddings plus learned absolute position embeddings feed a stack
    of TransformerLayer blocks, followed by a linear LM head.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        # Keep every ctor arg as an attribute (vocab_size, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        blocks = [
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)
        ]
        self.layers = nn.ModuleList(blocks)
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        hidden = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for block in self.layers:
            hidden = block(hidden)
        logits = self.lm_head(hidden)  # B,T,C
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: append max_new_tokens sampled tokens."""
        idx = start_idx
        for _ in range(max_new_tokens):
            context = idx[:, -self.max_seq_len:]
            logits, _ = self(context)
            last_logits = logits[:, -1, :]  # B, C
            probs = F.softmax(last_logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = GPT2
# WikiText files read as one contiguous character stream each.
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset name + sequence length + optional CLI tag.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character vocabulary over ALL splits so val/test characters are always known.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> token id
itow = {i:w for i,w in enumerate(chars)}  # token id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s+seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s+1:s+1+seq_len] for s in starts]).to(device)
    return inputs, targets
# Encode each split once into a flat tensor of token ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over `data` (1-D token id tensor).

    Strides through the data (capped near 10k windows), accumulates
    token-weighted cross-entropy and returns exp(mean loss). Relies on the
    module-level `max_seq_len` and `device` globals.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    # Every window start offset to evaluate.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    loss_weighted_sum = 0.0
    token_count = 0
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        xs = torch.stack([data[s:s + max_seq_len] for s in batch_starts]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch_starts]).to(device)
        _, mean_loss = model(xs, ys)
        # mean_loss is already averaged over tokens; re-weight by token count.
        batch_tokens = ys.numel()
        loss_weighted_sum += mean_loss.item() * batch_tokens
        token_count += batch_tokens
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print()  # final newline after the \r progress line
    return np.exp(loss_weighted_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a sampled continuation of token ids `start_idxs` from model `m`."""
    # NOTE: the mutable default is kept for interface compatibility; it is never mutated.
    prompt = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Heuristic sizing hints: ~8 tokens per parameter — TODO confirm the intended rule of thumb.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer, vocab, config) to disk."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, target)
# Training config for checkpointing: everything needed to rebuild the model
# and audit/resume a run from a saved checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed prompts used to eyeball generation quality throughout training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) free-form sample and task completions, logged at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss / steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used to MEASURE the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation: val perplexity, sample completions, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation; `i` here is the last loop index (max_iters - 1).
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
# Held-out test perplexity, logged once at the end.
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,473 @@
import shutil
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter

import optical_matrix_multiplication as omm
# Reproducibility: fixed RNG seed for weight init and batch sampling.
seed = 1337
torch.manual_seed(seed)
# Optimization hyper-parameters.
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Model shape.
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch (meters) used to size the optical multiplier apertures
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product through a non-negative multiplier `sim`.

    The optical multiplier only accepts non-negative inputs scaled to [0, 1].
    Each operand is split into positive/negative parts, the four cross
    products are computed on normalized data and rescaled, and the signed
    result is reassembled as t1 - t2 - t3 + t4
    (= (A+ - A-)(B+ - B-) when `sim` is an exact matmul).

    Args:
        sim: callable (A, B) -> A @ B for non-negative, [0, 1]-scaled 4-D tensors.
        tensor_1, tensor_2: 3-D or 4-D tensors; 3-D inputs get a leading axis added.
    Returns:
        3-D tensor: the product for the first leading-axis element.
    """
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    max_A_pos = torch.max(A_pos)  # may be 0 when A has no positive values
    max_A_neg = torch.max(A_neg)  # may be 0 when A has no negative values
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    def _term(part_a, max_a, part_b, max_b):
        # Skip the simulation when either side is identically zero: the
        # contribution is exactly a zero matrix of the result shape.
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        return torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                           tensor_1.shape[2], tensor_2.shape[3], device=device)
    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary positional embedding (RoFormer, arXiv:2104.09864).

    Precomputes an angle table for `max_seq_len` positions and rotates the
    input features pairwise in `forward`.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Classic RoPE frequency schedule: one frequency per feature pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so the table covers both halves of the feature vector.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))
    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each feature pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        # Slice the precomputed table for the positions actually present.
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh layer (https://jiachenzhu.github.io/DyT/).

    Normalization-free replacement for LayerNorm: an elementwise
    `tanh(alpha * x)` squashing followed by a learned affine transform.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal attention block whose matmuls run through optical simulators.

    `sim_scores` and `sim_output` are omm.OpticalMul instances (shared across
    layers) applied via `new_formula`, which handles signs and [0, 1] scaling.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Keep ctor args as attributes. NOTE: __dict__.update bypasses
        # nn.Module.__setattr__, so the sims are NOT registered as submodules here.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains on attention scores and attention output.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity in the single-head case.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal attention; both matmuls are simulated optically."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by h_dim**-0.5, not per-head dim — confirm intended for num_heads > 1.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUGFIX: functional dropout defaults to training=True, so the original
        # applied dropout even in eval mode; tie it to the module's mode instead.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
pixel_size = 3.6e-6):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
if max_seq_len < 512:
self.sim_scores = omm.OpticalMul(
omm.Config(right_matrix_count_columns = max_seq_len,
right_matrix_count_rows = h_dim // num_heads,
right_matrix_width = pixel_size * max_seq_len,
right_matrix_height = pixel_size * (h_dim // num_heads),
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01,
trainable_cylind_lens=False)
)
self.sim_output = omm.OpticalMul(
omm.Config(right_matrix_count_columns = h_dim // num_heads,
right_matrix_count_rows = max_seq_len,
right_matrix_width = pixel_size * (h_dim // num_heads),
right_matrix_height = pixel_size * max_seq_len,
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01,
trainable_cylind_lens=False)
)
if max_seq_len >= 512:
self.sim_scores = omm.OpticalMul(
omm.Config(right_matrix_count_columns = max_seq_len,
right_matrix_count_rows = h_dim // num_heads,
right_matrix_width = pixel_size * max_seq_len,
right_matrix_height = pixel_size * (h_dim // num_heads),
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.15,
lens_size = 8192 * 2,
trainable_cylind_lens=False)
)
self.sim_output = omm.OpticalMul(
omm.Config(right_matrix_count_columns = h_dim // num_heads,
right_matrix_count_rows = max_seq_len,
right_matrix_width = pixel_size * (h_dim // num_heads),
right_matrix_height = pixel_size * max_seq_len,
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.15,
lens_size = 8192 * 2,
trainable_cylind_lens=False)
)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
for _ in range(layers_num)])
self.tok_embeds = nn.Embedding(vocab_size, h_dim)
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
self.lm_head = nn.Linear(h_dim, vocab_size)
def forward(self, x, targets=None):
    """Run the language model on a batch of token indices.

    Args:
        x: LongTensor of token ids, shape (B, T), T <= max_seq_len.
        targets: optional LongTensor of next-token ids, shape (B, T).

    Returns:
        (logits, loss): logits of shape (B, T, vocab_size); loss is the mean
        cross-entropy over all positions, or None when targets is None.
    """
    # Token embeddings plus learned absolute position embeddings,
    # truncated to the actual sequence length.
    x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
    for layer in self.layers:
        x = layer(x)
    logits = self.lm_head(x)  # (B, T, C)
    # F.cross_entropy expects the class dimension second, hence the rearrange.
    loss = None
    if targets is not None:
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
    return logits, loss
# what is the purpose? autoregressive inference!
def generate(self, start_idx, max_new_tokens):
    """Autoregressively sample max_new_tokens token ids after start_idx."""
    idx = start_idx
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's context window.
        context = idx[:, -self.max_seq_len:]
        logits, _ = self(context)
        # Distribution over the next token comes from the last position.
        last_logits = logits[:, -1, :]  # (B, C)
        probs = F.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
        idx = torch.cat([idx, next_token], dim=1)
    return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
# WikiText token files; the script treats them as plain text for char-level training.
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# Timestamped TensorBoard log directory, one per run.
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits, so validation/test
# characters are guaranteed to be in-vocabulary.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> token id
itow = {i:w for i,w in enumerate(chars)}  # token id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random contiguous windows and their one-step-shifted
    targets from a 1-D token tensor; both are moved to the global `device`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Pre-tokenized splits kept on CPU; batches are moved to `device` on demand.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Corpus perplexity of `model` over `data` using strided sliding windows.

    NOTE(review): relies on the module-level globals `max_seq_len` and `device`.
    """
    model.eval()
    # Subsample start positions so at most ~10000 windows are evaluated.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_windows, batch_size):
        chunk = starts[begin:min(begin + batch_size, n_windows)]
        # Stack the windows of this chunk into (B, T) input/target tensors.
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # The model returns the mean loss over tokens; re-weight by token count
        # so the final average is over tokens, not over batches.
        _, mean_loss = model(xs, ys)
        n_tok = ys.numel()
        loss_sum += mean_loss.item() * n_tok
        token_count += n_tok
        done = min(begin + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a sampled continuation of the token-id prompt `start_idxs`."""
    prompt = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# NOTE(review): the *8 factor looks like a tokens-per-parameter heuristic, and
# the iteration estimate divides by batch_size only (not batch_size*max_seq_len)
# — confirm the intended units.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize model/optimizer state plus vocab maps and run config to
    checkpoint_dir/checkpoint_<step>.pt so training can be resumed/audited."""
    out_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, out_path)
# Training config for checkpointing: everything needed to rebuild the model
# and reconstruct the hyperparameters of this run from a checkpoint file.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Log the untrained model's output as a step-0 baseline.
m.eval()
# Fixed probe prompts reused throughout training to track qualitative progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss/steps, so after
    # the inner loop accumulated_loss equals the mean loss of the full batch.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips: clip_grad_norm_ is used here
        # only to measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation: validation perplexity, sample completions,
    # task-prompt probes and a full checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits plus sample completions.
# NOTE(review): relies on the loop variable `i` leaking out of the for-loop.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,473 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility: fix the global PyTorch RNG.
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): defined but apparently unused below — confirm.
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): defined but apparently unused below — confirm.
layers_num = 2
h_dim = 64           # model width
max_seq_len = 256    # context window length
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # presumably the physical pixel pitch (m) for the optical sims — confirm
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def _nonneg_product_term(sim, a, a_max, b, b_max, zeros):
    """One rescaled partial product sim(a/a_max, b/b_max) * a_max * b_max,
    or the shared `zeros` tensor when either part is identically zero."""
    if a_max > 0 and b_max > 0:
        return sim(a / a_max, b / b_max) * a_max * b_max
    return zeros


def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative multiplier `sim`.

    `sim` (an optical multiplication simulator) can only multiply matrices
    with entries in [0, 1], so each operand is split into its positive and
    negative parts, each part is normalized by its maximum, and the four
    partial products are recombined as
    (A+ - A-)(B+ - B-) = A+B+ - A+B- - A-B+ + A-B-.

    Accepts 3-D or 4-D (batched) tensors; 3-D inputs get a leading batch
    dimension. Returns only the first batch element as a 3-D tensor,
    matching how attention() calls it with 3-D inputs.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # A maximum of 0 means the corresponding part has no non-zero entries.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    # Single shared all-zero fallback, allocated once on the right device
    # (the original built it via zeros_like(empty(...)) on CPU and cloned +
    # copied it to `device` separately for every degenerate term).
    zeros = torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                        tensor_1.shape[2], tensor_2.shape[3],
                        dtype=tensor_1.dtype, device=device)
    t1 = _nonneg_product_term(sim, A_pos, max_A_pos, B_pos, max_B_pos, zeros)
    t2 = _nonneg_product_term(sim, A_pos, max_A_pos, B_neg, max_B_neg, zeros)
    t3 = _nonneg_product_term(sim, A_neg, max_A_neg, B_pos, max_B_pos, zeros)
    t4 = _nonneg_product_term(sim, A_neg, max_A_neg, B_neg, max_B_neg, zeros)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes rotation angles for every position up to max_seq_len and
    rotates channel pairs of the input by a position-dependent angle.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder, one frequency per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation within each half.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        t = x.size(1)
        angles = self.emb[offset:offset + t].view(1, t, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (https://jiachenzhu.github.io/DyT/), a
    normalization-free replacement for LayerNorm:
    y = tanh(alpha * x) * weight + bias, with a learned scalar alpha
    and per-feature affine weight/bias.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matrix products
    (Q @ K^T and attention @ V) run through optical-multiplier simulations.

    Args:
        h_dim: model width (width of every projection).
        sim_scores: optical multiplier used for the Q @ K^T product.
        sim_output: optical multiplier used for the attention @ V product.
        num_heads: number of attention heads.
        dropout_rate: dropout applied after attention and the MLP.
        max_seq_len: longest sequence the RoPE table is built for.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # NOTE(review): writing through __dict__ bypasses nn.Module.__setattr__,
        # so sim_scores/sim_output are NOT registered as submodules (their
        # parameters, if any, are invisible to parameters()) — presumably
        # intentional so all layers share the same simulators; confirm.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)  # normalization-free LayerNorm replacements
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learned scalar gains compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head models.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        # Rotary position encoding is applied to Q and K only.
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled dot-product scores via the optical multiplication simulation.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: position t may attend only to positions <= t.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        # Weighted sum of values, again through the optical simulation.
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Pre-norm residual blocks: x + Drop(Attn(DyT(x))), then the MLP.
        # FIX: F.dropout1d defaults to training=True, which kept dropout
        # active even in eval() mode; gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Character-level GPT-style language model whose attention matrix
    products are evaluated by optical-multiplication simulators (omm).

    Two simulators are built once and shared by all layers: sim_scores
    multiplies Q @ K^T and sim_output multiplies attention @ V. Sequences
    of 512+ tokens use a longer propagation distance and a doubled lens.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Store all constructor hyperparameters as attributes. NOTE(review):
        # this bypasses nn.Module.__setattr__, harmless for plain values.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short sequences: near-field configuration (1 cm, default lens).
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01,
                           trainable_cylind_lens=False)
            )
        # Long sequences: 15 cm propagation and a doubled lens aperture.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2,
                           trainable_cylind_lens=False)
            )
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (in addition to RoPE in layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss) for token ids x of shape (B, T); loss is the
        mean cross-entropy against `targets`, or None when targets is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants the class dimension second, hence the rearrange.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens token ids after start_idx."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C — next-token distribution
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
# WikiText token files; the script treats them as plain text for char-level training.
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# Timestamped TensorBoard log directory, one per run.
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits, so validation/test
# characters are guaranteed to be in-vocabulary.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> token id
itow = {i:w for i,w in enumerate(chars)}  # token id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random contiguous windows and their one-step-shifted
    targets from a 1-D token tensor; both are moved to the global `device`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Pre-tokenized splits kept on CPU; batches are moved to `device` on demand.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Corpus perplexity of `model` over `data` using strided sliding windows.

    NOTE(review): relies on the module-level globals `max_seq_len` and `device`.
    """
    model.eval()
    # Subsample start positions so at most ~10000 windows are evaluated.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_windows, batch_size):
        chunk = starts[begin:min(begin + batch_size, n_windows)]
        # Stack the windows of this chunk into (B, T) input/target tensors.
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # The model returns the mean loss over tokens; re-weight by token count
        # so the final average is over tokens, not over batches.
        _, mean_loss = model(xs, ys)
        n_tok = ys.numel()
        loss_sum += mean_loss.item() * n_tok
        token_count += n_tok
        done = min(begin + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a sampled continuation of the token-id prompt `start_idxs`."""
    prompt = torch.tensor([start_idxs]).to(device)
    generated = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# NOTE(review): the *8 factor looks like a tokens-per-parameter heuristic, and
# the iteration estimate divides by batch_size only (not batch_size*max_seq_len)
# — confirm the intended units.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize model/optimizer state plus vocab maps and run config to
    checkpoint_dir/checkpoint_<step>.pt so training can be resumed/audited."""
    out_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, out_path)
# Training config for checkpointing: everything needed to rebuild the model
# and reconstruct the hyperparameters of this run from a checkpoint file.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Log the untrained model's output as a step-0 baseline.
m.eval()
# Fixed probe prompts reused throughout training to track qualitative progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss/steps, so after
    # the inner loop accumulated_loss equals the mean loss of the full batch.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips: clip_grad_norm_ is used here
        # only to measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation: validation perplexity, sample completions,
    # task-prompt probes and a full checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits plus sample completions.
# NOTE(review): relies on the loop variable `i` leaking out of the for-loop.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,473 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility: fix the global PyTorch RNG.
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): defined but apparently unused below — confirm.
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): defined but apparently unused below — confirm.
layers_num = 2
h_dim = 64           # model width
max_seq_len = 512    # context window length (triggers the long-sequence optics config)
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # presumably the physical pixel pitch (m) for the optical sims — confirm
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def _nonneg_product_term(sim, a, a_max, b, b_max, zeros):
    """One rescaled partial product sim(a/a_max, b/b_max) * a_max * b_max,
    or the shared `zeros` tensor when either part is identically zero."""
    if a_max > 0 and b_max > 0:
        return sim(a / a_max, b / b_max) * a_max * b_max
    return zeros


def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative multiplier `sim`.

    `sim` (an optical multiplication simulator) can only multiply matrices
    with entries in [0, 1], so each operand is split into its positive and
    negative parts, each part is normalized by its maximum, and the four
    partial products are recombined as
    (A+ - A-)(B+ - B-) = A+B+ - A+B- - A-B+ + A-B-.

    Accepts 3-D or 4-D (batched) tensors; 3-D inputs get a leading batch
    dimension. Returns only the first batch element as a 3-D tensor,
    matching how attention() calls it with 3-D inputs.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # A maximum of 0 means the corresponding part has no non-zero entries.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    # Single shared all-zero fallback, allocated once on the right device
    # (the original built it via zeros_like(empty(...)) on CPU and cloned +
    # copied it to `device` separately for every degenerate term).
    zeros = torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                        tensor_1.shape[2], tensor_2.shape[3],
                        dtype=tensor_1.dtype, device=device)
    t1 = _nonneg_product_term(sim, A_pos, max_A_pos, B_pos, max_B_pos, zeros)
    t2 = _nonneg_product_term(sim, A_pos, max_A_pos, B_neg, max_B_neg, zeros)
    t3 = _nonneg_product_term(sim, A_neg, max_A_neg, B_pos, max_B_pos, zeros)
    t4 = _nonneg_product_term(sim, A_neg, max_A_neg, B_neg, max_B_neg, zeros)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes rotation angles for every position up to max_seq_len and
    rotates channel pairs of the input by a position-dependent angle.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder, one frequency per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation within each half.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        t = x.size(1)
        angles = self.emb[offset:offset + t].view(1, t, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (https://jiachenzhu.github.io/DyT/), a
    normalization-free replacement for LayerNorm:
    y = tanh(alpha * x) * weight + bias, with a learned scalar alpha
    and per-feature affine weight/bias.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matrix products
    (Q @ K^T and attention @ V) run through optical-multiplier simulations.

    Args:
        h_dim: model width (width of every projection).
        sim_scores: optical multiplier used for the Q @ K^T product.
        sim_output: optical multiplier used for the attention @ V product.
        num_heads: number of attention heads.
        dropout_rate: dropout applied after attention and the MLP.
        max_seq_len: longest sequence the RoPE table is built for.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # NOTE(review): writing through __dict__ bypasses nn.Module.__setattr__,
        # so sim_scores/sim_output are NOT registered as submodules (their
        # parameters, if any, are invisible to parameters()) — presumably
        # intentional so all layers share the same simulators; confirm.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)  # normalization-free LayerNorm replacements
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learned scalar gains compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity for single-head models.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        # Rotary position encoding is applied to Q and K only.
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled dot-product scores via the optical multiplication simulation.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: position t may attend only to positions <= t.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        # Weighted sum of values, again through the optical simulation.
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Pre-norm residual blocks: x + Drop(Attn(DyT(x))), then the MLP.
        # FIX: F.dropout1d defaults to training=True, which kept dropout
        # active even in eval() mode; gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Character-level GPT-style causal LM whose attention products are computed
    by optical matrix-multiplication simulators (omm.OpticalMul) shared by all
    layers."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The two 512-threshold branches previously duplicated the whole config;
        # they differed only in these optics parameters. Longer sequences use a
        # larger propagation distance and a bigger lens.
        if max_seq_len < 512:
            extra = {'distance': 0.01}
        else:
            extra = {'distance': 0.15, 'lens_size': 8192 * 2}
        # Simulator for q @ k^T: right operand is (head_dim x seq_len).
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       trainable_cylind_lens=False,
                       **extra)
        )
        # Simulator for attn @ v: right operand is (seq_len x head_dim).
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       trainable_cylind_lens=False,
                       **extra)
        )
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets are not given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        # cross_entropy expects the class dim second -> (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:]  # last position only: (B, C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Build the character vocabulary over the union of all splits so val/test
# never contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from data.

    Uses the module-level `device` global for placement.
    """
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[s:s+seq_len] for s in starts]).to(device)
    y = torch.stack([data[s+1:s+1+seq_len] for s in starts]).to(device)
    return x, y
# Tokenize each split once up front into long tensors of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate character-level perplexity of `model` over `data`.

    Slides windows of length `max_seq_len` (module-level global) over `data`
    with a stride chosen so that roughly 10k windows are scored, batches them,
    and exponentiates the token-weighted mean cross-entropy. Uses the
    module-level `device` global for placement.

    NOTE(review): windows overlap, so every token in each window contributes to
    the loss, not just the last one — this weights tokens differently from the
    standard strided-evaluation protocol; confirm this is intended.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a sampled continuation of `start_idxs` from model `m`.

    The default is now an immutable tuple: the previous `[0]` default was a
    shared mutable object (classic Python pitfall), even though it was never
    mutated here. Callers passing lists are unaffected.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model from the globals configured above and move to device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule-of-thumb estimates (~8 tokens per parameter) — heuristic, not a guarantee.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist the full training state (weights, optimizer, vocab maps, config)."""
    out_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, out_path)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline (untrained) generations, logged at step 0 for comparison.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main optimization loop with gradient accumulation and periodic evaluation.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the accumulated gradient matches a full-batch step.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips — this call only measures the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sample generations, and checkpointing.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on val and test splits, final sample generations, final checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,473 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# ----- Hyperparameters / run configuration -----
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch used by the optical simulator configs
# Sanity check: per-step micro-batch size must divide evenly.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed with a non-negative-only multiplier.

    The optical simulator `sim` multiplies normalized non-negative matrices, so
    the signed product A @ B is decomposed into four non-negative products:
    A@B = A+@B+ - A+@B- - A-@B+ + A-@B-, each rescaled by the maxima used for
    normalization. Inputs may be 3-D (a leading size-1 dim is added) or 4-D;
    the leading (size-1) dim is dropped from the result.
    """
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    max_A_pos = torch.max(A_pos)  # may be 0 when A has no positive entries
    max_A_neg = torch.max(A_neg)  # may be 0 when A has no negative entries
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])
    def _term(A, max_A, B, max_B):
        # Skip the simulator entirely when one factor is all zeros.
        # (Previously a CPU zeros_like(torch.empty(...)) template was cloned
        # and moved per branch — replaced by one direct device-placed zeros.)
        if max_A > 0 and max_B > 0:
            return sim(A / max_A, B / max_B) * max_A * max_B
        return torch.zeros(out_shape, device=device)
    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes the angle table for max_seq_len positions and rotates input
    features by their position-dependent angles.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Standard 10000^(-2i/dim) frequency schedule over even feature indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))
    def rotate_half(self, x):
        # Map the two feature halves (x1, x2) -> (-x2, x1).
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        seq_len = x.size(1)
        table = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh normalization substitute (https://jiachenzhu.github.io/DyT/).

    Replaces LayerNorm with elementwise tanh(alpha * x) followed by a learned
    per-feature affine transform.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matmuls run through
    optical-multiplication simulators.

    `sim_scores` computes q @ k^T and `sim_output` computes attn @ v via
    `new_formula`; learnable scalars k1/k2 rescale the simulated products.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all ctor args as attributes. NOTE(review): writing to __dict__
        # directly bypasses nn.Module.__setattr__, so sim_scores/sim_output are
        # NOT registered as submodules (their params stay out of this layer's
        # state_dict) — presumably intentional since the sims are shared; verify.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains compensating the normalization inside new_formula.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal softmax attention with optically simulated matrix products."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: positions may not attend to the future.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True, so dropout previously
        # stayed active in eval mode (corrupting perplexity and generation).
        # Gate it on self.training, as nn.Dropout does.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Character-level GPT-style causal LM whose attention products are computed
    by optical matrix-multiplication simulators (omm.OpticalMul) shared by all
    layers."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The two 512-threshold branches previously duplicated the whole config;
        # they differed only in these optics parameters. Longer sequences use a
        # larger propagation distance and a bigger lens.
        if max_seq_len < 512:
            extra = {'distance': 0.01}
        else:
            extra = {'distance': 0.15, 'lens_size': 8192 * 2}
        # Simulator for q @ k^T: right operand is (head_dim x seq_len).
        self.sim_scores = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       trainable_cylind_lens=False,
                       **extra)
        )
        # Simulator for attn @ v: right operand is (seq_len x head_dim).
        self.sim_output = omm.OpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       trainable_cylind_lens=False,
                       **extra)
        )
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets are not given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        # cross_entropy expects the class dim second -> (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:]  # last position only: (B, C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Build the character vocabulary over the union of all splits so val/test
# never contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> id
itow = {i:w for i,w in enumerate(chars)}  # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from data.

    Uses the module-level `device` global for placement.
    """
    starts = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[s:s+seq_len] for s in starts]).to(device)
    y = torch.stack([data[s+1:s+1+seq_len] for s in starts]).to(device)
    return x, y
# Tokenize each split once up front into long tensors of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Estimate character-level perplexity of `model` over `data`.

    Slides windows of length `max_seq_len` (module-level global) over `data`
    with a stride chosen so that roughly 10k windows are scored, batches them,
    and exponentiates the token-weighted mean cross-entropy. Uses the
    module-level `device` global for placement.

    NOTE(review): windows overlap, so every token in each window contributes to
    the loss, not just the last one — this weights tokens differently from the
    standard strided-evaluation protocol; confirm this is intended.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)
        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a sampled continuation of `start_idxs` from model `m`.

    The default is now an immutable tuple: the previous `[0]` default was a
    shared mutable object (classic Python pitfall), even though it was never
    mutated here. Callers passing lists are unaffected.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model from the globals configured above and move to device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule-of-thumb estimates (~8 tokens per parameter) — heuristic, not a guarantee.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist the full training state (weights, optimizer, vocab maps, config)."""
    out_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, out_path)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline (untrained) generations, logged at step 0 for comparison.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main optimization loop with gradient accumulation and periodic evaluation.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the accumulated gradient matches a full-batch step.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips — this call only measures the total grad norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sample generations, and checkpointing.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on val and test splits, final sample generations, final checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,472 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# ----- Hyperparameters / run configuration (512-token variant) -----
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # physical pixel pitch used by the optical simulator configs
# Sanity check: per-step micro-batch size must divide evenly.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed with a non-negative-only multiplier.

    The optical simulator `sim` multiplies normalized non-negative matrices, so
    the signed product A @ B is decomposed into four non-negative products:
    A@B = A+@B+ - A+@B- - A-@B+ + A-@B-, each rescaled by the maxima used for
    normalization. Inputs may be 3-D (a leading size-1 dim is added) or 4-D;
    the leading (size-1) dim is dropped from the result.
    """
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    max_A_pos = torch.max(A_pos)  # may be 0 when A has no positive entries
    max_A_neg = torch.max(A_neg)  # may be 0 when A has no negative entries
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])
    def _term(A, max_A, B, max_B):
        # Skip the simulator entirely when one factor is all zeros.
        # (Previously a CPU zeros_like(torch.empty(...)) template was cloned
        # and moved per branch — replaced by one direct device-placed zeros.)
        if max_A > 0 and max_B > 0:
            return sim(A / max_A, B / max_B) * max_A * max_B
        return torch.zeros(out_shape, device=device)
    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes the angle table for max_seq_len positions and rotates input
    features by their position-dependent angles.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Standard 10000^(-2i/dim) frequency schedule over even feature indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))
    def rotate_half(self, x):
        # Map the two feature halves (x1, x2) -> (-x2, x1).
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        seq_len = x.size(1)
        table = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh normalization substitute (https://jiachenzhu.github.io/DyT/).

    Replaces LayerNorm with elementwise tanh(alpha * x) followed by a learned
    per-feature affine transform.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
class OpticLinear(nn.Module):
    """Drop-in nn.Linear-style layer whose matmul runs through an optical sim.

    `weight` is stored transposed relative to nn.Linear — (in_features,
    out_features) — because the simulator computes input @ weight directly.
    A learnable scalar `k` rescales the simulated product.
    """
    def __init__(
        self,
        in_features,
        out_features,
        bias = True,
        device = None,
        dtype = None,
        pixel_size = 3.6e-6
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((in_features, out_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.k = nn.Parameter(torch.randn(1))
        self.sim = omm.OpticalMul(
            omm.Config(
                right_matrix_count_columns = out_features,
                right_matrix_count_rows = in_features,
                right_matrix_width = pixel_size * out_features,
                right_matrix_height = pixel_size * in_features,
                min_height_gap = pixel_size,
                right_matrix_split_x = 2,
                right_matrix_split_y = 2,
                left_matrix_split_x = 2,
                left_matrix_split_y = 2,
                result_matrix_split = 2,
                distance = 0.01)
        )
        self.reset_parameters()
    def forward(self, input):
        """Run the optically simulated affine transform.

        BUGFIX: with bias=False, self.bias is None and the previous
        unconditional `+ self.bias` raised a TypeError; add it only if present.
        """
        out = self.k * new_formula(self.sim, input, self.weight)
        return out if self.bias is None else out + self.bias
    def reset_parameters(self) -> None:
        """Initialize weights as nn.Linear does (kaiming uniform, a=sqrt(5)).

        Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        https://github.com/pytorch/pytorch/issues/57109
        """
        import math  # local import: `math` is not in this file's module-level imports
        # BUGFIX: the module-level imports do not bring a bare `init` into
        # scope — use nn.init (nn is imported) for the same functions.
        # NOTE(review): weight is (in, out), transposed vs nn.Linear, so the
        # fan computed by _calculate_fan_in_and_fan_out is swapped relative to
        # nn.Linear — preserved as-is; confirm this is intended.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
    def extra_repr(self) -> str:
        """
        Return the extra representation of the module.
        """
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block: self-attention + optical-linear feed-forward."""
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash constructor args (h_dim, num_heads, dropout_rate, max_seq_len) as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = OpticLinear(h_dim, 4*h_dim)
        self.ff2 = OpticLinear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal scaled dot-product attention with RoPE applied to q and k."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim rather than the per-head dim — confirm for num_heads > 1.
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
    def forward(self, x):
        # BUG FIX: functional dropout defaults to training=True, so the original applied
        # dropout even under model.eval() (perplexity/generation). Gate it on the
        # module's training flag so evaluation is deterministic.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2FF(nn.Module):
    """GPT-2-style causal char LM whose feed-forward sublayers use optical linear layers."""
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Record every constructor argument as a plain attribute.
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        self.layers = nn.ModuleList(
            TransformerLayer(h_dim=h_dim, num_heads=num_heads,
                             dropout_rate=dropout_rate, max_seq_len=max_seq_len)
            for _ in range(layers_num))
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        h = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            h = layer(h)
        logits = self.lm_head(h)  # (B, T, vocab)
        if targets is None:
            return logits, None
        return logits, F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Extend start_idx (B, T) by sampling max_new_tokens tokens autoregressively."""
        idx = start_idx
        for _ in range(max_new_tokens):
            window = idx[:, -self.max_seq_len:]  # crop to the context window
            logits, _ = self(window)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
            nxt = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
###################################################################################################
# Experiment selection and run bookkeeping (logs, repo snapshot, checkpoints).
MODEL_CLASS = OpticGPT2FF
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over the union of all three splits.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random windows; returns inputs and shifted-by-one targets on `device`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([s_and for s_and in (data[s:s + seq_len] for s in starts)])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# Tokenize each split once; kept on CPU, batches are moved to `device` on demand.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over `data`.

    Uses the module-level `max_seq_len` and `device`; the stride keeps the number
    of evaluated windows near 10000. Puts the model in eval mode and leaves it there.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    loss_sum = 0.0
    token_count = 0
    for batch_begin in range(0, total_sequences, batch_size):
        window_starts = start_positions[batch_begin:min(batch_begin + batch_size, total_sequences)]
        x_batch = torch.stack([data[s:s + max_seq_len] for s in window_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in window_starts]).to(device)
        # Model returns the mean loss over every token in the batch.
        _, mean_loss = model(x_batch, y_batch)
        n_tokens = y_batch.numel()
        loss_sum += mean_loss.item() * n_tokens
        token_count += n_tokens
        processed = min(batch_begin + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a text completion from model `m` seeded with token ids `start_idxs`."""
    # (mutable default is safe here: the list is never mutated)
    seed = torch.tensor([start_idxs]).to(device)
    out_ids = m.generate(start_idx=seed, max_new_tokens=max_new_tokens)
    return decode(out_ids[0].tolist())
# Instantiate the model and report its size.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Heuristic data budget: 8 tokens per parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
# NOTE(review): the iteration estimate divides by batch_size only, not
# batch_size * max_seq_len — confirm intent.
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist full training state (weights, optimizer, vocab maps, config) for resume."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
# Hyperparameters bundled into every checkpoint so a run can be rebuilt/resumed.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts logged periodically to eyeball qualitative progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline samples from the untrained model (step 0).
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # One optimizer step accumulated over gradient_accumulation_steps micro-batches.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    # max_norm=1e6 effectively never clips; the call is used to *measure* the norm.
    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval + samples + checkpoint (also fires at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, final samples, final checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# @ -0,0 +1,471 @@  -- diff hunk marker left over from the web diff view (boundary of the next file)
import shutil
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter

import optical_matrix_multiplication as omm
# Reproducibility and training hyperparameters for this run (context length 128).
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): eval_interval and eval_iters appear unused in this script — confirm.
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch passed into the optical simulator configs
# Micro-batching requires the batch to split evenly.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
device = tensor_1.device
A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0)
A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0)
B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0)
B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0)
max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений
max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений
max_B_pos = torch.max(B_pos)
max_B_neg = torch.max(B_neg)
zero_template = torch.zeros_like(
torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3]))
if max_A_pos > 0 and max_B_pos > 0:
t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
else:
t1 = zero_template.clone().to(device)
if max_A_pos > 0 and max_B_neg > 0:
t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
else:
t2 = zero_template.clone().to(device)
if max_A_neg > 0 and max_B_pos > 0:
t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
else:
t3 = zero_template.clone().to(device)
if max_A_neg > 0 and max_B_neg > 0:
t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
else:
t4 = zero_template.clone().to(device)
return (t1 - t2 - t3 + t4)[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One geometrically-spaced frequency per feature pair, as in the paper.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the angle table spans the full feature width.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))
    def rotate_half(self, x):
        """Map the (a, b) feature halves to (-b, a) — a 90° rotation in each plane."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        """Rotate x (B, T, H) by the angles of absolute positions offset..offset+T-1."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing plus per-feature affine, a norm-free LayerNorm stand-in."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose attention matmuls run through optical sims."""
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash args as plain attributes; writing through __dict__ also keeps the shared
        # sim modules from being re-registered as submodules of every layer.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal attention with RoPE; both matrix products go through the optical sims."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim rather than the per-head dim — confirm for num_heads > 1.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUG FIX: functional dropout defaults to training=True, so the original applied
        # dropout even under model.eval() (perplexity/generation). Gate it on the
        # module's training flag so evaluation is deterministic.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Decoder-only character LM whose attention products run on optical multipliers.

    Two shared `omm.OpticalMul` simulators are built once and reused by every layer:
    `sim_scores` for q @ kᵀ (right operand shaped (head_dim, T)) and `sim_output`
    for attn @ v (right operand shaped (T, head_dim)).
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash constructor args as plain attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # Optics regime: sequences >= 512 use a longer propagation distance and a
        # doubled lens (values taken verbatim from the original duplicated branches).
        if max_seq_len < 512:
            distance, lens_size = 0.01, None
        else:
            distance, lens_size = 0.15, 8192 * 2
        def make_sim(rows, cols):
            # Build one optical multiplier whose *right* operand is a (rows, cols) matrix.
            cfg = dict(
                right_matrix_count_columns = cols,
                right_matrix_count_rows = rows,
                right_matrix_width = pixel_size * cols,
                right_matrix_height = pixel_size * rows,
                min_height_gap = pixel_size,
                right_matrix_split_x = 2,
                right_matrix_split_y = 2,
                left_matrix_split_x = 2,
                left_matrix_split_y = 2,
                result_matrix_split = 2,
                distance = distance,
                trainable_cylind_lens = False,
            )
            if lens_size is not None:
                cfg['lens_size'] = lens_size
            return omm.OpticalMul(omm.Config(**cfg))
        self.sim_scores = make_sim(head_dim, max_seq_len)  # right operand: k.T
        self.sim_output = make_sim(max_seq_len, head_dim)  # right operand: v
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: extend start_idx by max_new_tokens tokens."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
# Experiment selection and run bookkeeping (logs, repo snapshot, checkpoints).
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over the union of all three splits.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample batch_size random windows; returns inputs and shifted-by-one targets on `device`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# Tokenize each split once; kept on CPU, batches are moved to `device` on demand.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over `data`.

    Uses the module-level `max_seq_len` and `device`; the stride keeps the number
    of evaluated windows near 10000. Puts the model in eval mode and leaves it there.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    loss_sum = 0.0
    token_count = 0
    for batch_begin in range(0, total_sequences, batch_size):
        window_starts = start_positions[batch_begin:min(batch_begin + batch_size, total_sequences)]
        x_batch = torch.stack([data[s:s + max_seq_len] for s in window_starts]).to(device)
        y_batch = torch.stack([data[s + 1:s + max_seq_len + 1] for s in window_starts]).to(device)
        # Model returns the mean loss over every token in the batch.
        _, mean_loss = model(x_batch, y_batch)
        n_tokens = y_batch.numel()
        loss_sum += mean_loss.item() * n_tokens
        token_count += n_tokens
        processed = min(batch_begin + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Decode a text completion from model `m` seeded with token ids `start_idxs`."""
    # (mutable default is safe here: the list is never mutated)
    seed = torch.tensor([start_idxs]).to(device)
    out_ids = m.generate(start_idx=seed, max_new_tokens=max_new_tokens)
    return decode(out_ids[0].tolist())
# Instantiate the model and report its size.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Heuristic data budget: 8 tokens per parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
# NOTE(review): the iteration estimate divides by batch_size only, not
# batch_size * max_seq_len — confirm intent.
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist full training state (weights, optimizer, vocab maps, config) for resume."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
# Hyperparameters bundled into every checkpoint so a run can be rebuilt/resumed.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts logged periodically to eyeball qualitative progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline samples from the untrained model (step 0).
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # One optimizer step accumulated over gradient_accumulation_steps micro-batches.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    # max_norm=1e6 effectively never clips; the call is used to *measure* the norm.
    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval + samples + checkpoint (also fires at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, final samples, final checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# @ -0,0 +1,471 @@  -- diff hunk marker left over from the web diff view (boundary of the next file)
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyperparameters for this run (context length 256).
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): eval_interval and eval_iters appear unused in this script — confirm.
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch passed into the optical simulator configs
# Micro-batching requires the batch to split evenly.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
device = tensor_1.device
A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0)
A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0)
B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0)
B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0)
max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений
max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений
max_B_pos = torch.max(B_pos)
max_B_neg = torch.max(B_neg)
zero_template = torch.zeros_like(
torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3]))
if max_A_pos > 0 and max_B_pos > 0:
t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
else:
t1 = zero_template.clone().to(device)
if max_A_pos > 0 and max_B_neg > 0:
t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
else:
t2 = zero_template.clone().to(device)
if max_A_neg > 0 and max_B_pos > 0:
t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
else:
t3 = zero_template.clone().to(device)
if max_A_neg > 0 and max_B_neg > 0:
t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
else:
t4 = zero_template.clone().to(device)
return (t1 - t2 - t3 + t4)[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One geometrically-spaced frequency per feature pair, as in the paper.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the angle table spans the full feature width.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))
    def rotate_half(self, x):
        """Map the (a, b) feature halves to (-b, a) — a 90° rotation in each plane."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        """Rotate x (B, T, H) by the angles of absolute positions offset..offset+T-1."""
        T = x.size(1)
        table = self.emb[offset:offset + T].view(1, T, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing plus per-feature affine, a norm-free LayerNorm stand-in."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
  """Pre-norm transformer block whose two matrix products (QK^T and
  attn @ V) are routed through simulated optical multipliers.

  Args:
      h_dim: model width; must be divisible by ``num_heads``.
      sim_scores: optical multiplier for the attention-score product.
      sim_output: optical multiplier for the attention-output product.
      num_heads: number of attention heads.
      dropout_rate: dropout probability on both residual branches.
      max_seq_len: maximum sequence length (sizes the RoPE table).
  """
  def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
    super().__init__()
    # Stash all constructor args as plain attributes (bypasses Module
    # registration, keeping the shared sim modules out of this state_dict).
    self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
    self.q_proj = nn.Linear(h_dim, h_dim)
    self.k_proj = nn.Linear(h_dim, h_dim)
    self.v_proj = nn.Linear(h_dim, h_dim)
    self.o_proj = nn.Linear(h_dim, h_dim)
    self.ff1 = nn.Linear(h_dim, 4*h_dim)
    self.ff2 = nn.Linear(4*h_dim, h_dim)
    self.ln1 = DyT(h_dim)
    self.ln2 = DyT(h_dim)
    self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
  def split_to_heads(self, x, B, T, H):
    # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
    if self.num_heads <= 1: return x
    return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
  def gather_heads(self, x, B, T, H):
    # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
    if self.num_heads <= 1: return x
    return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
  def attention(self, x):
    q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
    k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
    v = self.split_to_heads(self.v_proj(x), *x.shape)
    # Scaled dot-product scores computed through the optical multiplier.
    scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
    # Causal mask: position t may only attend to positions <= t.
    tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
    scores = scores.masked_fill(tril == 0, float('-inf'))
    attention = nn.functional.softmax(scores, dim=2)
    output = new_formula(self.sim_output, attention, v)
    return self.o_proj(self.gather_heads(output, *x.shape))
  def forward(self, x):
    # BUG FIX: functional dropout defaults to training=True, so the original
    # applied dropout during eval/generation too.  Tie it to self.training.
    x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
    x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
    return x
class OpticGPT2(nn.Module):
  """GPT-style character LM whose attention matrix products are simulated
  optical multiplications (omm.OpticalMul), shared across all layers.

  Sequences of length >= 512 use a longer propagation distance and a
  larger lens; shorter sequences keep the library's default lens size.
  """
  def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
               pixel_size = 3.6e-6):
    super().__init__()
    # Stash all constructor args as plain attributes.
    self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
    head_dim = h_dim // num_heads
    # Sequence-length-dependent optical settings.  (Was two ~30-line
    # duplicated `if max_seq_len < 512` / `if max_seq_len >= 512` branches
    # differing only in these values.)
    if max_seq_len < 512:
      optics = {'distance': 0.01}
    else:
      optics = {'distance': 0.15, 'lens_size': 8192 * 2}
    # scores multiplier: (T x d) @ (d x T); output multiplier: (T x T) @ (T x d).
    self.sim_scores = self._make_optical_mul(max_seq_len, head_dim, pixel_size, optics)
    self.sim_output = self._make_optical_mul(head_dim, max_seq_len, pixel_size, optics)
    self.layers = nn.ModuleList([
      TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                       dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
      for _ in range(layers_num)])
    self.tok_embeds = nn.Embedding(vocab_size, h_dim)
    self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
    self.lm_head = nn.Linear(h_dim, vocab_size)
  @staticmethod
  def _make_optical_mul(count_columns, count_rows, pixel_size, optics):
    # One simulated optical multiplier; ``optics`` carries the
    # sequence-length-dependent distance / lens settings.
    return omm.OpticalMul(
      omm.Config(right_matrix_count_columns = count_columns,
                 right_matrix_count_rows = count_rows,
                 right_matrix_width = pixel_size * count_columns,
                 right_matrix_height = pixel_size * count_rows,
                 min_height_gap = pixel_size,
                 right_matrix_split_x = 2,
                 right_matrix_split_y = 2,
                 left_matrix_split_x = 2,
                 left_matrix_split_y = 2,
                 result_matrix_split = 2,
                 trainable_cylind_lens = False,
                 **optics)
    )
  def forward(self, x, targets=None):
    # Token embeddings plus learned absolute position embeddings.
    x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
    for l in self.layers:
      x = l(x)
    logits = self.lm_head(x)  # (B, T, vocab)
    # cross_entropy expects (B, C, T); loss is None in pure inference.
    loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
    return logits, loss
  def generate(self, start_idx, max_new_tokens):
    """Autoregressive sampling: append one multinomially sampled token per step."""
    idx = start_idx
    for _ in range(max_new_tokens):
      idx_cond = idx[:, -self.max_seq_len:]  # crop context to the model window
      logits, _ = self(idx_cond)
      probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
      idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
      idx = torch.cat([idx, idx_next], dim=1)
    return idx
###################################################################################################
# Experiment setup: model class, dataset paths, TensorBoard logging, and a
# read-only snapshot of the repository for reproducibility.
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset + sequence length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over the union of all three splits so
# encode/decode cover every symbol that can appear at evaluation time.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}   # char -> id
itow = {i:w for i,w in enumerate(chars)}   # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample ``batch_size`` random windows and their shifted-by-one targets."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = [data[s:s + seq_len] for s in starts]
    targets = [data[s + 1:s + 1 + seq_len] for s in starts]
    # Batches are moved to the module-level `device`.
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
# Tokenized splits kept on CPU; get_batch moves sampled batches to `device`.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided full-corpus perplexity: exp of the token-weighted mean NLL."""
    model.eval()
    # Subsample start positions so evaluation stays around 10k sequences.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_total = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_total, batch_size):
        chunk = starts[begin:begin + batch_size]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        _, mean_loss = model(xs, ys)
        # mean_loss is averaged over tokens; re-weight by the token count so
        # the final average is exact even if the last batch is smaller.
        loss_sum += mean_loss.item() * ys.numel()
        token_count += ys.numel()
        done = min(begin + batch_size, n_total)
        print(f"\rppl {done}/{n_total} ({done/n_total*100:.1f}%)", end="", flush=True)
    print()
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Generate a decoded text completion from model ``m`` seeded with token ids."""
    # NOTE(review): the mutable default [0] is shared across calls but is
    # never mutated here, so it is harmless.
    seed = torch.tensor([start_idxs]).to(device)
    out_tokens = m.generate(start_idx=seed, max_new_tokens=max_new_tokens)
    return decode(out_tokens[0].tolist())
# Instantiate the model from the module-level hyper-parameters.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Heuristic used here: ~8 training tokens per trainable parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the complete training state (model, optimizer, vocab, config)."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
# (everything needed to rebuild the model and inspect/resume a run).
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline sampling before any training, so later completions can be compared.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss / steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): max_norm=1e6 makes this effectively gradient-norm
        # *logging*, not clipping, and it only runs every 100th step.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling and checkpointing (also fires at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# ---- "@ -0,0 +1,471 @@": stray diff-hunk marker left over from concatenating file copies; commented out so the module parses ----
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyper-parameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
# Physical pixel pitch of the simulated optical setup (metres).
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product routed through ``sim``, which only accepts
    non-negative inputs scaled into [0, 1].

    Uses the decomposition A@B = A+B+ - A+B- - A-B+ + A-B- where
    A+ = max(A, 0) and A- = max(-A, 0).  Each non-negative factor is
    normalized by its global max before calling ``sim`` and the result is
    rescaled; a term whose factor is identically zero contributes an exact
    zero without invoking ``sim``.

    Args:
        sim: callable emulating matrix multiplication of two 4-D
            non-negative tensors (e.g. an omm.OpticalMul instance).
        tensor_1: left operand, (b, i, j) or (1, b, i, j).
        tensor_2: right operand, (b, j, k) or (1, b, j, k).

    Returns:
        (b, i, k) tensor holding the reconstructed signed product.
    """
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def term(a, b):
        # One decomposition term: normalize, simulate, rescale.
        a_max, b_max = torch.max(a), torch.max(b)
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        # All-zero factor: the product term is exactly zero.  (Replaces the
        # original zeros_like(torch.empty(...)).clone().to(device) pattern,
        # which allocated on CPU and copied four times.)
        return torch.zeros(out_shape, device=device)

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    result = term(A_pos, B_pos) - term(A_pos, B_neg) - term(A_neg, B_pos) + term(A_neg, B_neg)
    return result[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
  """Rotary position embedding (RoFormer, arXiv:2104.09864).

  Precomputes rotation angles for every position up to ``max_seq_len`` and
  applies them to the last dimension of the input at forward time.
  """
  def __init__(self, dim, max_seq_len=512):
    super().__init__()
    # Per-channel-pair inverse frequencies: 10000^(-2k/dim).
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (10000 ** exponents)
    positions = torch.arange(max_seq_len).type_as(inv_freq)
    # Outer product -> (max_seq_len, dim/2) angle table.
    freqs = positions[:, None] * inv_freq[None, :]
    # Duplicated along channels so it lines up with rotate_half's halves.
    self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
  def rotate_half(self, x):
    # (x1, x2) -> (-x2, x1): a 90-degree rotation within each channel pair.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
  def forward(self, x, offset=0):
    # x: (batch, seq, dim); rotate each position by its precomputed angle.
    angles = self.emb[offset:offset + x.size(1)].unsqueeze(0)
    return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
  """Dynamic Tanh normalization replacement (https://jiachenzhu.github.io/DyT/).

  Elementwise ``tanh(alpha * x)`` squashing followed by a learnable
  per-feature affine transform, used here in place of LayerNorm.
  """
  def __init__(self, num_features, alpha_init_value=0.5):
    super().__init__()
    # Scalar tanh steepness, learned jointly with the affine parameters.
    self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
    self.weight = nn.Parameter(torch.ones(num_features))
    self.bias = nn.Parameter(torch.zeros(num_features))
  def forward(self, x):
    squashed = torch.tanh(self.alpha * x)
    return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
  """Pre-norm transformer block whose two matrix products (QK^T and
  attn @ V) are routed through simulated optical multipliers.

  Args:
      h_dim: model width; must be divisible by ``num_heads``.
      sim_scores: optical multiplier for the attention-score product.
      sim_output: optical multiplier for the attention-output product.
      num_heads: number of attention heads.
      dropout_rate: dropout probability on both residual branches.
      max_seq_len: maximum sequence length (sizes the RoPE table).
  """
  def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
    super().__init__()
    # Stash all constructor args as plain attributes (bypasses Module
    # registration, keeping the shared sim modules out of this state_dict).
    self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
    self.q_proj = nn.Linear(h_dim, h_dim)
    self.k_proj = nn.Linear(h_dim, h_dim)
    self.v_proj = nn.Linear(h_dim, h_dim)
    self.o_proj = nn.Linear(h_dim, h_dim)
    self.ff1 = nn.Linear(h_dim, 4*h_dim)
    self.ff2 = nn.Linear(4*h_dim, h_dim)
    self.ln1 = DyT(h_dim)
    self.ln2 = DyT(h_dim)
    self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
  def split_to_heads(self, x, B, T, H):
    # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
    if self.num_heads <= 1: return x
    return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
  def gather_heads(self, x, B, T, H):
    # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
    if self.num_heads <= 1: return x
    return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
  def attention(self, x):
    q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
    k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
    v = self.split_to_heads(self.v_proj(x), *x.shape)
    # Scaled dot-product scores computed through the optical multiplier.
    scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
    # Causal mask: position t may only attend to positions <= t.
    tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
    scores = scores.masked_fill(tril == 0, float('-inf'))
    attention = nn.functional.softmax(scores, dim=2)
    output = new_formula(self.sim_output, attention, v)
    return self.o_proj(self.gather_heads(output, *x.shape))
  def forward(self, x):
    # BUG FIX: functional dropout defaults to training=True, so the original
    # applied dropout during eval/generation too.  Tie it to self.training.
    x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
    x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
    return x
class OpticGPT2(nn.Module):
  """GPT-style character LM whose attention matrix products are simulated
  optical multiplications (omm.OpticalMul), shared across all layers.

  Sequences of length >= 512 use a longer propagation distance and a
  larger lens; shorter sequences keep the library's default lens size.
  """
  def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
               pixel_size = 3.6e-6):
    super().__init__()
    # Stash all constructor args as plain attributes.
    self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
    head_dim = h_dim // num_heads
    # Sequence-length-dependent optical settings.  (Was two ~30-line
    # duplicated `if max_seq_len < 512` / `if max_seq_len >= 512` branches
    # differing only in these values.)
    if max_seq_len < 512:
      optics = {'distance': 0.01}
    else:
      optics = {'distance': 0.15, 'lens_size': 8192 * 2}
    # scores multiplier: (T x d) @ (d x T); output multiplier: (T x T) @ (T x d).
    self.sim_scores = self._make_optical_mul(max_seq_len, head_dim, pixel_size, optics)
    self.sim_output = self._make_optical_mul(head_dim, max_seq_len, pixel_size, optics)
    self.layers = nn.ModuleList([
      TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                       dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
      for _ in range(layers_num)])
    self.tok_embeds = nn.Embedding(vocab_size, h_dim)
    self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
    self.lm_head = nn.Linear(h_dim, vocab_size)
  @staticmethod
  def _make_optical_mul(count_columns, count_rows, pixel_size, optics):
    # One simulated optical multiplier; ``optics`` carries the
    # sequence-length-dependent distance / lens settings.
    return omm.OpticalMul(
      omm.Config(right_matrix_count_columns = count_columns,
                 right_matrix_count_rows = count_rows,
                 right_matrix_width = pixel_size * count_columns,
                 right_matrix_height = pixel_size * count_rows,
                 min_height_gap = pixel_size,
                 right_matrix_split_x = 2,
                 right_matrix_split_y = 2,
                 left_matrix_split_x = 2,
                 left_matrix_split_y = 2,
                 result_matrix_split = 2,
                 trainable_cylind_lens = False,
                 **optics)
    )
  def forward(self, x, targets=None):
    # Token embeddings plus learned absolute position embeddings.
    x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
    for l in self.layers:
      x = l(x)
    logits = self.lm_head(x)  # (B, T, vocab)
    # cross_entropy expects (B, C, T); loss is None in pure inference.
    loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
    return logits, loss
  def generate(self, start_idx, max_new_tokens):
    """Autoregressive sampling: append one multinomially sampled token per step."""
    idx = start_idx
    for _ in range(max_new_tokens):
      idx_cond = idx[:, -self.max_seq_len:]  # crop context to the model window
      logits, _ = self(idx_cond)
      probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
      idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
      idx = torch.cat([idx, idx_next], dim=1)
    return idx
###################################################################################################
# Experiment setup: model class, dataset paths, TensorBoard logging, and a
# read-only snapshot of the repository for reproducibility.
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset + sequence length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over the union of all three splits so
# encode/decode cover every symbol that can appear at evaluation time.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}   # char -> id
itow = {i:w for i,w in enumerate(chars)}   # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample ``batch_size`` random windows and their shifted-by-one targets."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = [data[s:s + seq_len] for s in starts]
    targets = [data[s + 1:s + 1 + seq_len] for s in starts]
    # Batches are moved to the module-level `device`.
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
# Tokenized splits kept on CPU; get_batch moves sampled batches to `device`.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Strided full-corpus perplexity: exp of the token-weighted mean NLL."""
    model.eval()
    # Subsample start positions so evaluation stays around 10k sequences.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_total = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_total, batch_size):
        chunk = starts[begin:begin + batch_size]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        _, mean_loss = model(xs, ys)
        # mean_loss is averaged over tokens; re-weight by the token count so
        # the final average is exact even if the last batch is smaller.
        loss_sum += mean_loss.item() * ys.numel()
        token_count += ys.numel()
        done = min(begin + batch_size, n_total)
        print(f"\rppl {done}/{n_total} ({done/n_total*100:.1f}%)", end="", flush=True)
    print()
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Generate a decoded text completion from model ``m`` seeded with token ids."""
    # NOTE(review): the mutable default [0] is shared across calls but is
    # never mutated here, so it is harmless.
    seed = torch.tensor([start_idxs]).to(device)
    out_tokens = m.generate(start_idx=seed, max_new_tokens=max_new_tokens)
    return decode(out_tokens[0].tolist())
# Instantiate the model from the module-level hyper-parameters.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Heuristic used here: ~8 training tokens per trainable parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the complete training state (model, optimizer, vocab, config)."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
# (everything needed to rebuild the model and inspect/resume a run).
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline sampling before any training, so later completions can be compared.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss / steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): max_norm=1e6 makes this effectively gradient-norm
        # *logging*, not clipping, and it only runs every 100th step.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling and checkpointing (also fires at i == 0).
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# ---- "@ -0,0 +1,471 @@": stray diff-hunk marker left over from concatenating file copies; commented out so the module parses ----
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
# Reproducibility and training hyper-parameters.
seed = 1337
torch.manual_seed(seed)
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
# Physical pixel pitch of the simulated optical setup (metres).
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product routed through ``sim``, which only accepts
    non-negative inputs scaled into [0, 1].

    Uses the decomposition A@B = A+B+ - A+B- - A-B+ + A-B- where
    A+ = max(A, 0) and A- = max(-A, 0).  Each non-negative factor is
    normalized by its global max before calling ``sim`` and the result is
    rescaled; a term whose factor is identically zero contributes an exact
    zero without invoking ``sim``.

    Args:
        sim: callable emulating matrix multiplication of two 4-D
            non-negative tensors (e.g. an omm.OpticalMul instance).
        tensor_1: left operand, (b, i, j) or (1, b, i, j).
        tensor_2: right operand, (b, j, k) or (1, b, j, k).

    Returns:
        (b, i, k) tensor holding the reconstructed signed product.
    """
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def term(a, b):
        # One decomposition term: normalize, simulate, rescale.
        a_max, b_max = torch.max(a), torch.max(b)
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        # All-zero factor: the product term is exactly zero.  (Replaces the
        # original zeros_like(torch.empty(...)).clone().to(device) pattern,
        # which allocated on CPU and copied four times.)
        return torch.zeros(out_shape, device=device)

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    result = term(A_pos, B_pos) - term(A_pos, B_neg) - term(A_neg, B_pos) + term(A_neg, B_neg)
    return result[0,:,:,:]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
  """Rotary position embedding (RoFormer, arXiv:2104.09864).

  Precomputes rotation angles for every position up to ``max_seq_len`` and
  applies them to the last dimension of the input at forward time.
  """
  def __init__(self, dim, max_seq_len=512):
    super().__init__()
    # Per-channel-pair inverse frequencies: 10000^(-2k/dim).
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (10000 ** exponents)
    positions = torch.arange(max_seq_len).type_as(inv_freq)
    # Outer product -> (max_seq_len, dim/2) angle table.
    freqs = positions[:, None] * inv_freq[None, :]
    # Duplicated along channels so it lines up with rotate_half's halves.
    self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
  def rotate_half(self, x):
    # (x1, x2) -> (-x2, x1): a 90-degree rotation within each channel pair.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
  def forward(self, x, offset=0):
    # x: (batch, seq, dim); rotate each position by its precomputed angle.
    angles = self.emb[offset:offset + x.size(1)].unsqueeze(0)
    return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer — a normalization-free substitute for LayerNorm
    (https://jiachenzhu.github.io/DyT/): y = weight * tanh(alpha * x) + bias.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block (attention + GELU FFN) whose two attention
    matrix products (Q @ K^T and attention @ V) are computed by external
    optical-multiplication simulators via new_formula.

    Args:
        h_dim: model width.
        sim_scores: simulator callable used for the Q @ K^T product.
        sim_output: simulator callable used for the attention @ V product.
        num_heads: attention heads; 1 skips head reshaping entirely.
        dropout_rate: dropout probability on attention and FFN residuals.
        max_seq_len: capacity of the rotary-embedding table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash ctor args as attributes; writing through __dict__ also keeps
        # the shared simulators from being registered as submodules here.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal attention with both matmuls routed through the simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: positions may only attend to themselves and the past.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Pass training=self.training: functional dropout defaults to
        # training=True regardless of module mode, which previously kept
        # dropout active even after model.eval(), corrupting perplexity
        # measurements and generation.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2(nn.Module):
    """Character-level GPT-style language model whose attention matrix
    products are evaluated by optical matrix-multiplication simulators
    (omm.OpticalMul).

    Two simulators are built once and shared by every layer:
    `sim_scores` for Q @ K^T and `sim_output` for attention @ V.
    Sequences of 512+ tokens use a longer propagation distance (0.15 m) and
    a doubled lens; shorter sequences use the cheaper 0.01 m configuration.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Store every constructor argument as a same-named attribute.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # Short-sequence optical setup: 1 cm propagation, default lens size.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
                )
        # Long-sequence setup: 15 cm propagation and a doubled lens.
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
                )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
                )
        # All layers share the same two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (used in addition to RoPE
        # inside the layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given.

        x: (B, T) token ids; targets: (B, T) next-token ids or None.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx`
        (multinomial sampling, context cropped to the last max_seq_len)."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
MODEL_CLASS = OpticGPT2
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset + sequence length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary over the union of all splits, so the
# encode/decode tables cover every symbol seen at train or eval time.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random contiguous windows of length `seq_len`;
    targets are the same windows shifted one token to the right."""
    starts = torch.randint(len(data) - seq_len, (batch_size,))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Tokenize each split once up front; batches index into these 1D tensors.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over `data` (1D token-id tensor).

    Windows of length `max_seq_len` start every `stride` tokens (stride is
    chosen to evaluate roughly 10k windows) and are scored in batches; the
    per-token mean losses are re-weighted by token count before exp().
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_windows, batch_size):
        chunk = starts[begin:min(begin + batch_size, n_windows)]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # Model returns the mean loss over all tokens in the batch.
        _, mean_loss = model(xs, ys)
        n_tokens = ys.numel()
        loss_sum += mean_loss.item() * n_tokens
        token_count += n_tokens
        done = min(begin + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Sample a continuation of the prompt token ids `start_idxs` from model
    `m` and return the decoded string (prompt included)."""
    prompt = torch.tensor([start_idxs]).to(device)
    tokens = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(tokens[0].tolist())
# Instantiate the model with the module-level hyperparameters.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
# Rule of thumb: ~8 training tokens per parameter.
print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
# NOTE(review): divides by batch_size only, while one iteration consumes
# batch_size * max_seq_len tokens — confirm the estimate is intended.
print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer state, vocab
    tables, config) to `<checkpoint_dir>/checkpoint_<step>.pt`."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed prompts used to eyeball generation quality throughout training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) completions at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss/K so the
    # summed gradient matches a single batch of the full size.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): clip_grad_norm_ both clips (max_norm=1e6, i.e.
        # practically never) and returns the total norm used for logging.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval: validation perplexity, sample completions, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation: validation and test perplexity plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,490 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
# Fix the RNG for reproducible runs.
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Compute tensor_1 @ tensor_2 through a non-negative-only multiplier.

    The simulator `sim` multiplies matrices with entries scaled to [0, 1],
    so each operand is split into positive and negative parts, each part is
    normalised by its maximum, and the four partial products are recombined:
    A @ B = A+ B+ - A+ B- - A- B+ + A- B-.

    Args:
        sim: callable (left, right) -> product; receives inputs in [0, 1].
        tensor_1: left operand, 3D (b, n, k) or 4D (a, b, n, k).
        tensor_2: right operand, 3D (b, k, m) or 4D (a, b, k, m).

    Returns:
        3D tensor: the first slice along the leading axis of the recombined
        product (callers pass 3D inputs, which get a singleton leading axis).
    """
    # Promote 3D inputs to 4D with a singleton leading axis.
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # Per-part maxima; a maximum of 0 means that part is identically zero.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    # One shared zero tensor, created directly on the input's device/dtype
    # (previously a CPU template was built and cloned+moved per branch).
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])
    zeros = torch.zeros(out_shape, device=device, dtype=tensor_1.dtype)

    def _part(A, max_A, B, max_B):
        # Skip the simulator entirely when either operand part is all-zero.
        if max_A > 0 and max_B > 0:
            return sim(A / max_A, B / max_B) * max_A * max_B
        return zeros

    t1 = _part(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _part(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _part(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _part(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes per-position rotation angles for `max_seq_len` positions and
    applies the rotation to the last dimension of the input on forward.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the angle table matches the full feature dimension.
        self.register_buffer('emb', torch.cat((angles, angles), dim=-1))

    def rotate_half(self, x):
        # Map (x1, x2) -> (-x2, x1) across the split feature halves.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        length = x.size(1)
        angles = self.emb[offset:offset + length].view(1, length, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer — a normalization-free substitute for LayerNorm
    (https://jiachenzhu.github.io/DyT/): y = weight * tanh(alpha * x) + bias.
    """
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block (attention + GELU FFN) whose two attention
    matrix products are computed by external optical simulators, each
    rescaled by a per-layer learnable scalar (k1, k2).

    Args:
        h_dim: model width.
        sim_scores: simulator callable used for the Q @ K^T product.
        sim_output: simulator callable used for the attention @ V product.
        num_heads: attention heads; 1 skips head reshaping entirely.
        dropout_rate: dropout probability on attention and FFN residuals.
        max_seq_len: capacity of the rotary-embedding table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash ctor args as attributes; writing through __dict__ also keeps
        # the shared simulators from being registered as submodules here.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Trainable parameters also live inside the optical matmul simulators
        # (one scalar per simulator); this layer additionally learns one
        # scalar gain per optical product.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal attention with both matmuls routed through the simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: positions may only attend to themselves and the past.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Pass training=self.training: functional dropout defaults to
        # training=True regardless of module mode, which previously kept
        # dropout active even after model.eval(), corrupting perplexity
        # measurements and generation.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndFocalDistLens(nn.Module):
    """Character-level GPT-style LM using optical matmul simulators with a
    trainable focal-distance lens (omm.TrainableFocalDistLensOpticalMul),
    plus per-layer trainable scalars (k1/k2 in TransformerLayer).

    Both simulators are wrapped in omm.ScatterDataParallel and shared by all
    layers: `sim_scores` for Q @ K^T, `sim_output` for attention @ V.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Store every constructor argument as a same-named attribute.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # NOTE(review): `!= 512` sends sequence lengths ABOVE 512 down this
        # short-distance branch; the sibling script uses `< 512` — confirm
        # which comparison is intended.
        if max_seq_len != 512:
            self.sim_scores = omm.TrainableFocalDistLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01)
                )
            self.sim_output = omm.TrainableFocalDistLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01)
                )
        # 512-token setup: 15 cm propagation and a doubled lens.
        if max_seq_len == 512:
            self.sim_scores = omm.TrainableFocalDistLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2)
                )
            self.sim_output = omm.TrainableFocalDistLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2)
                )
        # Wrap for multi-GPU scattering; layers share these wrapped instances.
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(
                h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (in addition to RoPE inside
        # the layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given.

        x: (B, T) token ids; targets: (B, T) next-token ids or None.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx`
        (multinomial sampling, context cropped to the last max_seq_len)."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the simulators' lens operators and each layer's k1/k2 scalars
        to TensorBoard at `global_step`."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2  # number of stacked transformer blocks
h_dim = 64  # embedding / model width
max_seq_len = 64  # context window length in tokens (characters)
num_heads = 1  # attention heads per block
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch fed to the optical simulator; presumably meters -- TODO confirm
assert batch_size % gradient_accumulation_steps == 0  # micro-batches must tile the batch exactly
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndFocalDistLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset + sequence length + optional CLI suffix.
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary over the union of all splits, so the
# encode/decode tables cover every symbol seen at train or eval time.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random contiguous windows of length `seq_len`;
    targets are the same windows shifted one token to the right."""
    starts = torch.randint(len(data) - seq_len, (batch_size,))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Tokenize each split once up front; batches index into these 1D tensors.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Sliding-window perplexity of `model` over `data` (1D token-id tensor).

    Windows of length `max_seq_len` start every `stride` tokens (stride is
    chosen to evaluate roughly 10k windows) and are scored in batches; the
    per-token mean losses are re-weighted by token count before exp().
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_windows = len(starts)
    loss_sum = 0.0
    token_count = 0
    for begin in range(0, n_windows, batch_size):
        chunk = starts[begin:min(begin + batch_size, n_windows)]
        xs = torch.stack([data[s:s + max_seq_len] for s in chunk]).to(device)
        ys = torch.stack([data[s + 1:s + max_seq_len + 1] for s in chunk]).to(device)
        # Model returns the mean loss over all tokens in the batch.
        _, mean_loss = model(xs, ys)
        n_tokens = ys.numel()
        loss_sum += mean_loss.item() * n_tokens
        token_count += n_tokens
        done = min(begin + batch_size, n_windows)
        print(f"\rppl {done}/{n_windows} ({done/n_windows*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(loss_sum / token_count)
#################################### Model #########################################
def complete(m, start_idxs=[0], max_new_tokens=100):
    """Sample a continuation of the prompt token ids `start_idxs` from model
    `m` and return the decoded string (prompt included)."""
    prompt = torch.tensor([start_idxs]).to(device)
    tokens = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(tokens[0].tolist())
# Instantiate the model with the module-level hyperparameters.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer state, vocab
    tables, config) to `<checkpoint_dir>/checkpoint_<step>.pt`."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
m.eval()
# Fixed prompts used to eyeball generation quality throughout training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) completions at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss/K so the
    # summed gradient matches a single batch of the full size.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): clip_grad_norm_ both clips (max_norm=1e6, i.e.
        # practically never) and returns the total norm used for logging.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval: validation perplexity, sample completions, optic-param
    # logging, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation: validation and test perplexity plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
m.log_trainable_optic_params(writer, max_iters)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,490 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
seed = 1337
torch.manual_seed(seed)  # reproducible init and batch sampling
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): unused below -- evaluation actually runs every 5000 steps
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): unused in this script
layers_num = 2          # number of transformer layers
h_dim = 64              # model width (embedding size)
max_seq_len = 128       # context length; also selects the optical-sim config below
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6     # physical pixel pitch (metres) for the optical simulators
# Per-step batch must split evenly across accumulation micro-steps.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed with a non-negative-only multiplier.

    The optical simulator `sim` can only multiply non-negative matrices, so
    each operand is split into its positive and negative parts and the signed
    product is reassembled from the four cross terms:
        A @ B = (A+ - A-) @ (B+ - B-) = A+B+ - A+B- - A-B+ + A-B-
    Each part is normalised to [0, 1] before calling `sim` and rescaled
    afterwards; any term whose operand part is identically zero skips the
    simulator and contributes zeros.

    Args:
        sim: callable(left, right) multiplying 4-D batched matrices whose
            values lie in [0, 1].
        tensor_1: left operand; 3-D inputs get a leading batch dim of 1.
        tensor_2: right operand, likewise.

    Returns:
        The product with the leading batch dim dropped (index 0 only, matching
        the unsqueeze applied above).
    """
    # Promote 3-D inputs to 4-D (batch of 1) for the simulator.
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    max_A_pos = torch.max(A_pos)  # may be 0 when there are no positive values
    max_A_neg = torch.max(A_neg)  # may be 0 when there are no negative values
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    # Result shape of one term; zeros are created directly on the right device
    # (the original built an uninitialised CPU tensor and cloned it per term).
    result_shape = (tensor_1.shape[0], tensor_1.shape[1],
                    tensor_1.shape[2], tensor_2.shape[3])
    if max_A_pos > 0 and max_B_pos > 0:
        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
    else:
        t1 = torch.zeros(result_shape, device=device)
    if max_A_pos > 0 and max_B_neg > 0:
        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
    else:
        t2 = torch.zeros(result_shape, device=device)
    if max_A_neg > 0 and max_B_pos > 0:
        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
    else:
        t3 = torch.zeros(result_shape, device=device)
    if max_A_neg > 0 and max_B_neg > 0:
        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
    else:
        t4 = torch.zeros(result_shape, device=device)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates channel pairs by position-dependent angles."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse wavelength per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # Outer product (position x frequency) gives the angle table.
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate along the channel axis so the table spans all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each half-pair.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` of shape (batch, seq, dim), positions starting at `offset`."""
        n = x.size(1)
        table = self.emb[offset:offset + n].view(1, n, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: tanh(alpha * x) followed by a learned per-feature affine map."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Learned scalar steepness of the squashing nonlinearity.
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        # Per-feature scale and shift, initialised to the identity transform.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose two attention matmuls
    (Q @ K^T and attn @ V) run through optical-multiplication simulators
    instead of torch.matmul.

    Args:
        h_dim: model width; must be divisible by num_heads.
        sim_scores: simulator used for the Q @ K^T product.
        sim_output: simulator used for the attention @ V product.
        num_heads: number of attention heads.
        dropout_rate: dropout probability after attention and after the FFN.
        max_seq_len: maximum sequence length for the RoPE angle table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every ctor argument as an attribute (h_dim, sim_scores, ...).
        # NOTE(review): plain __dict__ assignment bypasses nn.Module
        # registration, so the simulators are shared, not owned, by the layer.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal self-attention with both matmuls routed through the simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled dot-product scores; k1 is a trainable scalar compensating
        # the simulator's output scale.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: forbid attending to future positions.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Pre-norm residual blocks: x + Drop(Attn(DyT(x))), then the FFN.
        # BUGFIX: functional dropout defaults to training=True, so the original
        # kept dropping activations even under model.eval() (perplexity and
        # generation were evaluated with dropout active). Pass the module mode.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT whose attention matmuls run through shared
    optical-multiplication simulators (omm.TrainableLensOpticalMul).

    Two simulators are created once and shared by every TransformerLayer:
    `sim_scores` multiplies Q @ K^T and `sim_output` multiplies attn @ V.
    A context of 512 uses a doubled lens (8192 * 2) and a longer propagation
    distance (0.15 m vs 0.01 m); every other length uses the defaults.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                pixel_size = 3.6e-6):
        super().__init__()
        # Stash every ctor argument as an attribute (vocab_size, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        if max_seq_len != 512:
            # Default optics: library-default lens size, 0.01 m propagation.
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                        right_matrix_count_rows = h_dim // num_heads,
                        right_matrix_width = pixel_size * max_seq_len,
                        right_matrix_height = pixel_size * (h_dim // num_heads),
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.01)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                        right_matrix_count_rows = max_seq_len,
                        right_matrix_width = pixel_size * (h_dim // num_heads),
                        right_matrix_height = pixel_size * max_seq_len,
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.01)
            )
        if max_seq_len == 512:
            # Long-context optics: larger lens and propagation distance.
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                        right_matrix_count_rows = h_dim // num_heads,
                        right_matrix_width = pixel_size * max_seq_len,
                        right_matrix_height = pixel_size * (h_dim // num_heads),
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.15,
                        lens_size = 8192 * 2)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                        right_matrix_count_rows = max_seq_len,
                        right_matrix_width = pixel_size * (h_dim // num_heads),
                        right_matrix_height = pixel_size * max_seq_len,
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.15,
                        lens_size = 8192 * 2)
            )
        # NOTE(review): presumably scatters the batch across GPUs; the layers
        # call the wrapped simulators directly -- verify against omm docs.
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        # All layers share the same two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                            dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings -- used in addition to the RoPE
        # rotation applied inside each layer.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects class dim second: (B, C, T) vs targets (B, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # Autoregressive sampling: repeatedly feed the last max_seq_len tokens and
    # sample the next token from the softmax distribution.
    def generate(self, start_idx, max_new_tokens):
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the lens operators and per-layer correction scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits so val/test chars are known.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}   # char -> id
itow = {i:w for i,w in enumerate(chars)}   # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of contiguous windows from `data` (1-D token tensor).

    Targets are the inputs shifted one position right (next-token prediction).
    Both tensors are moved to the module-level `device`.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Each corpus split encoded once as a flat 1-D LongTensor of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Token-weighted perplexity of `model` over the 1-D token tensor `data`.

    Windows of the module-level `max_seq_len` are scored at a stride chosen so
    at most ~10k windows are evaluated; batches are moved to `device`.
    Returns a numpy float (callers use .item()).
    """
    model.eval()
    # Subsample start offsets so the scan stays ~10k windows regardless of size.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_seq = len(starts)
    loss_sum = 0.0
    token_count = 0
    for lo in range(0, n_seq, batch_size):
        chunk = starts[lo:min(lo + batch_size, n_seq)]
        xs = torch.stack([
            data[s:s + max_seq_len]
            for s in chunk
        ]).to(device)
        ys = torch.stack([
            data[s + 1:s + max_seq_len + 1]
            for s in chunk
        ]).to(device)
        _, mean_loss = model(xs, ys)
        # mean_loss is averaged over tokens; re-weight by this batch's tokens.
        n_tok = ys.numel()
        loss_sum += mean_loss.item() * n_tok
        token_count += n_tok
        processed = min(lo + batch_size, n_seq)
        print(f"\rppl {processed}/{n_seq} ({processed/n_seq*100:.1f}%)", end="", flush=True)
    print()
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=None, max_new_tokens=100):
    """Decode a text completion from model `m`.

    Args:
        m: model exposing generate(start_idx, max_new_tokens).
        start_idxs: prompt token ids; defaults to [0]. (The original used a
            mutable default argument `[0]`; a None sentinel avoids that pitfall
            while keeping the same default behaviour.)
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string including the prompt.
    """
    if start_idxs is None:
        start_idxs = [0]
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Log the architecture once, with the parameter count. (The original logged a
# bare str(m) under the same tag/step and immediately followed it with this
# richer entry -- the first write was redundant.)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist model/optimizer state plus vocab maps and config to checkpoint_dir."""
    out_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, out_path)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Untrained-model baseline: step-0 completions logged for later comparison.
m.eval()
# Short prompts probing pattern continuation and trivial world knowledge.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each micro-batch contributes loss / steps.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): clip_grad_norm_ returns the total norm but ALSO clips
        # (max_norm=1e6 is effectively a no-op) -- and only on logged steps.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval every 5000 steps (including step 0): val perplexity,
    # sample completions, optics parameters, and a rolling checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# ---- Final evaluation on the val and test splits after the last step ----
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
# Qualitative samples from the trained model.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
m.log_trainable_optic_params(writer, max_iters)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,491 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
seed = 1337
torch.manual_seed(seed)  # reproducible init and batch sampling
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300  # NOTE(review): unused below -- evaluation actually runs every 5000 steps
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # NOTE(review): unused in this script
layers_num = 2          # number of transformer layers
h_dim = 64              # model width (embedding size)
max_seq_len = 256       # context length; also selects the optical-sim config below
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6     # physical pixel pitch (metres) for the optical simulators
# Per-step batch must split evenly across accumulation micro-steps.
assert batch_size % gradient_accumulation_steps == 0
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed with a non-negative-only multiplier.

    The optical simulator `sim` can only multiply non-negative matrices, so
    each operand is split into its positive and negative parts and the signed
    product is reassembled from the four cross terms:
        A @ B = (A+ - A-) @ (B+ - B-) = A+B+ - A+B- - A-B+ + A-B-
    Each part is normalised to [0, 1] before calling `sim` and rescaled
    afterwards; any term whose operand part is identically zero skips the
    simulator and contributes zeros.

    Args:
        sim: callable(left, right) multiplying 4-D batched matrices whose
            values lie in [0, 1].
        tensor_1: left operand; 3-D inputs get a leading batch dim of 1.
        tensor_2: right operand, likewise.

    Returns:
        The product with the leading batch dim dropped (index 0 only, matching
        the unsqueeze applied above).
    """
    # Promote 3-D inputs to 4-D (batch of 1) for the simulator.
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    max_A_pos = torch.max(A_pos)  # may be 0 when there are no positive values
    max_A_neg = torch.max(A_neg)  # may be 0 when there are no negative values
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    # Result shape of one term; zeros are created directly on the right device
    # (the original built an uninitialised CPU tensor and cloned it per term).
    result_shape = (tensor_1.shape[0], tensor_1.shape[1],
                    tensor_1.shape[2], tensor_2.shape[3])
    if max_A_pos > 0 and max_B_pos > 0:
        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
    else:
        t1 = torch.zeros(result_shape, device=device)
    if max_A_pos > 0 and max_B_neg > 0:
        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
    else:
        t2 = torch.zeros(result_shape, device=device)
    if max_A_neg > 0 and max_B_pos > 0:
        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
    else:
        t3 = torch.zeros(result_shape, device=device)
    if max_A_neg > 0 and max_B_neg > 0:
        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
    else:
        t4 = torch.zeros(result_shape, device=device)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates channel pairs by position-dependent angles."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse wavelength per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # Outer product (position x frequency) gives the angle table.
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate along the channel axis so the table spans all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each half-pair.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` of shape (batch, seq, dim), positions starting at `offset`."""
        n = x.size(1)
        table = self.emb[offset:offset + n].view(1, n, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: tanh(alpha * x) followed by a learned per-feature affine map."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Learned scalar steepness of the squashing nonlinearity.
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        # Per-feature scale and shift, initialised to the identity transform.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose two attention matmuls
    (Q @ K^T and attn @ V) run through optical-multiplication simulators
    instead of torch.matmul.

    Args:
        h_dim: model width; must be divisible by num_heads.
        sim_scores: simulator used for the Q @ K^T product.
        sim_output: simulator used for the attention @ V product.
        num_heads: number of attention heads.
        dropout_rate: dropout probability after attention and after the FFN.
        max_seq_len: maximum sequence length for the RoPE angle table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash every ctor argument as an attribute (h_dim, sim_scores, ...).
        # NOTE(review): plain __dict__ assignment bypasses nn.Module
        # registration, so the simulators are shared, not owned, by the layer.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        """Causal self-attention with both matmuls routed through the simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled dot-product scores; k1 is a trainable scalar compensating
        # the simulator's output scale.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: forbid attending to future positions.
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # Pre-norm residual blocks: x + Drop(Attn(DyT(x))), then the FFN.
        # BUGFIX: functional dropout defaults to training=True, so the original
        # kept dropping activations even under model.eval() (perplexity and
        # generation were evaluated with dropout active). Pass the module mode.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT whose attention matmuls run through shared
    optical-multiplication simulators (omm.TrainableLensOpticalMul).

    Two simulators are created once and shared by every TransformerLayer:
    `sim_scores` multiplies Q @ K^T and `sim_output` multiplies attn @ V.
    A context of 512 uses a doubled lens (8192 * 2) and a longer propagation
    distance (0.15 m vs 0.01 m); every other length uses the defaults.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                pixel_size = 3.6e-6):
        super().__init__()
        # Stash every ctor argument as an attribute (vocab_size, h_dim, ...).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        if max_seq_len != 512:
            # Default optics: library-default lens size, 0.01 m propagation.
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                        right_matrix_count_rows = h_dim // num_heads,
                        right_matrix_width = pixel_size * max_seq_len,
                        right_matrix_height = pixel_size * (h_dim // num_heads),
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.01)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                        right_matrix_count_rows = max_seq_len,
                        right_matrix_width = pixel_size * (h_dim // num_heads),
                        right_matrix_height = pixel_size * max_seq_len,
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.01)
            )
        if max_seq_len == 512:
            # Long-context optics: larger lens and propagation distance.
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                        right_matrix_count_rows = h_dim // num_heads,
                        right_matrix_width = pixel_size * max_seq_len,
                        right_matrix_height = pixel_size * (h_dim // num_heads),
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.15,
                        lens_size = 8192 * 2)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                        right_matrix_count_rows = max_seq_len,
                        right_matrix_width = pixel_size * (h_dim // num_heads),
                        right_matrix_height = pixel_size * max_seq_len,
                        min_height_gap = pixel_size,
                        right_matrix_split_x = 2,
                        right_matrix_split_y = 2,
                        left_matrix_split_x = 2,
                        left_matrix_split_y = 2,
                        result_matrix_split = 2,
                        distance = 0.15,
                        lens_size = 8192 * 2)
            )
        # NOTE(review): presumably scatters the batch across GPUs; the layers
        # call the wrapped simulators directly -- verify against omm docs.
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        # All layers share the same two simulator instances.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                            dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings -- used in addition to the RoPE
        # rotation applied inside each layer.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy expects class dim second: (B, C, T) vs targets (B, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # Autoregressive sampling: repeatedly feed the last max_seq_len tokens and
    # sample the next token from the softmax distribution.
    def generate(self, start_idx, max_new_tokens):
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the lens operators and per-layer correction scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + context length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits so val/test chars are known.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}   # char -> id
itow = {i:w for i,w in enumerate(chars)}   # id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of contiguous windows from `data` (1-D token tensor).

    Targets are the inputs shifted one position right (next-token prediction).
    Both tensors are moved to the module-level `device`.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Each corpus split encoded once as a flat 1-D LongTensor of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Token-weighted perplexity of `model` over the 1-D token tensor `data`.

    Windows of the module-level `max_seq_len` are scored at a stride chosen so
    at most ~10k windows are evaluated; batches are moved to `device`.
    Returns a numpy float (callers use .item()).
    """
    model.eval()
    # Subsample start offsets so the scan stays ~10k windows regardless of size.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))
    n_seq = len(starts)
    loss_sum = 0.0
    token_count = 0
    for lo in range(0, n_seq, batch_size):
        chunk = starts[lo:min(lo + batch_size, n_seq)]
        xs = torch.stack([
            data[s:s + max_seq_len]
            for s in chunk
        ]).to(device)
        ys = torch.stack([
            data[s + 1:s + max_seq_len + 1]
            for s in chunk
        ]).to(device)
        _, mean_loss = model(xs, ys)
        # mean_loss is averaged over tokens; re-weight by this batch's tokens.
        n_tok = ys.numel()
        loss_sum += mean_loss.item() * n_tok
        token_count += n_tok
        processed = min(lo + batch_size, n_seq)
        print(f"\rppl {processed}/{n_seq} ({processed/n_seq*100:.1f}%)", end="", flush=True)
    print()
    return np.exp(loss_sum / token_count)
#################################### Model #########################################mo
def complete(m, start_idxs=None, max_new_tokens=100):
    """Decode a text completion from model `m`.

    Args:
        m: model exposing generate(start_idx, max_new_tokens).
        start_idxs: prompt token ids; defaults to [0]. (The original used a
            mutable default argument `[0]`; a None sentinel avoids that pitfall
            while keeping the same default behaviour.)
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string including the prompt.
    """
    if start_idxs is None:
        start_idxs = [0]
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Log the architecture once, with the parameter count. (The original logged a
# bare str(m) under the same tag/step and immediately followed it with this
# richer entry -- the first write was redundant.)
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state to checkpoint_dir/checkpoint_<step>.pt.

    Stores model/optimizer state dicts, the last loss, the hyperparameter
    config, and both tokenizer maps so a run can be rebuilt and resumed.
    """
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing: every hyperparameter needed to rebuild
# the model and reproduce/resume this run is bundled into each checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts, re-generated throughout training to eyeball progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) completions logged at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: summing (loss / steps) over micro-batches yields
    # the mean loss of the full logical batch.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips — this call is used only to
        # measure and log the total gradient norm every 100 steps.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation, sample generation, optic-param logging, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, final samples and checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
m.log_trainable_optic_params(writer, max_iters)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# ---- diff hunk marker (@ -0,0 +1,491 @@): start of the next newly added script ----
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
# Reproducibility
seed = 1337
torch.manual_seed(seed)
# Optimisation hyperparameters
batch_size = 50  # total sequences per optimizer step (split across micro-batches)
gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Model shape
layers_num = 2
h_dim = 64
max_seq_len = 512  # context length; 512 selects the large-lens optical config below
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch (metres) fed into the optical simulator configs
assert batch_size % gradient_accumulation_steps == 0  # micro-batch size must be integral
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only multiplier.

    The optical simulator ``sim`` can only multiply non-negative matrices
    scaled to [0, 1].  Each operand is split into positive/negative parts
    (A = A⁺ - A⁻, B = B⁺ - B⁻), each part is normalised by its global max,
    the four cross products are simulated, and the signed result is
    reassembled as A·B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻.

    Args:
        sim: callable performing the batched matrix product on 4-D inputs,
            e.g. an ``omm`` simulator or ``torch.matmul``.
        tensor_1: left operand, shape (b, m, k) or (1, b, m, k).
        tensor_2: right operand, shape (b, k, n) or (1, b, k, n).

    Returns:
        Tensor of shape (b, m, n).  NOTE(review): only element 0 of the
        outermost dimension is returned — fine for the 3-D inputs used in
        this file, but genuinely 4-D batched inputs would lose data.
    """
    # Promote 3-D inputs to 4-D so `sim` always sees a leading batch dim.
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # Global maxima used for [0, 1] normalisation; a zero max means the
    # corresponding part is absent and its term contributes nothing.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(a, max_a, b, max_b):
        # max_a·max_b·sim(a/max_a, b/max_b); skip the simulation entirely
        # when either part is identically zero.  Allocates the zero result
        # directly on the right device (previously an uninitialized CPU
        # tensor was created, zeroed via zeros_like, cloned and moved).
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864)."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse frequency per pair of channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate the angle table so it covers all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90° rotation applied pairwise.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # x: (batch, seq, dim); apply the rotation for positions
        # offset .. offset+seq-1.
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (https://jiachenzhu.github.io/DyT/).

    Normalization-free replacement for LayerNorm: an element-wise
    tanh(alpha * x) followed by a learnable affine transform.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose attention matmuls (Q·Kᵀ and
    softmax(scores)·V) run through optical-multiplication simulators.

    Args:
        h_dim: model width.
        sim_scores: simulator used for the Q·Kᵀ product.
        sim_output: simulator used for the attention·V product.
        num_heads: attention heads (heads are folded into the batch dim).
        dropout_rate: dropout probability on both residual branches.
        max_seq_len: maximum sequence length for the RoPE table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all ctor args as attributes; plain __dict__ assignment also
        # keeps the shared simulators out of nn.Module's submodule registry.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        # Causal scaled dot-product attention; k1/k2 are trainable scalars
        # compensating the optical simulators' output scale.
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True regardless of module
        # mode, so dropout previously stayed active during eval/generation;
        # pass the module's training flag so model.eval() disables it.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """GPT-style char-level LM whose attention matmuls run through
    trainable-lens optical simulators shared by all layers.

    The two previously copy-pasted `omm.Config` blocks (one per
    max_seq_len regime) are merged into a single parameterised pair;
    behaviour and interface are unchanged.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # Longer contexts use a larger propagation distance and lens size
        # (matches the original duplicated branches exactly).
        if max_seq_len == 512:
            extra = dict(distance = 0.15, lens_size = 8192 * 2)
        else:
            extra = dict(distance = 0.01)
        # Simulator for Q·Kᵀ: right matrix is head_dim rows x max_seq_len cols.
        self.sim_scores = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **extra)
        )
        # Simulator for attention·V: right matrix is max_seq_len rows x head_dim cols.
        self.sim_output = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **extra)
        )
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        # All layers share the two simulators.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # NOTE(review): learned absolute positions are added here while RoPE
        # is also applied inside the layers — confirm both are intended.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B, T, C
        # cross_entropy expects class dim second for sequence targets: (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` ids after `start_idx`."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:]  # B, C — last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the trainable lens operators and per-layer scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group nicely in TensorBoard (e.g. layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# NOTE(review): this run loads the WikiText token files configured above
# (the old tinyshakespeare wget hint was stale).
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Char-level vocabulary built over the union of all three splits so every
# split encodes without unknown symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from data."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    x = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    y = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return x, y
# Encode each split once into a 1-D LongTensor of char ids (created on CPU;
# batches are moved to `device` inside get_batch / perplexity).
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data, batch_size=32, seq_len=None):
    """Compute token-level perplexity of `model` over `data`.

    Slides a window of `seq_len` tokens over `data` (subsampled by a stride
    so roughly 10k windows at most are evaluated), accumulates the model's
    mean cross-entropy weighted by token count, and returns exp(mean loss).

    Args:
        model: callable returning (logits, mean_loss) for an (x, y) batch.
        data: 1-D LongTensor of token ids.
        batch_size: number of windows per forward pass.
        seq_len: window length; defaults to the module-level max_seq_len
            (previously hard-coded, now a backward-compatible parameter).

    Returns:
        Perplexity as np.float64.
    """
    if seq_len is None:
        seq_len = max_seq_len
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid window start positions.
    start_positions = list(range(0, len(data) - seq_len - 1, stride))
    total_sequences = len(start_positions)
    for i in range(0, total_sequences, batch_size):
        # Slicing already clamps at the end of the list.
        batch_starts = start_positions[i:i + batch_size]
        x_batch = torch.stack([
            data[start:start + seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + seq_len + 1]
            for start in batch_starts
        ]).to(device)
        # mean_loss is already averaged over all tokens in the batch.
        _, mean_loss = model(x_batch, y_batch)
        # Re-weight by token count so partial final batches are handled exactly.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # Progress update on one line.
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print() # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=None, max_new_tokens=100):
    """Autoregressively decode text from model `m` starting at given token ids.

    Args:
        m: model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: list of prompt token ids; defaults to ``[0]``.  The default
            is created per call — the previous ``start_idxs=[0]`` mutable
            default would be shared across calls.
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string (prompt included).
    """
    if start_idxs is None:
        start_idxs = [0]
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Build the model and log its architecture plus parameter count once.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Previously a bare str(m) was logged to the same tag/step and immediately
# superseded by this fuller description — log only once.
model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
writer.add_text('model', model_description, 0)
#################################### Train #########################################
# AdamW with decoupled weight decay.  NOTE(review): betas=(0.90, 0.95) deviate
# from the (0.9, 0.999) default — looks like a GPT-style recipe; confirm intent.
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state to checkpoint_dir/checkpoint_<step>.pt.

    Stores model/optimizer state dicts, the last loss, the hyperparameter
    config, and both tokenizer maps so a run can be rebuilt and resumed.
    """
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing: every hyperparameter needed to rebuild
# the model and reproduce/resume this run is bundled into each checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
m.eval()
# Fixed probe prompts, re-generated throughout training to eyeball progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Baseline (untrained) completions logged at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: summing (loss / steps) over micro-batches yields
    # the mean loss of the full logical batch.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips — this call is used only to
        # measure and log the total gradient norm every 100 steps.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic evaluation, sample generation, optic-param logging, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, final samples and checkpoint.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
m.log_trainable_optic_params(writer, max_iters)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# ---- diff hunk marker (@ -0,0 +1,490 @@): start of the next newly added script ----
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
# Reproducibility
seed = 1337
torch.manual_seed(seed)
# Optimisation hyperparameters
batch_size = 50  # total sequences per optimizer step (split across micro-batches)
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Model shape
layers_num = 2
h_dim = 64
max_seq_len = 64  # context length; values below 512 select the short-distance optical config
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6  # pixel pitch (metres) fed into the optical simulator configs
assert batch_size % gradient_accumulation_steps == 0  # micro-batch size must be integral
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed through a non-negative-only multiplier.

    The optical simulator ``sim`` can only multiply non-negative matrices
    scaled to [0, 1].  Each operand is split into positive/negative parts
    (A = A⁺ - A⁻, B = B⁺ - B⁻), each part is normalised by its global max,
    the four cross products are simulated, and the signed result is
    reassembled as A·B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻.

    Args:
        sim: callable performing the batched matrix product on 4-D inputs,
            e.g. an ``omm`` simulator or ``torch.matmul``.
        tensor_1: left operand, shape (b, m, k) or (1, b, m, k).
        tensor_2: right operand, shape (b, k, n) or (1, b, k, n).

    Returns:
        Tensor of shape (b, m, n).  NOTE(review): only element 0 of the
        outermost dimension is returned — fine for the 3-D inputs used in
        this file, but genuinely 4-D batched inputs would lose data.
    """
    # Promote 3-D inputs to 4-D so `sim` always sees a leading batch dim.
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # Global maxima used for [0, 1] normalisation; a zero max means the
    # corresponding part is absent and its term contributes nothing.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(a, max_a, b, max_b):
        # max_a·max_b·sim(a/max_a, b/max_b); skip the simulation entirely
        # when either part is identically zero.  Allocates the zero result
        # directly on the right device (previously an uninitialized CPU
        # tensor was created, zeroed via zeros_like, cloned and moved).
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864)."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse frequency per pair of channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate the angle table so it covers all `dim` channels.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90° rotation applied pairwise.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # x: (batch, seq, dim); apply the rotation for positions
        # offset .. offset+seq-1.
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer (https://jiachenzhu.github.io/DyT/).

    Normalization-free replacement for LayerNorm: an element-wise
    tanh(alpha * x) followed by a learnable affine transform.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose attention matmuls (Q·Kᵀ and
    softmax(scores)·V) run through optical-multiplication simulators.

    Args:
        h_dim: model width.
        sim_scores: simulator used for the Q·Kᵀ product.
        sim_output: simulator used for the attention·V product.
        num_heads: attention heads (heads are folded into the batch dim).
        dropout_rate: dropout probability on both residual branches.
        max_seq_len: maximum sequence length for the RoPE table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all ctor args as attributes; plain __dict__ assignment also
        # keeps the shared simulators out of nn.Module's submodule registry.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        # Causal scaled dot-product attention; k1/k2 are trainable scalars
        # compensating the optical simulators' output scale.
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True regardless of module
        # mode, so dropout previously stayed active during eval/generation;
        # pass the module's training flag so model.eval() disables it.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """GPT-style char-level LM whose attention matmuls run through
    trainable-lens optical simulators shared by all layers.

    The two previously copy-pasted `omm.Config` blocks (one per
    max_seq_len regime) are merged into a single parameterised pair;
    behaviour and interface are unchanged.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # Longer contexts use a larger propagation distance and lens size
        # (matches the original duplicated branches exactly).
        if max_seq_len == 512:
            extra = dict(distance = 0.15, lens_size = 8192 * 2)
        else:
            extra = dict(distance = 0.01)
        # Simulator for Q·Kᵀ: right matrix is head_dim rows x max_seq_len cols.
        self.sim_scores = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **extra)
        )
        # Simulator for attention·V: right matrix is max_seq_len rows x head_dim cols.
        self.sim_output = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **extra)
        )
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        # All layers share the two simulators.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # NOTE(review): learned absolute positions are added here while RoPE
        # is also applied inside the layers — confirm both are intended.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when `targets` is None."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B, T, C
        # cross_entropy expects class dim second for sequence targets: (B, C, T).
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` ids after `start_idx`."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]  # crop to the context window
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:]  # B, C — last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the trainable lens operators and per-layer scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group nicely in TensorBoard (e.g. layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run tag: script name + dataset + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# NOTE(review): this run loads the WikiText token files configured above
# (the old tinyshakespeare wget hint was stale).
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Char-level vocabulary built over the union of all three splits so every
# split encodes without unknown symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) sequences from *data*.

    Targets are the inputs shifted right by one position, as usual for
    autoregressive language modeling.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
train_data = torch.tensor(encode(train_text), dtype=torch.long)  # token ids for the train split
val_data = torch.tensor(encode(val_text), dtype=torch.long)  # token ids for the validation split
test_data = torch.tensor(encode(test_text), dtype=torch.long)  # token ids for the test split
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Compute corpus perplexity of *model* over *data* (1-D tensor of token ids).

    Evaluation windows of length ``max_seq_len`` (module global) are taken every
    ``stride`` tokens — the stride caps the number of windows at roughly 10k —
    and processed in batches of *batch_size*.  The model is expected to return
    ``(logits, mean_loss)`` where ``mean_loss`` is averaged over all tokens in
    the batch.

    Returns ``exp(mean token loss)`` as an ``np.float64`` (supports ``.item()``),
    or ``np.float64('inf')`` when *data* is too short for a single window.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # Precompute all valid start positions; each window needs max_seq_len + 1 tokens.
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)
    if total_sequences == 0:
        # Guard against division by zero below: data shorter than one window.
        return np.float64('inf')
    # Process windows in batches.
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:i + batch_size]
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)
        _, mean_loss = model(x_batch, y_batch)
        # mean_loss is already averaged over tokens, so re-weight by token count
        # to accumulate an exact corpus-level mean.
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        # In-place progress line (\r), finished by the print() below.
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
    print()  # Final newline
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=None, max_new_tokens=100):
    """Generate *max_new_tokens* continuation tokens from prompt ids and decode to text.

    *start_idxs* is a list of token ids; ``None`` stands in for the historical
    default ``[0]`` (a mutable default argument is avoided per best practice —
    behavior for all existing callers is unchanged).
    """
    if start_idxs is None:
        start_idxs = [0]
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model on the configured device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
).to(device)
# Record the architecture and total parameter count in TensorBoard.
param_count = sum(p.numel() for p in m.parameters())
model_description = str(m) + f'\nParameters count - {param_count}'
writer.add_text('model', model_description, 0)
#################################### Train #########################################
# AdamW with decoupled weight decay; note the non-default betas (0.90, 0.95).
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Serialize the full training state (weights, optimizer, vocab maps, config)
    to ``checkpoint_dir/checkpoint_<step>.pt`` so a run can be resumed or inspected."""
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt')
# Training config for checkpointing
# Hyperparameters bundled into every checkpoint so a run can be reconstructed
# later without the original script (all values are module-level globals).
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Baseline (step 0) generations from the untrained model, for comparison
# against later samples in TensorBoard.
m.eval()
# Fixed qualitative probes re-run at every evaluation point.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Free-form completion seeded with a blank-line prompt.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main training loop: gradient accumulation, periodic eval + sampling + checkpoint.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Split the nominal batch into micro-batches; gradients accumulate in-place.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        # Scale so the summed gradients match one full-batch step.
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # With max_norm=1e6 clipping is effectively a no-op; the call is used
        # here to measure and log the total gradient norm every 100 steps.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Every 5000 steps (including step 0): validation perplexity, sample
    # generations, optic-parameter logging, and a checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation after the loop; note `i` here is the last loop index
# (max_iters - 1), so the step-(i+1) scalars land just past the last train step.
m.eval()
ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
# Test-set perplexity is computed only once, here at the very end.
ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
# Final qualitative samples: free-form completion plus the fixed task prompts.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
m.log_trainable_optic_params(writer, max_iters)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
# (Removed non-source residue from a web-UI capture: "Loading… / Cancel / Save".)