import os
import sys

DEVICE_IDX = os.environ['CUDA_VISIBLE_DEVICES']

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import AUROC
import pandas as pd
import numpy as np
import schedulefree
from einops import rearrange
from datetime import datetime, timedelta
from pathlib import Path

DEVICE = "cuda"

comment = sys.argv[1]

logs_dir = f'runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{DEVICE_IDX}_{comment}/'
checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{DEVICE_IDX}_{comment}/'
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
print("Logs dir:", logs_dir)
print("Checkpoints dir:", checkpoints_dir)
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + "bert_training.py")
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes())  # snapshot this version of the script next to its logs
script_snapshot_path.chmod(0o400)  # read-only, so the snapshot cannot be edited later

def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
    # besides the state dicts, store every plain (non-underscore) attribute of the
    # encoder and model, so the checkpoint carries its own construction arguments
    checkpoint = {
        'encoder': {
            'state_dict': encoder.state_dict(),
            **{k: v for k, v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'model': {
            'state_dict': model.state_dict(),
            **{k: v for k, v in model.__dict__.items() if k[0] != '_' and k != 'training'}
        },
        'optimizer': {
            'state_dict': optimizer.state_dict(),
        },
        'epoch': epoch,
        'loss': loss,
        'rocauc': rocauc,
        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
    }
    path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
    torch.save(checkpoint, path)
    print(f"\nCheckpoint saved to {path}")

#################################################################################################
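# A checkpoint written above bundles weights together with the modules' plain
# attributes, so a consumer can rebuild the network without re-reading this
# script's hyperparameters. A minimal loading sketch (the file name is
# hypothetical; weights_only=False because the checkpoint holds Python objects):
#
#   ckpt = torch.load("checkpoints/<run>/epoch_10_0.7900.pth", map_location="cpu", weights_only=False)
#   enc = Encoder(cat_columns=ckpt['encoder']['cat_columns'],
#                 num_columns=ckpt['encoder']['num_columns'],
#                 cat_features_max_id=ckpt['encoder']['cat_features_max_id'],
#                 category_feature_dim=ckpt['encoder']['category_feature_dim'],
#                 out_dim=ckpt['encoder']['out_dim'])
#   enc.load_state_dict(ckpt['encoder']['state_dict'])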
class CreditProductsDataset:
    def __init__(self,
                 features_path, targets_path,
                 train_test_split_ratio=0.9,
                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None,
                 dropout_rate=0.0):
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        if Path(self.train_uniq_client_ids_path).exists():
            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:, 0].values
            print("Loaded", self.train_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.train_uniq_client_ids_path}")
        if Path(self.test_uniq_client_ids_path).exists():
            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:, 0].values
            print("Loaded", self.test_uniq_client_ids_path)
        else:
            raise Exception(f"No {self.test_uniq_client_ids_path}")
        assert len(np.intersect1d(self.train_uniq_client_ids, self.test_uniq_client_ids)) == 0, "Train contains test examples."
        self.features_df = pd.read_parquet(features_path)
        self.targets_df = pd.read_csv(targets_path)
        self.uniq_client_ids = self.features_df.id.unique()
        self.max_user_history = self.features_df.rn.max()
        self.id_columns = ['id', 'rn']
        self.cat_columns = [
            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530',
            'is_zero_loans3060', 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit',
            'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', 'is_zero_maxover2limit',
            'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6',
            'enc_paym_7', 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13',
            'enc_paym_14', 'enc_paym_15', 'enc_paym_16', 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20',
            'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type',
            'enc_loans_account_cur', 'pclose_flag', 'fclose_flag',
            'pre_loans5', 'pre_loans6090', 'pre_loans530', 'pre_loans90', 'pre_loans3060'
        ]
        self.num_columns = ['pre_loans5']  # TODO: an empty list gets DataParallel to crash

        # Build a unified category index across all columns, so a single embedding
        # table serves every categorical feature. Index 0 is reserved for padding
        # and its embedding is zeroed during training.
        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1  # zero embedding is for padding

        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
        self.features_df = self.features_df.set_index('id')
        self.targets_df = self.targets_df.set_index('id')
        self.targets_df = self.targets_df.sort_index()

        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()

        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True)  # pads to the implicit max sequence length
        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)

        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)

    def get_batch(self, batch_size=4):
        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False)  # consider replace=True
        # bernoulli_ takes the keep probability, so this implements feature dropout
        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(1 - self.dropout_rate)
        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(1 - self.dropout_rate)
        targets_batch = self.targets[sampled_ids]
        return cat_features_batch, num_features_batch, targets_batch

    def get_test_batch_iterator(self, batch_size=4):
        for i in range(0, len(self.test_uniq_client_ids), batch_size):
            ids = self.test_uniq_client_ids[i:i + batch_size]
            cat_features_batch = self.cat_features[ids]
            num_features_batch = self.num_features[ids]
            targets_batch = self.targets[ids]
            yield cat_features_batch, num_features_batch, targets_batch
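# A toy walkthrough of the unified-index arithmetic above (values are made up):
#
#   df = pd.DataFrame({'A': [0, 1, 2], 'B': [0, 1, 0]})
#   cards = df.max(axis=0)[['A', 'B']] + 1   # cardinalities: A -> 3, B -> 2
#   integral = cards.cumsum()                # running totals: A -> 3, B -> 5
#   df[['B']] = df[['B']] + integral[['B']]  # B's codes become 5..6
#   df = df + 1                              # reserve index 0 for padding
#
# Each column now occupies a disjoint code range, so one nn.Embedding can host
# them all. Note the offset is the cumulative sum *including* the column itself,
# which leaves a few table slots unused but keeps the ranges collision-free.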
class Encoder(nn.Module):
    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, dropout_rate=0.5):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)

    def forward(self, cat_features_batch, num_features_batch, targets_batch):
        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
        inputs = self.proj(embed_tensor)
        targets = targets_batch
        return inputs, targets

# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, offset=0):
        seq_len = x.size(1)
        emb = self.emb[offset:offset + seq_len].view(1, seq_len, -1)
        cos = emb.cos()
        sin = emb.sin()
        return (x * cos) + (self.rotate_half(x) * sin)
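# Hedged sanity check for the RoPE block above (dimensions are illustrative):
# rotary embedding is a pure rotation, so it preserves shapes and per-position
# vector norms while encoding position into the phase.
#
#   rope = RoPE(dim=16, max_seq_len=32)
#   x = torch.randn(2, 32, 16)  # (batch, seq, head_dim)
#   y = rope(x)
#   assert y.shape == x.shape
#   assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)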
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias

# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    def __init__(self, h_dim, num_heads=4, dropout_rate=0.1, max_seq_len=128):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4 * h_dim)
        self.ff2 = nn.Linear(4 * h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.ln3 = DyT(max_seq_len)  # applied over the key axis of the score matrix, in place of softmax
        self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x

    def gather_heads(self, x, B, T, H):
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x

    def attention(self, x):
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)  # note: scaled by the full model dim, not the per-head dim
        # training=self.training so dropout is disabled in eval mode
        # (functional dropout ignores the module's mode by default)
        attention = self.ln3(F.dropout1d(scores, p=self.dropout_rate, training=self.training))
        return self.o_proj(self.gather_heads(attention @ v, *x.shape))

    def forward(self, x):
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x

class BertClassifier(nn.Module):
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate=0.1):
        super().__init__()
        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
        self.cls_token = nn.Parameter(torch.randn(1, 1, h_dim))
        self.max_seq_len = max_seq_len + self.cls_token.shape[1]  # leave room for the prepended CLS token
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num)
        ])
        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.Dropout(0.1), nn.GELU(), nn.Linear(h_dim, class_num))
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    def forward(self, x):
        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x)
        x = self.classifier_head(x[:, 0, :])  # classify from the CLS position
        return x if self.class_num > 1 else x[:, 0]

class Model(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, cat_inputs, num_inputs, targets):
        inputs, targets = self.encoder(cat_inputs, num_inputs, targets)
        return self.classifier(inputs), targets
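# End-to-end shape sketch for the modules above, under hypothetical small
# dimensions (not the training configuration used below):
#
#   enc = Encoder(cat_columns=['c1', 'c2'], num_columns=['n1'],
#                 cat_features_max_id=10, category_feature_dim=4, out_dim=32)
#   clf = BertClassifier(layers_num=2, h_dim=32, class_num=1, max_seq_len=8, num_heads=2)
#   m = Model(enc, clf)
#   cat = torch.randint(0, 11, (3, 8, 2))   # (batch, seq, n_cat_columns)
#   num = torch.randn(3, 8, 1)              # (batch, seq, n_num_columns)
#   tgt = torch.randint(0, 2, (3,)).float()
#   logits, tgt = m(cat, num, tgt)          # logits: shape (3,) since class_num == 1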
h_dim = 64
category_feature_dim = 8
layers_num = 6
num_heads = 2
class_num = 1
dropout_rate = 0.4
epochs = 800
batch_size = 30000

start_prep_time = datetime.now()
credit_train_dataset = CreditProductsDataset(
    features_path="/wd/data/train_data/",
    targets_path="/wd/data/train_target.csv",
    # train_uniq_client_ids_path="/wd/train_uniq_client_ids.csv",
    # test_uniq_client_ids_path="/wd/test_uniq_client_ids.csv",
    # train_uniq_client_ids_path="/wd/dima_train_ids.csv",
    # test_uniq_client_ids_path="/wd/dima_test_ids.csv",
    # train_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_train_ids.csv",
    # test_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_test_ids.csv",
    train_uniq_client_ids_path="/wd/fold3_train_ids.csv",
    test_uniq_client_ids_path="/wd/fold3_test_ids.csv",
    dropout_rate=dropout_rate
)
batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")

encoder = Encoder(
    cat_columns=credit_train_dataset.cat_columns,
    num_columns=credit_train_dataset.num_columns,
    cat_features_max_id=credit_train_dataset.cat_features.max(),
    category_feature_dim=category_feature_dim,
    out_dim=h_dim,
    dropout_rate=dropout_rate
)
classifier = BertClassifier(
    layers_num=layers_num,
    num_heads=num_heads,
    h_dim=h_dim,
    class_num=class_num,
    max_seq_len=credit_train_dataset.max_user_history,
    dropout_rate=dropout_rate
)
model = Model(encoder=encoder, classifier=classifier)
model = torch.nn.DataParallel(model, device_ids=[int(idx) for idx in DEVICE_IDX.split(",")]).to(DEVICE)

# The Road Less Scheduled https://arxiv.org/html/2405.15682v4
optimizer = schedulefree.AdamWScheduleFree(model.parameters())

positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
pos_weight = negative_counts / (positive_counts + 1e-15)
print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

# wraps the dataset so DataLoader workers can assemble batches in parallel
class WrapperDataset(Dataset):
    def __init__(self, credit_dataset, encoder, batch_size):
        self.credit_dataset = credit_dataset
        self.encoder = encoder
        self.batch_size = batch_size

    def __len__(self):
        return len(self.credit_dataset.uniq_client_ids) // self.batch_size

    def __getitem__(self, idx):
        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
        return cat_inputs, num_inputs, targets

training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8 * 2, pin_memory=True)

test_auroc = AUROC(task='binary')

def test(epoch):
    model.eval()
    optimizer.eval()
    test_auroc.reset()  # start from a clean metric state, otherwise AUROC accumulates across epochs
    with torch.no_grad():
        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
            test_cat_inputs, test_num_inputs, test_targets = [
                x.to("cuda", non_blocking=True) for x in [test_cat_inputs, test_num_inputs, test_targets]
            ]
            outputs, targets = model(test_cat_inputs, test_num_inputs, test_targets)
            test_auroc.update(outputs, targets.long())
            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end=" " * 2)
    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end=" " * 2)
    print()

start_time = datetime.now()
print("Started at:", start_time)
last_display_time = start_time
last_checkpoint_time = start_time
try:
    for epoch in range(epochs):
        test(epoch)
        for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
            model.train()
            optimizer.train()
            optimizer.zero_grad()
            # DataLoader batch_size=1, so drop the leading singleton dimension
            cat_inputs, num_inputs, targets = [
                x.to("cuda", non_blocking=True) for x in [cat_inputs[0], num_inputs[0], targets[0]]
            ]
            outputs, targets = model(cat_inputs, num_inputs, targets)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            current_time = datetime.now()
            if current_time - last_display_time > timedelta(seconds=1):
                last_display_time = current_time
                writer.add_scalar('Loss', loss.item(), epoch * batches_per_epoch + batch_id)
                print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item():.6f} {comment}", end=" " * 2)
            if current_time - last_checkpoint_time > timedelta(hours=4):
                last_checkpoint_time = current_time
                test(epoch)
                save_checkpoint(
                    credit_dataset=credit_train_dataset, encoder=encoder, model=model, optimizer=optimizer,
                    epoch=epoch, loss=loss.item(), rocauc=test_auroc.compute().item(),
                    checkpoints_dir=checkpoints_dir)
except KeyboardInterrupt:
    print()
finally:
    test(epoch + 1)
    save_checkpoint(
        credit_dataset=credit_train_dataset, encoder=encoder, model=model, optimizer=optimizer,
        epoch=epoch + 1, loss=loss.item(), rocauc=test_auroc.compute().item(),
        checkpoints_dir=checkpoints_dir)
writer.close()
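# Usage sketch (argument values are examples): the script reads the GPU list
# from CUDA_VISIBLE_DEVICES and takes a free-form run comment as its single
# CLI argument, e.g.:
#
#   CUDA_VISIBLE_DEVICES=0,1 python bert_training.py baseline_fold3
#
# TensorBoard logs land in runs/, checkpoints in checkpoints/, and a read-only
# snapshot of this script is stored next to the logs for reproducibility.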