| 
						
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -0,0 +1,452 @@
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import os
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import sys
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# DEVICE_IDX = sys.argv[1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# os.environ['CUDA_VISIBLE_DEVICES'] = f"{DEVICE_IDX}"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch import nn
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import torch
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import pandas as pd
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import numpy as np
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch.utils.tensorboard import SummaryWriter
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from datetime import datetime, timedelta
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torchmetrics import AUROC
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch.nn.utils.rnn import pad_sequence
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch.utils.data import Dataset, DataLoader
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from pathlib import Path
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import schedulefree
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from einops import rearrange, repeat
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import torch.nn.functional as F
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch.nn.parallel import DistributedDataParallel as DDP
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from torch.distributed.fsdp.wrap import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    size_based_auto_wrap_policy,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    enable_wrap,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    wrap,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import functools
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from layers import *
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    checkpoint = {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'encoder': {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'state_dict': encoder.state_dict(),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'model': {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'state_dict': model.state_dict(),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'optimizer': {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'state_dict': optimizer.state_dict(),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'epoch': epoch,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'loss': loss,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'rocauc': rocauc,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # if torch.distributed.get_rank() == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    torch.save(checkpoint, path)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    print(f"\nCheckpoint saved to {path}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				######################################## Dataset #########################################################
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class CreditProductsDataset:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        features_path, targets_path, train_test_split_ratio=0.9,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        train_uniq_client_ids_path=None, test_uniq_client_ids_path=None,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        dropout_rate=0.0
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    ):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if Path(self.train_uniq_client_ids_path).exists():
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            print("Loaded", self.train_uniq_client_ids_path)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        else: 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            raise Exception(f"No {self.train_uniq_client_ids_path}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if Path(self.test_uniq_client_ids_path).exists():
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            print("Loaded", self.test_uniq_client_ids_path)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        else: 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            raise Exception(f"No {self.test_uniq_client_ids_path}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert(len(np.intersect1d(self.train_uniq_client_ids, self.test_uniq_client_ids)) == 0), "Train contains test examples."
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.features_df = pd.read_parquet(features_path)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.targets_df = pd.read_csv(targets_path)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.uniq_client_ids = self.features_df.id.unique()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.max_user_history = self.features_df.rn.max()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.id_columns = ['id', 'rn']
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cat_columns = [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'fclose_flag',
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'pre_loans5', 'pre_loans6090', 'pre_loans530', 'pre_loans90', 'pre_loans3060'
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.num_columns = ['pre_loans5'] # TODO empty list get DatParallel to crash
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.features_df = self.features_df.set_index('id')
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.targets_df = self.targets_df.set_index('id')
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.targets_df = self.targets_df.sort_index()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def get_batch(self, batch_size=4):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cat_features_batch = self.cat_features[sampled_ids]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        num_features_batch = self.num_features[sampled_ids]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if self.dropout_rate > 0.0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            cat_features_batch *= torch.empty_like(cat_features_batch).bernoulli_(1-self.dropout_rate) # argument is keep_prob
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            num_features_batch *= torch.empty_like(num_features_batch).bernoulli_(1-self.dropout_rate) # argument is keep_prob
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        targets_batch = self.targets[sampled_ids]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return cat_features_batch, num_features_batch, targets_batch
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def get_test_batch_iterator(self, batch_size=4):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for i in range(0, len(self.test_uniq_client_ids), batch_size):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            ids = self.test_uniq_client_ids[i:i+batch_size]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            cat_features_batch = self.cat_features[ids]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            num_features_batch = self.num_features[ids] 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            targets_batch = self.targets[ids]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            yield cat_features_batch, num_features_batch, targets_batch
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# for parallel data selection 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class WrapperDataset(Dataset):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, credit_dataset, encoder, batch_size, datasets_per_epoch):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.credit_dataset = credit_dataset
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.encoder = encoder
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.batch_size = batch_size
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.num_batches = len(self.credit_dataset.train_uniq_client_ids) // self.batch_size // torch.distributed.get_world_size() * datasets_per_epoch
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __len__(self):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return self.num_batches
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __getitem__(self, idx):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return cat_inputs, num_inputs, targets
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				##################################### Model ###########################################################################################
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class Encoder(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, dropout_rate=0.0):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, cat_features_batch, num_features_batch, targets_batch):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        inputs = self.proj(embed_tensor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if self.dropout_rate > 0.0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            inputs = F.dropout1d(inputs, p=self.dropout_rate)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        targets = targets_batch
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return inputs, targets
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class RoPE(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, dim, max_seq_len=512):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        t = torch.arange(max_seq_len).type_as(inv_freq)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        freqs = torch.einsum('i,j->ij', t, inv_freq)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def rotate_half(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x1, x2 = x.chunk(2, dim=-1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return torch.cat((-x2, x1), dim=-1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x, offset=0):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        seq_len = x.size(1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cos = emb.cos()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        sin = emb.sin()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return (x * cos) + (self.rotate_half(x) * sin)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# Transformers without Normalization https://jiachenzhu.github.io/DyT/
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class DyT(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, num_features, alpha_init_value=0.5):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.weight = nn.Parameter(torch.ones(num_features))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.bias = nn.Parameter(torch.zeros(num_features))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = torch.tanh(self.alpha * x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x * self.weight + self.bias
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class DyC(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, num_features, alpha_init_value=0.5):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.weight = nn.Parameter(torch.ones(num_features))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.bias = nn.Parameter(torch.zeros(num_features))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = torch.clip(self.alpha * x, min=-1, max=1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x * self.weight + self.bias
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# from layers import ChebyKANLayer
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class TransformerLayer(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.q_proj = nn.Linear(h_dim, h_dim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.k_proj = nn.Linear(h_dim, h_dim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.v_proj = nn.Linear(h_dim, h_dim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.o_proj = nn.Linear(h_dim, h_dim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.ff1 = nn.Linear(h_dim, 4*h_dim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.ff2 = nn.Linear(4*h_dim, h_dim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.ln1 = DyC(h_dim) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.ln2 = DyC(h_dim) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.ln3 = DyC(max_seq_len) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def split_to_heads(self, x, B, T, H):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def gather_heads(self, x, B, T, H):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # how to check that attention is actually make some difference
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def attention(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        v = self.split_to_heads(self.v_proj(x), *x.shape)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # attention = nn.functional.softmax(F.dropout1d(scores, p=self.dropout_rate), dim=2)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # attention = self.ln3(F.dropout1d(scores, p=self.dropout_rate))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        attention = self.ln3(scores)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class BertClassifier(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate = 0.1):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cls_token = nn.Parameter(torch.randn(1,1,h_dim)) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.max_seq_len = max_seq_len + self.cls_token.shape[1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x + self.pos_embeds[:, :x.shape[1], :]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for l in self.layers:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            x = l(x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = self.classifier_head(x[:,0,:])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x[:,:] if self.class_num > 1 else x[:,0]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class Model(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, encoder, classifier):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super().__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.encoder = encoder
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.classifier = classifier
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, cat_inputs, num_inputs, targets):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        inputs, targets = self.encoder(cat_inputs, num_inputs, targets)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return self.classifier(inputs), targets
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc, writer):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        model.eval()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        optimizer.eval()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        with torch.no_grad():
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                test_cat_inputs, test_num_inputs, test_targets = [x.to(device_id, non_blocking=True) for x in [test_cat_inputs, test_num_inputs, test_targets]]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                outputs, targets = model(test_cat_inputs, test_num_inputs, test_targets)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                test_auroc.update(outputs, targets.long())
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if torch.distributed.get_rank() == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            print()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				######################################### Training ################################################################
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				h_dim = 32
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				category_feature_dim = 8
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				layers_num = 6
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				num_heads = 2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class_num = 1
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				dataset_dropout_rate = 0.4
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				encoder_dropout_rate = 0.0
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				classifier_dropout_date = 0.4
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				epochs = 500
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				batch_size = 2000
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				datasets_per_epoch = 5
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				num_workers = 10
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				if __name__ == "__main__":
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    env_dict = {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        key: os.environ[key]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    torch.distributed.init_process_group(backend="nccl")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    print(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        f"[{os.getpid()}] world_size = {torch.distributed.get_world_size()}, "
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        + f"rank = {torch.distributed.get_rank()}, backend={torch.distributed.get_backend()}"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    rank = torch.distributed.get_rank()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    world_size = torch.distributed.get_world_size()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # DEVICE = "cuda"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    if rank == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        comment = sys.argv[1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        logs_dir = f'runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        writer = SummaryWriter(logs_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        script_snapshot_path = Path(logs_dir + "bert_training_ddp.py")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print("Logs dir:", logs_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print("Chekpoints dir:", сheсkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        script_snapshot_path.chmod(0o400) # with read-only permission
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    start_prep_time = datetime.now()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    credit_train_dataset = CreditProductsDataset(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        features_path="/wd/data/train_data/", 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        targets_path="/wd/data/train_target.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # train_uniq_client_ids_path="/wd/train_uniq_client_ids.csv", 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # test_uniq_client_ids_path="/wd/test_uniq_client_ids.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # train_uniq_client_ids_path="/wd/dima_train_ids.csv",          
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # test_uniq_client_ids_path="/wd/dima_test_ids.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # train_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_train_ids.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # test_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_test_ids.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        train_uniq_client_ids_path=f"/wd/fold3_train_ids.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        dropout_rate=dataset_dropout_rate
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    if rank == 0: print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    encoder = Encoder(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cat_columns=credit_train_dataset.cat_columns,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        num_columns=credit_train_dataset.num_columns, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        cat_features_max_id=credit_train_dataset.cat_features.max(),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        category_feature_dim=category_feature_dim, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        out_dim=h_dim,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        dropout_rate=encoder_dropout_rate
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    classifier = BertClassifier(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        layers_num=layers_num, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        num_heads=num_heads,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        h_dim=h_dim, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        class_num=class_num, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        max_seq_len=credit_train_dataset.max_user_history,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        dropout_rate = classifier_dropout_date
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    device_id = int(os.environ["LOCAL_RANK"])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    model = Model(encoder=encoder, classifier=classifier).to(f"cuda:{device_id}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    if rank == 0: print(f"Model parameters count: ", sum(p.numel() for p in model.parameters()))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    model = DDP(model, device_ids=[device_id])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # my_auto_wrap_policy = functools.partial(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    #     size_based_auto_wrap_policy, min_num_params=20000
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # The Road Less Scheduled https://arxiv.org/html/2405.15682v4
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    optimizer = schedulefree.AdamWScheduleFree(model.parameters())
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    pos_weight = negative_counts / (positive_counts + 1e-15) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    if rank == 0: print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size, datasets_per_epoch=datasets_per_epoch)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # number of batches to go through dataset once
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    batches_per_epoch = len(training_data)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    if rank == 0: print(f"Number of batches per epoch: {batches_per_epoch}, Number of datasets per epoch : {datasets_per_epoch}")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    test_auroc = AUROC(task='binary', sync_on_compute=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    start_time = datetime.now()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    if rank == 0: print("Started at:", start_time)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    last_display_time = start_time
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    last_checkpoint_time = start_time
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    ddp_loss = torch.zeros(1).to(rank) # quickly goes to nan TODO debug. Are all replicas training properly?
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    try:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for epoch in range(epochs):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            writer=writer if rank==0 else None)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                model.train()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                optimizer.train()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                optimizer.zero_grad()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                cat_inputs, num_inputs, targets = [x.to(device_id, non_blocking=True) for x in [cat_inputs[0], num_inputs[0], targets[0]]]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                outputs, targets = model(cat_inputs, num_inputs, targets)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                loss = criterion(outputs, targets) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                loss.backward()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                optimizer.step()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                ddp_loss[0] = loss.item()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                torch.distributed.all_reduce(ddp_loss, op=torch.distributed.ReduceOp.SUM)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                ddp_loss[0] /= world_size
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                current_time = datetime.now()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                if current_time - last_display_time > timedelta(seconds=1):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    last_display_time = current_time
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    if rank == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        writer.add_scalar(f'Loss', ddp_loss, epoch*batches_per_epoch+batch_id)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {ddp_loss[0]:.6f} {comment}", end = " "*2)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                if current_time - last_checkpoint_time > timedelta(hours=8): # TODO ddp bug, deadlock
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    last_checkpoint_time = current_time
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                         writer=writer if rank==0 else None)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    rocauc = test_auroc.compute().item()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    if rank == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        # TODO ddp bug, deadlock
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        save_checkpoint(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            credit_dataset=credit_train_dataset,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    torch.distributed.barrier()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    except KeyboardInterrupt:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    finally: 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        test(epoch+1, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        writer=writer if rank==0 else None)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        rocauc = test_auroc.compute().item()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if rank == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            save_checkpoint(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                credit_dataset=credit_train_dataset,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if rank == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            writer.close()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        torch.distributed.destroy_process_group()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 |