diff --git a/src/image_demo.py b/src/image_demo.py index ba99934..6c993e6 100644 --- a/src/image_demo.py +++ b/src/image_demo.py @@ -11,25 +11,23 @@ import argparse class ImageDemoOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - self.parser.add_argument('--net_model_path', '-n', type=str, default="../models/last_transfered_net.pth", help="Net model path folder") - self.parser.add_argument('--lut_model_path', '-l', type=str, default="../models/last_transfered_lut.pth", help="Lut model path folder") + self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth","../models/last_transfered_lut.pth"], help="Model paths for comparison") self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path") self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path") - self.parser.add_argument('--project_path', type=str, default="../", help="Project path.") + self.parser.add_argument('--output_path', type=str, default="../models/", help="Project path.") self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.") self.parser.add_argument('--mirror', action='store_true', default=False) def parse_args(self): args = self.parser.parse_args() - args.project_path = Path(args.project_path).resolve() + args.output_path = Path(args.output_path).resolve() args.hr_image_path = Path(args.hr_image_path).resolve() args.lr_image_path = Path(args.lr_image_path).resolve() - args.net_model_path = Path(args.net_model_path).resolve() - args.lut_model_path = Path(args.lut_model_path).resolve() + args.model_paths = [Path(x).resolve() for x in args.model_paths] return args def __repr__(self): - config = self.parser.parse_args() + config = self.parse_args() message = '' message += '----------------- Options ---------------\n' for k, v in sorted(vars(config).items()): @@ -45,12 +43,10 @@ config_inst = ImageDemoOptions() config = config_inst.parse_args() start_script_time = datetime.now() - -net_model = LoadCheckpoint(config.net_model_path).cuda() -lut_model = LoadCheckpoint(config.lut_model_path).cuda() - -print(net_model) -print(lut_model) +print(config_inst) +models = [LoadCheckpoint(x).cuda() for x in config.model_paths] +for m in models: + print(m) lr_image = cv2.imread(str(config.lr_image_path))[:,:,::-1] image_gt = cv2.imread(str(config.hr_image_path))[:,:,::-1] @@ -62,15 +58,36 @@ image_gt = image_gt.copy() input_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)[None,...].cuda() -with torch.inference_mode(): - net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() - lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() - +predictions = [] +for model in models: + with torch.inference_mode(): + prediction = model(input_image).cpu().type(torch.uint8).squeeze().permute(1,2,0).numpy().copy() + predictions.append(prediction) image_gt = cv2.putText(image_gt, 'GT', org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) -image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) -image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) +images_predicted = [] +for model_path, model, prediction in zip(config.model_paths, models, predictions): + prediction = cv2.putText(prediction, model_path.stem, org=(20, 50) , fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,255,255), thickness=2, lineType= cv2.LINE_AA) + images_predicted.append(prediction) + +image_count = 1 + len(images_predicted) +t = np.sqrt(image_count).astype(np.int32) +residual = image_count % t +if residual != 0: + column_count = image_count + row_count = 1 +else: + column_count = image_count // t + row_count = t +images = [image_gt] + images_predicted -Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(config.project_path / "models" / 'last_transfered_demo.png') +columns = [] +for i in range(row_count): + row = [] + for j in range(column_count): + row.append(images[i*column_count + j]) + columns.append(np.concatenate(row, axis=1)) +canvas = np.concatenate(columns, axis=0).astype(np.uint8) +Image.fromarray(canvas).save(config.output_path / 'image_demo.png') print(datetime.now() - start_script_time ) \ No newline at end of file