|
|
@ -6,19 +6,16 @@ from pathlib import Path
|
|
|
|
from PIL import Image
|
|
|
|
from PIL import Image
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
|
|
# @ray.remote(num_cpus=1, num_gpus=0.3)
|
|
|
|
|
|
|
|
def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'):
|
|
|
|
def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'):
|
|
|
|
with torch.inference_mode():
|
|
|
|
with torch.inference_mode():
|
|
|
|
start_time = time.perf_counter_ns()
|
|
|
|
start_time = time.perf_counter_ns()
|
|
|
|
# prepare lr_image
|
|
|
|
# prepare lr_image
|
|
|
|
lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
|
|
|
|
lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
|
|
|
|
lr_image = lr_image.unsqueeze(0).to(torch.device(device))
|
|
|
|
lr_image = lr_image.unsqueeze(0).to(torch.device(device))
|
|
|
|
b, c, h, w = lr_image.shape
|
|
|
|
|
|
|
|
lr_image = lr_image.reshape(b, c, h, w)
|
|
|
|
|
|
|
|
# predict
|
|
|
|
# predict
|
|
|
|
pred_lr_image = model(lr_image)
|
|
|
|
pred_lr_image = model(lr_image)
|
|
|
|
# postprocess
|
|
|
|
# postprocess
|
|
|
|
pred_lr_image = pred_lr_image.squeeze(0).permute(1,2,0).type(torch.uint8)
|
|
|
|
pred_lr_image = pred_lr_image.squeeze(0).permute(1,2,0).type(torch.uint8)
|
|
|
|
pred_lr_image = pred_lr_image.cpu().numpy()
|
|
|
|
pred_lr_image = pred_lr_image.cpu().numpy()
|
|
|
|
run_time_ns = time.perf_counter_ns() - start_time
|
|
|
|
run_time_ns = time.perf_counter_ns() - start_time
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
torch.cuda.empty_cache()
|
|
|
@ -28,12 +25,11 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu
|
|
|
|
|
|
|
|
|
|
|
|
# metrics
|
|
|
|
# metrics
|
|
|
|
hr_image = modcrop(hr_image, model.scale)
|
|
|
|
hr_image = modcrop(hr_image, model.scale)
|
|
|
|
left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
|
|
|
|
Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
|
|
|
|
lr_area = np.prod(lr_image.shape[-2:])
|
|
|
|
lr_area = np.prod(lr_image.shape[-2:])
|
|
|
|
return PSNR(left, right, model.scale), cal_ssim(left, right), run_time_ns, lr_area
|
|
|
|
return PSNR(Y_left, Y_right, model.scale), cal_ssim(Y_left, Y_right), run_time_ns, lr_area
|
|
|
|
|
|
|
|
|
|
|
|
def valid_steps(model, datasets, config, log_prefix=""):
|
|
|
|
def valid_steps(model, datasets, config, log_prefix=""):
|
|
|
|
# ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
|
|
|
|
|
|
|
|
dataset_names = list(datasets.keys())
|
|
|
|
dataset_names = list(datasets.keys())
|
|
|
|
|
|
|
|
|
|
|
|
results = []
|
|
|
|
results = []
|
|
|
@ -52,17 +48,11 @@ def valid_steps(model, datasets, config, log_prefix=""):
|
|
|
|
test_dataset = datasets[dataset_name]
|
|
|
|
test_dataset = datasets[dataset_name]
|
|
|
|
tasks = []
|
|
|
|
tasks = []
|
|
|
|
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
|
|
|
|
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
|
|
|
|
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None
|
|
|
|
output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
|
|
|
|
task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device)
|
|
|
|
task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device)
|
|
|
|
tasks.append(task)
|
|
|
|
tasks.append(task)
|
|
|
|
|
|
|
|
|
|
|
|
total_time = time.time() - start_time
|
|
|
|
total_time = time.time() - start_time
|
|
|
|
# ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
|
|
|
|
|
|
|
|
# while len(remaining_refs) > 0:
|
|
|
|
|
|
|
|
# print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
|
|
|
|
|
|
|
|
# ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None)
|
|
|
|
|
|
|
|
# print("\r", end=" ")
|
|
|
|
|
|
|
|
# tasks = [ray.get(task) for task in tasks]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for psnr, ssim, run_time_ns, lr_area in tasks:
|
|
|
|
for psnr, ssim, run_time_ns, lr_area in tasks:
|
|
|
|
psnrs.append(psnr)
|
|
|
|
psnrs.append(psnr)
|
|
|
|