Trainable lens experiments. Refactor perplexity. Bert experiments.

pull/1/head
Vladimir Protsenko 2 months ago
parent 51147b36b3
commit 58b3271cc8

2
.gitignore vendored

@ -214,3 +214,5 @@ __marimo__/
# Streamlit
.streamlit/secrets.toml
checkpoints/

@ -0,0 +1,475 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
# Global step counter placeholder (appears unused below — TODO confirm before removing).
step = 1
# Physical pixel pitch of the simulated optical modulator, in meters.
pixel_size: float = 3.6e-6
# Prefer GPU when available; modules and batches below are moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Divide *matrix* by *max_val*, with a tiny epsilon guarding division by zero."""
    denominator = max_val + 1e-10
    return matrix / denominator
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Matrix-multiply two possibly negative-valued tensors on the optical
    simulator `sim`, which multiplies normalized non-negative matrices.

    Strategy: shift both operands by |global min| into the non-negative
    range, run up to four optical products, and recombine them so the
    shifts cancel analytically (see the expansion comment below).
    NOTE(review): assumes 3-D (head, rows, cols) inputs — a leading batch
    dim is added for `sim` and stripped on return; confirm `sim`'s contract.
    """
    tensor_1 = tensor_1[None,:,:,:]
    tensor_2 = tensor_2[None,:,:,:]
    # Fast path: both operands already non-negative — one normalized optical
    # product suffices; normalization is undone with max_abs ** 2.
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
        return sim(a, b)[0,:,:,:] * max_abs **2
    # General path: shift both operands by |global min| so they become
    # non-negative before normalizing into the simulator's range.
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(device)
    shift_b = min_abs * torch.ones(tensor_2.shape).to(device)
    a_a_sh = tensor_1 + shift_a
    b_b_sh = tensor_2 + shift_b
    # (a_shifted - shift_a)(b_shifted - shift_b) =
    # a_shifted*b_shifted - a_shifted*shift_b - b_shifted*shift_a + shift_a*shift_b
    a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
    shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
    # Recombine the four optical products and rescale back to original units.
    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# Run label derived from this script's filename (originally sys.argv[2]).
comment = Path(__file__).stem # sys.argv[2]
# Set to a .pth path to resume from a saved checkpoint.
checkpoint_file = None
# Timestamped output directories for TensorBoard logs and checkpoints.
# NOTE(review): 'сhekpoints' is misspelled and its first character is a
# Cyrillic 'с'; kept as-is since the identifier and path are used
# consistently below, and renaming would relocate existing outputs.
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
# SummaryWriter creates logs_dir; the script then copies itself there for
# experiment reproducibility.
writer = SummaryWriter(logs_dir)
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialize training state to `<сhekpoints_dir>/epoch_<epoch>_<loss>.pth`.

    Stores model/encoder weights plus their public (non-underscore,
    non-'training') instance attributes, optimizer state, metrics, and the
    dataset split-file paths needed to reproduce the train/test partition.
    """
    checkpoint = {
        'encoder': {
            'state_dict': encoder.state_dict(),
            # Also persist the hyperparameters that __init__ left on the
            # instance, so the module can be rebuilt when loading.
            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'model': {
            'state_dict': model.state_dict(),
            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'epoch': epoch,
        'optimizer': {
            'state_dict': optimizer.state_dict(),
        },
        'loss': loss,
        'rocauc': rocauc,
        # Split files are referenced by path, not copied into the checkpoint.
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore `encoder` and `model` weights (module-level globals) from a
    checkpoint produced by save_checkpoint.

    Fix: a missing file used to be silently ignored, hiding typos in resume
    paths; a warning is now printed so the fresh start is visible.
    """
    if not os.path.exists(checkpoint_file):
        print(f"Warning: checkpoint file {checkpoint_file} not found, nothing loaded")
        return
    checkpoint = torch.load(checkpoint_file)
    #optimizer.load_state_dict(checkpoint['optimizer'])
    encoder.load_state_dict(checkpoint['encoder']['state_dict'])
    model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Loads per-client credit-history features and default-flag targets,
    re-encodes all categorical columns into one unified embedding-id space,
    and pads each client's history into fixed-length tensors.

    Requires pre-saved train/test client-id split files; raises if either
    is missing. NOTE(review): `train_test_split_ratio` is accepted but never
    used here — the split comes entirely from the id files.
    """
    def __init__(self,
                 features_path, targets_path, train_test_split_ratio=0.9,
                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
                 ):
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        # Load the fixed train/test client-id splits (first CSV column of each).
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        # 'rn' is the per-client record number, so its max is the longest history.
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        # Everything that is neither an id nor categorical is numeric.
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        # Running offsets so each column's codes occupy a disjoint id range.
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        # Sort by (id, rn) so each client's records are contiguous and ordered.
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # Split the flat record table into per-client sequences and right-pad
        # to a common length; pad value 0 matches the padding embedding id.
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        """Sample a random training batch with feature-level dropout: each
        feature value is zeroed with prob 0.1 (zeroed categorical ids hit the
        padding embedding).

        NOTE(review): indexing the padded tensors directly with client ids
        assumes ids are contiguous 0..N-1 row positions — confirm upstream.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        """Yield deterministic, un-augmented test batches in split-file order."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Feature encoder: embeds unified categorical ids, applies a learned
    affine transform to numeric columns, and linearly projects the fused
    vector to `out_dim` per history step."""

    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # Keep constructor args as public attributes; save_checkpoint reads
        # the non-underscore entries of __dict__ to persist hyperparameters.
        self.cat_columns = cat_columns
        self.num_columns = num_columns
        self.cat_features_max_id = cat_features_max_id
        self.category_feature_dim = category_feature_dim
        self.out_dim = out_dim
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        # Index 0 is reserved for padding and stays a zero vector.
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    @property
    def device(self):
        # All parameters live on a single device; report it.
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        """Return (projected inputs, targets), both moved to this module's device."""
        cats = cat_features_batch.to(self.device)
        nums = num_features_batch.to(self.device)
        batch, seq = cats.shape[0], cats.shape[1]
        # (B, T, n_cat) int ids -> (B, T, n_cat * category_feature_dim)
        cat_vecs = self.cat_embeds(cats.type(torch.int32)).reshape(batch, seq, -1)
        # Learned per-column affine rescaling of the numeric features.
        num_vecs = self.num_scales * nums + self.num_shifts
        fused = torch.concat([cat_vecs, num_vecs], dim=-1)
        return self.proj(fused), targets_batch.to(self.device)
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes an angle table for `max_seq_len` positions and rotates the
    last dimension of (B, T, dim) inputs by position-dependent angles.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder: one frequency per pair of channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so both halves of the channel dim share the same angles.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation within each channel pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate *x* by the angles of positions [offset, offset + T)."""
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """DynamicTanh normalization-free layer (https://jiachenzhu.github.io/DyT/):
    y = tanh(alpha * x) * weight + bias, with learnable alpha/weight/bias."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder block (DyT norm, RoPE positions) whose
    two attention matmuls (Q·K^T and attn·V) run on optical-multiplication
    simulators via optics_matmul_shift.

    Fix: F.dropout1d defaults to training=True, so dropout was previously
    applied even in eval mode; it is now gated on self.training.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Keep constructor args as attributes (save_checkpoint persists the
        # non-underscore entries of __dict__).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled dot-product scores computed on the optical simulator.
        # NOTE(review): scaling uses full h_dim, not per-head h_dim//num_heads
        # as in Vaswani et al. — confirm this is intended.
        scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # Residual connections around attention and FFN; dropout only fires
        # while training (F.dropout1d would otherwise also fire in eval mode).
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style encoder-classifier whose attention matmuls run on shared
    optical simulators; classification reads the prepended CLS token.
    Register tokens (num_reg) follow "Vision Transformers Need Registers".
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # Keep constructor args as attributes (save_checkpoint persists them).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # Effective sequence length includes the CLS (+register) tokens.
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)  # debug print of the layer geometry
        # Two simulator configurations: short sequences use distance 0.01;
        # long ones (>=512) use distance 0.15 and an explicit larger lens.
        # The branches are otherwise identical.
        # NOTE(review): the branch condition tests the constructor max_seq_len
        # while the configs use self.max_seq_len (CLS included) — confirm the
        # boundary behavior right around 512 is intended.
        if max_seq_len < 512:
            # sim_scores computes Q @ K^T: right operand is (h_dim/heads) x seq.
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
            # sim_output computes attn @ V: right operand is seq x (h_dim/heads).
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
        # All transformer layers share the same two simulator instances.
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        # Learned absolute position embeddings, in addition to RoPE inside the
        # layers — presumably intentional; verify.
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    def forward(self, x):
        # Prepend CLS/register tokens, then add positional embeddings.
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # Classify from the CLS position only.
        x = self.classifier_head(x[:,0,:])
        # Single-logit mode (class_num == 1) returns a 1-D tensor for BCE loss.
        return x[:,:] if self.class_num > 1 else x[:,0]
# ---- dataset ----
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    # NOTE(review): the filename contains a space before ".csv" — confirm it
    # matches the actual file on disk.
    targets_path="/wd/finbert_data/train_target .csv",
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# ---- hyperparameters ----
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit; paired with BCEWithLogitsLoss below
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
# Schedule-free AdamW: the train/eval loops call optimizer.train()/.eval().
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# Class-imbalance weighting for BCE: weight positives by the neg/pos ratio.
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
# NOTE(review): pos_weight tensor is created on CPU; confirm it interacts
# correctly with CUDA outputs on a GPU run.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter letting a DataLoader's worker pool draw random training
    batches from a CreditProductsDataset.

    Each __getitem__ returns one full batch; the DataLoader is used with
    batch_size=1 and the extra leading dim is stripped by the caller.

    Fix: __getitem__ previously read the module-level `credit_train_dataset`
    instead of the dataset passed to the constructor; it now uses
    self.credit_dataset, so any dataset instance can be wrapped.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder  # kept for interface compatibility (unused here)
        self.batch_size = batch_size

    def __len__(self):
        # One epoch = enough random batches to cover all clients once.
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size

    def __getitem__(self, idx):
        # idx is ignored: batches are sampled randomly by the dataset itself.
        return self.credit_dataset.get_batch(batch_size=self.batch_size)
# num_workers>0: batches are sampled in parallel worker processes;
# batch_size=1 because WrapperDataset already returns whole batches.
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
# Streaming ROC-AUC accumulators (val_auroc appears unused below — TODO confirm).
val_auroc = AUROC(task='binary')
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate on the full test split and log ROC-AUC to TensorBoard.

    NOTE(review): test_auroc is never reset, so the reported value
    accumulates over all previous calls — confirm this is intended.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedule-free optimizer also has train/eval modes
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            # Inline progress display (overwrites the same console line).
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    # Evaluate before each epoch (and once more after the loop finishes).
    test(epoch)
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()  # schedule-free optimizer train mode
        # The DataLoader adds a leading dim of size 1; index [0] strips it.
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        # Log the gradient norm at most once per second; with max_norm=1e6
        # clip_grad_norm_ effectively only measures, it does not clip.
        if current_time - last_display_time > timedelta(seconds=1):
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        # Throttled console/TensorBoard progress reporting (once per second).
        if current_time - last_display_time > timedelta(seconds=1):
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
        # Periodic checkpoint every 4 hours of wall-clock time.
        if current_time - last_checkpoint_time > timedelta(hours=4):
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# Final evaluation and checkpoint after all epochs.
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -0,0 +1,483 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
# Global step counter placeholder (appears unused below — TODO confirm before removing).
step = 1
# Physical pixel pitch of the simulated optical modulator, in meters.
pixel_size: float = 3.6e-6
# Prefer GPU when available; modules and batches below are moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def new_formula(sim, tensor_1, tensor_2):
    """Multiply two arbitrary-sign tensors on the optical simulator `sim`,
    which only handles non-negative matrices normalized into [0, 1].

    Each operand is split into positive/negative parts, A = A+ - A- and
    B = B+ - B-, so A @ B = A+B+ - A+B- - A-B+ + A-B-. Each non-negative
    product is normalized by its max, run through `sim`, and rescaled;
    products with an all-zero operand are skipped.

    Fixes: the zero placeholder was built via torch.zeros_like(torch.empty(...))
    on CPU and then cloned and moved per branch; it is now allocated directly
    on the right device with no redundant clones. Comments translated to English.

    NOTE(review): index [0] is taken on return, so 4-D inputs with batch > 1
    lose trailing batches — confirm callers only pass 3-D tensors.
    """
    # Promote (H, M, K) inputs to the 4-D layout `sim` expects.
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # Per-part maxima; a zero max means that part has no support.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    def _zeros():
        # Result-shaped zeros for any skipped partial product.
        return torch.zeros(tensor_1.shape[0], tensor_1.shape[1],
                           tensor_1.shape[2], tensor_2.shape[3], device=device)

    if max_A_pos > 0 and max_B_pos > 0:
        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
    else:
        t1 = _zeros()
    if max_A_pos > 0 and max_B_neg > 0:
        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
    else:
        t2 = _zeros()
    if max_A_neg > 0 and max_B_pos > 0:
        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
    else:
        t3 = _zeros()
    if max_A_neg > 0 and max_B_neg > 0:
        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
    else:
        t4 = _zeros()
    return (t1 - t2 - t3 + t4)[0,:,:,:]
# Run label derived from this script's filename (originally sys.argv[2]).
comment = Path(__file__).stem # sys.argv[2]
# Set to a .pth path to resume from a saved checkpoint.
checkpoint_file = None
# Timestamped output directories for TensorBoard logs and checkpoints.
# NOTE(review): 'сhekpoints' is misspelled and its first character is a
# Cyrillic 'с'; kept as-is since the identifier and path are used
# consistently below, and renaming would relocate existing outputs.
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
# SummaryWriter creates logs_dir; the script then copies itself there for
# experiment reproducibility.
writer = SummaryWriter(logs_dir)
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialize training state to `<сhekpoints_dir>/epoch_<epoch>_<loss>.pth`.

    Stores model/encoder weights plus their public (non-underscore,
    non-'training') instance attributes, optimizer state, metrics, and the
    dataset split-file paths needed to reproduce the train/test partition.
    """
    checkpoint = {
        'encoder': {
            'state_dict': encoder.state_dict(),
            # Also persist the hyperparameters that __init__ left on the
            # instance, so the module can be rebuilt when loading.
            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'model': {
            'state_dict': model.state_dict(),
            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'epoch': epoch,
        'optimizer': {
            'state_dict': optimizer.state_dict(),
        },
        'loss': loss,
        'rocauc': rocauc,
        # Split files are referenced by path, not copied into the checkpoint.
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore `encoder` and `model` weights (module-level globals) from a
    checkpoint produced by save_checkpoint.

    Fix: a missing file used to be silently ignored, hiding typos in resume
    paths; a warning is now printed so the fresh start is visible.
    """
    if not os.path.exists(checkpoint_file):
        print(f"Warning: checkpoint file {checkpoint_file} not found, nothing loaded")
        return
    checkpoint = torch.load(checkpoint_file)
    #optimizer.load_state_dict(checkpoint['optimizer'])
    encoder.load_state_dict(checkpoint['encoder']['state_dict'])
    model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Loads per-client credit-history features and default-flag targets,
    re-encodes all categorical columns into one unified embedding-id space,
    and pads each client's history into fixed-length tensors.

    Requires pre-saved train/test client-id split files; raises if either
    is missing. NOTE(review): `train_test_split_ratio` is accepted but never
    used here — the split comes entirely from the id files.
    """
    def __init__(self,
                 features_path, targets_path, train_test_split_ratio=0.9,
                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
                 ):
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        # Load the fixed train/test client-id splits (first CSV column of each).
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        # 'rn' is the per-client record number, so its max is the longest history.
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        # Everything that is neither an id nor categorical is numeric.
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        # Running offsets so each column's codes occupy a disjoint id range.
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        # Sort by (id, rn) so each client's records are contiguous and ordered.
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # Split the flat record table into per-client sequences and right-pad
        # to a common length; pad value 0 matches the padding embedding id.
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        """Sample a random training batch with feature-level dropout: each
        feature value is zeroed with prob 0.1 (zeroed categorical ids hit the
        padding embedding).

        NOTE(review): indexing the padded tensors directly with client ids
        assumes ids are contiguous 0..N-1 row positions — confirm upstream.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        """Yield deterministic, un-augmented test batches in split-file order."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Feature encoder: embeds unified categorical ids, applies a learned
    affine transform to numeric columns, and linearly projects the fused
    vector to `out_dim` per history step."""

    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # Keep constructor args as public attributes; save_checkpoint reads
        # the non-underscore entries of __dict__ to persist hyperparameters.
        self.cat_columns = cat_columns
        self.num_columns = num_columns
        self.cat_features_max_id = cat_features_max_id
        self.category_feature_dim = category_feature_dim
        self.out_dim = out_dim
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        # Index 0 is reserved for padding and stays a zero vector.
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    @property
    def device(self):
        # All parameters live on a single device; report it.
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        """Return (projected inputs, targets), both moved to this module's device."""
        cats = cat_features_batch.to(self.device)
        nums = num_features_batch.to(self.device)
        batch, seq = cats.shape[0], cats.shape[1]
        # (B, T, n_cat) int ids -> (B, T, n_cat * category_feature_dim)
        cat_vecs = self.cat_embeds(cats.type(torch.int32)).reshape(batch, seq, -1)
        # Learned per-column affine rescaling of the numeric features.
        num_vecs = self.num_scales * nums + self.num_shifts
        fused = torch.concat([cat_vecs, num_vecs], dim=-1)
        return self.proj(fused), targets_batch.to(self.device)
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, arXiv:2104.09864).

    Precomputes an angle table for `max_seq_len` positions and rotates the
    last dimension of (B, T, dim) inputs by position-dependent angles.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder: one frequency per pair of channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so both halves of the channel dim share the same angles.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation within each channel pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate *x* by the angles of positions [offset, offset + T)."""
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """DynamicTanh normalization-free layer (https://jiachenzhu.github.io/DyT/):
    y = tanh(alpha * x) * weight + bias, with learnable alpha/weight/bias."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder block (DyT norm, RoPE positions) whose
    two attention matmuls (Q·K^T and attn·V) run on optical-multiplication
    simulators via new_formula.

    Fix: F.dropout1d defaults to training=True, so dropout was previously
    applied even in eval mode; it is now gated on self.training.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Keep constructor args as attributes (save_checkpoint persists the
        # non-underscore entries of __dict__).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Learnable scalar gains compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Scaled dot-product scores computed on the optical simulator.
        # NOTE(review): scaling uses full h_dim, not per-head h_dim//num_heads
        # as in Vaswani et al. — confirm this is intended.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # Residual connections around attention and FFN; dropout only fires
        # while training (F.dropout1d would otherwise also fire in eval mode).
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style encoder classifier over pre-embedded sequences.

    Prepends a CLS token (plus num_reg register tokens), adds learned absolute
    position embeddings (RoPE is additionally applied inside each layer), runs
    layers_num TransformerLayers that all share one pair of optical matmul
    simulators, and classifies from the CLS position.
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # Store ctor args as attributes (class_num, num_heads, ... are read later).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # Effective sequence length grows by CLS + register tokens; the optical
        # simulators below are sized to this extended length.
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        # Two simulator regimes: short sequences use distance=0.01; long ones
        # (>=512) use distance=0.15 and an explicit lens_size — presumably tuned
        # for the larger aperture, confirm against the omm package defaults.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
        # One simulator pair is shared by every layer.
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
    def forward(self, x):
        # Prepend CLS (+ register) tokens to every sequence in the batch.
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # Classify from the CLS position only.
        x = self.classifier_head(x[:,0,:])
        # (B, classes) for multi-class; squeeze to (B,) logits for the binary head.
        return x[:,:] if self.class_num > 1 else x[:,0]
# --- Dataset and model construction (module-level script) ---
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    # NOTE(review): the filename below contains a space before ".csv" — confirm intentional.
    targets_path="/wd/finbert_data/train_target .csv",
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# Model hyperparameters for this run.
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit -> binary classification with BCEWithLogitsLoss
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
# Single schedule-free optimizer over both modules' parameters.
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# Class-imbalance correction: weight positives by the negative/positive ratio.
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter exposing CreditProductsDataset.get_batch through the torch Dataset
    API so DataLoader worker processes can sample batches in parallel.

    Each __getitem__ returns one full random batch regardless of idx, so the
    DataLoader should be created with batch_size=1.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder  # kept for API symmetry; not used here
        self.batch_size = batch_size
    def __len__(self):
        # One epoch = enough random batches to cover every unique client once.
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size
    def __getitem__(self, idx):
        # BUG FIX: previously read the global `credit_train_dataset`; use the
        # dataset this wrapper was constructed with.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
# batch_size=1 because WrapperDataset already yields full batches; 8 workers sample in parallel.
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
val_auroc = AUROC(task='binary')  # NOTE(review): appears unused below — confirm
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate on the test split and log ROC-AUC to tensorboard.

    Reads the module-level model/encoder/optimizer/writer globals and the shared
    test_auroc metric; `epoch` only fixes the tensorboard x-coordinate.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedule-free optimizers need an explicit eval mode switch
    # BUG FIX: the metric was never reset, so every call mixed predictions from
    # all previous evaluations into one cumulative AUROC; reset per call.
    test_auroc.reset()
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
# --- Training loop (module-level script) ---
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    test(epoch)  # evaluate before each training epoch
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()
        # DataLoader batch_size=1 wraps each ready-made batch; [0] unwraps it.
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # clip_grad_norm_ with max_norm=1e6 is effectively log-only; note it
            # runs only on display ticks, so most steps are not clipped at all.
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            # Switch to eval (schedule-free averaging) before reporting the loss.
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
            # Periodic checkpoint; only evaluated on display ticks.
            if current_time - last_checkpoint_time > timedelta(hours=4):
                last_checkpoint_time = current_time
                save_checkpoint(
                    credit_dataset=credit_train_dataset,
                    encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# Final evaluation and checkpoint after the last epoch.
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -0,0 +1,473 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
step = 1  # NOTE(review): appears unused in this file — confirm before removing
# Pixel pitch in meters (3.6 um) used to size the optical matrices — presumably
# the simulator's SLM pixel pitch; confirm against the omm package.
pixel_size: float = 3.6e-6
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
    """Divide `matrix` by `max_val`; the epsilon keeps a zero max from producing inf/NaN."""
    denominator = max_val + 1e-10
    return matrix / denominator
def optics_matmul_shift(sim, tensor_1, tensor_2):
    """Signed matmul through a non-negative optical simulator `sim` via the shift trick.

    The simulator only accepts non-negative inputs, so both operands are shifted
    up by the same constant S and the cross terms are subtracted:
        (A+S)(B+S') = AB + AS' + SB + SS'
        => AB = (A+S)(B+S') - (A+S)S' - S(B+S') + SS'
    Inputs are 3-D (batch, rows, cols); a leading singleton dim is added for the
    simulator and stripped from the result. Uses the module-level `device`.
    """
    tensor_1 = tensor_1[None,:,:,:]
    tensor_2 = tensor_2[None,:,:,:]
    # Fast path: already non-negative — just rescale into [0, 1] and undo afterwards.
    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
        a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
        return sim(a, b)[0,:,:,:] * max_abs **2
    # Shift by the magnitude of the most negative entry so both operands become non-negative.
    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
    shift_a = min_abs * torch.ones(tensor_1.shape).to(device)
    shift_b = min_abs * torch.ones(tensor_2.shape).to(device)
    a_a_sh = tensor_1 + shift_a
    b_b_sh = tensor_2 + shift_b
    # Normalise everything into [0, 1] for the simulator.
    a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
    shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
    # Four simulator passes, combined per the identity above; rescale at the end.
    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# --- Run bookkeeping: tensorboard dir, checkpoint dir, script snapshot ---
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None  # set to a .pth path to resume weights
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
# NOTE(review): `сhekpoints_dir` starts with a Cyrillic 'с' (and is misspelled) —
# it works, but is easy to mistype; consider renaming project-wide.
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
writer = SummaryWriter(logs_dir)
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialize encoder/model weights, their non-private attributes (which act
    as a record of ctor hyperparameters), optimizer state, metrics, and the
    dataset split paths to <сhekpoints_dir>/epoch_<epoch>_<loss>.pth."""
    checkpoint = {
        'encoder': {
            'state_dict': encoder.state_dict(),
            # non-private attrs double as the hyperparameter record
            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'model': {
            'state_dict': model.state_dict(),
            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'epoch': epoch,
        'optimizer': {
            'state_dict': optimizer.state_dict(),
        },
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
    }
    path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    """Restore encoder/model weights from `checkpoint_file` into the module-level
    encoder and model. Optimizer state is intentionally not restored.

    NOTE(review): silently does nothing when the file does not exist — consider
    logging or raising so a mistyped path is noticed.
    """
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        encoder.load_state_dict(checkpoint['encoder']['state_dict'])
        model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    """Loads per-client credit-history features (parquet) and binary targets
    (csv), builds a unified categorical index for embeddings, and serves padded
    train batches / test batch iterators as tensors.

    Train/test client-id splits are read from pre-computed csv files; both must
    exist or construction raises.
    """
    def __init__(self,
            features_path, targets_path, train_test_split_ratio=0.9,
            train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
        ):
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        # Offset each column's codes by the cumulative cardinality of preceding
        # columns so every (column, value) pair gets a distinct embedding id.
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        # NOTE(review): int16 overflows if the unified index exceeds 32767 — confirm total cardinality.
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)
    def get_batch(self, batch_size=4):
        # Random train batch; indexing tensors by raw client ids assumes ids are
        # contiguous 0..N-1 and aligned with the sorted tensor rows — TODO confirm.
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch
    def get_test_batch_iterator(self, batch_size=4):
        # Sequential, non-augmented batches over the test split.
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Embeds unified categorical ids and affinely rescales numeric features,
    then projects the concatenation down to the transformer width.

    Args:
        cat_columns / num_columns: column-name lists (only their lengths are used).
        cat_features_max_id: largest unified category id (embedding table size - 1).
        category_feature_dim: embedding size per categorical column.
        out_dim: output width fed to the transformer.
    """
    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # Store ctor args as attributes (bypasses nn.Module.__setattr__; fine for plain values).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        # Index 0 is the padding embedding and stays zero during training.
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        # Learned per-feature affine normalisation for numeric columns.
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)
    @property
    def device(self):
        # Device of the module's parameters; used to move incoming batches.
        return next(self.parameters()).device
    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        cat_features_batch = cat_features_batch.to(self.device)
        num_features_batch = num_features_batch.to(self.device)
        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
        # (B, T, n_cat, emb) -> (B, T, n_cat*emb)
        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
        inputs = self.proj(embed_tensor)
        targets = targets_batch.to(self.device)
        return inputs, targets
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency ladder over half the feature dim, as in RoFormer.
        freqs_per_pair = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(freqs_per_pair)
        angle_table = torch.outer(positions, freqs_per_pair)
        # Duplicate so the table spans the full feature dim; buffer follows .to(device).
        self.register_buffer('emb', torch.cat((angle_table, angle_table), dim=-1))
    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation applied across the split halves.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)
    def forward(self, x, offset=0):
        n = x.size(1)
        table = self.emb[offset:offset + n].view(1, n, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable tanh squashing plus per-feature affine, used here as a norm-free LayerNorm substitute."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Scalar tanh sharpness, learned; initialised at alpha_init_value.
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        # Per-feature scale and shift, mirroring LayerNorm's elementwise affine.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder block whose attention matmuls (Q.K^T and
    attn.V) are routed through optical simulators via optics_matmul_shift.

    Args:
        h_dim: model width; assumed divisible by num_heads.
        sim_scores: optical simulator used for the Q @ K^T score matmul.
        sim_output: optical simulator used for the attention @ V matmul.
        num_heads: attention heads, folded into the batch dimension.
        dropout_rate: dropout probability on both residual branches.
        max_seq_len: length of the rotary-embedding table (per-head dim).
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Store all ctor args as attributes. Going through __dict__ bypasses
        # nn.Module.__setattr__, so sim_scores/sim_output are NOT registered as
        # submodules — presumably intentional since they are shared across
        # layers; confirm their parameters are handled elsewhere if trainable.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by full h_dim**-0.5, not the per-head dim — confirm intended.
        scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = optics_matmul_shift(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout stayed
        # active under model.eval(); gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style encoder classifier over pre-embedded sequences.

    Prepends a CLS token (plus num_reg register tokens), adds learned absolute
    position embeddings (RoPE is additionally applied inside each layer), runs
    layers_num TransformerLayers that all share one pair of optical matmul
    simulators, and classifies from the CLS position.
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # Store ctor args as attributes (class_num, num_heads, ... are read later).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # Effective sequence length grows by CLS + register tokens; the optical
        # simulators below are sized to this extended length.
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        # Two simulator regimes: short sequences use distance=0.01; long ones
        # (>=512) use distance=0.15 and an explicit lens_size — presumably tuned
        # for the larger aperture, confirm against the omm package defaults.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
        # One simulator pair is shared by every layer.
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
    def forward(self, x):
        # Prepend CLS (+ register) tokens to every sequence in the batch.
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # Classify from the CLS position only.
        x = self.classifier_head(x[:,0,:])
        # (B, classes) for multi-class; squeeze to (B,) logits for the binary head.
        return x[:,:] if self.class_num > 1 else x[:,0]
# --- Dataset and model construction (module-level script) ---
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    # NOTE(review): the filename below contains a space before ".csv" — confirm intentional.
    targets_path="/wd/finbert_data/train_target .csv",
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
# Model hyperparameters for this run.
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1  # single logit -> binary classification with BCEWithLogitsLoss
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
# Single schedule-free optimizer over both modules' parameters.
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
# Class-imbalance correction: weight positives by the negative/positive ratio.
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Adapter exposing CreditProductsDataset.get_batch through the torch Dataset
    API so DataLoader worker processes can sample batches in parallel.

    Each __getitem__ returns one full random batch regardless of idx, so the
    DataLoader should be created with batch_size=1.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder  # kept for API symmetry; not used here
        self.batch_size = batch_size
    def __len__(self):
        # One epoch = enough random batches to cover every unique client once.
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size
    def __getitem__(self, idx):
        # BUG FIX: previously read the global `credit_train_dataset`; use the
        # dataset this wrapper was constructed with.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
# batch_size=1 because WrapperDataset already yields full batches; 8 workers sample in parallel.
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
val_auroc = AUROC(task='binary')  # NOTE(review): appears unused below — confirm
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate on the test split and log ROC-AUC to tensorboard.

    Reads the module-level model/encoder/optimizer/writer globals and the shared
    test_auroc metric; `epoch` only fixes the tensorboard x-coordinate.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()  # schedule-free optimizers need an explicit eval mode switch
    # BUG FIX: the metric was never reset, so every call mixed predictions from
    # all previous evaluations into one cumulative AUROC; reset per call.
    test_auroc.reset()
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
# --- Training loop (module-level script) ---
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    test(epoch)  # evaluate before each training epoch
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()
        # DataLoader batch_size=1 wraps each ready-made batch; [0] unwraps it.
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # clip_grad_norm_ with max_norm=1e6 is effectively log-only; note it
            # runs only on display ticks, so most steps are not clipped at all.
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            # Switch to eval (schedule-free averaging) before reporting the loss.
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
            # Periodic checkpoint; only evaluated on display ticks.
            if current_time - last_checkpoint_time > timedelta(hours=4):
                last_checkpoint_time = current_time
                save_checkpoint(
                    credit_dataset=credit_train_dataset,
                    encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# Final evaluation and checkpoint after the last epoch.
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -0,0 +1,483 @@
import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
step = 1  # NOTE(review): appears unused in this file — confirm before removing
# Pixel pitch in meters (3.6 um) used to size the optical matrices — presumably
# the simulator's SLM pixel pitch; confirm against the omm package.
pixel_size: float = 3.6e-6
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def new_formula(sim, tensor_1, tensor_2):
    """Exact signed matmul through a non-negative simulator `sim` by splitting
    both operands into positive/negative parts:

        A . B = (A+ - A-)(B+ - B-) = A+B+ - A+B- - A-B+ + A-B-

    Each part is normalised into [0, 1] before the simulator call and rescaled
    afterwards; parts that are identically zero skip their simulator pass.

    Args:
        sim: callable taking two non-negative 4-D tensors and returning their
            (batched) matrix product.
        tensor_1, tensor_2: 3-D (batch, rows, cols) or 4-D operands.
    Returns:
        3-D tensor: the signed product with the leading singleton dim stripped.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # Per-part maxima; zero when that part has no mass, which lets us skip its pass.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    # Result shape of a batched matmul of the two operands.
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])
    def _term(part_a, max_a, part_b, max_b):
        # One normalised simulator pass, or an all-zero term when either part vanishes.
        # (Replaces the old torch.zeros_like(torch.empty(...)) + clone() pattern,
        # which allocated an uninitialised CPU tensor and copied it per branch.)
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)
    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# --- Run bookkeeping: tensorboard dir, checkpoint dir, script snapshot ---
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None  # set to a .pth path to resume weights
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
# NOTE(review): `сhekpoints_dir` starts with a Cyrillic 'с' (and is misspelled) —
# it works, but is easy to mistype; consider renaming project-wide.
сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Chekpoints dir:", сhekpoints_dir)
writer = SummaryWriter(logs_dir)
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"):
    """Serialize training state to `<сhekpoints_dir>/epoch_<epoch>_<loss>.pth`.

    For the encoder and the model the checkpoint stores the state_dict plus
    every public attribute (anything not starting with '_' and not 'training'),
    so hyperparameters travel with the weights. The optimizer contributes only
    its state_dict; the train/test client-id split paths are recorded so a
    restored run can reuse the same split.
    """
    def snapshot(module):
        # state_dict plus the module's plain public attributes
        meta = {key: val for key, val in module.__dict__.items()
                if key[0] != '_' and key != 'training'}
        return {'state_dict': module.state_dict(), **meta}

    payload = {
        'encoder': snapshot(encoder),
        'model': snapshot(model),
        'epoch': epoch,
        'optimizer': {'state_dict': optimizer.state_dict()},
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path,
    }
    destination = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(payload, destination)
    print(f"\nCheckpoint saved to {destination}")
def load_checkpoint(checkpoint_file):
    """Restore encoder/model weights from *checkpoint_file* into the globals.

    Only the weights are restored; the optimizer state is intentionally left
    untouched (the commented line shows how it would be loaded). Previously a
    missing file was skipped silently, which hid typos in the path — now a
    warning is printed instead (behavior otherwise unchanged: no exception).
    """
    if os.path.exists(checkpoint_file):
        # map_location keeps restore working on a CPU-only host even when the
        # checkpoint was written from CUDA tensors.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        encoder.load_state_dict(checkpoint['encoder']['state_dict'])
        model.load_state_dict(checkpoint['model']['state_dict'])
    else:
        print(f"Warning: checkpoint file not found, nothing loaded: {checkpoint_file}")
class CreditProductsDataset:
    """In-memory dataset of per-client credit-product histories.

    Loads parquet features and a csv of binary targets, remaps every
    categorical column into a single shared embedding index space (index 0 is
    reserved for padding), pads each client's history to the global maximum
    length, and keeps everything as pre-built tensors so batches are plain
    tensor slices.
    """
    def __init__(self,
            features_path, targets_path, train_test_split_ratio=0.9,
            train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
        ):
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        # The train/test split must already exist on disk; this class never
        # re-splits (train_test_split_ratio is currently unused).
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        # rn is the per-client record number, so its max is the longest history
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        # numeric columns = everything that is neither an id nor categorical
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        # offset every column after the first by the cumulative cardinality of
        # the preceding columns, so no two columns share an embedding index
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        # records per client, in client-id order (used to split into sequences)
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        """Sample a random training batch with 10% feature dropout.

        NOTE(review): client ids are used directly as row indices into the
        padded tensors — this assumes ids are contiguous 0-based integers
        matching the sorted tensor order; confirm against the data.
        """
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        """Yield deterministic test batches (no dropout, no shuffling)."""
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
class Encoder(nn.Module):
    """Embed categorical and numeric features into a single token sequence.

    Each categorical column is looked up in one shared embedding table (index 0
    is the zeroed padding slot); numeric columns get a learned per-column affine
    scale/shift. The concatenation is projected to `out_dim` per time step.
    """

    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        # Keep constructor arguments as plain public attributes: save_checkpoint()
        # serializes every non-underscore attribute alongside the state_dict.
        self.cat_columns = cat_columns
        self.num_columns = num_columns
        self.cat_features_max_id = cat_features_max_id
        self.category_feature_dim = category_feature_dim
        self.out_dim = out_dim
        self.total_h_dim = len(cat_columns) * category_feature_dim + len(num_columns)
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, category_feature_dim, padding_idx=0)
        self.num_scales = nn.Parameter(torch.randn(1, len(num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    @property
    def device(self):
        """Device of the module's parameters."""
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        """Return (projected embeddings (B, T, out_dim), targets on device)."""
        dev = self.device
        cats = cat_features_batch.to(dev)
        nums = num_features_batch.to(dev)
        batch, seq = cats.shape[0], cats.shape[1]
        # (B, T, n_cat, emb) -> (B, T, n_cat*emb)
        cat_part = self.cat_embeds(cats.type(torch.int32)).reshape(batch, seq, -1)
        num_part = self.num_scales * nums + self.num_shifts
        inputs = self.proj(torch.concat([cat_part, num_part], dim=-1))
        return inputs, targets_batch.to(dev)
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles.

    Precomputes a (max_seq_len, dim) angle table; `forward` applies the rotation
    for positions [offset, offset + seq_len).
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        # duplicate so the table spans the full feature dim
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Map (x1, x2) halves of the last dim to (-x2, x1)."""
        first, second = torch.chunk(x, 2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate `x` (B, T, dim) by the angles for its absolute positions."""
        length = x.size(1)
        table = self.emb[offset:offset + length].view(1, length, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: learnable-slope tanh plus per-feature affine transform.

    A normalization-free drop-in replacement for LayerNorm: y = tanh(alpha*x) * w + b.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matmuls (Q@K^T and
    attention@V) are routed through optical matrix-multiplication simulators.

    Args:
        h_dim: model width; must be divisible by num_heads.
        sim_scores: optical-matmul module for the Q @ K^T product.
        sim_output: optical-matmul module for the attention @ V product.
        num_heads: number of attention heads.
        dropout_rate: dropout probability after attention and after the FFN.
        max_seq_len: size of the rotary-embedding position table.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # NOTE(review): writing through __dict__ (not setattr) keeps the shared
        # sim_scores/sim_output out of this layer's submodule registry, so they
        # are not duplicated per layer in state_dict — presumably intentional.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h)
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # new_formula (defined earlier in this file) dispatches the product to the
        # optical simulator with sign/shift handling.
        # NOTE(review): scaling uses h_dim**-0.5, not head_dim**-0.5 — equal only
        # when num_heads == 1; confirm this is intended for multi-head runs.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: functional dropout defaults to training=True, which kept
        # dropout active even in model.eval(); tie it to the module's mode so
        # evaluation is deterministic.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    """BERT-style encoder + classification head whose attention matmuls run
    through two shared optical-matmul simulators (omm.OpticalMul).

    A learnable CLS token (plus optional register tokens, num_reg) is prepended
    to the sequence; classification reads only the CLS position's output.
    """
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1):
        super().__init__()
        # stash constructor args as public attributes (also serialized by save_checkpoint)
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1
        # effective sequence length seen by the layers = data tokens + cls/reg tokens
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        # Two simulator configurations depending on sequence length: short
        # sequences use a close propagation distance (0.01 m); long sequences
        # (>= 512) need a larger distance and an explicit lens_size. Exactly one
        # of the two branches below runs.
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.01,
                    trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = self.max_seq_len,
                    right_matrix_count_rows = h_dim // num_heads,
                    right_matrix_width = pixel_size * self.max_seq_len,
                    right_matrix_height = pixel_size * (h_dim // num_heads),
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                    right_matrix_count_rows = self.max_seq_len,
                    right_matrix_width = pixel_size * (h_dim // num_heads),
                    right_matrix_height = pixel_size * self.max_seq_len,
                    min_height_gap = pixel_size,
                    right_matrix_split_x = 2,
                    right_matrix_split_y = 2,
                    left_matrix_split_x = 2,
                    left_matrix_split_y = 2,
                    result_matrix_split = 2,
                    distance = 0.15,
                    lens_size = 8192 * 2,
                    trainable_cylind_lens=False)
            )
        # the same two simulator instances are shared by every layer
        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    def forward(self, x):
        # prepend CLS (+ register) tokens, then add learned absolute positions
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        # classify from the CLS position only
        x = self.classifier_head(x[:,0,:])
        # single-logit mode (class_num == 1) returns a flat (B,) tensor for BCE
        return x[:,:] if self.class_num > 1 else x[:,0]
# --- dataset, model, optimizer and loss setup ---------------------------------
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    # NOTE(review): the filename contains a space before ".csv" — confirm it
    # matches the actual file on disk.
    targets_path="/wd/finbert_data/train_target .csv",
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")

# hyperparameters
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1          # single logit -> BCEWithLogitsLoss below
dropout_rate = 0.1

encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate = dropout_rate
).to(device)
print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')

# schedule-free optimizer: no LR schedule, but requires explicit .train()/.eval() calls
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))

if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)

# weight the positive class by the train-set imbalance ratio
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    """Torch Dataset adapter where each "item" is one full random training batch
    drawn from the wrapped credit dataset, so a DataLoader with batch_size=1 and
    several workers pre-fetches batches in parallel.

    Args:
        credit_dataset: object exposing get_batch(batch_size) and uniq_client_ids.
        encoder: kept for interface compatibility; unused here (encoding runs in
            the training loop on the main process).
        batch_size: number of clients sampled per generated batch.
    """
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder
        self.batch_size = batch_size

    def __len__(self):
        # one "epoch" = enough random batches to cover every client once on average
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size

    def __getitem__(self, idx):
        # BUG FIX: previously read the module-level `credit_train_dataset`
        # instead of the dataset passed to the constructor, which silently
        # defeated the wrapper. `idx` is intentionally ignored: get_batch
        # samples clients at random.
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
# batch_size=1 because each WrapperDataset item is already a full batch; workers
# generate random batches in parallel.
# NOTE(review): each of the 8 workers gets its own numpy RNG state — confirm
# workers are seeded differently so they do not sample identical batches.
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

val_auroc = AUROC(task='binary')   # NOTE(review): created but not updated in this file — verify it is used elsewhere
test_auroc = AUROC(task='binary')
def test(epoch):
    """Evaluate on the held-out clients and log ROC-AUC to TensorBoard.

    Uses the module-level model/encoder/optimizer/writer and `test_auroc`.
    The schedule-free optimizer needs optimizer.eval() so inference uses the
    averaged weights.
    """
    model.eval()
    encoder.eval()
    optimizer.eval()
    # BUG FIX: the metric state used to persist across calls, so every logged
    # ROC-AUC was accumulated over all previous evaluations too; reset first so
    # each call reports this evaluation only.
    test_auroc.reset()
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40)
    print()
# --- main training loop -------------------------------------------------------
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    # evaluate before each epoch (and once more after the loop)
    test(epoch)
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()  # schedule-free optimizer needs explicit train mode
        # DataLoader adds a leading dim of size 1; strip it ([0]) since each
        # dataset item is already a full batch
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            # max_norm=1e6 makes this effectively measurement-only: the call
            # returns the total gradient norm for logging without really clipping
            writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id)
        optimizer.step()
        optimizer.zero_grad()
        # throttled console/TensorBoard reporting (at most ~1/sec)
        if current_time - last_display_time > timedelta(seconds=1):
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id)
            print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40)
        # periodic checkpoint every 4 hours of wall time
        if current_time - last_checkpoint_time > timedelta(hours=4):
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder = encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
# final evaluation and checkpoint
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder = encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir)
writer.close()

@ -5,6 +5,12 @@ __version__ = "3.0.0"
from .config import Config
from . import propagator
from .optical_mul import OpticalMul
from .optical_mul import (
OpticalMul,
TrainableLensOpticalMul,
TrainableScalarOpticalMul,
TrainableScalarAndLensOpticalMul,
TrainableFocalDistLensOpticalMul
)
from .parallel import DataParallel
from .parallel import ScatterDataParallel

@ -1,7 +1,15 @@
import torch as _torch
import torch.nn as _nn
from .config import Config as _Config
from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop
from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorCylindLens as _PropCylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop
from .propagator import (
PropagatorTrainableCylindLens as _PropagatorTrainableCylindLens,
PropagatorTrainableFocalDistCylindLens as _PropagatorTrainableFocalDistCylindLens
)
from torch.utils.tensorboard import SummaryWriter
from typing import Optional
import matplotlib.pyplot as plt
class OpticalMul(_nn.Module):
"""
@ -14,22 +22,129 @@ class OpticalMul(_nn.Module):
Args:
config: конфигурация расчётной системы.
"""
super(OpticalMul, self).__init__()
self.trainable_cylind_lens = config._trainable_cylind_lens
super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
self._propagator_two: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределений световых полей.
Returns:
Матрицы содержащие вектора левой матрицы.
"""
data = data.cfloat().flip(-1)
data = data.unsqueeze(-2)
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
return data
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки правой матрицы к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределения светового поля.
Returns:
Матрица - оптический элемент в центре модели.
"""
if (data.dim() > 4) and data.size(-1) == 2:
data = _torch.view_as_complex(data)
data = data.cfloat().transpose(-2, -1)
data = data.unsqueeze(-3)
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
return data
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
"""
Метод получения результата матричного умножения.
Args:
data: матрицы выходого распределения светового поля системы.
Returns:
Вектор столбец (амплитудное распределение).
"""
### Закоментированная часть кода - более физически корректный вариант работы модели,
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
field = field.abs().squeeze(-1) #**2
field = self._avg_pool(field)
return field.flip(-1) #**0.5
def forward(self,
input: _torch.Tensor,
other: _torch.Tensor) -> _torch.Tensor:
"""
Метод выполения матричного умножения.
Args:
input: матрица (B, C, H, W).
other: матрица (B, C, W, K).
Returns:
Рензультат матричного умножения (B, C, H, K).
Example:
>>> mul = OpticalMul(...)
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 256, 256])
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 64, 128])
"""
vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other)
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1))
return self.prepare_out(vec_field)
class TrainableScalarOpticalMul(_nn.Module):
"""
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
"""
def __init__(self, config: _Config):
"""
Конструктор класса.
Args:
config: конфигурация расчётной системы.
"""
super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropСylindLens(config.matrix_plane, config, trainable=self.trainable_cylind_lens)
prop_four = _PropCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
if self.trainable_cylind_lens:
self._propagator_one: _Prop = prop_one + prop_two + prop_three
self._propagator_between = prop_four
else:
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
self._propagator_two: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
@ -38,6 +153,120 @@ class OpticalMul(_nn.Module):
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
self.k = nn.Parameter(_torch.tensor(1))
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределений световых полей.
Returns:
Матрицы содержащие вектора левой матрицы.
"""
data = data.cfloat().flip(-1)
data = data.unsqueeze(-2)
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
return data
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки правой матрицы к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределения светового поля.
Returns:
Матрица - оптический элемент в центре модели.
"""
if (data.dim() > 4) and data.size(-1) == 2:
data = _torch.view_as_complex(data)
data = data.cfloat().transpose(-2, -1)
data = data.unsqueeze(-3)
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
return data
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
"""
Метод получения результата матричного умножения.
Args:
data: матрицы выходого распределения светового поля системы.
Returns:
Вектор столбец (амплитудное распределение).
"""
### Закоментированная часть кода - более физически корректный вариант работы модели,
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
field = field.abs().squeeze(-1) #**2
field = self._avg_pool(field)
return field.flip(-1) #**0.5
def forward(self,
input: _torch.Tensor,
other: _torch.Tensor) -> _torch.Tensor:
"""
Метод выполения матричного умножения.
Args:
input: матрица (B, C, H, W).
other: матрица (B, C, W, K).
Returns:
Рензультат матричного умножения (B, C, H, K).
Example:
>>> mul = OpticalMul(...)
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 256, 256])
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 64, 128])
"""
vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other)
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1))
return self.k * self.prepare_out(vec_field)
class TrainableLensOpticalMul(_nn.Module):
"""
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
"""
def __init__(self, config: _Config):
"""
Конструктор класса.
Args:
config: конфигурация расчётной системы.
"""
super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three
self._propagator_cylind_lens: _Prop = prop_four
self._propagator_three: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
"""
@ -69,6 +298,8 @@ class OpticalMul(_nn.Module):
data = data.cfloat().transpose(-2, -1)
data = data.unsqueeze(-3)
# TODO data should be at least two seq length. For one we get
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
return data
@ -114,12 +345,308 @@ class OpticalMul(_nn.Module):
"""
vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other)
if self.trainable_cylind_lens:
vec_field = self._propagator_one(vec_field)
vec_field = self._propagator_between(vec_field)
else:
vec_field = self._propagator_one(vec_field)
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
return self.prepare_out(vec_field)
@_torch.no_grad()
def log_cylind_lens_operator_x(
self,
writer: SummaryWriter,
tag: str,
global_step: Optional[int] = None,
):
# 1. Apply exp to get the wrapped phase as it would be physically seen
# This ensures values outside [-pi, pi] wrap correctly
complex_op = _torch.exp(-1j * self._propagator_cylind_lens._operator_X_phi)
wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π]
# 2. Normalize for Image Visualization [0, 1]
phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi)
# 3. Log as a 1-pixel high row
# Shape: [1, 1, Width]
phase_row = phase_normalized.unsqueeze(0).unsqueeze(0)
writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW')
vec_field = self._propagator_two(vec_field * mat_field)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}')
ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}")
ax.set_xlabel("Pixel Index")
ax.set_ylabel("Phase (rad)")
ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5])
ax.grid(True, linestyle='--', alpha=0.6)
# Send the figure to the "Plots" or "Images" tab in TensorBoard
writer.add_figure(f"{tag}/phase_profile", fig, global_step)
plt.close(fig) # Important: prevent memory leaks
class TrainableFocalDistLensOpticalMul(_nn.Module):
"""
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
"""
def __init__(self, config: _Config):
"""
Конструктор класса.
Args:
config: конфигурация расчётной системы.
"""
super().__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropagatorTrainableFocalDistCylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three
self._propagator_cylind_lens: _Prop = prop_four
self._propagator_three: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределений световых полей.
Returns:
Матрицы содержащие вектора левой матрицы.
"""
data = data.cfloat().flip(-1)
data = data.unsqueeze(-2)
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
return data
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки правой матрицы к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределения светового поля.
Returns:
Матрица - оптический элемент в центре модели.
"""
if (data.dim() > 4) and data.size(-1) == 2:
data = _torch.view_as_complex(data)
data = data.cfloat().transpose(-2, -1)
data = data.unsqueeze(-3)
# TODO data should be at least two seq length. For one we get
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
return data
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
"""
Метод получения результата матричного умножения.
Args:
data: матрицы выходого распределения светового поля системы.
Returns:
Вектор столбец (амплитудное распределение).
"""
### Закоментированная часть кода - более физически корректный вариант работы модели,
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
field = field.abs().squeeze(-1) #**2
field = self._avg_pool(field)
return field.flip(-1) #**0.5
def forward(self,
input: _torch.Tensor,
other: _torch.Tensor) -> _torch.Tensor:
"""
Метод выполения матричного умножения.
Args:
input: матрица (B, C, H, W).
other: матрица (B, C, W, K).
Returns:
Рензультат матричного умножения (B, C, H, K).
Example:
>>> mul = OpticalMul(...)
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 256, 256])
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 64, 128])
"""
vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other)
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:])
vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
return self.prepare_out(vec_field)
    @_torch.no_grad()
    def log_cylind_lens_operator_x(
        self,
        writer: SummaryWriter,
        tag: str,
        global_step: Optional[int] = None,
    ):
        """Log the cylindrical lens x-axis phase profile to TensorBoard.

        Logs the current focal distance as a scalar, the wrapped phase as a
        1-pixel-high image, and a matplotlib line plot of the phase profile.

        Args:
            writer: TensorBoard ``SummaryWriter`` instance.
            tag: namespace prefix for the logged scalar/image/figure.
            global_step: TensorBoard step index.
        """
        # NOTE(review): assumes the cylindrical propagator exposes _distance,
        # _K and _linspace_by_x (i.e. the trainable-focal-distance variant) —
        # confirm when a different propagator type is plugged in.
        # 1. Apply exp to get the wrapped phase as it would be physically seen
        # This ensures values outside [-pi, pi] wrap correctly
        lens = self._propagator_cylind_lens
        writer.add_scalar(f"{tag}/focal_distance", lens._distance.detach().cpu().numpy(), global_step)
        complex_op = _torch.exp(-1j * lens._K / lens._distance * lens._linspace_by_x**2)
        wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π]
        # 2. Normalize for Image Visualization [0, 1]
        phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi)
        # 3. Log as a 1-pixel high row
        # Shape: [1, 1, Width]
        phase_row = phase_normalized.unsqueeze(0).unsqueeze(0)
        writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW')
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}')
        ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}\nFocal distance: {lens._distance.detach().cpu().numpy()}")
        ax.set_xlabel("Pixel Index")
        ax.set_ylabel("Phase (rad)")
        ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5])
        ax.grid(True, linestyle='--', alpha=0.6)
        # Send the figure to the "Plots" or "Images" tab in TensorBoard
        writer.add_figure(f"{tag}/phase_profile", fig, global_step)
        plt.close(fig) # Important: prevent memory leaks
class TrainableScalarAndLensOpticalMul(_nn.Module):
    """
    Optical matrix-matrix multiplication system with a trainable cylindrical
    lens and a trainable scalar output gain ``k``.
    """
    def __init__(self, config: _Config):
        """
        Args:
            config: configuration of the simulated optical system.
        """
        super().__init__()
        prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
        prop_two = _PropCrossLens(config.first_lens_plane, config)
        prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
        prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config)
        prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
        prop_six = _PropCrossLens(config.second_lens_plane, config).T
        prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
        self._propagator_one: _Prop = prop_one + prop_two + prop_three
        self._propagator_two: _Prop = prop_four
        self._propagator_three: _Prop = prop_five + prop_six + prop_seven
        kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
        kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
        self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
        self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
        self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
        # BUG FIX: nn.Parameter(torch.tensor(1)) wraps an *integer* tensor,
        # and integer tensors cannot require gradients (RuntimeError at
        # construction). Use a float init; also use the module's _nn/_torch
        # aliases for consistency with the rest of the file.
        self.k = _nn.Parameter(_torch.tensor(1.0))
    def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the left-hand matrix, as a set of column vectors, for the
        system input.

        Args:
            data: complex-amplitude distributions of the light fields.

        Returns:
            Fields containing the column vectors of the left-hand matrix.
        """
        data = data.cfloat().flip(-1)
        data = data.unsqueeze(-2)
        data = _torch.kron(data.contiguous(), self._kron_vec_utils)
        return data
    def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
        """
        Prepare the right-hand matrix for the system input.

        Args:
            data: complex-amplitude distribution of the light field.

        Returns:
            Matrix acting as the optical element at the centre of the model.
        """
        # A trailing size-2 axis on a >4-D tensor encodes (real, imag) parts.
        if (data.dim() > 4) and data.size(-1) == 2:
            data = _torch.view_as_complex(data)
        data = data.cfloat().transpose(-2, -1)
        data = data.unsqueeze(-3)
        data = _torch.kron(data.contiguous(), self._kron_mat_utils)
        return data
    def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
        """
        Extract the matrix-multiplication result from the output light field.

        Args:
            field: output light-field distribution of the system.

        Returns:
            Column vector (amplitude distribution).
        """
        ### The commented-out powers (**2 / **0.5) are the more physically
        ### correct variant, but they would need much more memory in training.
        field = field.abs().squeeze(-1) #**2
        field = self._avg_pool(field)
        return field.flip(-1) #**0.5
    def forward(self,
                input: _torch.Tensor,
                other: _torch.Tensor) -> _torch.Tensor:
        """
        Perform the matrix multiplication optically, scaled by the trainable
        gain ``k``.

        Args:
            input: matrix (B, C, H, W).
            other: matrix (B, C, W, K).

        Returns:
            Result of the matrix multiplication (B, C, H, K).

        Example:
            >>> mul = OpticalMul(...)
            >>> A = torch.rand((1, 1, 64, 256)) > 0.5
            >>> B = torch.rand((1, 1, 256, 128)) > 0.5
            >>> mul(A, B).shape
            torch.Size([1, 1, 64, 128])
        """
        vec_field = self.prepare_vector(input)
        mat_field = self.prepare_matrix(other)
        vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_two(vec_field, mat_field.shape[-2:])
        vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
        return self.k * self.prepare_out(vec_field)

@ -7,6 +7,7 @@ from typing import Tuple as _Tuple, Sequence as _Sequence
from abc import ABC as _ABC
import collections as _collections
import copy as _copy
class Propagator(_ABC, _nn.Module):
"""
@ -16,18 +17,12 @@ class Propagator(_ABC, _nn.Module):
operator_X: оператор отображающий распроранение светового поля вдоль оси абсцисс
operator_Y: оператор отображающий распроранение светового поля вдоль оси ординат
"""
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False):
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor):
super(Propagator, self).__init__()
operator_X: _torch.Tensor = _torch.view_as_real(operator_X)
operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y)
if trainable:
self._operator_X = _nn.Parameter(operator_X)
self._operator_Y = _nn.Parameter(operator_Y)
self._trainable = trainable
else:
self.register_buffer('_operator_X', operator_X, persistent=True)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
self._trainable = trainable
self.register_buffer('_operator_X', operator_X, persistent=True)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
@property
def operator_X(self) -> _torch.Tensor:
@ -98,7 +93,14 @@ class Propagator(_ABC, _nn.Module):
"""
return self.cat(propagator)
def forward(self, field: _torch.Tensor) -> _torch.Tensor:
@staticmethod
def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
start = (total_rows - num_to_take) // 2
end = start + num_to_take
return slice(start, end)
def forward(self,
field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
"""
Метод распространения светового поля в среде.
@ -109,13 +111,23 @@ class Propagator(_ABC, _nn.Module):
Распределение комплексной амплитуды светового поля,
после распространения.
"""
if self._trainable:
return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X)
else:
return self.operator_Y @ field @ self.operator_X
if (resul_shape is not None):
field_shape = field.shape[-2:]
operator_Y_shape = self.operator_Y.shape[-2:]
operator_X_shape = self.operator_X.shape[-2:]
slice_one = Propagator.__slice_calculation(operator_Y_shape[0], resul_shape[0])
slice_two = Propagator.__slice_calculation(operator_Y_shape[1], field_shape[0])
slice_three = Propagator.__slice_calculation(operator_X_shape[0], field_shape[1])
slice_four = Propagator.__slice_calculation(operator_X_shape[1], resul_shape[1])
return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
return self.operator_Y @ field @ self.operator_X
def __repr__(self):
return f"Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
return f"Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
class PropagatorLens(Propagator):
"""
@ -145,7 +157,7 @@ class PropagatorCrossLens(PropagatorLens):
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase, trainable = False):
config: _ConfigOpticBase):
"""
Конструктор класса скрещенной линзы.
@ -155,18 +167,17 @@ class PropagatorCrossLens(PropagatorLens):
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2)
super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y),
trainable)
super(PropagatorCrossLens, self).__init__(
_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorСylindLens(PropagatorLens):
class PropagatorCylindLens(PropagatorLens):
"""
Класс распространения света в цилиндрической линзе,
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase,
trainable = False):
config: _ConfigOpticBase):
"""
Конструктор класса цилиндрической линзы.
@ -176,14 +187,10 @@ class PropagatorСylindLens(PropagatorLens):
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)
if trainable:
super(PropagatorСylindLens, self).__init__(operator_X,
operator_Y,
trainable)
else:
super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y),
trainable)
super(PropagatorCylindLens, self).__init__(
_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorSinc(Propagator):
"""
@ -192,7 +199,7 @@ class PropagatorSinc(Propagator):
"""
def __init__(self, first_plane: _ConfigDesignPlane,
second_plane: _ConfigDesignPlane,
config: _ConfigOpticBase, trainable = False):
config: _ConfigOpticBase):
"""
Конструктор класса распространения в свободном пространстве.
@ -204,7 +211,7 @@ class PropagatorSinc(Propagator):
operator_X, operator_Y = self.__get_operators(first_plane,
second_plane,
config)
super(PropagatorSinc, self).__init__(operator_X, operator_Y, trainable)
super(PropagatorSinc, self).__init__(operator_X, operator_Y)
def __get_operator_for_dim(self,
pixel_size_in: float,
@ -238,3 +245,136 @@ class PropagatorSinc(Propagator):
difference_y,
config)
return operator_X, operator_Y
#######################################################################################################################
class PropagatorTrainableCylindLens(_ABC, _nn.Module):
    """
    Light propagation through a trainable cylindrical lens modelled as a thin
    transparent optical element: the full per-pixel phase profile along the
    x axis is a learnable parameter, while the y axis acts as identity.
    """
    def __init__(self,
                 plane: _ConfigDesignPlane,
                 config: _ConfigOpticBase
                 ):
        """
        Args:
            plane: design plane providing the x/y coordinate grids.
            config: optical configuration providing the wavenumber K and
                propagation distance (initial focal length).
        """
        super().__init__()
        # non smooth profile after training. better to train only focal length?
        self._operator_X_phi = _nn.Parameter(config.K / config.distance * plane.linspace_by_x**2)
        operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat))
        # Stored as a real (…, 2) view so the buffer holds no complex dtype.
        operator_Y = _torch.view_as_real(operator_Y)
        self.register_buffer('_operator_Y', operator_Y, persistent=True)
    @property
    def operator_X(self) -> _torch.Tensor:
        """
        Returns:
            Diagonal operator for propagation along the x axis, built from
            the trainable phase profile.
        """
        return _torch.diag_embed(_torch.exp(-1j * self._operator_X_phi))
    @property
    def operator_Y(self) -> _torch.Tensor:
        """
        Returns:
            Operator for propagation along the y axis (identity here).
        """
        return _torch.view_as_complex(self._operator_Y)
    @staticmethod
    def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
        # Centre-crop helper: slice of length num_to_take centred in total_rows.
        start = (total_rows - num_to_take) // 2
        end = start + num_to_take
        return slice(start, end)
    def forward(self,
                field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
        """
        Propagate the light field through the element.

        Args:
            field: complex-amplitude distribution of the light field.
            resul_shape: target (H, W) of the result; when given, both
                operators are centre-cropped to match field/result sizes.

        Returns:
            Complex-amplitude distribution after propagation.
        """
        if (resul_shape is not None):
            field_shape = field.shape[-2:]
            operator_Y_shape = self.operator_Y.shape[-2:]
            operator_X_shape = self.operator_X.shape[-2:]
            slice_one = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0])
            slice_two = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0])
            slice_three = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1])
            slice_four = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1])
            return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
        return self.operator_Y @ field @ self.operator_X
class PropagatorTrainableFocalDistCylindLens(_ABC, _nn.Module):
    """
    Light propagation through a trainable cylindrical lens modelled as a thin
    transparent optical element. Unlike PropagatorTrainableCylindLens, only
    the scalar focal distance is learned; the quadratic phase profile is
    recomputed from it on every access.
    """
    def __init__(self,
                 plane: _ConfigDesignPlane,
                 config: _ConfigOpticBase
                 ):
        """
        Args:
            plane: design plane providing the x/y coordinate grids.
            config: optical configuration providing the wavenumber K and the
                initial focal distance.
        """
        super().__init__()
        # Single trainable scalar: the focal distance of the lens.
        self._distance = _nn.Parameter(_torch.tensor(config.distance))
        self.register_buffer('_K', _torch.tensor(config.K), persistent=True)
        self.register_buffer('_linspace_by_x', plane.linspace_by_x.detach().clone(), persistent=True)
        operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat))
        # Stored as a real (…, 2) view so the buffer holds no complex dtype.
        operator_Y = _torch.view_as_real(operator_Y)
        self.register_buffer('_operator_Y', operator_Y, persistent=True)
    @property
    def operator_X(self) -> _torch.Tensor:
        """
        Returns:
            Diagonal operator for propagation along the x axis, rebuilt from
            the current trainable focal distance.
        """
        return _torch.diag_embed(_torch.exp(-1j * self._K / self._distance * self._linspace_by_x**2))
    @property
    def operator_Y(self) -> _torch.Tensor:
        """
        Returns:
            Operator for propagation along the y axis (identity here).
        """
        return _torch.view_as_complex(self._operator_Y)
    @staticmethod
    def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
        # Centre-crop helper: slice of length num_to_take centred in total_rows.
        start = (total_rows - num_to_take) // 2
        end = start + num_to_take
        return slice(start, end)
    def forward(self,
                field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
        """
        Propagate the light field through the element.

        Args:
            field: complex-amplitude distribution of the light field.
            resul_shape: target (H, W) of the result; when given, both
                operators are centre-cropped to match field/result sizes.

        Returns:
            Complex-amplitude distribution after propagation.
        """
        if (resul_shape is not None):
            field_shape = field.shape[-2:]
            operator_Y_shape = self.operator_Y.shape[-2:]
            operator_X_shape = self.operator_X.shape[-2:]
            slice_one = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0])
            slice_two = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0])
            slice_three = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1])
            slice_four = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1])
            return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
        return self.operator_Y @ field @ self.operator_X

@ -179,6 +179,8 @@ class OpticGPT2NOKoef(nn.Module):
lens_size = 8192 * 2,
trainable_cylind_lens=False)
)
self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
self.sim_output = omm.ScatterDataParallel(self.sim_output)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,

@ -0,0 +1,341 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import shutil
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding over (batch, seq, dim) activations."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One frequency per channel pair; the table is duplicated so it
        # spans the full feature dimension.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation within channel pairs.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh layer: tanh(alpha * x) * weight + bias (LayerNorm-free)."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm decoder block: causal multi-head self-attention + GELU MLP.

    Uses DyT in place of LayerNorm and RoPE for relative positions.
    """
    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash the plain hyper-parameters (h_dim, num_heads, ...) as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scales by h_dim**-0.5, not head_dim**-0.5 as in
        # Vaswani et al. — kept as-is to preserve existing behavior.
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
        attention = nn.functional.softmax(scores, dim=2)
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout stayed
        # active during evaluation; gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class GPT2(nn.Module):
    """Minimal GPT-2-style decoder-only character language model.

    Args:
        vocab_size: size of the character vocabulary.
        layers_num: number of TransformerLayer blocks.
        h_dim: model width.
        max_seq_len: maximum context length (positional table size).
        num_heads: attention heads per layer.
        dropout_rate: dropout probability inside each layer.
        pixel_size: accepted for interface parity with the optical model
            variants; unused here.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute positions (in addition to RoPE inside each layer).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        """Return (logits [B, T, C], mean cross-entropy loss or None)."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x) # B,T,C
        # Idiom fix: `targets is not None` instead of `not targets is None`.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
    # Autoregressive sampling; no_grad added so inference does not build
    # autograd graphs (pure memory/perf win, no behavioral change).
    @torch.no_grad()
    def generate(self, start_idx, max_new_tokens):
        """Sample max_new_tokens tokens after start_idx [B, T]; returns [B, T+n]."""
        idx = start_idx
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            idx_cond = idx[:,-self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:] # B, C — keep only the last position
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
###################################################################################################
# ----- Training hyper-parameters -----
batch_size = 50
gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# NOTE(review): eval_interval / eval_iters are not referenced by the visible
# training loop (it evaluates every 5000 steps) — confirm before relying on them.
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
# Each optimizer step consumes batch_size samples split across accumulation steps.
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py
MODEL_CLASS = GPT2
# ----- Data, logging and snapshot locations -----
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset name + sequence length (+ optional CLI suffix).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over the union of all three splits so
# encode/decode never meet an unknown symbol.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}  # char -> token id
itow = {i:w for i,w in enumerate(chars)}  # token id -> char
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    # Sample batch_size random windows; targets are inputs shifted by one token.
    ix = torch.randint(len(data)-seq_len, (batch_size, ))
    x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
    return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Compute token-level perplexity of `model` over `data` (1-D LongTensor).

    Slides a window of `max_seq_len` tokens with a stride chosen so that at
    most ~10000 windows are evaluated, weighting each window's mean loss by
    its token count. Prints progress to stdout. Relies on the module-level
    `max_seq_len` and `device` globals.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # (Removed an unused `losses` list and the unused `logits` binding.)
    for i in range(0, len(data)-max_seq_len-1, stride):
        x = data[i:(i+max_seq_len)].to(device)
        y = data[(i+1):(i+max_seq_len+1)].to(device)
        _, mean_loss = model(x[None,...], y[None,...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode `max_new_tokens` tokens from model `m` given prompt token ids.

    Args:
        m: model exposing .generate(start_idx, max_new_tokens).
        start_idxs: sequence of prompt token ids. A tuple default replaces
            the original mutable-default list (shared-state pitfall).
        max_new_tokens: number of tokens to sample.

    Returns:
        The decoded string, prompt included. Uses the module-level
        `device` and `decode` globals.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model from the global hyper-parameters and move it to the
# selected device; log its structure as text to TensorBoard.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
writer.add_text('model', str(m), 0)
# AdamW with mild weight decay; betas (0.90, 0.95) follow common GPT setups.
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Training config for checkpointing — everything needed to rebuild the model
# and resume training from a saved .pt file.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Log untrained-model completions first as a qualitative baseline.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Gradient accumulation: each sub-batch contributes loss / steps so the
    # summed gradient matches a single full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # clip_grad_norm_ with a huge max_norm is used here only to *measure*
        # the total gradient norm, not to actually clip it.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Every 5000 steps: validation perplexity, sample completions, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on the validation and test splits plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,399 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product via four non-negative optical multiplications.

    Splits A and B into positive/negative parts, normalizes each part into
    [0, 1] for the optical simulator `sim`, rescales the results and
    recombines: A@B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻.

    Args:
        sim: callable simulating matmul for non-negative inputs in [0, 1],
            operating on 4-D (batch, channel, H, W) tensors.
        tensor_1: (C, H, W) or (B, C, H, W) left operand.
        tensor_2: (C, W, K) or (B, C, W, K) right operand.

    Returns:
        3-D tensor (C, H, K) — batch element 0 of the 4-D result.
    """
    tensor_1 = tensor_1[None, :, :, :] if tensor_1.dim() < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if tensor_2.dim() < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)

    # A max of 0 means the corresponding part is absent entirely.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(part_a, max_a, part_b, max_b):
        # Skip the optical pass when either factor is identically zero.
        # FIX: allocate zeros directly on the right device instead of the
        # original zeros_like(empty(...)).clone().to(device) round-trip.
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)

    # NOTE(review): only batch element 0 is returned; inputs with B > 1 would
    # silently lose batches — kept as-is to preserve existing behavior.
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding over (batch, seq, dim) activations."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One frequency per channel pair; the table is duplicated so it
        # spans the full feature dimension.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation within channel pairs.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        table = self.emb[offset:offset + x.size(1)].view(1, x.size(1), -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic-Tanh layer: tanh(alpha * x) * weight + bias (LayerNorm-free)."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Causal self-attention block whose two matmuls (QKᵀ and attn·V) run
    through optical simulators via `new_formula`; DyT norms, RoPE positions.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash hyper-parameters and the two simulators as plain attributes.
        # NOTE(review): sim_scores/sim_output land in __dict__, bypassing
        # nn.Module submodule registration — confirm this is intended.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for single-head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Optical QKᵀ, scaled by the trainable per-layer gain k1.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        # Optical attn·V, scaled by the trainable per-layer gain k2.
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout stayed
        # active during evaluation; gate it on the module's training flag.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndFocalDistLens(nn.Module):
    """Character-level GPT-style LM whose attention products run through
    optical-multiplication simulators with trainable focal-distance lenses.

    Args:
        vocab_size: number of character tokens.
        layers_num: transformer layers (all share the two simulators).
        h_dim: model width.
        max_seq_len: context length; 512 selects a larger optical setup.
        num_heads: attention heads.
        dropout_rate: residual-branch dropout probability.
        pixel_size: physical SLM pixel pitch in meters.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # seq_len 512 needs a longer propagation distance and a larger lens;
        # the previous code duplicated the whole Config twice for this.
        extra = {'distance': 0.15, 'lens_size': 8192 * 2} if max_seq_len == 512 else {'distance': 0.01}
        # sim_scores multiplies (T, d) @ (d, T): right matrix is d x T.
        self.sim_scores = omm.TrainableFocalDistLensOpticalMul(
            self._optic_config(rows=head_dim, cols=max_seq_len, **extra))
        # sim_output multiplies (T, T) @ (T, d): right matrix is T x d.
        self.sim_output = omm.TrainableFocalDistLensOpticalMul(
            self._optic_config(rows=max_seq_len, cols=head_dim, **extra))
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(
                h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def _optic_config(self, rows, cols, **extra):
        # Shared omm.Config boilerplate; rows/cols describe the right matrix,
        # physical size is one pixel per matrix element.
        return omm.Config(right_matrix_count_columns=cols,
                          right_matrix_count_rows=rows,
                          right_matrix_width=self.pixel_size * cols,
                          right_matrix_height=self.pixel_size * rows,
                          min_height_gap=self.pixel_size,
                          right_matrix_split_x=2,
                          right_matrix_split_y=2,
                          left_matrix_split_x=2,
                          left_matrix_split_y=2,
                          result_matrix_split=2,
                          **extra)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets are not given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        # cross_entropy expects the class dim second, hence "b t c -> b c t".
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens continuations of start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log lens operators and per-layer trainable scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group the scalars per layer in TensorBoard.
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
# Training hyper-parameters.
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndFocalDistLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# Capture the timestamp ONCE so the directory name cannot straddle a clock tick
# (four separate datetime.now() calls could disagree around a minute boundary).
_run_stamp = datetime.now()
logs_dir = f'logs/{_run_stamp.date()}_{_run_stamp.hour:02d}_{_run_stamp.minute:02d}_{_run_stamp.second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
import shutil  # BUGFIX: shutil is used below but was never imported in this file
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Load the three raw WikiText splits as plain text.
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits so val/test contain no OOV chars.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
# char <-> integer-id lookup tables plus encode/decode helpers.
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    # Sample random windows; targets are the same windows shifted right by one.
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# Integer-encode each split once, up front (1-D LongTensors on CPU).
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Token-weighted sliding-window perplexity of `model` over `data`.

    Consistent with the checkpointing variant of this script: per-window mean
    losses are weighted by token count before exponentiation (equivalent to a
    plain mean when every window is full, but robust if that ever changes).
    """
    stride = max(1, len(data) // 10000)  # cap evaluation at ~10k windows
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        logits, loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the prompt token ids and decode it to text.

    The default was a mutable list ([0]); a tuple avoids the shared-mutable-
    default pitfall and is interchangeable for torch.tensor construction.
    """
    prompt = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the optical GPT model and move it to the training device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Record the architecture string in TensorBoard for reproducibility.
writer.add_text('model', str(m), 0)
#################################### Train #########################################
# AdamW with GPT-style betas and decoupled weight decay.
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
m.eval()
# Fixed prompts used to eyeball qualitative progress during training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Log untrained-baseline samples at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Accumulate gradients over micro-batches before one optimizer step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep effective loss scale constant
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively disables clipping; the call is used only
        # to measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling, and optic-parameter logging.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
# Final evaluation on validation and test splits plus last sample completions.
# NOTE(review): `i` below is the leaked last value of the training-loop index.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)

# ==== next script variant (stray git-diff hunk marker removed) ====
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Sign-decomposed matrix product computed through an optical simulator.

    The simulator `sim` can only multiply non-negative, [0, 1]-normalized
    matrices, so both operands are split into positive/negative parts
    (A = A⁺ - A⁻, B = B⁺ - B⁻), each part normalized by its max, the four
    non-negative products simulated, rescaled and recombined:
        A·B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻

    Args:
        sim: callable taking two 4-D tensors and returning their product.
        tensor_1, tensor_2: 3-D (promoted to 4-D) or 4-D operands.

    Returns:
        3-D tensor: the product with the leading batch dimension stripped.
        NOTE(review): for a真 4-D batched input only batch 0 is returned;
        callers in this file always pass 3-D tensors.
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # A max of 0 means that sign component is absent; skip its product.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _part(a, a_max, b, b_max):
        # Simulate a·b on normalized inputs, undo the scaling; an absent sign
        # component contributes an exact zero. (The previous code built a CPU
        # float32 tensor via zeros_like(empty(...)) and moved/cloned it — this
        # allocates once, directly on the right device with the right dtype.)
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        return torch.zeros(out_shape, dtype=tensor_1.dtype, device=device)

    t1 = _part(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _part(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _part(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _part(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, Su et al. 2021)."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency spectrum: one frequency per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so both halves of the feature vector share the same angle.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each channel pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # x: (batch, seq, dim); rotate features by their absolute-position angle.
        T = x.size(1)
        angles = self.emb[offset:offset + T].view(1, T, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer: tanh(alpha*x)*weight + bias, a LayerNorm replacement
    from "Transformers without Normalization" (DyT)."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matmuls (Q·Kᵀ and attn·V)
    are routed through shared optical-multiplication simulators.

    Args:
        h_dim: model width.
        sim_scores: optical simulator for the Q·Kᵀ product.
        sim_output: optical simulator for the attention·V product.
        num_heads: attention heads (<=1 disables head splitting).
        dropout_rate: dropout probability on both residual branches.
        max_seq_len: maximum sequence length for the RoPE tables.
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Plain __dict__ assignment keeps the shared simulators out of this
        # module's submodule registry (they are shared across layers).
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # Trainable per-layer scalars compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal scaled dot-product attention with both matmuls done optically."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: no position may attend to a later one.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True, so dropout previously
        # stayed active even under model.eval(); tie it to self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT-style LM whose attention products run through
    optical-multiplication simulators with trainable lenses.

    Args:
        vocab_size: number of character tokens.
        layers_num: transformer layers (all share the two simulators).
        h_dim: model width.
        max_seq_len: context length; 512 selects a larger optical setup.
        num_heads: attention heads.
        dropout_rate: residual-branch dropout probability.
        pixel_size: physical SLM pixel pitch in meters.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # seq_len 512 needs a longer propagation distance and a larger lens;
        # the previous code duplicated the whole Config twice for this.
        extra = {'distance': 0.15, 'lens_size': 8192 * 2} if max_seq_len == 512 else {'distance': 0.01}
        # sim_scores multiplies (T, d) @ (d, T): right matrix is d x T.
        self.sim_scores = omm.TrainableLensOpticalMul(
            self._optic_config(rows=head_dim, cols=max_seq_len, **extra))
        # sim_output multiplies (T, T) @ (T, d): right matrix is T x d.
        self.sim_output = omm.TrainableLensOpticalMul(
            self._optic_config(rows=max_seq_len, cols=head_dim, **extra))
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def _optic_config(self, rows, cols, **extra):
        # Shared omm.Config boilerplate; rows/cols describe the right matrix,
        # physical size is one pixel per matrix element.
        return omm.Config(right_matrix_count_columns=cols,
                          right_matrix_count_rows=rows,
                          right_matrix_width=self.pixel_size * cols,
                          right_matrix_height=self.pixel_size * rows,
                          min_height_gap=self.pixel_size,
                          right_matrix_split_x=2,
                          right_matrix_split_y=2,
                          left_matrix_split_x=2,
                          left_matrix_split_y=2,
                          result_matrix_split=2,
                          **extra)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when targets are not given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        # cross_entropy expects the class dim second, hence "b t c -> b c t".
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample max_new_tokens continuations of start_idx."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop to the context window
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log lens operators and per-layer trainable scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group the scalars per layer in TensorBoard.
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
# Training hyper-parameters.
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 128
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# BUGFIX: capture the timestamp ONCE so logs_dir and checkpoints_dir get
# identical names (separate datetime.now() calls could straddle a clock tick
# and produce mismatched directory names).
_run_stamp = datetime.now()
_stamp = f"{_run_stamp.date()}_{_run_stamp.hour:02d}_{_run_stamp.minute:02d}_{_run_stamp.second:02d}"
logs_dir = f'logs/{_stamp}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission
# Standalone checkpoints directory, named identically to the log dir.
checkpoints_dir = f'./checkpoints/{_stamp}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Load the three raw WikiText splits as plain text.
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over ALL splits so val/test contain no OOV chars.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
# char <-> integer-id lookup tables plus encode/decode helpers.
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    # Sample random windows; targets are the same windows shifted right by one.
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# Integer-encode each split once, up front (1-D LongTensors on CPU).
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Token-weighted sliding-window perplexity of `model` over `data`.

    Per-window mean losses are weighted by token count before exponentiation.
    (Removed the dead, never-used `losses` list left over from the refactor.)
    """
    model.eval()
    stride = max(1, len(data) // 10000)  # cap evaluation at ~10k windows
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data)-max_seq_len-1, stride):
        x = data[i:(i+max_seq_len)].to(device)
        y = data[(i+1):(i+max_seq_len+1)].to(device)
        logits, mean_loss = model(x[None,...], y[None,...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the prompt token ids and decode it to text.

    The default was a mutable list ([0]); a tuple avoids the shared-mutable-
    default pitfall and is interchangeable for torch.tensor construction.
    """
    prompt = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the optical GPT model and move it to the training device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Record the architecture string in TensorBoard for reproducibility.
writer.add_text('model', str(m), 0)
#################################### Train #########################################
# AdamW with GPT-style betas and decoupled weight decay.
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist the full training state (weights, optimizer, config, vocab maps)
    to `<checkpoint_dir>/checkpoint_<step>.pt` with a zero-padded step number."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, target)
# Everything needed to rebuild the model and resume training from a checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()  # wall-clock reference for elapsed-time logging
print("Started at:", start_time)
m.eval()
# Fixed prompts used to eyeball qualitative progress during training.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
# Log untrained-baseline samples at step 0.
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Accumulate gradients over micro-batches before one optimizer step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep effective loss scale constant
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively disables clipping; the call is used only
        # to measure and log the total gradient norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic validation, sampling, optic-parameter logging and checkpointing.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, last samples, final checkpoint.
# NOTE(review): `i` below is the leaked last value of the training-loop index.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# ==== next script variant (stray git-diff hunk marker removed) ====
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Sign-decomposed matrix product computed through an optical simulator.

    The simulator `sim` can only multiply non-negative, [0, 1]-normalized
    matrices, so both operands are split into positive/negative parts
    (A = A⁺ - A⁻, B = B⁺ - B⁻), each part normalized by its max, the four
    non-negative products simulated, rescaled and recombined:
        A·B = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻

    Args:
        sim: callable taking two 4-D tensors and returning their product.
        tensor_1, tensor_2: 3-D (promoted to 4-D) or 4-D operands.

    Returns:
        3-D tensor: the product with the leading batch dimension stripped.
        NOTE(review): for a genuinely batched 4-D input only batch 0 is
        returned; callers in this file always pass 3-D tensors.
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    # A max of 0 means that sign component is absent; skip its product.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _part(a, a_max, b, b_max):
        # Simulate a·b on normalized inputs, undo the scaling; an absent sign
        # component contributes an exact zero. (The previous code built a CPU
        # float32 tensor via zeros_like(empty(...)) and moved/cloned it — this
        # allocates once, directly on the right device with the right dtype.)
        if a_max > 0 and b_max > 0:
            return sim(a / a_max, b / b_max) * a_max * b_max
        return torch.zeros(out_shape, dtype=tensor_1.dtype, device=device)

    t1 = _part(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _part(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _part(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _part(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding (RoFormer, Su et al. 2021)."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Geometric frequency spectrum: one frequency per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate so both halves of the feature vector share the same angle.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each channel pair.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # x: (batch, seq, dim); rotate features by their absolute-position angle.
        T = x.size(1)
        angles = self.emb[offset:offset + T].view(1, T, -1)
        return x * angles.cos() + self.rotate_half(x) * angles.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer: tanh(alpha*x)*weight + bias, a LayerNorm replacement
    from "Transformers without Normalization" (DyT)."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.weight * squashed + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm transformer block whose two attention matmuls (Q·Kᵀ and attn·V)
    are routed through shared optical-multiplication simulators.

    Args:
        h_dim: model width.
        sim_scores: optical simulator for the Q·Kᵀ product.
        sim_output: optical simulator for the attention·V product.
        num_heads: attention heads (<=1 disables head splitting).
        dropout_rate: dropout probability on both residual branches.
        max_seq_len: maximum sequence length for the RoPE tables.
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Plain __dict__ assignment keeps the shared simulators out of this
        # module's submodule registry (they are shared across layers).
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # Trainable per-layer scalars compensating the optical products' scale.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); no-op for a single head.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal scaled dot-product attention with both matmuls done optically."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask: no position may attend to a later one.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True, so dropout previously
        # stayed active even under model.eval(); tie it to self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT-style LM whose attention matmuls run through
    trainable-lens optical multiplication simulators.

    The two simulators (scores = Q@K^T, output = attn@V) are built once
    and shared by every transformer layer.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash constructor arguments as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # NOTE(review): the two branches differ only in `distance` and the extra
        # `lens_size` for the 512-token geometry — presumably tuned per optical
        # setup; confirm against the omm package documentation.
        if max_seq_len != 512:
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01)
            )
        if max_seq_len == 512:
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2)
            )
        # Wrap the simulators so work is scattered across available GPUs.
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (in addition to RoPE inside layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        # x: (B, T) token ids; returns (logits (B, T, C), loss or None).
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants class dim second, hence the (b, c, t) rearrange.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        # Sample tokens one at a time, feeding back at most max_seq_len context.
        # NOTE(review): no eval()/no_grad() here — callers are expected to set them.
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        # Log the trainable lens operators and the per-layer scalar gains.
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
# Training hyper-parameters for the character-level LM run.
batch_size = 50
gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
# Micro-batch size is batch_size // gradient_accumulation_steps; must divide evenly.
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset + seq len + optional CLI suffix (sys.argv[1]).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
# NOTE(review): `shutil` is used here but does not appear in this file's import
# block — confirm it is imported, otherwise this line raises NameError.
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over all three splits so every split is encodable.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Encode each split once into long tensors of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Token-level perplexity of `model` over `data` using strided windows.

    Window losses are weighted by token count so the result equals
    exp(total NLL / total tokens). Leaves the model in eval mode.
    Uses module globals `max_seq_len` and `device`.
    """
    model.eval()
    # Cap the number of evaluated windows at ~10k regardless of corpus size.
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # (removed an unused `losses` accumulator that was never read)
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        logits, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a text completion from seed token ids (seed included in output).

    Fix: the default was a mutable list (`[0]`) — replaced with a tuple;
    any iterable of ints still works for callers passing encode(...) lists.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model from the hyper-parameters above and move it to device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
writer.add_text('model', str(m), 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist model/optimizer state plus run config and vocab maps to a .pt file."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, target)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Sample from the untrained model first as a step-0 baseline in TensorBoard.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main optimization loop with gradient accumulation and periodic evaluation.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Micro-batches of batch_size // steps; each loss is pre-scaled so the
    # summed gradients match one full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    # max_norm=1e6 means this effectively only *measures* the gradient norm.
    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval: val perplexity, sample completions, optic params, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# NOTE(review): 'loss' was already logged at step i inside the loop — this
# re-logs the same value at the same step.
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,464 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
# Fixed seed for reproducible init and batch sampling.
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product through a positive-only simulator `sim`.

    Splits A and B into positive/negative parts, normalizes each part to
    [0, 1] for the simulator, and recombines:
        A@B = A+ B+ - A+ B- - A- B+ + A- B-.
    3-D inputs get a leading batch dim of 1 added; the leading dim of the
    result is dropped again (NOTE(review): for genuinely 4-D batched input
    this keeps only batch element 0 — confirm callers only pass 3-D).
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # Per-part maxima; any may be 0 when that sign is absent entirely.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(part_a, max_a, part_b, max_b):
        # Skip the simulator when either operand part is identically zero.
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        # Allocate zeros directly on the input device (the original built a
        # CPU template via zeros_like(empty(...)) and copied it per term).
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse frequency per feature pair; outer product with the
        # position index gives an angle per (position, frequency) cell.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate along the feature axis so the table spans the full dim.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): pairs the two feature halves for rotation.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # x: (batch, seq, dim); `offset` shifts the position window.
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: tanh(alpha * x) followed by a learned per-feature affine map."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose two attention matmuls
    (Q@K^T and attn@V) are routed through optical-multiplication
    simulators with trainable per-layer scalar gains."""
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash ctor args as plain attributes; sim_scores/sim_output are
        # shared modules and deliberately bypass nn.Module registration.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Trainable scalar gains applied to the two optical matmul results.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))
    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads.
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Q@K^T through the optical simulator, scaled like standard attention.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        # Causal mask, allocated directly on the layer's device.
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1], device=self.q_proj.bias.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        # attention @ V through the second simulator.
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))
    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout stayed
        # active in eval mode (perplexity / generation). Gate on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT-style LM whose attention matmuls run through
    trainable-lens optical multiplication simulators.

    The two simulators (scores = Q@K^T, output = attn@V) are built once
    and shared by every transformer layer.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Stash constructor arguments as attributes.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        # NOTE(review): the two branches differ only in `distance` and the extra
        # `lens_size` for the 512-token geometry — presumably tuned per optical
        # setup; confirm against the omm package documentation.
        if max_seq_len != 512:
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.01)
            )
        if max_seq_len == 512:
            self.sim_scores = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = max_seq_len,
                           right_matrix_count_rows = h_dim // num_heads,
                           right_matrix_width = pixel_size * max_seq_len,
                           right_matrix_height = pixel_size * (h_dim // num_heads),
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2)
            )
            self.sim_output = omm.TrainableLensOpticalMul(
                omm.Config(right_matrix_count_columns = h_dim // num_heads,
                           right_matrix_count_rows = max_seq_len,
                           right_matrix_width = pixel_size * (h_dim // num_heads),
                           right_matrix_height = pixel_size * max_seq_len,
                           min_height_gap = pixel_size,
                           right_matrix_split_x = 2,
                           right_matrix_split_y = 2,
                           left_matrix_split_x = 2,
                           left_matrix_split_y = 2,
                           result_matrix_split = 2,
                           distance = 0.15,
                           lens_size = 8192 * 2)
            )
        # Wrap the simulators so work is scattered across available GPUs.
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        # Learned absolute position embeddings (in addition to RoPE inside layers).
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)
    def forward(self, x, targets=None):
        # x: (B, T) token ids; returns (logits (B, T, C), loss or None).
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x) # B,T,C
        # cross_entropy wants class dim second, hence the (b, c, t) rearrange.
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
        return logits, loss
    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        # Sample tokens one at a time, feeding back at most max_seq_len context.
        # NOTE(review): no eval()/no_grad() here — callers are expected to set them.
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:,-self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx
    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        # Log the trainable lens operators and the per-layer scalar gains.
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
# Training hyper-parameters for the character-level LM run (512-token variant).
batch_size = 50
gradient_accumulation_steps = 10 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4) #40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
# Micro-batch size is batch_size // gradient_accumulation_steps; must divide evenly.
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
# Run name: script stem + dataset + seq len + optional CLI suffix (sys.argv[1]).
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
script_snapshot_path.chmod(0o500) # with read-only permission
# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()
# Character-level vocabulary built over all three splits so every split is encodable.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, next-token target) windows from `data`."""
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
# Encode each split once into long tensors of character ids.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Token-level perplexity of `model` over `data` using strided windows.

    Window losses are weighted by token count so the result equals
    exp(total NLL / total tokens). Leaves the model in eval mode.
    Uses module globals `max_seq_len` and `device`.
    """
    model.eval()
    # Cap the number of evaluated windows at ~10k regardless of corpus size.
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # (removed an unused `losses` accumulator that was never read)
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        logits, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################
def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a text completion from seed token ids (seed included in output).

    Fix: the default was a mutable list (`[0]`) — replaced with a tuple;
    any iterable of ints still works for callers passing encode(...) lists.
    """
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the model from the hyper-parameters above and move it to device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
writer.add_text('model', str(m), 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Persist model/optimizer state plus run config and vocab maps to a .pt file."""
    target = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(payload, target)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Sample from the untrained model first as a step-0 baseline in TensorBoard.
m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main optimization loop with gradient accumulation and periodic evaluation.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Micro-batches of batch_size // steps; each loss is pre-scaled so the
    # summed gradients match one full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    # max_norm=1e6 means this effectively only *measures* the gradient norm.
    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    # Periodic eval: val perplexity, sample completions, optic params, checkpoint.
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )
# Final evaluation on validation and test splits, plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
# NOTE(review): 'loss' was already logged at step i inside the loop — this
# re-logs the same value at the same step.
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@ -0,0 +1,464 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil
# Fixed seed for reproducible init and batch sampling.
seed = 1337
torch.manual_seed(seed)
############################### MODEL #############################################################
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product through a positive-only simulator `sim`.

    Splits A and B into positive/negative parts, normalizes each part to
    [0, 1] for the simulator, and recombines:
        A@B = A+ B+ - A+ B- - A- B+ + A- B-.
    3-D inputs get a leading batch dim of 1 added; the leading dim of the
    result is dropped again (NOTE(review): for genuinely 4-D batched input
    this keeps only batch element 0 — confirm callers only pass 3-D).
    """
    if len(tensor_1.shape) < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if len(tensor_2.shape) < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)
    # Per-part maxima; any may be 0 when that sign is absent entirely.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(part_a, max_a, part_b, max_b):
        # Skip the simulator when either operand part is identically zero.
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        # Allocate zeros directly on the input device (the original built a
        # CPU template via zeros_like(empty(...)) and copied it per term).
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _term(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _term(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _term(A_neg, max_A_neg, B_neg, max_B_neg)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding: rotates feature pairs by position-dependent angles."""
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # One inverse frequency per feature pair; outer product with the
        # position index gives an angle per (position, frequency) cell.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        angles = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate along the feature axis so the table spans the full dim.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        # (x1, x2) -> (-x2, x1): pairs the two feature halves for rotation.
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def forward(self, x, offset=0):
        # x: (batch, seq, dim); `offset` shifts the position window.
        seq_len = x.size(1)
        table = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        return x * table.cos() + self.rotate_half(x) * table.sin()
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: tanh(alpha * x) followed by a learned per-feature affine map."""
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) causal transformer block whose two attention matmuls
    (QK^T and attn @ V) are routed through shared optical-multiplication
    simulators via the module-level `new_formula` helper.

    Fix: `F.dropout1d` defaults to `training=True`, so the original applied
    dropout even in eval()/generation; we now pass `training=self.training`.
    """
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all ctor args directly in __dict__. NOTE(review): this bypasses
        # nn.Module attribute registration, so the shared `sim_scores` /
        # `sim_output` modules are deliberately NOT registered as this layer's
        # submodules (they are owned once by the parent model).
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # Trainable optics live inside the optical-matmul modules; here we keep
        # only one learnable output scalar per optical product.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        # (B, T, n*h) -> (B*n, T, h); identity when single-headed.
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        # Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h).
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal attention; both matmuls go through the optical simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by h_dim ** -0.5 rather than head_dim ** -0.5 —
        # confirm this is the intended scaling when num_heads > 1.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # training=self.training disables dropout in eval mode (bug fix).
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT-style language model whose attention matmuls run
    through two trainable-lens optical-multiplication simulators that are
    shared by every transformer layer.
    """
    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # Optical-geometry settings shared by both simulators. The original
        # duplicated the full Config twice; only distance/lens_size differ:
        # the 512-token setup needs a longer propagation distance and a
        # larger lens, everything else uses the compact defaults.
        shared_cfg = dict(min_height_gap = pixel_size,
                          right_matrix_split_x = 2,
                          right_matrix_split_y = 2,
                          left_matrix_split_x = 2,
                          left_matrix_split_y = 2,
                          result_matrix_split = 2)
        if max_seq_len == 512:
            shared_cfg.update(distance = 0.15, lens_size = 8192 * 2)
        else:
            shared_cfg.update(distance = 0.01)
        # Simulator for QK^T: right operand is (head_dim x max_seq_len).
        self.sim_scores = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       **shared_cfg))
        # Simulator for attn @ V: right operand is (max_seq_len x head_dim).
        self.sim_output = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * max_seq_len,
                       **shared_cfg))
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits [B, T, vocab], mean cross-entropy loss or None).

        x: [B, T] token ids; targets (optional): [B, T] next-token ids.
        """
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # B, T, C
        loss = None
        if targets is not None:
            # cross_entropy wants the class dimension second, hence b t c -> b c t.
            loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx` [B, T]."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]   # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]               # B, C — last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the lens operators and each layer's k1/k2 scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group nicely in TensorBoard (e.g., optic_scalars/layer_0/k1).
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)
###################################################################################################
import shutil  # required by the repository snapshot below (was missing -> NameError)

# ----- Hyperparameters -----
batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
# Micro-batches must evenly divide the logical batch.
assert batch_size % gradient_accumulation_steps == 0

# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"
# A single timestamp shared by the logs and checkpoints directories, so a
# run's artifacts can be matched by name (originally two separate now() calls).
run_stamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
logs_dir = f'logs/{run_stamp}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission

# Create standalone checkpoints directory with your specified format
checkpoints_dir = f'./checkpoints/{run_stamp}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Character-level vocabulary built over all splits so every symbol is encodable.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w: i for i, w in enumerate(chars)}  # char -> id
itow = {i: w for i, w in enumerate(chars)}  # id -> char

# PEP 8 (E731): plain defs instead of lambda assignments.
def encode(s):
    """Map a string to a list of character ids."""
    return [wtoi[w] for w in s]

def decode(idx):
    """Map an iterable of character ids back to a string."""
    return ''.join(itow[i] for i in idx)
def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random windows of length `seq_len` from `data`.

    Targets are the inputs shifted one position to the right; both tensors
    are moved to the module-level `device`.
    """
    starts = torch.randint(len(data) - seq_len, (batch_size,))
    inputs = torch.stack([data[s:s + seq_len] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts])
    return inputs.to(device), targets.to(device)
# Pre-encoded 1-D tensors of character ids, one per split.
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
    """Sliding-window perplexity of `model` over a 1-D tensor of token ids.

    Windows of length `max_seq_len` (module global) are scored every `stride`
    tokens, where the stride is chosen so roughly 10k windows are evaluated.
    Per-window mean losses are token-weighted, averaged, and exponentiated.
    Side effect: leaves the model in eval mode.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    # (removed an unused `losses = []` accumulator)
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        _, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    if total_tokens_count == 0:
        # Data shorter than one window: nothing was scored.
        return float('inf')
    return np.exp(total_loss_sum / total_tokens_count)
#################################### Model #########################################mo
def complete(m, start_idxs=None, max_new_tokens=100):
    """Sample a continuation of `start_idxs` (a list of token ids, default [0])
    with `m.generate` and decode it back to text."""
    # None sentinel replaces the original mutable default argument `[0]`.
    start_idxs = [0] if start_idxs is None else start_idxs
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Instantiate the optical-attention language model and move it to the device.
m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
# Record the model architecture string in TensorBoard for reproducibility.
writer.add_text('model', str(m), 0)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
#################################### Checkpoint Function #########################################
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Write a checkpoint with the complete training state.

    Saves model/optimizer state dicts, the running loss, the hyperparameter
    config, and the char<->id vocabularies to
    `<checkpoint_dir>/checkpoint_<step>.pt`.
    """
    destination = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    torch.save(state, destination)
# Training config for checkpointing
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################
start_time = datetime.now()
print("Started at:", start_time)
# Log baseline (untrained) completions at step 0 for later comparison.
m.eval()
# Fixed probe prompts used throughout training to track qualitative progress.
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)
# Main optimisation loop: gradient accumulation, periodic evaluation,
# completion logging, and checkpointing every 5000 iterations.
for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    # Each micro-batch is batch_size // gradient_accumulation_steps samples;
    # losses are rescaled so gradients match a full-batch step.
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # NOTE(review): with max_norm=1e6 this effectively only MEASURES the
        # total gradient norm (and only every 100 iterations); it does not
        # meaningfully clip gradients.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        # Periodic evaluation: validation perplexity, sampled completions,
        # optics-parameter logging, and a checkpoint at this step.
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

# Final evaluation: validation and test perplexity plus sample completions.
m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)
# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")