| 
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -26,7 +26,7 @@ from torch.distributed.fsdp.wrap import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import functools
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    checkpoint = {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'encoder': {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            'state_dict': encoder.state_dict(),
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -45,7 +45,7 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, roca
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    # if torch.distributed.get_rank() == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    torch.save(checkpoint, path)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    print(f"\nCheckpoint saved to {path}")
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -330,11 +330,11 @@ if __name__ == "__main__":
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        comment = sys.argv[1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        writer = SummaryWriter(logs_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print("Logs dir:", logs_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print("Chekpoints dir:", сheсkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print("Chekpoints dir:", checkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        script_snapshot_path.chmod(0o400) # with read-only permission
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -432,7 +432,7 @@ if __name__ == "__main__":
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        save_checkpoint(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            credit_dataset=credit_train_dataset,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    torch.distributed.barrier()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    except KeyboardInterrupt:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        print()
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -444,7 +444,7 @@ if __name__ == "__main__":
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            save_checkpoint(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                credit_dataset=credit_train_dataset,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch, 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if rank == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            writer.close()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        torch.distributed.destroy_process_group()
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
			
			 | 
			 | 
			
				
 
 |