|
|
@ -9,25 +9,23 @@ import cv2
|
|
|
|
from PIL import Image
|
|
|
|
from PIL import Image
|
|
|
|
from datetime import datetime
|
|
|
|
from datetime import datetime
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
class DemoOptions():
|
|
|
|
class ImageDemoOptions():
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
|
|
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
|
|
self.parser.add_argument('--net_model_path', '-n', type=str, default=None, help="Net model path folder")
|
|
|
|
self.parser.add_argument('--net_model_path', '-n', type=str, default="../../models/last_transfered_net.pth", help="Net model path folder")
|
|
|
|
self.parser.add_argument('--lut_model_path', '-l', type=str, default=None, help="Lut model path folder")
|
|
|
|
self.parser.add_argument('--lut_model_path', '-l', type=str, default="../../models/last_transfered_lut.pth", help="Lut model path folder")
|
|
|
|
self.parser.add_argument('--project_path', '-q', type=str, default="../../", help="Project path.")
|
|
|
|
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../../data/Set14/HR/monarch.png", help="HR image path")
|
|
|
|
self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
|
|
|
|
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../../data/Set14/LR/X4/monarch.png", help="LR image path")
|
|
|
|
|
|
|
|
self.parser.add_argument('--project_path', type=str, default="../../", help="Project path.")
|
|
|
|
|
|
|
|
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(self):
|
|
|
|
def parse_args(self):
|
|
|
|
args = self.parser.parse_args()
|
|
|
|
args = self.parser.parse_args()
|
|
|
|
args.project_path = Path(args.project_path).resolve()
|
|
|
|
args.project_path = Path(args.project_path).resolve()
|
|
|
|
if args.net_model_path is None:
|
|
|
|
args.hr_image_path = Path(args.hr_image_path).resolve()
|
|
|
|
args.project_path / "models" / "last_transfered_net.pth"
|
|
|
|
args.lr_image_path = Path(args.lr_image_path).resolve()
|
|
|
|
else:
|
|
|
|
args.net_model_path = Path(args.net_model_path).resolve()
|
|
|
|
args.net_model_path = Path(args.net_model_path).resolve()
|
|
|
|
args.lut_model_path = Path(args.lut_model_path).resolve()
|
|
|
|
if args.lut_model_path is None:
|
|
|
|
|
|
|
|
args.project_path / "models" / "last_transfered_lut.pth"
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
args.lut_model_path = Path(args.lut_model_path).resolve()
|
|
|
|
|
|
|
|
return args
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
def print_options(self, opt):
|
|
|
|
def print_options(self, opt):
|
|
|
@ -43,21 +41,10 @@ class DemoOptions():
|
|
|
|
print(message)
|
|
|
|
print(message)
|
|
|
|
print()
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
|
|
config_inst = DemoOptions()
|
|
|
|
config_inst = ImageDemoOptions()
|
|
|
|
config = config_inst.parse_args()
|
|
|
|
config = config_inst.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
start_script_time = datetime.now()
|
|
|
|
start_script_time = datetime.now()
|
|
|
|
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCNetCentered_10000.pth")
|
|
|
|
|
|
|
|
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_centered/checkpoints/RCLutCentered_0.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCNetRot90_7x7_10000.pth")
|
|
|
|
|
|
|
|
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_7x7/checkpoints/RCLutRot90_7x7_0.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCNetRot90_3x3_10000.pth")
|
|
|
|
|
|
|
|
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_rot90_3x3/checkpoints/RCLutRot90_3x3_0.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# net_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCNetx1_46000.pth")
|
|
|
|
|
|
|
|
# lut_model = LoadCheckpoint("/wd/luts/models/rcnet_x1/checkpoints/RCLutx1_0.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
net_model = LoadCheckpoint(config.net_model_path).cuda()
|
|
|
|
net_model = LoadCheckpoint(config.net_model_path).cuda()
|
|
|
|
lut_model = LoadCheckpoint(config.lut_model_path).cuda()
|
|
|
|
lut_model = LoadCheckpoint(config.lut_model_path).cuda()
|
|
|
@ -65,11 +52,8 @@ lut_model = LoadCheckpoint(config.lut_model_path).cuda()
|
|
|
|
print(net_model)
|
|
|
|
print(net_model)
|
|
|
|
print(lut_model)
|
|
|
|
print(lut_model)
|
|
|
|
|
|
|
|
|
|
|
|
lr_image = cv2.imread(str(config.project_path / "data" / "Set14/LR/X4/lenna.png"))[:,:,::-1].copy()
|
|
|
|
lr_image = cv2.imread(str(config.project_path / "data" / "Set14/LR/X4/monarch.png"))[:,::-1,::-1].copy()
|
|
|
|
image_gt = cv2.imread(str(config.project_path / "data" / "Set14/HR/lenna.png"))[:,:,::-1].copy()
|
|
|
|
image_gt = cv2.imread(str(config.project_path / "data" / "Set14/HR/monarch.png"))[:,::-1,::-1].copy()
|
|
|
|
|
|
|
|
|
|
|
|
# lr_image = cv2.imread(str(project_path / "data" / "Synthetic/LR/X4/linear.png"))[:,:,::-1].copy()
|
|
|
|
|
|
|
|
# image_gt = cv2.imread(str(project_path / "data" / "Synthetic/HR/linear.png"))[:,:,::-1].copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
|
|
|
|
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
|
|
|
|
|
|
|
|
|
|
|
|