```diff
@@ -80,7 +80,7 @@ class CreditProductsDataset:
             'fclose_flag',
             'pre_loans5', 'pre_loans6090', 'pre_loans530', 'pre_loans90', 'pre_loans3060'
         ]
-        self.num_columns = ['pre_loans5'] # TODO empty list get DatParallel to crash
+        self.num_columns = []
 
         # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
         self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
```

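The `# make unified category index` comment refers to packing every categorical column into one shared embedding id space, with index 0 reserved for padding so its (zeroed) vector stays inert. A minimal sketch of that layout, with toy cardinalities and a hypothetical offset scheme (the repo's actual preprocessing may differ):

```python
import torch
import torch.nn as nn

# Toy per-column category counts; the real ones come from features_df.max(axis=0) + 1.
cardinalities = torch.tensor([3, 5, 4])
# Shift each column into its own slice of one shared id space; 0 stays free for padding.
offsets = torch.cat([torch.zeros(1, dtype=torch.long), cardinalities.cumsum(0)[:-1]]) + 1
cat_embeds = nn.Embedding(int(cardinalities.sum()) + 1, embedding_dim=4, padding_idx=0)

row = torch.tensor([[2, 4, 1]])      # one client step, one raw code per column
unified = row + offsets              # per-column shift into the shared space
print(cat_embeds(unified).shape)     # torch.Size([1, 3, 4]) -> one vector per column
```
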
```diff
@@ -102,43 +102,43 @@ class CreditProductsDataset:
         self.targets_df = self.targets_df.sort_index()
         self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)
 
-    def get_batch(self, batch_size=4):
+    def get_train_batch(self, batch_size=4):
         sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
         cat_features_batch = self.cat_features[sampled_ids]
         num_features_batch = self.num_features[sampled_ids]
-        if self.dropout_rate > 0.0:
-            cat_features_batch *= torch.empty_like(cat_features_batch).bernoulli_(1-self.dropout_rate) # argument is keep_prob
-            num_features_batch *= torch.empty_like(num_features_batch).bernoulli_(1-self.dropout_rate) # argument is keep_prob
+        cat_features_batch *= torch.empty_like(cat_features_batch).bernoulli_(1-self.dropout_rate) # arg is keep_probability
+        num_features_batch *= torch.empty_like(num_features_batch).bernoulli_(1-self.dropout_rate)
         targets_batch = self.targets[sampled_ids]
         return cat_features_batch, num_features_batch, targets_batch
 
     def get_test_batch_iterator(self, batch_size=4):
         for i in range(0, len(self.test_uniq_client_ids), batch_size):
             ids = self.test_uniq_client_ids[i:i+batch_size]
             cat_features_batch = self.cat_features[ids]
             num_features_batch = self.num_features[ids]
             targets_batch = self.targets[ids]
             yield cat_features_batch, num_features_batch, targets_batch
 
 # for parallel data selection
 class WrapperDataset(Dataset):
-    def __init__(self, credit_dataset, encoder, batch_size, datasets_per_epoch):
+    def __init__(self, credit_dataset, batch_size, datasets_per_epoch=1):
         self.credit_dataset = credit_dataset
-        self.encoder = encoder
         self.batch_size = batch_size
-        self.num_batches = len(self.credit_dataset.train_uniq_client_ids) // self.batch_size // torch.distributed.get_world_size() * datasets_per_epoch
+        self.num_batches = len(self.credit_dataset.train_uniq_client_ids) \
+                            // self.batch_size \
+                            * datasets_per_epoch
 
     def __len__(self):
         return self.num_batches
 
     def __getitem__(self, idx):
-        cat_inputs, num_inputs, targets = self.credit_dataset.get_batch(batch_size=self.batch_size)
+        cat_inputs, num_inputs, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
         return cat_inputs, num_inputs, targets
 
 ##################################### Model ###########################################################################################
 
 class Encoder(nn.Module):
-    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, dropout_rate=0.0):
+    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
         super().__init__()
         self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
         self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
```

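`get_train_batch` regularizes by zeroing random feature entries in-place; as the inline comment says, `bernoulli_` takes the keep probability, not the drop probability. A self-contained sketch of the trick:

```python
import torch

dropout_rate = 0.4
features = torch.randint(1, 10, (3, 6)).float()   # toy batch of category codes
# bernoulli_(p) fills the tensor with 1s with probability p, so pass 1 - dropout_rate.
mask = torch.empty_like(features).bernoulli_(1 - dropout_rate)
features *= mask                                   # ~40% of entries become 0
print(features)
```

Because the unified index reserves 0 for padding, a dropped category id collapses onto the padding embedding, which is kept at zero during training.
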
```diff
@@ -150,14 +150,10 @@ class Encoder(nn.Module):
     def forward(self, cat_features_batch, num_features_batch, targets_batch):
         cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
         cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
         num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
         embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
-
         inputs = self.proj(embed_tensor)
-        if self.dropout_rate > 0.0:
-            inputs = F.dropout1d(inputs, p=self.dropout_rate)
         targets = targets_batch
-
         return inputs, targets
 
 # RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
```

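For numeric columns, `forward` builds `num_embed_tensor` with a learned per-feature scale and shift rather than a lookup table. A hypothetical stand-in for `num_scales`/`num_shifts` (their exact shapes are defined outside this hunk, so this is an assumption):

```python
import torch
import torch.nn as nn

class NumAffine(nn.Module):
    """Assumed form of the num_scales/num_shifts trick: a learnable
    per-column affine map applied to raw numeric features."""
    def __init__(self, num_columns):
        super().__init__()
        self.num_scales = nn.Parameter(torch.ones(num_columns))
        self.num_shifts = nn.Parameter(torch.zeros(num_columns))

    def forward(self, x):                # x: (batch, seq_len, num_columns)
        return self.num_scales * x + self.num_shifts

print(NumAffine(2)(torch.randn(4, 10, 2)).shape)   # torch.Size([4, 10, 2])
```
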
```diff
@@ -192,18 +188,6 @@ class DyT(nn.Module):
         x = torch.tanh(self.alpha * x)
         return x * self.weight + self.bias
 
-class DyC(nn.Module):
-    def __init__(self, num_features, alpha_init_value=0.5):
-        super().__init__()
-        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
-        self.weight = nn.Parameter(torch.ones(num_features))
-        self.bias = nn.Parameter(torch.zeros(num_features))
-
-    def forward(self, x):
-        x = torch.clip(self.alpha * x, min=-1, max=1)
-        return x * self.weight + self.bias
-
-# from layers import ChebyKANLayer
 # Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
 # NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
 class TransformerLayer(nn.Module):
```

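This hunk keeps `DyT` and drops the clipping variant `DyC`. `DyT`'s `__init__` sits outside the hunk; assuming it mirrors the deleted `DyC` parameterization, the whole layer looks like:

```python
import torch
import torch.nn as nn

# DyT ("Dynamic Tanh"), a normalization-free stand-in for LayerNorm:
# a smooth tanh squash replaces statistics-based normalization.
class DyT(nn.Module):
    def __init__(self, num_features, alpha_init_value=0.5):   # assumed, mirrors DyC
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)  # shared squash scale
        self.weight = nn.Parameter(torch.ones(num_features))         # per-feature gain
        self.bias = nn.Parameter(torch.zeros(num_features))          # per-feature shift

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias

print(DyT(8)(torch.randn(2, 5, 8)).shape)   # torch.Size([2, 5, 8])
```
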
```diff
@@ -216,9 +200,8 @@ class TransformerLayer(nn.Module):
         self.o_proj = nn.Linear(h_dim, h_dim)
         self.ff1 = nn.Linear(h_dim, 4*h_dim)
         self.ff2 = nn.Linear(4*h_dim, h_dim)
-        self.ln1 = DyC(h_dim)
-        self.ln2 = DyC(h_dim)
-        self.ln3 = DyC(max_seq_len)
+        self.ln1 = DyT(h_dim)
+        self.ln2 = DyT(h_dim)
         self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
 
     def split_to_heads(self, x, B, T, H):
```

```diff
@@ -227,15 +210,12 @@ class TransformerLayer(nn.Module):
     def gather_heads(self, x, B, T, H):
         return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
 
-    # how to check that attention is actually make some difference
     def attention(self, x):
         q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
         k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
         v = self.split_to_heads(self.v_proj(x), *x.shape)
         scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
-        # attention = nn.functional.softmax(F.dropout1d(scores, p=self.dropout_rate), dim=2)
-        # attention = self.ln3(F.dropout1d(scores, p=self.dropout_rate))
-        attention = self.ln3(scores)
+        attention = nn.functional.softmax(scores, dim=2)
         return self.o_proj(self.gather_heads(attention @ v, *x.shape))
 
     def forward(self, x):
```

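The change above retires the experimental `DyC`-normalized scores (`self.ln3`) and restores plain softmax attention. Stripped of the head splitting and RoPE, the computation `attention` now performs is:

```python
import torch

def softmax_attention(q, k, v, h_dim):
    scores = (q @ k.transpose(1, 2)) * (h_dim ** -0.5)  # (B, T, T) scaled dot products
    attention = torch.softmax(scores, dim=2)            # each query's weights sum to 1
    return attention @ v                                # weighted mix of value vectors

B, T, H = 2, 7, 16
q, k, v = (torch.randn(B, T, H) for _ in range(3))
print(softmax_attention(q, k, v, H).shape)              # torch.Size([2, 7, 16])
```
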
```diff
@@ -271,20 +251,21 @@ class Model(nn.Module):
         inputs, targets = self.encoder(cat_inputs, num_inputs, targets)
         return self.classifier(inputs), targets
 
-def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc, writer):
+def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_dataset, test_auroc, writer):
         model.eval()
         optimizer.eval()
         with torch.no_grad():
-            test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
+            test_iterator = credit_dataset.get_test_batch_iterator(batch_size=batch_size)
             for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
-                test_cat_inputs, test_num_inputs, test_targets = [x.to(device_id, non_blocking=True) for x in [test_cat_inputs, test_num_inputs, test_targets]]
+                test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
+                test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
+                test_targets = test_targets.to("cuda", non_blocking=True)
                 outputs, targets = model(test_cat_inputs, test_num_inputs, test_targets)
                 test_auroc.update(outputs, targets.long())
-                print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
-        if torch.distributed.get_rank() == 0:
-            writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
-            print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
-            print()
+                print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
+        writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
+        print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
+        print()
 
 ######################################### Training ################################################################
 
```

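`test` streams batches through a torchmetrics `AUROC` object: `update()` accumulates predictions and labels, `compute()` returns the ROC-AUC over everything seen so far. A minimal usage sketch:

```python
import torch
from torchmetrics import AUROC

test_auroc = AUROC(task='binary')
for _ in range(3):                              # stand-in for get_test_batch_iterator
    logits = torch.randn(8)                     # raw logits are fine; torchmetrics
    labels = torch.randint(0, 2, (8,))          # applies a sigmoid to out-of-range preds
    test_auroc.update(logits, labels.long())
print(test_auroc.compute().item())              # ROC-AUC over every update() so far
test_auroc.reset()                              # clear state before the next pass
```
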
```diff
@@ -294,11 +275,10 @@ layers_num = 6
 num_heads = 2
 class_num = 1
 dataset_dropout_rate = 0.4
-encoder_dropout_rate = 0.0
 classifier_dropout_date = 0.4
 epochs = 500
-batch_size = 2000
-datasets_per_epoch = 5
+batch_size = 30000
+datasets_per_epoch = 1
 num_workers = 10
 
 comment = sys.argv[1]
```

```diff
@@ -316,12 +296,6 @@ start_prep_time = datetime.now()
 credit_train_dataset = CreditProductsDataset(
     features_path="/wd/data/train_data/",
     targets_path="/wd/data/train_target.csv",
-    # train_uniq_client_ids_path="/wd/train_uniq_client_ids.csv",
-    # test_uniq_client_ids_path="/wd/test_uniq_client_ids.csv",
-    # train_uniq_client_ids_path="/wd/dima_train_ids.csv",
-    # test_uniq_client_ids_path="/wd/dima_test_ids.csv",
-    # train_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_train_ids.csv",
-    # test_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_test_ids.csv",
     train_uniq_client_ids_path=f"/wd/fold3_train_ids.csv",
     test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv",
     dropout_rate=dataset_dropout_rate
```

```diff
@@ -333,8 +307,7 @@ encoder = Encoder(
     num_columns=credit_train_dataset.num_columns,
     cat_features_max_id=credit_train_dataset.cat_features.max(),
     category_feature_dim=category_feature_dim,
-    out_dim=h_dim,
-    dropout_rate=encoder_dropout_rate
+    out_dim=h_dim
 )
 
 classifier = BertClassifier(
```

```diff
@@ -346,11 +319,8 @@ classifier = BertClassifier(
     dropout_rate = classifier_dropout_date
 )
 
-device_id = int(os.environ["LOCAL_RANK"])
-model = Model(encoder=encoder, classifier=classifier).to(f"cuda:{device_id}")
+model = Model(encoder=encoder, classifier=classifier).to("cuda")
 print(f"Model parameters count: ", sum(p.numel() for p in model.parameters()))
-model = DDP(model, device_ids=[device_id])
-
 
 # The Road Less Scheduled https://arxiv.org/html/2405.15682v4
 optimizer = schedulefree.AdamWScheduleFree(model.parameters())
```

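Schedule-free AdamW has no learning-rate schedule, but it does maintain averaged and non-averaged weight sequences, so the optimizer itself must be switched between modes alongside the model; hence the paired `model.train(); optimizer.train()` and `model.eval(); optimizer.eval()` calls later in this file. A minimal sketch of the required pattern:

```python
import torch
import schedulefree

model = torch.nn.Linear(4, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters())

model.train(); optimizer.train()          # training mode: gradients at the interpolated point
loss = model(torch.randn(8, 4)).mean()
optimizer.zero_grad(); loss.backward(); optimizer.step()

model.eval(); optimizer.eval()            # eval mode: swap in the averaged weights
with torch.no_grad():
    _ = model(torch.randn(8, 4))
```
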
```diff
@@ -359,15 +329,16 @@ positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train
 negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
 pos_weight = negative_counts / (positive_counts + 1e-15)
 print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
+
 criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
 
-training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size, datasets_per_epoch=datasets_per_epoch)
+training_data = WrapperDataset(credit_dataset=credit_train_dataset, batch_size=batch_size, datasets_per_epoch=datasets_per_epoch)
 dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
-# number of batches to go through dataset once
+
 batches_per_epoch = len(training_data)
 print(f"Number of batches per epoch: {batches_per_epoch}, Number of datasets per epoch : {datasets_per_epoch}")
 
-test_auroc = AUROC(task='binary', sync_on_compute=True)
+test_auroc = AUROC(task='binary')
 
 start_time = datetime.now()
 print("Started at:", start_time)
```

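A worked instance of the imbalance weighting above: `pos_weight` is the negative-to-positive ratio, so each positive example is up-weighted until both classes contribute equally to the expected loss.

```python
import torch

positive_counts, negative_counts = 400, 9600   # toy counts; real ones come from targets_df
pos_weight = negative_counts / (positive_counts + 1e-15)
print(pos_weight)                              # 24.0 -> each positive counts 24x

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
loss = criterion(torch.randn(16), torch.randint(0, 2, (16,)).float())
print(loss.item())
```
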
```diff
@@ -375,45 +346,82 @@ last_display_time = start_time
 last_checkpoint_time = start_time
 try:
     for epoch in range(epochs):
-        test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc,
-             writer=writer )
+        test(
+            start_time=start_time,
+            epoch=epoch,
+            batches_per_epoch=batches_per_epoch,
+            batch_size=batch_size,
+            model=model,
+            optimizer=optimizer,
+            credit_dataset=credit_train_dataset,
+            test_auroc=test_auroc,
+            writer=writer
+        )
         for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
             model.train()
             optimizer.train()
             optimizer.zero_grad()
-            cat_inputs, num_inputs, targets = [x.to(device_id, non_blocking=True) for x in [cat_inputs[0], num_inputs[0], targets[0]]]
-            outputs, targets = model(cat_inputs, num_inputs, targets)
-            loss = criterion(outputs, targets)
+            outputs, targets = model(
+                cat_inputs[0].to("cuda"),
+                num_inputs[0].to("cuda"),
+                targets[0].to("cuda")
+            )
+            loss = criterion(outputs, targets)
             loss.backward()
             optimizer.step()
-            ddp_loss[0] = loss.item()
-            torch.distributed.all_reduce(ddp_loss, op=torch.distributed.ReduceOp.SUM)
-            ddp_loss[0] /= world_size
+
             current_time = datetime.now()
             if current_time - last_display_time > timedelta(seconds=1):
                 last_display_time = current_time
-                if rank == 0:
-                    writer.add_scalar(f'Loss', loss.item(), epoch*batches_per_epoch+batch_id)
-                    print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item():.6f} {comment}", end = " "*2)
+                writer.add_scalar(f'Loss', loss.item(), epoch*batches_per_epoch+batch_id)
+                print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item():.6f} {comment}", end = " "*2)
             if current_time - last_checkpoint_time > timedelta(hours=8):
                 last_checkpoint_time = current_time
-                test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc,
-                        writer=writer )
+                test(
+                    start_time=start_time,
+                    epoch=epoch,
+                    batches_per_epoch=batches_per_epoch,
+                    batch_size=batch_size,
+                    model=model,
+                    optimizer=optimizer,
+                    credit_dataset=credit_train_dataset,
+                    test_auroc=test_auroc,
+                    writer=writer
+                )
                 rocauc = test_auroc.compute().item()
                 save_checkpoint(
-                        credit_dataset=credit_train_dataset,
-                        encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
-                        loss=loss.item(), rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
+                    credit_dataset=credit_train_dataset,
+                    encoder = model.module.encoder,
+                    model=model.module.classifier,
+                    optimizer=optimizer,
+                    epoch=epoch,
+                    loss=loss.item(),
+                    rocauc=rocauc,
+                    сheсkpoints_dir=сheсkpoints_dir
+                )
 except KeyboardInterrupt:
     print()
 finally:
-    test(epoch+1, batches_per_epoch, batch_size, model, optimizer, credit_train_dataset, test_auroc,
-         writer=writer if rank==0 else None)
+    test(
+        start_time=start_time,
+        epoch=epoch+1,
+        batches_per_epoch=batches_per_epoch,
+        batch_size=batch_size,
+        model=model,
+        optimizer=optimizer,
+        credit_dataset=credit_train_dataset,
+        test_auroc=test_auroc,
+        writer=writer
+    )
     rocauc = test_auroc.compute().item()
     save_checkpoint(
-        credit_dataset=credit_train_dataset,
-        encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
-        loss=loss.item(), rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
-    writer.close()
-    torch.distributed.destroy_process_group()
-
+        credit_dataset=credit_train_dataset,
+        encoder = model.encoder,
+        model=model.classifier,
+        optimizer=optimizer,
+        epoch=epoch,
+        loss=loss.item(),
+        rocauc=rocauc,
+        сheсkpoints_dir=сheсkpoints_dir
+    )
+    writer.close()
```
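The training loop's `cat_inputs[0]` / `num_inputs[0]` / `targets[0]` indexing follows from the `WrapperDataset` pattern set up earlier: every `__getitem__` already returns a full batch, so `DataLoader(..., batch_size=1)` merely adds a leading collate dimension of size 1 while moving batch assembly into `num_workers` background processes. A toy reproduction of the pattern:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PrebatchedDataset(Dataset):
    """Toy WrapperDataset: each item is a whole training batch."""
    def __init__(self, num_batches, batch_size):
        self.num_batches, self.batch_size = num_batches, batch_size

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        return torch.randn(self.batch_size, 8)   # one full batch per index

loader = DataLoader(PrebatchedDataset(num_batches=4, batch_size=32), batch_size=1)
for wrapped in loader:
    batch = wrapped[0]                           # strip the size-1 collate dimension
    print(batch.shape)                           # torch.Size([32, 8])
```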