From 58b3271cc8920a6571d5cc59f55b6ca73dd8c8ab Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sun, 8 Feb 2026 20:59:51 +0000 Subject: [PATCH] Trainable lens experiments. Refactor perplexity. Bert experiments. --- .gitignore | 2 + src/bert_optica_koef.py | 475 +++++++++++++++ src/bert_optica_koef_newf.py | 483 +++++++++++++++ src/bert_optica_nokoef.py | 473 +++++++++++++++ src/bert_optica_nokoef_newf.py | 483 +++++++++++++++ src/optical_matrix_multiplication/__init__.py | 8 +- .../optical_mul.py | 559 +++++++++++++++++- .../propagator.py | 206 +++++-- src/optics_char_gpt2_nokoef.py | 2 + src/train_gpt2.py | 341 +++++++++++ ...ain_optics_trainable_focal_dist_lens_64.py | 399 +++++++++++++ src/train_optics_trainable_lens_128.py | 464 +++++++++++++++ src/train_optics_trainable_lens_256.py | 464 +++++++++++++++ src/train_optics_trainable_lens_512.py | 464 +++++++++++++++ src/train_optics_trainable_lens_64.py | 464 +++++++++++++++ 15 files changed, 5237 insertions(+), 50 deletions(-) create mode 100644 src/bert_optica_koef.py create mode 100644 src/bert_optica_koef_newf.py create mode 100644 src/bert_optica_nokoef.py create mode 100644 src/bert_optica_nokoef_newf.py create mode 100644 src/train_gpt2.py create mode 100644 src/train_optics_trainable_focal_dist_lens_64.py create mode 100644 src/train_optics_trainable_lens_128.py create mode 100644 src/train_optics_trainable_lens_256.py create mode 100644 src/train_optics_trainable_lens_512.py create mode 100644 src/train_optics_trainable_lens_64.py diff --git a/.gitignore b/.gitignore index e15106e..d2133bf 100644 --- a/.gitignore +++ b/.gitignore @@ -214,3 +214,5 @@ __marimo__/ # Streamlit .streamlit/secrets.toml + +checkpoints/ \ No newline at end of file diff --git a/src/bert_optica_koef.py b/src/bert_optica_koef.py new file mode 100644 index 0000000..634c671 --- /dev/null +++ b/src/bert_optica_koef.py @@ -0,0 +1,475 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # 
f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics import AUROC +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + # (a_shifted - shift_a)(b_shifted - shift_b) = + # a_shifted*b_shifted - a_shifted*shift_b - b_shifted*shift_a + shift_a*shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + 
a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + 
self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = test_uniq_client_ids_path + if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = 
list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, 
num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids = self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = 
torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, 
n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + 
min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = 
datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + targets_path="/wd/finbert_data/train_target .csv", + train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/bert_optica_koef_newf.py b/src/bert_optica_koef_newf.py new file mode 100644 index 0000000..9e5fb1d --- /dev/null +++ b/src/bert_optica_koef_newf.py @@ -0,0 +1,483 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics import 
AUROC +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir = 
f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = test_uniq_client_ids_path 
+ if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all columns. 
zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids = 
self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), 
*x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + 
trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + targets_path="/wd/finbert_data/train_target .csv", + 
train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/bert_optica_nokoef.py b/src/bert_optica_nokoef.py new file mode 100644 index 0000000..2c5b249 --- /dev/null +++ b/src/bert_optica_nokoef.py @@ -0,0 +1,473 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics import AUROC 
+from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = 
f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = test_uniq_client_ids_path + if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + 
print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all columns. 
zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids = 
self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # self.k1 = nn.Parameter(torch.randn(1)) + # self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = 
self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + 
result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + 
targets_path="/wd/finbert_data/train_target .csv", + train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/bert_optica_nokoef_newf.py b/src/bert_optica_nokoef_newf.py new file mode 100644 index 0000000..dcfede5 --- /dev/null +++ b/src/bert_optica_nokoef_newf.py @@ -0,0 +1,483 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics 
import AUROC +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir 
= f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = 
test_uniq_client_ids_path + if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all 
columns. zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids 
= self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # self.k1 = nn.Parameter(torch.randn(1)) + # self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = 
self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 
2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + targets_path="/wd/finbert_data/train_target .csv", + 
train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/optical_matrix_multiplication/__init__.py b/src/optical_matrix_multiplication/__init__.py index 9a1844c..d5f8779 100644 --- a/src/optical_matrix_multiplication/__init__.py +++ b/src/optical_matrix_multiplication/__init__.py @@ -5,6 +5,12 @@ __version__ = "3.0.0" from .config import Config from . 
import propagator -from .optical_mul import OpticalMul +from .optical_mul import ( + OpticalMul, + TrainableLensOpticalMul, + TrainableScalarOpticalMul, + TrainableScalarAndLensOpticalMul, + TrainableFocalDistLensOpticalMul +) from .parallel import DataParallel from .parallel import ScatterDataParallel \ No newline at end of file diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index 11f52dc..cbe5f57 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -1,7 +1,15 @@ import torch as _torch import torch.nn as _nn from .config import Config as _Config -from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop +from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorCylindLens as _PropCylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop +from .propagator import ( + PropagatorTrainableCylindLens as _PropagatorTrainableCylindLens, + PropagatorTrainableFocalDistCylindLens as _PropagatorTrainableFocalDistCylindLens +) +from torch.utils.tensorboard import SummaryWriter +from typing import Optional +import matplotlib.pyplot as plt + class OpticalMul(_nn.Module): """ @@ -14,22 +22,129 @@ class OpticalMul(_nn.Module): Args: config: конфигурация расчётной системы. 
""" - super(OpticalMul, self).__init__() - self.trainable_cylind_lens = config._trainable_cylind_lens + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four + self._propagator_two: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. 
+ """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) + >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.prepare_out(vec_field) + +class TrainableScalarOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. 
+ """ + super().__init__() prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropСylindLens(config.matrix_plane, config, trainable=self.trainable_cylind_lens) + prop_four = _PropCylindLens(config.matrix_plane, config) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - if self.trainable_cylind_lens: - self._propagator_one: _Prop = prop_one + prop_two + prop_three - self._propagator_between = prop_four - else: - self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four self._propagator_two: _Prop = prop_five + prop_six + prop_seven kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) @@ -38,6 +153,421 @@ class OpticalMul(_nn.Module): self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + self.k = nn.Parameter(_torch.tensor(1)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. 
+ + Returns: + Матрица - оптический элемент в центре модели. + """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) + >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.k * self.prepare_out(vec_field) + +class TrainableLensOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. 
+ """ + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_cylind_lens: _Prop = prop_four + self._propagator_three: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. 
+ """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + # TODO data should be at least two seq length. For one we get + # untimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) 
+ >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.prepare_out(vec_field) + + @_torch.no_grad() + def log_cylind_lens_operator_x( + self, + writer: SummaryWriter, + tag: str, + global_step: Optional[int] = None, + ): + # 1. Apply exp to get the wrapped phase as it would be physically seen + # This ensures values outside [-pi, pi] wrap correctly + complex_op = _torch.exp(-1j * self._propagator_cylind_lens._operator_X_phi) + wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π] + + # 2. Normalize for Image Visualization [0, 1] + phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi) + + # 3. 
Log as a 1-pixel high row + # Shape: [1, 1, Width] + phase_row = phase_normalized.unsqueeze(0).unsqueeze(0) + writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW') + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}') + ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}") + ax.set_xlabel("Pixel Index") + ax.set_ylabel("Phase (rad)") + ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5]) + ax.grid(True, linestyle='--', alpha=0.6) + + + # Send the figure to the "Plots" or "Images" tab in TensorBoard + writer.add_figure(f"{tag}/phase_profile", fig, global_step) + plt.close(fig) # Important: prevent memory leaks + + +class TrainableFocalDistLensOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. + """ + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropagatorTrainableFocalDistCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_cylind_lens: _Prop = prop_four + self._propagator_three: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + 
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. + """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + # TODO data should be at least two seq length. For one we get + # untimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). 
+ """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) + >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.prepare_out(vec_field) + + @_torch.no_grad() + def log_cylind_lens_operator_x( + self, + writer: SummaryWriter, + tag: str, + global_step: Optional[int] = None, + ): + # 1. Apply exp to get the wrapped phase as it would be physically seen + # This ensures values outside [-pi, pi] wrap correctly + lens = self._propagator_cylind_lens + writer.add_scalar(f"{tag}/focal_distance", lens._distance.detach().cpu().numpy(), global_step) + + complex_op = _torch.exp(-1j * lens._K / lens._distance * lens._linspace_by_x**2) + wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π] + + # 2. Normalize for Image Visualization [0, 1] + phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi) + + # 3. 
Log as a 1-pixel high row + # Shape: [1, 1, Width] + phase_row = phase_normalized.unsqueeze(0).unsqueeze(0) + writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW') + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}') + ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}\nFocal distance: {lens._distance.detach().cpu().numpy()}") + ax.set_xlabel("Pixel Index") + ax.set_ylabel("Phase (rad)") + ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5]) + ax.grid(True, linestyle='--', alpha=0.6) + + # Send the figure to the "Plots" or "Images" tab in TensorBoard + writer.add_figure(f"{tag}/phase_profile", fig, global_step) + plt.close(fig) # Important: prevent memory leaks + + +class TrainableScalarAndLensOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. 
+ """ + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_two: _Prop = prop_four + self._propagator_three: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + self.k = nn.Parameter(torch.tensor(1)) def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: """ @@ -114,12 +644,9 @@ class OpticalMul(_nn.Module): """ vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) - if self.trainable_cylind_lens: - vec_field = self._propagator_one(vec_field) - vec_field = self._propagator_between(vec_field) - else: - vec_field = self._propagator_one(vec_field) - vec_field = self._propagator_two(vec_field * mat_field) + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_two(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) - return self.prepare_out(vec_field) \ No newline at end of file + return self.k * self.prepare_out(vec_field) \ No newline at end of file diff --git 
a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py index eaa061f..3079d13 100644 --- a/src/optical_matrix_multiplication/propagator.py +++ b/src/optical_matrix_multiplication/propagator.py @@ -7,6 +7,7 @@ from typing import Tuple as _Tuple, Sequence as _Sequence from abc import ABC as _ABC import collections as _collections +import copy as _copy class Propagator(_ABC, _nn.Module): """ @@ -16,18 +17,12 @@ class Propagator(_ABC, _nn.Module): operator_X: оператор отображающий распроcтранение светового поля вдоль оси абсцисс operator_Y: оператор отображающий распроcтранение светового поля вдоль оси ординат """ - def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False): + def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor): super(Propagator, self).__init__() operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) - if trainable: - self._operator_X = _nn.Parameter(operator_X) - self._operator_Y = _nn.Parameter(operator_Y) - self._trainable = trainable - else: - self.register_buffer('_operator_X', operator_X, persistent=True) - self.register_buffer('_operator_Y', operator_Y, persistent=True) - self._trainable = trainable + self.register_buffer('_operator_X', operator_X, persistent=True) + self.register_buffer('_operator_Y', operator_Y, persistent=True) @property def operator_X(self) -> _torch.Tensor: @@ -98,7 +93,14 @@ class Propagator(_ABC, _nn.Module): """ return self.cat(propagator) - def forward(self, field: _torch.Tensor) -> _torch.Tensor: + @staticmethod + def __slice_calculation(total_rows: int, num_to_take: int) -> slice: + start = (total_rows - num_to_take) // 2 + end = start + num_to_take + return slice(start, end) + + def forward(self, + field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: """ Метод распространения светового поля в среде. 
@@ -109,13 +111,23 @@ class Propagator(_ABC, _nn.Module): Распределение комплексной амплитуды светового поля, после распространения. """ - if self._trainable: - return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X) - else: - return self.operator_Y @ field @ self.operator_X + + if (resul_shape is not None): + field_shape = field.shape[-2:] + operator_Y_shape = self.operator_Y.shape[-2:] + operator_X_shape = self.operator_X.shape[-2:] + + slice_one = Propagator.__slice_calculation(operator_Y_shape[0], resul_shape[0]) + slice_two = Propagator.__slice_calculation(operator_Y_shape[1], field_shape[0]) + slice_three = Propagator.__slice_calculation(operator_X_shape[0], field_shape[1]) + slice_four = Propagator.__slice_calculation(operator_X_shape[1], resul_shape[1]) + + return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] + + return self.operator_Y @ field @ self.operator_X def __repr__(self): - return f"Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" + return f"Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" class PropagatorLens(Propagator): """ @@ -145,7 +157,7 @@ class PropagatorCrossLens(PropagatorLens): представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase, trainable = False): + config: _ConfigOpticBase): """ Конструктор класса скрещенной линзы. 
@@ -155,18 +167,17 @@ class PropagatorCrossLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2) - super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X), - _torch.diag_embed(operator_Y), - trainable) + super(PropagatorCrossLens, self).__init__( + _torch.diag_embed(operator_X), + _torch.diag_embed(operator_Y)) -class PropagatorСylindLens(PropagatorLens): +class PropagatorCylindLens(PropagatorLens): """ Класс распространения света в цилиндрической линзе, представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase, - trainable = False): + config: _ConfigOpticBase): """ Конструктор класса цилиндрической линзы. @@ -176,14 +187,10 @@ class PropagatorСylindLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) - if trainable: - super(PropagatorСylindLens, self).__init__(operator_X, - operator_Y, - trainable) - else: - super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), - _torch.diag_embed(operator_Y), - trainable) + super(PropagatorCylindLens, self).__init__( + _torch.diag_embed(operator_X), + _torch.diag_embed(operator_Y)) + class PropagatorSinc(Propagator): """ @@ -192,7 +199,7 @@ class PropagatorSinc(Propagator): """ def __init__(self, first_plane: _ConfigDesignPlane, second_plane: _ConfigDesignPlane, - config: _ConfigOpticBase, trainable = False): + config: _ConfigOpticBase): """ Конструктор класса распространения в свободном пространстве. 
@@ -204,7 +211,7 @@ class PropagatorSinc(Propagator): operator_X, operator_Y = self.__get_operators(first_plane, second_plane, config) - super(PropagatorSinc, self).__init__(operator_X, operator_Y, trainable) + super(PropagatorSinc, self).__init__(operator_X, operator_Y) def __get_operator_for_dim(self, pixel_size_in: float, @@ -237,4 +244,137 @@ class PropagatorSinc(Propagator): second_plane.pixel_size_by_y, difference_y, config) - return operator_X, operator_Y \ No newline at end of file + return operator_X, operator_Y + + +####################################################################################################################### + +class PropagatorTrainableCylindLens(_ABC, _nn.Module): + """ + Класс распространения света в обучаемой цилиндрической линзе, + представленной тонким прозрачным оптическим элементом. + """ + def __init__(self, + plane: _ConfigDesignPlane, + config: _ConfigOpticBase + ): + super().__init__() + # non smooth profile after training. better to train only focal length? 
+ self._operator_X_phi = _nn.Parameter(config.K / config.distance * plane.linspace_by_x**2) + operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)) + operator_Y = _torch.view_as_real(operator_Y) + self.register_buffer('_operator_Y', operator_Y, persistent=True) + + @property + def operator_X(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси абсцисс + """ + return _torch.diag_embed(_torch.exp(-1j * self._operator_X_phi)) + + @property + def operator_Y(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси ординат + """ + return _torch.view_as_complex(self._operator_Y) + + @staticmethod + def __slice_calculation(total_rows: int, num_to_take: int) -> slice: + start = (total_rows - num_to_take) // 2 + end = start + num_to_take + return slice(start, end) + + def forward(self, + field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: + """ + Метод распространения светового поля в среде. + + Args: + field: распределение комплексной амплитуды светового поля. + + Returns: + Распределение комплексной амплитуды светового поля, + после распространения. 
+ """ + + if (resul_shape is not None): + field_shape = field.shape[-2:] + operator_Y_shape = self.operator_Y.shape[-2:] + operator_X_shape = self.operator_X.shape[-2:] + + slice_one = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0]) + slice_two = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0]) + slice_three = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1]) + slice_four = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1]) + return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] + + return self.operator_Y @ field @ self.operator_X + + +class PropagatorTrainableFocalDistCylindLens(_ABC, _nn.Module): + """ + Класс распространения света в обучаемой цилиндрической линзе, + представленной тонким прозрачным оптическим элементом. + """ + def __init__(self, + plane: _ConfigDesignPlane, + config: _ConfigOpticBase + ): + super().__init__() + self._distance = _nn.Parameter(_torch.tensor(config.distance)) + self.register_buffer('_K', _torch.tensor(config.K), persistent=True) + self.register_buffer('_linspace_by_x', plane.linspace_by_x.detach().clone(), persistent=True) + operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)) + operator_Y = _torch.view_as_real(operator_Y) + self.register_buffer('_operator_Y', operator_Y, persistent=True) + + @property + def operator_X(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси абсцисс + """ + return _torch.diag_embed(_torch.exp(-1j * self._K / self._distance * self._linspace_by_x**2)) + + @property + def operator_Y(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси ординат + """ + return _torch.view_as_complex(self._operator_Y) + + @staticmethod + def __slice_calculation(total_rows: int, 
num_to_take: int) -> slice: + start = (total_rows - num_to_take) // 2 + end = start + num_to_take + return slice(start, end) + + def forward(self, + field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: + """ + Метод распространения светового поля в среде. + + Args: + field: распределение комплексной амплитуды светового поля. + + Returns: + Распределение комплексной амплитуды светового поля, + после распространения. + """ + + if (resul_shape is not None): + field_shape = field.shape[-2:] + operator_Y_shape = self.operator_Y.shape[-2:] + operator_X_shape = self.operator_X.shape[-2:] + + slice_one = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0]) + slice_two = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0]) + slice_three = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1]) + slice_four = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1]) + return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] + + return self.operator_Y @ field @ self.operator_X \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef.py b/src/optics_char_gpt2_nokoef.py index af1e518..0cd65d3 100644 --- a/src/optics_char_gpt2_nokoef.py +++ b/src/optics_char_gpt2_nokoef.py @@ -179,6 +179,8 @@ class OpticGPT2NOKoef(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, diff --git a/src/train_gpt2.py b/src/train_gpt2.py new file mode 100644 index 0000000..edee343 --- /dev/null +++ b/src/train_gpt2.py @@ -0,0 +1,341 @@ +import torch +import torch.nn 
as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, 
h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] 
+ for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository 
script_snapshot_path.chmod(0o500)  # read-only: protect the snapshot from accidental edits

# Standalone checkpoints directory, timestamped the same way as the logs dir.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Build the character vocabulary over all splits so val/test never contain
# out-of-vocabulary symbols.
text = train_text + val_text + test_text

chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])

def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random contiguous windows; targets are inputs shifted by one."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
    return x, y

train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)

@torch.no_grad()
def perplexity(model, data):
    """Token-weighted perplexity of `model` over `data` on strided windows.

    The model returns a per-token mean loss, so each window's loss is
    re-weighted by its token count before exponentiating.
    """
    model.eval()
    stride = max(1, len(data) // 10000)  # cap the number of evaluated windows at ~10k
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data)-max_seq_len-1, stride):
        x = data[i:(i+max_seq_len)].to(device)
        y = data[(i+1):(i+max_seq_len+1)].to(device)
        logits, mean_loss = model(x[None,...], y[None,...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    if total_tokens_count == 0:  # data shorter than one evaluation window
        return np.float64('nan')
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of `start_idxs` and decode it to text.

    Default is a tuple (not a list) to avoid the shared-mutable-default pitfall.
    """
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())

m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num
)
m = m.to(device)
writer.add_text('model', str(m), 0)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state (weights, optimizer,
    config and the vocabulary maps needed to decode generations later)."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)

# Training config stored inside every checkpoint so a run can be reconstructed.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}
#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()
task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # clip_grad_norm_ with a huge max_norm is used only to *measure* the total norm
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir
        )

m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

# =================================================================================
# new file: src/train_optics_trainable_focal_dist_lens_64.py
# =================================================================================

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil

seed = 1337
torch.manual_seed(seed)

############################### MODEL #############################################################

def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product A @ B emulated with the non-negative-only multiplier `sim`.

    The optical simulator can only multiply non-negative matrices normalized
    to [0, 1], so each operand is split into its positive/negative parts and
    the signed product is reassembled from the four cross terms:

        A @ B = A+ @ B+ - A+ @ B- - A- @ B+ + A- @ B-

    Each term is normalized by the per-part maxima before the simulated
    multiply and rescaled afterwards; a term whose part is identically zero
    is skipped (its maximum is 0 and normalization would divide by zero).
    """
    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def part_product(a, b):
        # One normalized optical multiply for a single (part_of_A, part_of_B) pair.
        max_a, max_b = torch.max(a), torch.max(b)
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    a_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    a_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    b_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    b_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    t1 = part_product(a_pos, b_pos)
    t2 = part_product(a_pos, b_neg)
    t3 = part_product(a_neg, b_pos)
    t4 = part_product(a_neg, b_neg)

    return (t1 - t2 - t3 + t4)[0,:,:,:]

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied over the last dim of (B, T, H) activations."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        # Standard RoPE spectrum: theta_j = 10000^(-2j/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # Duplicate so cos/sin cover both halves of the feature dimension.
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)

# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: drop-in LayerNorm replacement, y = tanh(alpha*x)*w + b."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias

# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) causal transformer block whose two attention matmuls
    (Q @ K^T and attn @ V) are routed through shared optical simulators."""

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # NOTE: storing ctor args via __dict__ (not setattr) keeps the shared
        # simulators out of this module's registered submodules — presumably
        # deliberate, so each layer does not duplicate their parameters.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
        # now trainable parameters are in optical mat mul class, one scalar per simulator
        # here we use only TrainableLensOpticalMul and scalars for each layer
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # Optical Q @ K^T, scaled like standard dot-product attention.
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))  # causal mask
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUGFIX: F.dropout1d defaults to training=True, so dropout previously
        # stayed active even under model.eval(); gate it on self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x

class OpticGPT2TrainableScalarAndFocalDistLens(nn.Module):
    """Char-level GPT-2-style LM whose attention matmuls run through
    trainable-focal-distance optical lens simulators shared across layers."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})

        # seq_len 512 needs a longer propagation distance and a larger lens grid;
        # factoring the two previously duplicated Config constructions.
        if max_seq_len == 512:
            sim_kwargs = dict(distance=0.15, lens_size=8192 * 2)
        else:
            sim_kwargs = dict(distance=0.01)
        head_dim = h_dim // num_heads

        # scores simulator: (T, head_dim) @ (head_dim, T), right operand is K^T.
        self.sim_scores = omm.TrainableFocalDistLensOpticalMul(
            omm.Config(right_matrix_count_columns = max_seq_len,
                       right_matrix_count_rows = head_dim,
                       right_matrix_width = pixel_size * max_seq_len,
                       right_matrix_height = pixel_size * head_dim,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **sim_kwargs)
        )
        # output simulator: (T, T) @ (T, head_dim), right operand is V.
        self.sim_output = omm.TrainableFocalDistLensOpticalMul(
            omm.Config(right_matrix_count_columns = head_dim,
                       right_matrix_count_rows = max_seq_len,
                       right_matrix_width = pixel_size * head_dim,
                       right_matrix_height = pixel_size * max_seq_len,
                       min_height_gap = pixel_size,
                       right_matrix_split_x = 2,
                       right_matrix_split_y = 2,
                       left_matrix_split_x = 2,
                       left_matrix_split_y = 2,
                       result_matrix_split = 2,
                       **sim_kwargs)
        )
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)

        self.layers = nn.ModuleList([
            TransformerLayer(
                h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
                dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B,T,C
        # cross_entropy expects the class dim second for sequence targets: (B, C, T)
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = 40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndFocalDistLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = 
SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + stride = max(1, len(data) // 10000) + losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, loss = model(x[None,...], y[None,...]) + losses.append(loss.item()) + print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(np.mean(losses)) + +#################################### Model 
#########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +writer.add_text('model', str(m), 0) + +#################################### Train ######################################### + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +m.eval() + +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) 
+ writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) \ No newline at end of file diff --git a/src/train_optics_trainable_lens_128.py b/src/train_optics_trainable_lens_128.py new file mode 100644 index 0000000..799f312 --- /dev/null +++ b/src/train_optics_trainable_lens_128.py @@ -0,0 +1,464 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import optical_matrix_multiplication as omm +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + 
+ A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# 
Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # now trainable parameters are in optical mat mul class, one scalar per simulator + # here we use only TrainableLensOpticalMul and scalars for each layer + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * 
(self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2TrainableScalarAndLens(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len != 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + if max_seq_len == 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim 
// num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer 
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# Snapshot the source tree next to the logs so every run is reproducible,
# then drop write permission so the snapshot cannot be edited afterwards.
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)
script_snapshot_path.chmod(0o500)

# Standalone checkpoints directory, timestamped in the same format as the logs dir.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Character-level vocabulary over the union of all splits, so val/test never
# contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join(itow[i] for i in idx)


def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, shifted-by-one target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data):
    """Token-weighted perplexity of `model` over `data`.

    Windows start every `stride` tokens and may overlap when
    stride < max_seq_len, so this is an approximation of true perplexity.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        _, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################


def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a continuation of the prompt `start_idxs` (token ids) as a string."""
    # Tuple default instead of a mutable-list default argument.
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
writer.add_text('model', str(m), 0)

#################################### Train #########################################

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################


def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model/optimizer state plus the vocab and config needed to reload it."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config stored inside every checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()

task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len,
                           batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep the step scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used only to log the total norm.
        writer.add_scalar('Gradient_Norm/Total',
                          torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(model=m, optimizer=optimizer, step=i, loss=accumulated_loss,
                        config=training_config, wtoi=wtoi, itow=itow,
                        checkpoint_dir=checkpoints_dir)

m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(model=m, optimizer=optimizer, step=max_iters, loss=accumulated_loss,
                config=training_config, wtoi=wtoi, itow=itow, checkpoint_dir=checkpoints_dir)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
# ===== new file: src/train_optics_trainable_lens_256.py =====
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil

seed = 1337
torch.manual_seed(seed)

############################### MODEL #############################################################


def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product through a simulator `sim` that only accepts
    non-negative inputs scaled to [0, 1].

    Both operands are split into positive/negative parts, each part is scaled
    to [0, 1], multiplied by `sim`, rescaled back, and the four partial
    products are recombined: A @ B = (A+ - A-) @ (B+ - B-).
    """
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    # Each max may be 0 when the corresponding part has no entries of that sign.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(a, b, max_a, max_b):
        # Skip the simulator entirely when one factor is identically zero.
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, B_pos, max_A_pos, max_B_pos)
    t2 = _term(A_pos, B_neg, max_A_pos, max_B_neg)
    t3 = _term(A_neg, B_pos, max_A_neg, max_B_pos)
    t4 = _term(A_neg, B_neg, max_A_neg, max_B_neg)

    # NOTE(review): assumes the leading batch dim is 1 — the result is squeezed.
    return (t1 - t2 - t3 + t4)[0, :, :, :]


# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)


# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias


# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Stores sim_scores/sim_output via __dict__, bypassing nn.Module
        # registration — the simulators are shared across layers.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # Trainable lens parameters live inside the optical mat-mul class;
        # here only one trainable scalar per optical product per layer.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal self-attention with both matmuls routed through the optical simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUGFIX: pass training=self.training — F.dropout1d defaults to
        # training=True, so dropout was previously active in eval mode too.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x


class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Char-level GPT whose attention matmuls run through trainable-lens optical simulators."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1,
                 dropout_rate=0.1, pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The two original config branches differed only in distance/lens_size;
        # merged into one path with a per-seq-len extras dict.
        extra = (dict(distance=0.15, lens_size=8192 * 2) if max_seq_len == 512
                 else dict(distance=0.01))
        self.sim_scores = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns=max_seq_len,
                       right_matrix_count_rows=head_dim,
                       right_matrix_width=pixel_size * max_seq_len,
                       right_matrix_height=pixel_size * head_dim,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       **extra))
        self.sim_output = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns=head_dim,
                       right_matrix_count_rows=max_seq_len,
                       right_matrix_width=pixel_size * head_dim,
                       right_matrix_height=pixel_size * max_seq_len,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       **extra))
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)

        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores,
                             sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: extend `start_idx` by `max_new_tokens` tokens."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(self, writer: SummaryWriter, global_step):
        """Dump lens operators and the per-layer attention scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group nicely in TensorBoard (e.g. optic_scalars/layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)


###################################################################################################

batch_size = 50
gradient_accumulation_steps = 5  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64

MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# Snapshot the source tree next to the logs so every run is reproducible,
# then drop write permission so the snapshot cannot be edited afterwards.
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)
script_snapshot_path.chmod(0o500)

# Standalone checkpoints directory, timestamped in the same format as the logs dir.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Character-level vocabulary over the union of all splits, so val/test never
# contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join(itow[i] for i in idx)


def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, shifted-by-one target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data):
    """Token-weighted perplexity of `model` over `data`.

    Windows start every `stride` tokens and may overlap when
    stride < max_seq_len, so this is an approximation of true perplexity.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        _, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################


def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a continuation of the prompt `start_idxs` (token ids) as a string."""
    # Tuple default instead of a mutable-list default argument.
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
writer.add_text('model', str(m), 0)

#################################### Train #########################################

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################


def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model/optimizer state plus the vocab and config needed to reload it."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config stored inside every checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()

task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len,
                           batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep the step scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used only to log the total norm.
        writer.add_scalar('Gradient_Norm/Total',
                          torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(model=m, optimizer=optimizer, step=i, loss=accumulated_loss,
                        config=training_config, wtoi=wtoi, itow=itow,
                        checkpoint_dir=checkpoints_dir)

m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i + 1)

# Save final checkpoint
save_checkpoint(model=m, optimizer=optimizer, step=max_iters, loss=accumulated_loss,
                config=training_config, wtoi=wtoi, itow=itow, checkpoint_dir=checkpoints_dir)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
# ===== new file: src/train_optics_trainable_lens_512.py =====
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil

seed = 1337
torch.manual_seed(seed)

############################### MODEL #############################################################


def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product through a simulator `sim` that only accepts
    non-negative inputs scaled to [0, 1].

    Both operands are split into positive/negative parts, each part is scaled
    to [0, 1], multiplied by `sim`, rescaled back, and the four partial
    products are recombined: A @ B = (A+ - A-) @ (B+ - B-).
    """
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    # Each max may be 0 when the corresponding part has no entries of that sign.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _term(a, b, max_a, max_b):
        # Skip the simulator entirely when one factor is identically zero.
        if max_a > 0 and max_b > 0:
            return sim(a / max_a, b / max_b) * max_a * max_b
        return torch.zeros(out_shape, device=device)

    t1 = _term(A_pos, B_pos, max_A_pos, max_B_pos)
    t2 = _term(A_pos, B_neg, max_A_pos, max_B_neg)
    t3 = _term(A_neg, B_pos, max_A_neg, max_B_pos)
    t4 = _term(A_neg, B_neg, max_A_neg, max_B_neg)

    # NOTE(review): assumes the leading batch dim is 1 — the result is squeezed.
    return (t1 - t2 - t3 + t4)[0, :, :, :]


# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)


# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias


# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Stores sim_scores/sim_output via __dict__, bypassing nn.Module
        # registration — the simulators are shared across layers.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # Trainable lens parameters live inside the optical mat-mul class;
        # here only one trainable scalar per optical product per layer.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal self-attention with both matmuls routed through the optical simulators."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUGFIX: pass training=self.training — F.dropout1d defaults to
        # training=True, so dropout was previously active in eval mode too.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x


class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Char-level GPT whose attention matmuls run through trainable-lens optical simulators."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1,
                 dropout_rate=0.1, pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        head_dim = h_dim // num_heads
        # The two original config branches differed only in distance/lens_size;
        # merged into one path with a per-seq-len extras dict.
        extra = (dict(distance=0.15, lens_size=8192 * 2) if max_seq_len == 512
                 else dict(distance=0.01))
        self.sim_scores = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns=max_seq_len,
                       right_matrix_count_rows=head_dim,
                       right_matrix_width=pixel_size * max_seq_len,
                       right_matrix_height=pixel_size * head_dim,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       **extra))
        self.sim_output = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns=head_dim,
                       right_matrix_count_rows=max_seq_len,
                       right_matrix_width=pixel_size * head_dim,
                       right_matrix_height=pixel_size * max_seq_len,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       **extra))
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)

        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores,
                             sim_output=self.sim_output, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        logits = self.lm_head(x)  # B,T,C
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressive sampling: extend `start_idx` by `max_new_tokens` tokens."""
        idx = start_idx
        for i in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]  # B, C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(self, writer: SummaryWriter, global_step):
        """Dump lens operators and the per-layer attention scalars to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group nicely in TensorBoard (e.g. optic_scalars/layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)


###################################################################################################

batch_size = 50
gradient_accumulation_steps = 10  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64

MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# Snapshot the source tree next to the logs so every run is reproducible,
# then drop write permission so the snapshot cannot be edited afterwards.
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)
script_snapshot_path.chmod(0o500)

# Standalone checkpoints directory, timestamped in the same format as the logs dir.
checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Character-level vocabulary over the union of all splits, so val/test never
# contain out-of-vocabulary symbols.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join(itow[i] for i in idx)


def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, shifted-by-one target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data):
    """Token-weighted perplexity of `model` over `data`.

    Windows start every `stride` tokens and may overlap when
    stride < max_seq_len, so this is an approximation of true perplexity.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        _, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################


def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Decode a continuation of the prompt `start_idxs` (token ids) as a string."""
    # Tuple default instead of a mutable-list default argument.
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
writer.add_text('model', str(m), 0)

#################################### Train #########################################

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################


def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model/optimizer state plus the vocab and config needed to reload it."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)


# Training config stored inside every checkpoint.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()

task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(p), 32) for p in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len,
                           batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep the step scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used only to log the total norm.
        writer.add_scalar('Gradient_Norm/Total',
                          torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
import optical_matrix_multiplication as omm
import shutil

seed = 1337
torch.manual_seed(seed)

############################### MODEL #############################################################

def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product via four non-negative optical multiplications.

    The optical simulator `sim` can only multiply matrices with entries in
    [0, 1], so each operand is split into positive and negative parts
    (A = A+ - A-, B = B+ - B-), every part is normalised by its own maximum,
    and A@B is reassembled as A+B+ - A+B- - A-B+ + A-B-.

    NOTE(review): a singleton leading batch dim is added for 3-D inputs and the
    result keeps only index 0 of that dim — for genuinely 4-D batched input the
    other batch entries are dropped; confirm callers always pass 3-D tensors.
    """
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    # Any of these maxima can be zero when the corresponding part is empty; in
    # that case the partial product is identically zero and the simulator call
    # (with its division by zero) must be skipped.
    max_A_pos = torch.max(A_pos)
    max_A_neg = torch.max(A_neg)
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])

    def _partial(part_a, max_a, part_b, max_b):
        """One normalised simulator call, or zeros when either part vanishes."""
        if max_a > 0 and max_b > 0:
            return sim(part_a / max_a, part_b / max_b) * max_a * max_b
        # Allocate directly on the target device instead of the original's
        # zeros_like(empty(...)) CPU template + per-branch .to(device).
        return torch.zeros(out_shape, device=device)

    t1 = _partial(A_pos, max_A_pos, B_pos, max_B_pos)
    t2 = _partial(A_pos, max_A_pos, B_neg, max_B_neg)
    t3 = _partial(A_neg, max_A_neg, B_pos, max_B_pos)
    t4 = _partial(A_neg, max_A_neg, B_neg, max_B_neg)

    return (t1 - t2 - t3 + t4)[0, :, :, :]

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied over the last dimension of (B, T, H) tensors."""

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)

# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic tanh: learnable tanh squashing with affine output, used in place of LayerNorm."""

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias

# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal transformer block whose two attention matmuls run through optical simulators."""

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        # Plain-attribute assignment (bypassing nn.Module.__setattr__) is
        # deliberate: the shared simulators must not be re-registered as
        # submodules of every layer.
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        # Trainable optical parameters live inside the shared simulators; each
        # layer only owns one scalar gain per optical product.
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))  # causal mask
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # BUG FIX: F.dropout1d defaults to training=True, so dropout was also
        # active in eval mode, corrupting perplexity and generation. Gate it on
        # self.training.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x

class OpticGPT2TrainableScalarAndLens(nn.Module):
    """Character-level GPT-2-style LM whose attention matmuls are simulated optically,
    with trainable lenses (inside the shared simulators) and per-layer scalar gains."""

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate=0.1,
                 pixel_size=3.6e-6):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        # Only the 512-token geometry needs a longer propagation distance and a
        # larger lens; all other sequence lengths share the default geometry.
        # (The original duplicated both Config blocks per branch.)
        if max_seq_len == 512:
            extra = dict(distance=0.15, lens_size=8192 * 2)
        else:
            extra = dict(distance=0.01)
        self.sim_scores = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns=max_seq_len,
                       right_matrix_count_rows=h_dim // num_heads,
                       right_matrix_width=pixel_size * max_seq_len,
                       right_matrix_height=pixel_size * (h_dim // num_heads),
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       **extra)
        )
        self.sim_output = omm.TrainableLensOpticalMul(
            omm.Config(right_matrix_count_columns=h_dim // num_heads,
                       right_matrix_count_rows=max_seq_len,
                       right_matrix_width=pixel_size * (h_dim // num_heads),
                       right_matrix_height=pixel_size * max_seq_len,
                       min_height_gap=pixel_size,
                       right_matrix_split_x=2,
                       right_matrix_split_y=2,
                       left_matrix_split_x=2,
                       left_matrix_split_y=2,
                       result_matrix_split=2,
                       **extra)
        )
        self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
        self.sim_output = omm.ScatterDataParallel(self.sim_output)

        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output,
                             num_heads=self.num_heads, dropout_rate=self.dropout_rate,
                             max_seq_len=self.max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)  # (B, T, vocab)
        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
        return logits, loss

    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `start_idx`."""
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]  # crop context to the model's window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # distribution at the last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def log_trainable_optic_params(
        self,
        writer: SummaryWriter,
        global_step,
    ):
        """Log the trainable lens operators and per-layer scalar gains to TensorBoard."""
        self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step)
        self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step)
        for i, layer in enumerate(self.layers):
            # f-string tags group the scalars nicely in TensorBoard (e.g. layer_0/k1)
            writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step)
            writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step)


###################################################################################################

batch_size = 50
gradient_accumulation_steps = 1  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = int(4e4)  # 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 64
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
assert batch_size % gradient_accumulation_steps == 0
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64

MODEL_CLASS = OpticGPT2TrainableScalarAndLens
train_data_path = Path("./data/wiki.train.tokens")
val_data_path = Path("./data/wiki.valid.tokens")
test_data_path = Path("./data/wiki.test.tokens")
comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}"

# Build the run timestamp ONCE. The original built it with several separate
# datetime.now() calls for the logs dir and again for the checkpoints dir, so
# the two directory names could disagree whenever the clock ticked between calls.
run_stamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')

logs_dir = f'logs/{run_stamp}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of the repository
script_snapshot_path.chmod(0o500)  # read-only so the snapshot cannot be edited by mistake

# Standalone checkpoints directory sharing the same run timestamp.
checkpoints_dir = f'./checkpoints/{run_stamp}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Checkpoints dir:", checkpoints_dir)

#################################### Dataset #########################################

with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()

with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()

with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

# Character-level vocabulary built over all three splits so encode() never
# meets an unknown symbol at validation/test time.
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)

wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}

encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join(itow[i] for i in idx)


def get_batch(data, seq_len, batch_size):
    """Sample `batch_size` random (input, shifted-by-one target) windows from `data`."""
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix]).to(device)
    y = torch.stack([data[i + 1:i + 1 + seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data):
    """Token-level perplexity of `model` over `data`.

    Windows of length `max_seq_len` are scored with a stride chosen so that at
    most ~10k windows are evaluated; each window's mean loss is re-weighted by
    its token count before exponentiating.
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        logits, mean_loss = model(x[None, ...], y[None, ...])
        num_tokens = y.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens
        print(f"\rppl {i}/{len(data) - max_seq_len - 1}", end="")
    if total_tokens_count == 0:  # data shorter than one window: avoid ZeroDivisionError
        return np.float64('nan')
    return np.exp(total_loss_sum / total_tokens_count)

#################################### Model #########################################

def complete(m, start_idxs=(0,), max_new_tokens=100):
    """Sample a continuation of the prompt token ids and decode it to text."""
    # Tuple default avoids the mutable-default-argument pitfall; callers may
    # still pass a list (encode() returns one).
    start_idx = torch.tensor([list(start_idxs)]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
writer.add_text('model', str(m), 0)

#################################### Train #########################################

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

#################################### Checkpoint Function #########################################

def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model/optimizer state plus config and vocab maps so training can resume."""
    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }, checkpoint_path)

# Training config embedded in every checkpoint for later reconstruction.
training_config = {
    'vocab_size': vocab_size,
    'layers_num': layers_num,
    'h_dim': h_dim,
    'max_seq_len': max_seq_len,
    'num_heads': num_heads,
    'dropout_rate': dropout_rate,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'pixel_size': pixel_size,
    'max_iters': max_iters,
}

#################################### Train #########################################

start_time = datetime.now()
print("Started at:", start_time)

m.eval()

task_prompts = [
    "1 2 3 4 5",
    "The capital of France is",
    "The chemical symbol of gold is",
    "If yesterday was Friday, then tomorrow will be",
    "The opposite of hot is",
    "The planets of the solar system are:",
    "My favorite color is",
    "If 5*x + 3 = 13, then x is",
]
completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps  # keep gradient scale independent of accumulation
        loss.backward()
        accumulated_loss += loss.item()
    if i % 100 == 0:
        # max_norm=1e6 effectively never clips; the call is used only to log the total norm.
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
    optimizer.step()
    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
        task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
        writer.add_text('completions/task', task_results, i)
        m.log_trainable_optic_params(writer, i)
        save_checkpoint(
            model=m,
            optimizer=optimizer,
            step=i,
            loss=accumulated_loss,
            config=training_config,
            wtoi=wtoi,
            itow=itow,
            checkpoint_dir=checkpoints_dir,
        )

m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n\n"), max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)
task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
print(task_results)
writer.add_text('completions/task', task_results, i+1)

# Save final checkpoint
save_checkpoint(
    model=m,
    optimizer=optimizer,
    step=max_iters,
    loss=accumulated_loss,
    config=training_config,
    wtoi=wtoi,
    itow=itow,
    checkpoint_dir=checkpoints_dir,
)
print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")