From e9c624894978c6257bedc4db90b50073a3047afe Mon Sep 17 00:00:00 2001 From: vlpr Date: Wed, 15 May 2024 10:53:04 +0000 Subject: [PATCH] cleanup val script --- src/common/validation.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/common/validation.py b/src/common/validation.py index c06bb1d..97f58f8 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -6,19 +6,16 @@ from pathlib import Path from PIL import Image import time -# @ray.remote(num_cpus=1, num_gpus=0.3) def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'): with torch.inference_mode(): start_time = time.perf_counter_ns() # prepare lr_image lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1) lr_image = lr_image.unsqueeze(0).to(torch.device(device)) - b, c, h, w = lr_image.shape - lr_image = lr_image.reshape(b, c, h, w) # predict pred_lr_image = model(lr_image) # postprocess - pred_lr_image = pred_lr_image.squeeze(0).permute(1,2,0).type(torch.uint8) + pred_lr_image = pred_lr_image.squeeze(0).permute(1,2,0).type(torch.uint8) pred_lr_image = pred_lr_image.cpu().numpy() run_time_ns = time.perf_counter_ns() - start_time torch.cuda.empty_cache() @@ -28,12 +25,11 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cu # metrics hr_image = modcrop(hr_image, model.scale) - left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] + Y_left, Y_right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] lr_area = np.prod(lr_image.shape[-2:]) - return PSNR(left, right, model.scale), cal_ssim(left, right), run_time_ns, lr_area + return PSNR(Y_left, Y_right, model.scale), cal_ssim(Y_left, Y_right), run_time_ns, lr_area def valid_steps(model, datasets, config, log_prefix=""): - # ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"}) dataset_names = list(datasets.keys()) results = [] @@ -52,17 +48,11 @@ def valid_steps(model, datasets, config, log_prefix=""): test_dataset = datasets[dataset_name] tasks = [] for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: - output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None + output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device) tasks.append(task) total_time = time.time() - start_time - # ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None) - # while len(remaining_refs) > 0: - # print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ") - # ready_refs, remaining_refs = ray.wait(tasks, num_returns=len(ready_refs)+1, timeout=None) - # print("\r", end=" ") - # tasks = [ray.get(task) for task in tasks] for psnr, ssim, run_time_ns, lr_area in tasks: psnrs.append(psnr)