update image demo for multiple models

main
protsenkovi 8 months ago
parent 49c6e2c608
commit d33b65c235

@ -11,25 +11,23 @@ import argparse
class ImageDemoOptions(): class ImageDemoOptions():
def __init__(self): def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--net_model_path', '-n', type=str, default="../models/last_transfered_net.pth", help="Net model path folder") self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth","../models/last_transfered_lut.pth"], help="Model paths for comparison")
self.parser.add_argument('--lut_model_path', '-l', type=str, default="../models/last_transfered_lut.pth", help="Lut model path folder")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path") self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path") self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--project_path', type=str, default="../", help="Project path.") self.parser.add_argument('--output_path', type=str, default="../models/", help="Project path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
self.parser.add_argument('--mirror', action='store_true', default=False) self.parser.add_argument('--mirror', action='store_true', default=False)
def parse_args(self): def parse_args(self):
args = self.parser.parse_args() args = self.parser.parse_args()
args.project_path = Path(args.project_path).resolve() args.output_path = Path(args.output_path).resolve()
args.hr_image_path = Path(args.hr_image_path).resolve() args.hr_image_path = Path(args.hr_image_path).resolve()
args.lr_image_path = Path(args.lr_image_path).resolve() args.lr_image_path = Path(args.lr_image_path).resolve()
args.net_model_path = Path(args.net_model_path).resolve() args.model_paths = [Path(x).resolve() for x in args.model_paths]
args.lut_model_path = Path(args.lut_model_path).resolve()
return args return args
def __repr__(self): def __repr__(self):
config = self.parser.parse_args() config = self.parse_args()
message = '' message = ''
message += '----------------- Options ---------------\n' message += '----------------- Options ---------------\n'
for k, v in sorted(vars(config).items()): for k, v in sorted(vars(config).items()):
@ -45,12 +43,10 @@ config_inst = ImageDemoOptions()
config = config_inst.parse_args() config = config_inst.parse_args()
start_script_time = datetime.now() start_script_time = datetime.now()
print(config_inst)
net_model = LoadCheckpoint(config.net_model_path).cuda() models = [LoadCheckpoint(x).cuda() for x in config.model_paths]
lut_model = LoadCheckpoint(config.lut_model_path).cuda() for m in models:
print(m)
print(net_model)
print(lut_model)
lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1] lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1]
image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1] image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1]
@ -62,15 +58,36 @@ image_gt = image_gt.copy()
input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda()
predictions = []
for model in models:
with torch.inference_mode(): with torch.inference_mode():
net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() prediction = model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy()
lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() predictions.append(prediction)
image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) images_predicted = []
image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) for model_path, model, prediction in zip(config.model_paths, models, predictions):
prediction = cv2.putText(prediction, model_path.stem, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA)
images_predicted.append(prediction)
image_count = 1 + len(images_predicted)
t = np.sqrt(image_count).astype(np.int32)
residual = image_count % t
if residual != 0:
column_count = image_count
row_count = 1
else:
column_count = image_count // t
row_count = t
images = [image_gt] + images_predicted
Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(config.project_path / "models" / 'last_transfered_demo.png') columns = []
for i in range(row_count):
row = []
for j in range(column_count):
row.append(images[i*column_count + j])
columns.append(np.concatenate(row, axis=1))
canvas = np.concatenate(columns, axis=0).astype(np.uint8)
Image.fromarray(canvas).save(config.output_path / 'image_demo.png')
print(datetime.now() - start_script_time ) print(datetime.now() - start_script_time )
Loading…
Cancel
Save