You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
import sys
|
|
sys.path.insert(0, "../") # run under the project directory
|
|
import logging
|
|
import math
|
|
import os
|
|
import time
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from PIL import Image
|
|
from pathlib import Path
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
torch.backends.cudnn.benchmark = True
|
|
from datetime import datetime
|
|
import argparse
|
|
import models
|
|
|
|
class TransferToLutOptions:
    """Command-line options for transferring a trained network checkpoint to a LUT model.

    Usage::

        options = TransferToLutOptions()
        config = options.parse_args()   # or parse_args([...]) for an explicit argv
        print(options)                  # formatted option table (see __repr__)
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth', help="model path folder")
        self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
        self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
        # Argument list used by the last successful parse_args() call; None means
        # "use sys.argv[1:]", which is argparse's own default.
        self._argv = None

    def parse_args(self, args=None):
        """Parse and validate the command line and add derived path fields.

        Args:
            args: Optional list of argument strings. When None (the default),
                argparse reads ``sys.argv[1:]``, exactly as before.

        Returns:
            argparse.Namespace with:
              - ``model_path`` converted to a ``Path``;
              - ``checkpoint_dir``: parent directory of the resolved model path;
              - ``models_dir``: three levels above the resolved model path.

        Exits (via ``parser.error``, status 2) when ``--quantization_bits`` is
        outside the documented range [1, 8] or ``--batch_size`` is not positive.
        """
        parsed = self.parser.parse_args(args)

        # The bucket width used downstream is 2**(8 - bits); outside [1, 8] it
        # degenerates (a fractional width above 8, a single bucket at 0), so
        # reject it up front as the help text promises.
        if not 1 <= parsed.quantization_bits <= 8:
            self.parser.error(
                f"argument --quantization_bits/-q: value must be in range [1, 8], got {parsed.quantization_bits}"
            )
        if parsed.batch_size < 1:
            self.parser.error(
                f"argument --batch_size/-b: value must be a positive integer, got {parsed.batch_size}"
            )

        model_path = Path(parsed.model_path)
        # resolve() follows symlinks such as last_trained_net.pth, so the derived
        # directories refer to where the checkpoint actually lives.
        resolved = model_path.resolve()
        parsed.model_path = model_path
        parsed.models_dir = resolved.parent.parent.parent
        parsed.checkpoint_dir = resolved.parent

        self._argv = args
        return parsed

    def __repr__(self):
        """Return a table of all raw options, annotating values that differ from their defaults.

        Re-parses the same argument list as the last successful ``parse_args``
        call (``sys.argv`` if none), so derived fields such as ``models_dir``
        are not included.
        """
        config = self.parser.parse_args(self._argv)
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(config).items()):
            default = self.parser.get_default(key)
            comment = ''
            if value != default:
                comment = '\t[default: %s]' % str(default)
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        return '\n'.join(lines)
|
|
|
|
|
|
def update_symlink(link, target):
    """Make ``link`` a symlink pointing at the resolved ``target``, replacing any existing entry.

    ``Path.exists()`` follows symlinks, so a dangling link (whose target was
    deleted) reports False. Also checking ``is_symlink()`` ensures such a link
    is removed, instead of making ``symlink_to`` fail with FileExistsError.
    """
    link = Path(link)
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(Path(target).resolve())


def main():
    """Transfer the configured network checkpoint to a LUT model and save it.

    The LUT checkpoint is written next to the source checkpoint. The
    ``last_transfered_net.pth`` / ``last_transfered_lut.pth`` symlinks in
    ``models_dir`` are then refreshed to point at the source and the new LUT.
    Exits with status 1 when the loaded model has no ``get_lut_model``.
    """
    start_time = datetime.now()
    print(start_time)

    config_inst = TransferToLutOptions()
    config = config_inst.parse_args()
    print(config_inst)

    model = models.LoadCheckpoint(config.model_path).cuda()
    if getattr(model, 'get_lut_model', None) is None:
        print("Transfer to lut can be applied only to the network model.", file=sys.stderr)
        sys.exit(1)
    print(model)

    print()
    print("Transfering:")
    # 2**bits buckets per input; bucket width 2**(8 - bits) presumably assumes
    # 8-bit (0..255) input values -- TODO confirm against get_lut_model.
    lut_model = model.get_lut_model(quantization_interval=2**(8-config.quantization_bits), batch_size=config.batch_size)
    print()
    print(lut_model)

    lut_path = config.checkpoint_dir / f"{lut_model.__class__.__name__}_0.pth"
    models.SaveCheckpoint(model=lut_model, path=lut_path)

    # Total bytes occupied by the LUT model's parameter tensors.
    lut_model_size = sum(p.nelement() * p.element_size() for p in lut_model.parameters())
    print("Saved to", lut_path, f"{lut_model_size/(2**20):.3f} MB")

    net_link = config.models_dir / "last_transfered_net.pth"
    lut_link = config.models_dir / "last_transfered_lut.pth"
    update_symlink(net_link, config.model_path)
    update_symlink(lut_link, lut_path)
    print("Updated link", net_link)
    print("Updated link", lut_link)

    print()
    print("Completed after", datetime.now() - start_time)


if __name__ == "__main__":
    main()