import sys
sys.path.insert(0, "../")  # run under the project directory
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
import models


class TransferToLutOptions():
    """Command-line options for converting a trained network checkpoint to a LUT model.

    Options:
        --model_path / -m: checkpoint file to load.
        --quantization_bits / -q: the LUT uses 2**bits buckets; must be in [1, 8].
        --batch_size / -b: batch size used for the input domain values.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth', help="path to the trained model checkpoint file")
        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")

    def parse_args(self):
        """Parse ``sys.argv`` and derive the directories used for saving outputs.

        Returns:
            argparse.Namespace with the parsed options, where ``model_path`` is
            converted to a ``Path`` and two attributes are added:
            ``checkpoint_dir`` (resolved directory containing the checkpoint)
            and ``models_dir`` (two directories above ``checkpoint_dir``).

        Exits through ``parser.error`` when ``--quantization_bits`` is not in [1, 8].
        """
        args = self.parser.parse_args()
        # Enforce the range promised in the help text: for bits > 8 the
        # quantization interval 2**(8 - bits) would become a fractional float.
        if not 1 <= args.quantization_bits <= 8:
            self.parser.error(f"--quantization_bits must be in range [1, 8], got {args.quantization_bits}")
        args.model_path = Path(args.model_path)
        args.checkpoint_dir = args.model_path.resolve().parent
        args.models_dir = args.checkpoint_dir.parent.parent
        return args

    def __repr__(self):
        """Return a table of all options, annotating values that differ from their defaults.

        Note: this re-parses ``sys.argv``, so it shows the raw (unconverted) values.
        """
        config = self.parser.parse_args()
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(config).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = f'\t[default: {default}]'
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        return message


def _replace_symlink(link, target):
    """Make ``link`` a symlink to ``target``, removing whatever ``link`` currently is.

    ``Path.exists()`` follows symlinks and reports False for a dangling link
    (one whose target was deleted); checking ``is_symlink()`` as well ensures
    such a link is removed instead of making ``symlink_to`` raise FileExistsError.
    """
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(target)


if __name__ == "__main__":
    start_time = datetime.now()
    print(start_time)
    config_inst = TransferToLutOptions()
    config = config_inst.parse_args()
    print(config_inst)

    model = models.LoadCheckpoint(config.model_path).cuda()
    # Only models that expose get_lut_model() can be converted.
    if getattr(model, 'get_lut_model', None) is None:
        print("Transfer to lut can be applied only to the network model.")
        sys.exit(1)
    print(model)
    print()
    print("Transfering:")
    # 2**bits buckets -> interval 2**(8 - bits); presumably over 8-bit input values.
    lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
    print()
    print(lut_model)

    lut_path = config.checkpoint_dir / f"{lut_model.__class__.__name__}_0.pth"
    models.SaveCheckpoint(model=lut_model, path=lut_path)
    # Total bytes occupied by the LUT model's parameters.
    lut_model_size = np.sum([x.nelement()*x.element_size() for x in lut_model.parameters()])
    print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")

    net_link = config.models_dir / "last_transfered_net.pth"
    lut_link = config.models_dir / "last_transfered_lut.pth"
    _replace_symlink(net_link, config.model_path.resolve())
    _replace_symlink(lut_link, lut_path.resolve())
    print("Updated link", net_link)
    print("Updated link", lut_link)
    print()
    print("Completed after", datetime.now()-start_time)