# transfer_to_lut.py — convert a trained network checkpoint into a LUT model checkpoint.

# Script preamble: path setup, third-party imports, and project-local imports.
import sys
sys.path.insert(0, "../") # run under the project directory — makes the project-local `models` package importable
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
# Let cuDNN auto-tune convolution algorithms; speeds up repeated forward passes
# with fixed input shapes (the LUT transfer sweeps the full input domain in batches).
torch.backends.cudnn.benchmark = True
from datetime import datetime
import argparse
import models
class TransferToLutOptions():
    """Command-line options for transferring a trained network model to a LUT model."""

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth',
                                 help="model path folder")
        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4,
                                 help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10,
                                 help="Size of the batch for the input domain values.")

    def parse_args(self):
        """Parse sys.argv and derive directory paths from the model path.

        Returns the argparse.Namespace with ``model_path`` converted to a
        ``Path`` and two derived attributes attached: ``checkpoint_dir``
        (directory containing the checkpoint) and ``models_dir`` (three
        levels above the checkpoint file).
        """
        args = self.parser.parse_args()
        args.model_path = Path(args.model_path)
        # Resolve once and reuse; the original re-wrapped model_path in Path()
        # and called .resolve() separately for each derived attribute.
        resolved = args.model_path.resolve()
        args.checkpoint_dir = resolved.parent
        args.models_dir = resolved.parent.parent.parent
        return args

    def __repr__(self):
        # NOTE: re-parses sys.argv to obtain current values (standard
        # options-printing pattern); only meaningful after argv is set up.
        config = self.parser.parse_args()
        lines = ['----------------- Options ---------------']
        for k, v in sorted(vars(config).items()):
            default = self.parser.get_default(k)
            # Flag any value that differs from the parser's default.
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        return '\n'.join(lines)
if __name__ == "__main__":
    start_time = datetime.now()
    print(start_time)
    config_inst = TransferToLutOptions()
    config = config_inst.parse_args()
    print(config_inst)

    # Load the trained network checkpoint; LUT transfer requires CUDA.
    model = models.LoadCheckpoint(config.model_path).cuda()
    if getattr(model, 'get_lut_model', None) is None:
        print("Transfer to lut can be applied only to the network model.")
        # sys.exit instead of the interactive-only `exit` builtin.
        sys.exit(1)
    print(model)
    print()

    print("Transfering:")
    # quantization_interval = step between LUT buckets over the 8-bit input domain.
    lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
    print()
    print(lut_model)

    lut_path = config.checkpoint_dir / f"{lut_model.__class__.__name__}_0.pth"
    models.SaveCheckpoint(model=lut_model, path=lut_path)
    # Total parameter storage in bytes, reported in MiB.
    lut_model_size = np.sum([x.nelement() * x.element_size() for x in lut_model.parameters()])
    print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")

    def _update_link(link, target):
        # Refresh a "last_*" convenience symlink.
        # BUG FIX: Path.exists() follows symlinks, so a *broken* symlink used to
        # report False, skip the unlink, and make symlink_to() raise
        # FileExistsError; also test is_symlink() before unlinking.
        if link.is_symlink() or link.exists():
            link.unlink()
        link.symlink_to(target)
        print("Updated link", link)

    _update_link(config.models_dir / "last_transfered_net.pth", config.model_path.resolve())
    _update_link(config.models_dir / "last_transfered_lut.pth", lut_path.resolve())

    print()
    print("Completed after", datetime.now() - start_time)