@@ -18,17 +18,8 @@ import torch.nn.functional as F
 
 DEVICE = "cuda"
 comment = sys.argv[1]
-logs_dir = f'runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{DEVICE_IDX}_{comment}/'
-сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{DEVICE_IDX}_{comment}/'
-Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True)
-print("Logs dir:", logs_dir)
-print("Chekpoints dir:", сheсkpoints_dir)
-writer = SummaryWriter(logs_dir)
-script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
-script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes())  # copy this version of script
-script_snapshot_path.chmod(0o400)  # with read-only permission
 
-def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
+def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
     checkpoint = {
         'encoder': {
             'state_dict': encoder.state_dict(),
@@ -47,28 +38,28 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
         'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
         'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
     }
-    path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
+    path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
     # if torch.distributed.get_rank() == 0:
     torch.save(checkpoint, path)
     print(f"\nCheckpoint saved to {path}")
 
 #################################################################################################
 
 class CreditProductsDataset:
-    def __init__(self, 
+    def __init__(self,
         features_path, targets_path, train_test_split_ratio=0.9,
-        train_uniq_client_ids_path=None, test_uniq_client_ids_path=None,
-        dropout_rate=0.0
+        train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
     ):
         self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
         if Path(self.train_uniq_client_ids_path).exists():
             self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:, 0].values
             print("Loaded", self.train_uniq_client_ids_path)
-        else: 
+        else:
             raise Exception(f"No {self.train_uniq_client_ids_path}")
         if Path(self.test_uniq_client_ids_path).exists():
             self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:, 0].values
             print("Loaded", self.test_uniq_client_ids_path)
-        else: 
+        else:
             raise Exception(f"No {self.test_uniq_client_ids_path}")
         assert(len(np.intersect1d(self.train_uniq_client_ids, self.test_uniq_client_ids)) == 0), "Train contains test examples."
         self.features_df = pd.read_parquet(features_path)
@@ -88,7 +79,8 @@ class CreditProductsDataset:
             'fclose_flag',
             'pre_loans5', 'pre_loans6090', 'pre_loans530', 'pre_loans90', 'pre_loans3060'
         ]
-        self.num_columns = ['pre_loans5']  # TODO empty list get DatParallel to crash
+        self.num_columns = []
 
+        # make unified category index for embeddings for all columns. zero index embedding for padding will be zeroed during training
         self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1
         self.cat_cardinalities_integral = self.cat_cardinalities.cumsum()
@@ -97,8 +89,6 @@ class CreditProductsDataset:
 
         self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True])
         self.features_df = self.features_df.set_index('id')
-        self.targets_df = self.targets_df.set_index('id')
-        self.targets_df = self.targets_df.sort_index()
 
         self.user_seq_lengths = self.features_df.index.value_counts().sort_index()
@@ -106,43 +96,80 @@ class CreditProductsDataset:
         self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True)  # implicit max seq
         self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
         self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
+        self.padding_mask = torch.ones(len(self.features_df), dtype=torch.bool)
+        self.padding_mask = pad_sequence(torch.split(self.padding_mask, self.user_seq_lengths.tolist()), batch_first=True)
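+        # NOTE: pad_sequence zero-fills by default, so padded timesteps come out False in the mask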
 
+        self.targets_df = self.targets_df.set_index('id')
+        self.targets_df = self.targets_df.sort_index()
         self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32)
 
-    def get_batch(self, batch_size=4):
+    def get_train_batch(self, batch_size=4):
         sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False)  # think about replace=True
-        cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(1 - self.dropout_rate)  # argument is keep_prob
-        num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(1 - self.dropout_rate)  # argument is keep_prob
-        targets_batch = self.targets[sampled_ids]
-        return cat_features_batch, num_features_batch, targets_batch
+        return (
+            self.cat_features[sampled_ids],
+            self.num_features[sampled_ids],
+            self.padding_mask[sampled_ids],
+            self.targets[sampled_ids]
+        )
 
     def get_test_batch_iterator(self, batch_size=4):
         for i in range(0, len(self.test_uniq_client_ids), batch_size):
-            ids = self.test_uniq_client_ids[i:i + batch_size]
-            cat_features_batch = self.cat_features[ids]
-            num_features_batch = self.num_features[ids]
-            targets_batch = self.targets[ids]
-            yield cat_features_batch, num_features_batch, targets_batch
+            sampled_ids = self.test_uniq_client_ids[i:i + batch_size]
+            yield (
+                self.cat_features[sampled_ids],
+                self.num_features[sampled_ids],
+                self.padding_mask[sampled_ids],
+                self.targets[sampled_ids]
+            )
 
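+# Each __getitem__ call below assembles one full random batch itself, so the
+# DataLoader is driven with batch_size=1 purely to parallelize batch assembly across workers.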
+# for parallel data selection
+class WrapperDataset(Dataset):
+    def __init__(self, credit_dataset, batch_size, datasets_per_epoch=1):
+        self.credit_dataset = credit_dataset
+        self.batch_size = batch_size
+        self.num_batches = len(self.credit_dataset.train_uniq_client_ids) \
+                           // self.batch_size \
+                           * datasets_per_epoch
+
+    def __len__(self):
+        return self.num_batches
+
+    def __getitem__(self, idx):
+        cat_inputs, num_inputs, padding_mask, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
+        return cat_inputs, num_inputs, padding_mask, targets
 
 ##################################### Model ###########################################################################################
 
 class Encoder(nn.Module):
-    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, dropout_rate=0.5):
+    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, features_dropout_rate=0.0):
         super().__init__()
-        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
+        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})  # all args are added as object variables
         self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
-        self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
-        self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
-        self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
+        if len(self.cat_columns) > 0:
+            self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
+        if len(self.num_columns) > 0:
+            # if it is == 0, the script crashes during backprop in DataParallel mode without this guard
+            self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
+            self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns)))
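+            # learned per-feature affine (scale * x + shift) for the numeric columns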
         self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)
 
-    def forward(self, cat_features_batch, num_features_batch, targets_batch):
-        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
-        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
-        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
-        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
+    def forward(self, cat_features_batch, num_features_batch):
+        if len(self.cat_columns) > 0 and len(self.num_columns) > 0:
+            cat_embed_tensor = self.cat_embeds(cat_features_batch.data.type(torch.int32))
+            cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.data.shape[0], cat_features_batch.data.shape[1], -1)
+            num_embed_tensor = self.num_scales * num_features_batch.data + self.num_shifts
+            embed_tensor = torch.concat([cat_embed_tensor.data, num_embed_tensor.data], dim=-1)
+        elif len(self.cat_columns) == 0:
+            embed_tensor = self.num_scales * num_features_batch.data + self.num_shifts
+        elif len(self.num_columns) == 0:
+            cat_embed_tensor = self.cat_embeds(cat_features_batch.data.type(torch.int32))
+            cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.data.shape[0], cat_features_batch.data.shape[1], -1)
+            embed_tensor = cat_embed_tensor.data
+        else:
+            raise Exception("The batch is empty.")
+        embed_tensor = F.dropout(embed_tensor, self.features_dropout_rate)
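+        # NOTE: F.dropout defaults to training=True, so this dropout stays active even under model.eval()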
         inputs = self.proj(embed_tensor)
-        targets = targets_batch
-        return inputs, targets
+        return inputs
 
 # RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
 class RoPE(nn.Module):
@@ -171,7 +198,7 @@ class DyT(nn.Module):
         self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
         self.weight = nn.Parameter(torch.ones(num_features))
         self.bias = nn.Parameter(torch.zeros(num_features))
-    
+
     def forward(self, x):
         x = torch.tanh(self.alpha * x)
         return x * self.weight + self.bias
@@ -188,27 +215,31 @@ class TransformerLayer(nn.Module):
         self.o_proj = nn.Linear(h_dim, h_dim)
         self.ff1 = nn.Linear(h_dim, 4 * h_dim)
         self.ff2 = nn.Linear(4 * h_dim, h_dim)
-        self.ln1 = DyT(h_dim) 
-        self.ln2 = DyT(h_dim) 
-        self.ln3 = DyT(max_seq_len) 
+        self.ln1 = DyT(h_dim)
+        self.ln2 = DyT(h_dim)
         self.rope = RoPE(dim=h_dim // self.num_heads, max_seq_len=max_seq_len)
 
     def split_to_heads(self, x, B, T, H):
-        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
+        if self.num_heads <= 1: return x
+        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
 
     def gather_heads(self, x, B, T, H):
-        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
+        if self.num_heads <= 1: return x
+        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
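+        # inverse of split_to_heads: (b*n, t, h) -> (b, t, n*h)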
 
-    def attention(self, x):
+    def attention(self, x, padding_mask):
+        padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape + (self.num_heads,))
+        padding_mask = self.split_to_heads(padding_mask, *padding_mask.shape)
         q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
         k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
         v = self.split_to_heads(self.v_proj(x), *x.shape)
         scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
-        attention = self.ln3(F.dropout1d(scores, p=self.dropout_rate))
+        scores = scores.masked_fill(~padding_mask, -1e9)
+        attention = nn.functional.softmax(scores, dim=2)
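+        # softmax over dim=2 normalizes attention across key positions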
         return self.o_proj(self.gather_heads(attention @ v, *x.shape))
 
-    def forward(self, x):
-        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
+    def forward(self, x, padding_mask):
+        x = x + F.dropout1d(self.attention(self.ln1(x), padding_mask), p=self.dropout_rate)
         x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
         return x
 
@@ -216,17 +247,18 @@ class BertClassifier(nn.Module):
     def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate=0.1):
         super().__init__()
         self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
-        self.cls_token = nn.Parameter(torch.randn(1, 1, h_dim)) 
+        self.cls_token = nn.Parameter(torch.randn(1, 1, h_dim))
         self.max_seq_len = max_seq_len + self.cls_token.shape[1]
         self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
-        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.Dropout(0.1), nn.GELU(), nn.Linear(h_dim, class_num))
+        self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
         self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 
-    def forward(self, x):
-        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
+    def forward(self, x, padding_mask):
+        x = torch.cat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], x.shape[2]]), x], dim=1)  # prepend
+        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)  # prepend
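+        # the CLS position is always valid, so a True column is prepended to the mask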
         x = x + self.pos_embeds[:, :x.shape[1], :]
         for l in self.layers:
-            x = l(x)
+            x = l(x, padding_mask)
         x = self.classifier_head(x[:, 0, :])
         return x[:, :] if self.class_num > 1 else x[:, 0]
 
@@ -235,56 +267,83 @@ class Model(nn.Module):
         super().__init__()
         self.encoder = encoder
         self.classifier = classifier
 
-    def forward(self, cat_inputs, num_inputs, targets):
-        inputs, targets = self.encoder(cat_inputs, num_inputs, targets)
-        return self.classifier(inputs), targets
+    def forward(self, cat_inputs, num_inputs, padding_mask):
+        inputs = self.encoder(cat_inputs, num_inputs)
+        return self.classifier(inputs, padding_mask)
 
+def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_dataset, test_auroc, writer):
+    model.eval()
+    optimizer.eval()
+    with torch.no_grad():
+        test_iterator = credit_dataset.get_test_batch_iterator(batch_size=batch_size)
+        for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
+            test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
+            test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
+            test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
+            test_targets = test_targets.to("cuda", non_blocking=True)
+            outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
+            test_auroc.update(outputs, test_targets.long())
+            print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids) // batch_size} {test_auroc.compute().item():.5f}", end=" " * 20)
+    if writer is not None:
+        writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
+    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end=" " * 20)
+    print()
 
 ######################################### Training ################################################################
 
 h_dim = 64
 category_feature_dim = 8
 layers_num = 6
 num_heads = 2
 class_num = 1
-dropout_rate = 0.4
-epochs = 800
-batch_size = 30000
+features_dropout_rate = 0.4
+model_dropout_rate = 0.4
+epochs = 500
+batch_size = 30000 * len(DEVICE_IDX.split(','))
+datasets_per_epoch = len(DEVICE_IDX.split(','))
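+# DataParallel splits dim 0 across devices, so the per-GPU batch stays at 30000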
+num_workers = 10
 
+logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+writer = SummaryWriter(logs_dir)
+checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
+Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
+print("Logs dir:", logs_dir)
+print("Checkpoints dir:", checkpoints_dir)
+script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes())  # copy this version of script
+script_snapshot_path.chmod(0o400)  # with read-only permission
 
 start_prep_time = datetime.now()
 credit_train_dataset = CreditProductsDataset(
     features_path="/wd/data/train_data/",
     targets_path="/wd/data/train_target.csv",
     # train_uniq_client_ids_path="/wd/train_uniq_client_ids.csv",
     # test_uniq_client_ids_path="/wd/test_uniq_client_ids.csv",
     # train_uniq_client_ids_path="/wd/dima_train_ids.csv",
     # test_uniq_client_ids_path="/wd/dima_test_ids.csv",
     # train_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_train_ids.csv",
     # test_uniq_client_ids_path=f"/wd/fold{DEVICE_IDX}_test_ids.csv",
     train_uniq_client_ids_path=f"/wd/fold3_train_ids.csv",
     test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv",
-    dropout_rate=dropout_rate
 )
-batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size
 print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
 
 encoder = Encoder(
     cat_columns=credit_train_dataset.cat_columns,
-    num_columns=credit_train_dataset.num_columns, 
+    num_columns=credit_train_dataset.num_columns,
     cat_features_max_id=credit_train_dataset.cat_features.max(),
-    category_feature_dim=category_feature_dim, 
+    category_feature_dim=category_feature_dim,
     out_dim=h_dim,
-    dropout_rate=dropout_rate
+    features_dropout_rate=features_dropout_rate
 )
 
 classifier = BertClassifier(
-    layers_num=layers_num, 
+    layers_num=layers_num,
     num_heads=num_heads,
-    h_dim=h_dim, 
-    class_num=class_num, 
+    h_dim=h_dim,
+    class_num=class_num,
     max_seq_len=credit_train_dataset.max_user_history,
-    dropout_rate=dropout_rate
+    dropout_rate=model_dropout_rate
 )
 
 model = Model(encoder=encoder, classifier=classifier)
 print(f"Model parameters count: ", sum(p.numel() for p in model.parameters()))
 model = torch.nn.DataParallel(model, device_ids=[int(idx) for idx in DEVICE_IDX.split(",")]).to(DEVICE)
 
 # The Road Less Scheduled https://arxiv.org/html/2405.15682v4
@@ -294,76 +353,99 @@ positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train
 negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
 pos_weight = negative_counts / (positive_counts + 1e-15)
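+# pos_weight = N_negative / N_positive re-balances BCEWithLogitsLoss toward the rare positive class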
 print(f"Class imbalance: {negative_counts} {positive_counts}. Pos weight: {pos_weight}")
 criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
 
-# for parallel data selection 
-class WrapperDataset(Dataset):
-    def __init__(self, credit_dataset, encoder, batch_size):
-        self.credit_dataset = credit_dataset
-        self.encoder = encoder
-        self.batch_size = batch_size
-
-    def __len__(self):
-        return len(self.credit_dataset.uniq_client_ids) // self.batch_size
-
-    def __getitem__(self, idx):
-        cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size)
-        return cat_inputs, num_inputs, targets
-
-training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size)
-dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8*2, pin_memory=True)
+training_data = WrapperDataset(credit_dataset=credit_train_dataset, batch_size=batch_size, datasets_per_epoch=datasets_per_epoch)
+dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
+batches_per_epoch = len(training_data)
+print(f"Number of batches per epoch: {batches_per_epoch}, Number of datasets per epoch: {datasets_per_epoch}")
 
 test_auroc = AUROC(task='binary')
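+# NOTE: this AUROC instance is apparently never reset, so compute() aggregates over all test() calls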
 
-def test(epoch):
-    model.eval()
-    optimizer.eval()
-    with torch.no_grad():
-        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
-        for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
-            test_cat_inputs, test_num_inputs, test_targets = [x.to("cuda", non_blocking=True) for x in [test_cat_inputs, test_num_inputs, test_targets]]
-            outputs, targets = model(test_cat_inputs, test_num_inputs, test_targets)
-            test_auroc.update(outputs, targets.long())
-            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids) // batch_size} {test_auroc.compute().item():.5f}", end=" " * 2)
-    writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
-    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end=" " * 2)
-    print()
-
 start_time = datetime.now()
 print("Started at:", start_time)
 last_display_time = start_time
 last_checkpoint_time = start_time
 try:
     for epoch in range(epochs):
-        test(epoch)
-        for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
+        test(
+            start_time=start_time,
+            epoch=epoch,
+            batches_per_epoch=batches_per_epoch,
+            batch_size=batch_size,
+            model=model,
+            optimizer=optimizer,
+            credit_dataset=credit_train_dataset,
+            test_auroc=test_auroc,
+            writer=writer
+        )
+        for batch_id, (cat_inputs, num_inputs, padding_mask, targets) in enumerate(dataloader):
             model.train()
             optimizer.train()
            optimizer.zero_grad()
-            cat_inputs, num_inputs, targets = [x.to("cuda", non_blocking=True) for x in [cat_inputs[0], num_inputs[0], targets[0]]]
-            outputs, targets = model(cat_inputs, num_inputs, targets)
-            loss = criterion(outputs, targets)
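+            # each DataLoader item is already a full batch; [0] strips the dummy batch dimension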
+            outputs = model(
+                cat_inputs[0].to("cuda"),
+                num_inputs[0].to("cuda"),
+                padding_mask[0].to("cuda")
+            )
+            loss = criterion(outputs, targets[0].to("cuda"))
             loss.backward()
             optimizer.step()
 
             current_time = datetime.now()
             if current_time - last_display_time > timedelta(seconds=1):
                 last_display_time = current_time
-                writer.add_scalar('Loss', loss.item(), epoch * batches_per_epoch + batch_id)
+                writer.add_scalar(f'Loss', loss.item(), epoch * batches_per_epoch + batch_id)
                 print(f"\r {current_time - start_time} {epoch + 1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item():.6f} {comment}", end=" " * 2)
-            if current_time - last_checkpoint_time > timedelta(hours=4):
+            if current_time - last_checkpoint_time > timedelta(hours=8):
                 last_checkpoint_time = current_time
-                test(epoch)
+                test(
+                    start_time=start_time,
+                    epoch=epoch,
+                    batches_per_epoch=batches_per_epoch,
+                    batch_size=batch_size,
+                    model=model,
+                    optimizer=optimizer,
+                    credit_dataset=credit_train_dataset,
+                    test_auroc=test_auroc,
+                    writer=None
+                )
+                rocauc = test_auroc.compute().item()
                 save_checkpoint(
                     credit_dataset=credit_train_dataset,
-                    encoder=encoder, model=model, optimizer=optimizer, epoch=epoch,
-                    loss=loss.item(), rocauc=test_auroc.compute().item(), сheсkpoints_dir=сheсkpoints_dir)
+                    encoder=model.module.encoder,
+                    model=model.module.classifier,
+                    optimizer=optimizer,
+                    epoch=epoch,
+                    loss=loss.item(),
+                    rocauc=rocauc,
+                    checkpoints_dir=checkpoints_dir
+                )
 except KeyboardInterrupt:
     print()
-finally: 
-    test(epoch + 1)
+finally:
+    test(
+        start_time=start_time,
+        epoch=epoch + 1,
+        batches_per_epoch=batches_per_epoch,
+        batch_size=batch_size,
+        model=model,
+        optimizer=optimizer,
+        credit_dataset=credit_train_dataset,
+        test_auroc=test_auroc,
+        writer=writer
+    )
+    rocauc = test_auroc.compute().item()
     save_checkpoint(
         credit_dataset=credit_train_dataset,
-        encoder=encoder, model=model, optimizer=optimizer, epoch=epoch + 1,
-        loss=loss.item(), rocauc=test_auroc.compute().item(), сheсkpoints_dir=сheсkpoints_dir)
+        encoder=model.module.encoder,
+        model=model.module.classifier,
+        optimizer=optimizer,
+        epoch=epoch,
+        loss=loss.item(),
+        rocauc=rocauc,
+        checkpoints_dir=checkpoints_dir
+    )
     writer.close()