@@ -1,7 +1,5 @@
 import os
 import sys
-# DEVICE_IDX = sys.argv[1]
-# os.environ['CUDA_VISIBLE_DEVICES'] = f"{DEVICE_IDX}"
 comment = sys.argv[1]
 
 from torch import nn
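
Note: the commented-out device pinning is simply deleted. If single-GPU pinning is ever wanted again, the usual approach (general CUDA behavior, not part of this diff) is to set the variable before anything initializes CUDA:

    import os, sys
    os.environ['CUDA_VISIBLE_DEVICES'] = sys.argv[1]  # hypothetical revival of the deleted lines
    import torch                                      # must be imported after the variable is set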
@@ -21,31 +19,7 @@ import torch.nn.functional as F
 from typing import Optional
 import functools
 
-def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
-    checkpoint = {
-        'encoder': {
-            'state_dict': encoder.state_dict(),
-            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
-        },
-        'model': {
-            'state_dict': model.state_dict(),
-            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
-        },
-        'optimizer': {
-            'state_dict': optimizer.state_dict(),
-        },
-        'epoch': epoch,
-        'loss': loss,
-        'rocauc': rocauc,
-        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
-        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
-    }
-    path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
-    # if torch.distributed.get_rank() == 0:
-    torch.save(checkpoint, path)
-    print(f"\nCheckpoint saved to {path}")
-
-######################################## Dataset #########################################################
+######################################## Dataset definition #########################################################
 
 class CreditProductsDataset:
     def __init__(self,
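
Note: the save_checkpoint(...) helper deleted above is reintroduced later in this diff as a zero-argument function that closes over the module-level globals. The `**{k:v ...}` spread stores every public, non-`training` attribute next to each state_dict, so a checkpoint carries its own hyperparameters. A minimal reload sketch, assuming those stored attributes mirror the constructor's keyword arguments (an assumption, not verified by this diff):

    ckpt = torch.load(path, map_location='cpu')
    enc_kwargs = {k: v for k, v in ckpt['encoder'].items() if k != 'state_dict'}
    encoder = Encoder(**enc_kwargs)  # assumes the saved attrs match __init__ params
    encoder.load_state_dict(ckpt['encoder']['state_dict'])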
@@ -140,7 +114,7 @@ class WrapperDataset(Dataset):
         cat_inputs, num_inputs, padding_mask, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
         return cat_inputs, num_inputs, padding_mask, targets
 
-##################################### Model ###########################################################################################
+##################################### Model definition ###########################################################################################
 
 class Encoder(nn.Module):
     def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, features_dropout_rate=0.0):
@@ -230,7 +204,7 @@ class TransformerLayer(nn.Module):
         return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
 
     def attention(self, x, padding_mask):
-        padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape+(self.num_heads,))
+        padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape+(self.num_heads,)) # B, T -> B, T, num_heads
         padding_mask = self.split_to_heads(padding_mask, *padding_mask.shape)
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
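
Note: the added comment documents the mask broadcast. A standalone shape check of the same expansion, with illustrative sizes:

    import torch
    B, T, num_heads = 2, 5, 4
    padding_mask = torch.zeros(B, T, dtype=torch.bool)
    expanded = padding_mask.unsqueeze(-1).expand(*padding_mask.shape + (num_heads,))
    assert expanded.shape == (B, T, num_heads)  # B, T -> B, T, num_heads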
@@ -251,7 +225,7 @@ class BertClassifier(nn.Module):
         self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
         self.cls_token = nn.Parameter(torch.randn(1,1,h_dim))
         self.max_seq_len = max_seq_len + self.cls_token.shape[1]
-        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
+        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
         self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
         self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 
@@ -274,25 +248,7 @@ class Model(nn.Module):
         inputs = self.encoder(cat_inputs, num_inputs)
         return self.classifier(inputs, padding_mask)
 
-def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_dataset, test_auroc, writer):
-        model.eval()
-        optimizer.eval()
-        with torch.no_grad():
-            test_iterator = credit_dataset.get_test_batch_iterator(batch_size=batch_size)
-            for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
-                test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
-                test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
-                test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
-                test_targets = test_targets.to("cuda", non_blocking=True)
-                outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
-                test_auroc.update(outputs, test_targets.long())
-                print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
-        if not writer is None:
-            writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
-        print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
-        print()
-
-######################################### Training ################################################################
+######################################### Training definition ################################################################
 
 h_dim = 64
 category_feature_dim = 8
@@ -306,7 +262,6 @@ batch_size = 30000
 datasets_per_epoch = 1
 num_workers = 10
 
-
 logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 writer = SummaryWriter(logs_dir)
 checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
@@ -365,34 +320,74 @@ batches_per_epoch = len(training_data)
 print(f"Number of batches per epoch: {batches_per_epoch}, Number of datasets per epoch : {datasets_per_epoch}")
 
 test_auroc = AUROC(task='binary')
+epoch = -1
+loss = torch.tensor([-1])
 
 start_time = datetime.now()
 print("Started at:", start_time)
 last_display_time = start_time
 last_checkpoint_time = start_time
 
+def test(is_tensorboard_logging=False):
+        model.eval()
+        optimizer.eval()
+        with torch.no_grad():
+            test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
+            for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
+                test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
+                test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
+                test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
+                test_targets = test_targets.to("cuda", non_blocking=True)
+                outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
+                test_auroc.update(outputs, test_targets.long())
+                print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
+        if is_tensorboard_logging:
+            writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
+        print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
+        print()
+
+def save_checkpoint():
+    test(is_tensorboard_logging=False)
+    checkpoint = {
+        'encoder': {
+            'state_dict': encoder.state_dict(),
+            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
+        },
+        'model': {
+            'state_dict': model.state_dict(),
+            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
+        },
+        'optimizer': {
+            'state_dict': optimizer.state_dict(),
+        },
+        'epoch': epoch,
+        'loss': loss.item(),
+        'rocauc': test_auroc.compute().item(),
+        'train_uniq_client_ids_path': credit_train_dataset.train_uniq_client_ids_path,
+        'test_uniq_client_ids_path': credit_train_dataset.test_uniq_client_ids_path
+    }
+    path = checkpoints_dir + f"epoch_{epoch}_{test_auroc.compute().item():.4f}.pth"
+    torch.save(checkpoint, path)
+    print(f"\nCheckpoint saved to {path}")
+
 ###################################### Training loop ################################################################
 
 try:
     for epoch in range(epochs):
-        test(
-            start_time=start_time,
-            epoch=epoch,
-            batches_per_epoch=batches_per_epoch,
-            batch_size=batch_size,
-            model=model,
-            optimizer=optimizer,
-            credit_dataset=credit_train_dataset,
-            test_auroc=test_auroc,
-            writer=writer
-        )
+        test(is_tensorboard_logging=True)
         for batch_id, (cat_inputs, num_inputs, padding_mask, targets) in enumerate(dataloader):
             model.train()
             optimizer.train()
             optimizer.zero_grad()
             outputs = model(
-                cat_inputs[0].to("cuda"),
-                num_inputs[0].to("cuda"),
-                padding_mask[0].to("cuda")
+                cat_inputs=cat_inputs[0].to("cuda"),
+                num_inputs=num_inputs[0].to("cuda"),
+                padding_mask=padding_mask[0].to("cuda")
             )
-            loss = criterion(
-                input=outputs,
-                target=targets[0].to("cuda")
-            )
+            loss = criterion(outputs, targets[0].to("cuda"))
             loss.backward()
             optimizer.step()
 
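
Note: test() and save_checkpoint() are now zero-argument closures over the globals defined just above, and save_checkpoint() re-runs test() so the AUROC baked into the filename is current; epoch and loss get sentinel values so a checkpoint can be written even before the first batch completes. Two caveats worth flagging, offered as assumptions rather than facts about this code: optimizer.train()/optimizer.eval() presumes an optimizer that exposes those modes (e.g. a schedule-free optimizer; stock torch.optim classes have neither method), and torchmetrics' AUROC accumulates state across update() calls, so if each evaluation is meant to stand alone, a reset may be wanted:

    def test(is_tensorboard_logging=False):
        test_auroc.reset()   # hypothetical addition, not in this diff
        ...                  # body as above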
@@ -401,53 +396,13 @@ try:
                 last_display_time = current_time
                 writer.add_scalar(f'Loss', loss.item(), epoch*batches_per_epoch+batch_id)
                 print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item():.6f} {comment}", end = " "*2)
+
             if current_time - last_checkpoint_time > timedelta(hours=8):
                 last_checkpoint_time = current_time
-                test(
-                    start_time=start_time,
-                    epoch=epoch,
-                    batches_per_epoch=batches_per_epoch,
-                    batch_size=batch_size,
-                    model=model,
-                    optimizer=optimizer,
-                    credit_dataset=credit_train_dataset,
-                    test_auroc=test_auroc,
-                    writer=None
-                )
-                rocauc = test_auroc.compute().item()
-                save_checkpoint(
-                    credit_dataset=credit_train_dataset,
-                    encoder = model.module.encoder,
-                    model=model.module.classifier,
-                    optimizer=optimizer,
-                    epoch=epoch,
-                    loss=loss.item(),
-                    rocauc=rocauc,
-                    checkpoints_dir=checkpoints_dir
-                )
+                save_checkpoint()
 except KeyboardInterrupt:
     print()
 finally:
-    test(
-        start_time=start_time,
-        epoch=epoch+1,
-        batches_per_epoch=batches_per_epoch,
-        batch_size=batch_size,
-        model=model,
-        optimizer=optimizer,
-        credit_dataset=credit_train_dataset,
-        test_auroc=test_auroc,
-        writer=writer
-    )
-    rocauc = test_auroc.compute().item()
-    save_checkpoint(
-        credit_dataset=credit_train_dataset,
-        encoder = model.encoder,
-        model=model.classifier,
-        optimizer=optimizer,
-        epoch=epoch,
-        loss=loss.item(),
-        rocauc=rocauc,
-        checkpoints_dir=checkpoints_dir
-    )
-    writer.close()
+    epoch = epoch + 1
+    save_checkpoint()
+    writer.close()
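
Note: the finally block now just advances epoch (so the saved filename reflects the last completed epoch), calls save_checkpoint() (which evaluates internally), and closes the writer. A hypothetical resume sketch built on the checkpoint layout above; the path selection is assumed, not part of this diff:

    from pathlib import Path
    latest = max(Path(checkpoints_dir).glob('epoch_*.pth'), key=lambda p: p.stat().st_mtime)
    ckpt = torch.load(latest, map_location='cpu')
    model.load_state_dict(ckpt['model']['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer']['state_dict'])
    epoch = ckpt['epoch'] + 1  # continue from the next epoch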