# gpt2_optics/src/bert_optica_koef_newf.py

import os
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}"
from torch import nn
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta
from torchmetrics import AUROC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import autograd
import optical_matrix_multiplication as omm
step = 1
pixel_size: float = 3.6e-6
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device - ', device)
def new_formula(sim, tensor_1, tensor_2):
    tensor_1 = tensor_1[None, :, :, :] if len(tensor_1.shape) < 4 else tensor_1
    tensor_2 = tensor_2[None, :, :, :] if len(tensor_2.shape) < 4 else tensor_2
    device = tensor_1.device
    A_pos = torch.clamp(tensor_1, min=0)   # A⁺ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A⁻ = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B⁺ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B⁻ = max(-B, 0)
    max_A_pos = torch.max(A_pos)  # may be 0 if there are no positive values
    max_A_neg = torch.max(A_neg)  # may be 0 if there are no negative values
    max_B_pos = torch.max(B_pos)
    max_B_neg = torch.max(B_neg)
    zero_template = torch.zeros(
        tensor_1.shape[0], tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])
    if max_A_pos > 0 and max_B_pos > 0:
        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
    else:
        t1 = zero_template.clone().to(device)
    if max_A_pos > 0 and max_B_neg > 0:
        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
    else:
        t2 = zero_template.clone().to(device)
    if max_A_neg > 0 and max_B_pos > 0:
        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
    else:
        t3 = zero_template.clone().to(device)
    if max_A_neg > 0 and max_B_neg > 0:
        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
    else:
        t4 = zero_template.clone().to(device)
    return (t1 - t2 - t3 + t4)[0, :, :, :]
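# A minimal sanity sketch (not part of training): with sim=torch.matmul the
# four-quadrant decomposition above must reproduce a plain matrix product,
# since A @ B = (A⁺ - A⁻)(B⁺ - B⁻) = A⁺B⁺ - A⁺B⁻ - A⁻B⁺ + A⁻B⁻. Here
# torch.matmul stands in for the optical multiplier, which only accepts
# non-negative normalized inputs.
_a, _b = torch.randn(2, 4, 3), torch.randn(2, 3, 5)
assert torch.allclose(new_formula(torch.matmul, _a, _b), torch.matmul(_a, _b), atol=1e-5)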
comment = Path(__file__).stem # sys.argv[2]
checkpoint_file = None
logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
checkpoints_dir = f'/wd/finbert_results/checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Checkpoints dir:", checkpoints_dir)
writer = SummaryWriter(logs_dir)
Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes())  # copy this version of the script
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir="checkpoints/"):
    checkpoint = {
        'encoder': {
            'state_dict': encoder.state_dict(),
            **{k: v for k, v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'model': {
            'state_dict': model.state_dict(),
            **{k: v for k, v in model.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'epoch': epoch,
        'optimizer': {
            'state_dict': optimizer.state_dict(),
        },
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
    }
    path = checkpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")
def load_checkpoint(checkpoint_file):
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file)
        # optimizer.load_state_dict(checkpoint['optimizer'])
        encoder.load_state_dict(checkpoint['encoder']['state_dict'])
        model.load_state_dict(checkpoint['model']['state_dict'])
class CreditProductsDataset:
    def __init__(self,
                 features_path, targets_path, train_test_split_ratio=0.9,
                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
                 ):
        self.train_uniq_client_ids_path = train_uniq_client_ids_path
        self.test_uniq_client_ids_path = test_uniq_client_ids_path
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:, 0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:, 0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
            'fclose_flag'
        ]
        self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns))
        # build a unified category index for embeddings across all columns; the zero-index
        # embedding is reserved for padding and zeroed during training
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1  # zero embedding is for padding
        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()
        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True)  # implicit max seq
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False)  # think about replace=True
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9)  # drop features with prob 0.1
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9)  # drop features with prob 0.1
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i+batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
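# Note on shapes (descriptive, not enforced): get_batch returns
# cat_features_batch as (batch, longest_history, len(cat_columns)) int16,
# num_features_batch as (batch, longest_history, len(num_columns)) float32,
# and targets_batch as (batch,) float32. Indexing self.cat_features with raw
# client ids relies on ids being contiguous row positions 0..N-1 (an
# assumption about this dataset's layout).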
class Encoder(nn.Module):
    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        cat_features_batch = cat_features_batch.to(self.device)
        num_features_batch = num_features_batch.to(self.device)
        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
        inputs = self.proj(embed_tensor)
        targets = targets_batch.to(self.device)
        return inputs, targets
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)
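# Sketch of the rotary property this buys (illustrative check only, not part
# of training): the inner product of rotated q and k depends only on the
# relative offset, so offsets (3, 5) and (7, 9) score identically.
_rope_demo = RoPE(dim=8, max_seq_len=16)
_q, _k = torch.randn(1, 1, 8), torch.randn(1, 1, 8)
_s1 = (_rope_demo(_q, offset=3) * _rope_demo(_k, offset=5)).sum()
_s2 = (_rope_demo(_q, offset=7) * _rope_demo(_k, offset=9)).sum()
assert torch.allclose(_s1, _s2, atol=1e-5)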
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias
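# Quick check of DyT (used here as a LayerNorm replacement, per the paper
# above): output keeps the input shape and is bounded elementwise by
# |weight| + |bias|, since tanh saturates in (-1, 1). A sketch, not training code.
_dyt = DyT(num_features=6)
_y = _dyt(torch.randn(4, 3, 6))
assert _y.shape == (4, 3, 6)
assert (_y.abs() <= _dyt.weight.abs() + _dyt.bias.abs() + 1e-6).all()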
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
        self.k1 = nn.Parameter(torch.randn(1))
        self.k2 = nn.Parameter(torch.randn(1))

    def split_to_heads(self, x, B, T, H):
        if self.num_heads <= 1:
            return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        if self.num_heads <= 1:
            return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        attention = nn.functional.softmax(scores, dim=2)
        output = self.k2 * new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
        return x
# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2
class BertClassifier(nn.Module):
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate=0.1):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1, 1 + num_reg, h_dim))  # register tokens can be added via a second dim > 1
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
        print(h_dim, self.max_seq_len)
        # the two branches below differ only in the optical geometry: long
        # sequences (>= 512) use a larger propagation distance and lens size
        if max_seq_len < 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns=self.max_seq_len,
                           right_matrix_count_rows=h_dim // num_heads,
                           right_matrix_width=pixel_size * self.max_seq_len,
                           right_matrix_height=pixel_size * (h_dim // num_heads),
                           min_height_gap=pixel_size,
                           right_matrix_split_x=2,
                           right_matrix_split_y=2,
                           left_matrix_split_x=2,
                           left_matrix_split_y=2,
                           result_matrix_split=2,
                           distance=0.01,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns=h_dim // num_heads,
                           right_matrix_count_rows=self.max_seq_len,
                           right_matrix_width=pixel_size * (h_dim // num_heads),
                           right_matrix_height=pixel_size * self.max_seq_len,
                           min_height_gap=pixel_size,
                           right_matrix_split_x=2,
                           right_matrix_split_y=2,
                           left_matrix_split_x=2,
                           left_matrix_split_y=2,
                           result_matrix_split=2,
                           distance=0.01,
                           trainable_cylind_lens=False)
            )
        if max_seq_len >= 512:
            self.sim_scores = omm.OpticalMul(
                omm.Config(right_matrix_count_columns=self.max_seq_len,
                           right_matrix_count_rows=h_dim // num_heads,
                           right_matrix_width=pixel_size * self.max_seq_len,
                           right_matrix_height=pixel_size * (h_dim // num_heads),
                           min_height_gap=pixel_size,
                           right_matrix_split_x=2,
                           right_matrix_split_y=2,
                           left_matrix_split_x=2,
                           left_matrix_split_y=2,
                           result_matrix_split=2,
                           distance=0.15,
                           lens_size=8192 * 2,
                           trainable_cylind_lens=False)
            )
            self.sim_output = omm.OpticalMul(
                omm.Config(right_matrix_count_columns=h_dim // num_heads,
                           right_matrix_count_rows=self.max_seq_len,
                           right_matrix_width=pixel_size * (h_dim // num_heads),
                           right_matrix_height=pixel_size * self.max_seq_len,
                           min_height_gap=pixel_size,
                           right_matrix_split_x=2,
                           right_matrix_split_y=2,
                           left_matrix_split_x=2,
                           left_matrix_split_y=2,
                           result_matrix_split=2,
                           distance=0.15,
                           lens_size=8192 * 2,
                           trainable_cylind_lens=False)
            )
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output,
                             num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)
        ])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    def forward(self, x):
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        x = self.classifier_head(x[:, 0, :])
        return x[:, :] if self.class_num > 1 else x[:, 0]
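# Forward flow in brief: the CLS (and optional register) tokens are prepended,
# learned absolute positional embeddings are added on top of the RoPE applied
# to q/k inside each layer, and only position 0 feeds the classifier head;
# register tokens (num_reg > 0) take part in attention but are discarded at
# readout, following the registers paper cited above.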
start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/finbert_data/train_data",
    targets_path="/wd/finbert_data/train_target.csv",
    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
h_dim = 64
category_feature_dim = 8
layers_num = 2
num_heads = 1
class_num = 1
dropout_rate = 0.1
encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
).to(device)
model = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate=dropout_rate,
).to(device)
print(f'Model parameters - {sum(p.numel() for p in model.parameters())}, encoder parameters - {sum(p.numel() for p in encoder.parameters())}')
model_description = str(model) + f'\nModel parameters - {sum(p.numel() for p in model.parameters())}, encoder parameters - {sum(p.numel() for p in encoder.parameters())}'
writer.add_text('model', model_description, 0)
optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
if checkpoint_file is not None:
    load_checkpoint(checkpoint_file)
positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
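# Worked example with made-up counts: 97,000 negatives and 3,000 positives
# give pos_weight ≈ 32.3, so BCEWithLogitsLoss multiplies each positive
# sample's loss term by ~32.3 to counter the class imbalance.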
epochs = 50
batch_size = 300
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
# for parallel data selection
class WrapperDataset(Dataset):
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder
        self.batch_size = batch_size

    def __len__(self):
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size

    def __getitem__(self, idx):
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets
training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
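# WrapperDataset already assembles full batches inside __getitem__, so the
# DataLoader runs with batch_size=1 purely to parallelize sampling across
# workers; this is why the training loop below unwraps cat_inputs[0],
# num_inputs[0], targets[0].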
val_auroc = AUROC(task='binary')
test_auroc = AUROC(task='binary')
def test(epoch):
    model.eval()
    encoder.eval()
    optimizer.eval()
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets)
            outputs = model(inputs)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end=" " * 40)
    # note: test_auroc accumulates state across calls; add test_auroc.reset()
    # here if a per-epoch AUROC is wanted instead of a running one
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end=" " * 40)
    print()
start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
for epoch in range(epochs):
    test(epoch)
    for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
        # with autograd.detect_anomaly(True):
        model.train()
        encoder.train()
        optimizer.train()
        inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        current_time = datetime.now()
        if current_time - last_display_time > timedelta(seconds=1):
            writer.add_scalar('Gradient_Norm/Total',
                              torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(),
                              epoch * batches_per_epoch + batch_id)
        optimizer.step()
        optimizer.zero_grad()
        if current_time - last_display_time > timedelta(seconds=1):
            model.eval()
            encoder.eval()
            optimizer.eval()
            last_display_time = current_time
            writer.add_scalar('Loss', loss.item(), epoch * batches_per_epoch + batch_id)
            print(f"\r {current_time - start_time} {epoch + 1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end=" " * 40)
        if current_time - last_checkpoint_time > timedelta(hours=4):
            last_checkpoint_time = current_time
            save_checkpoint(
                credit_dataset=credit_train_dataset,
                encoder=encoder, model=model, optimizer=optimizer, epoch=epoch,
                loss=loss.item(), rocauc=test_auroc.compute().item(), checkpoints_dir=checkpoints_dir)
test(epochs)
print()
save_checkpoint(
    credit_dataset=credit_train_dataset,
    encoder=encoder, model=model, optimizer=optimizer, epoch=epochs,
    loss=loss.item(), rocauc=test_auroc.compute().item(), checkpoints_dir=checkpoints_dir)
writer.close()