You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			85 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			85 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
| import sys
 | |
| sys.path.insert(0, "../")  # run under the project directory
 | |
| import logging
 | |
| import math
 | |
| import os
 | |
| import time
 | |
| import numpy as np
 | |
| import torch
 | |
| import torch.nn.functional as F
 | |
| import torch.optim as optim
 | |
| from PIL import Image
 | |
| from pathlib import Path
 | |
| from torch.utils.tensorboard import SummaryWriter
 | |
| torch.backends.cudnn.benchmark = True
 | |
| from datetime import datetime
 | |
| import argparse
 | |
| import models
 | |
| 
 | |
class TransferToLutOptions():
    """Command-line options for converting a trained network into its LUT model.

    Call ``parse_args()`` to obtain the parsed and validated options;
    ``repr()`` of the instance renders an options table in which values that
    differ from their defaults are annotated with the default.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth', help="Path to the trained model checkpoint file.")
        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
        # Raw (un-post-processed) namespace from the most recent parse_args()
        # call; reused by __repr__ so printing does not re-read sys.argv.
        self._raw_args = None

    def parse_args(self):
        """Parse and validate the command line (``sys.argv``).

        Returns:
            argparse.Namespace with the parsed options, where additionally
              - ``model_path`` is converted to ``pathlib.Path``;
              - ``checkpoint_dir`` is the directory of the resolved model file;
              - ``models_dir`` is three directory levels above the resolved
                model file.

        Raises:
            SystemExit: via ``parser.error`` (exit status 2) when
                ``quantization_bits`` is outside [1, 8] or ``batch_size`` is
                not positive.
        """
        args = self.parser.parse_args()
        # Enforce the range promised in the --quantization_bits help text;
        # outside it the LUT step 2**(8-bits) is >256 or fractional.
        if not 1 <= args.quantization_bits <= 8:
            self.parser.error(f"--quantization_bits must be in range [1, 8], got {args.quantization_bits}")
        if args.batch_size < 1:
            self.parser.error(f"--batch_size must be positive, got {args.batch_size}")
        self._raw_args = argparse.Namespace(**vars(args))

        args.model_path = Path(args.model_path)
        # resolve() follows symlinks (the default last_trained_net.pth is
        # presumably a link), so the derived directories are taken from the
        # real checkpoint location rather than the link's location.
        resolved_model_path = args.model_path.resolve()
        args.models_dir = resolved_model_path.parent.parent.parent
        args.checkpoint_dir = resolved_model_path.parent
        return args

    def __repr__(self):
        """Render all options as a table, noting the default of any changed value."""
        config = getattr(self, '_raw_args', None)
        if config is None:
            # parse_args() has not been called yet: fall back to parsing sys.argv.
            config = self.parser.parse_args()
        lines = ['----------------- Options ---------------\n']
        for k, v in sorted(vars(config).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        return ''.join(lines)
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     start_time = datetime.now()
 | |
|     print(start_time)
 | |
|     config_inst = TransferToLutOptions()
 | |
|     config = config_inst.parse_args()
 | |
|     
 | |
|     print(config_inst)
 | |
|    
 | |
|     model = models.LoadCheckpoint(config.model_path).cuda()
 | |
|     if getattr(model, 'get_lut_model', None) is None:
 | |
|         print("Transfer to lut can be applied only to the network model.")
 | |
|         exit(1)
 | |
|     print(model)    
 | |
| 
 | |
|     print()
 | |
|     print("Transfering:")
 | |
|     lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
 | |
|     print()
 | |
|     print(lut_model)
 | |
|     
 | |
|     lut_path = Path(config.checkpoint_dir) / f"{lut_model.__class__.__name__}_0.pth"
 | |
|     models.SaveCheckpoint(model=lut_model, path=lut_path)
 | |
|     
 | |
|     lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()])
 | |
|     print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")
 | |
| 
 | |
|     link = Path(config.models_dir / f"last_transfered_net.pth")
 | |
|     if link.exists():
 | |
|         link.unlink()
 | |
|         link.symlink_to(config.model_path.resolve())
 | |
|     link = Path(config.models_dir / f"last_transfered_lut.pth")
 | |
|     if link.exists():
 | |
|         link.unlink()
 | |
|         link.symlink_to(lut_path.resolve())
 | |
|     print("Updated link", config.models_dir / f"last_transfered_net.pth")
 | |
|     print("Updated link", config.models_dir / f"last_transfered_lut.pth")
 | |
| 
 | |
|     print()
 | |
|     print("Completed after", datetime.now()-start_time) |