| 
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -5,7 +5,7 @@ from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from pathlib import Path
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from PIL import Image
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import time
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from datetime import timedelta
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from datetime import timedelta, datetime
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=None, device='cuda'):      
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    with torch.inference_mode():
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -43,7 +43,7 @@ def val_image_pair(model, hr_image, lr_image, color_model, output_image_path=Non
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        ssim = cal_ssim(Y_left, Y_right)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return psnr, ssim, run_time_ns, lr_area
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def valid_steps(model, datasets, config, log_prefix=""):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def valid_steps(model, datasets, config, log_prefix="", print_progress = False):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    dataset_names = list(datasets.keys())
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    results = []
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -61,11 +61,16 @@ def valid_steps(model, datasets, config, log_prefix=""):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        test_dataset = datasets[dataset_name]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        tasks = []
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if print_progress:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            start_datetime = datetime.now()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for idx, (hr_image, lr_image, hr_image_path, lr_image_path) in enumerate(test_dataset):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            output_image_path = predictions_path / f'{Path(hr_image_path).stem}.png' if config.save_predictions else None
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            task = val_image_pair(model, hr_image, lr_image, color_model=config.color_model, output_image_path=output_image_path, device=config.device)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            tasks.append(task)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            if print_progress:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                print(f"\r{datetime.now()-start_datetime} {idx}/{len(test_dataset)} {hr_image_path}", end=" "*25)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if print_progress:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            print()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        total_time = time.time() - start_time 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for psnr, ssim, run_time_ns, lr_area in tasks:
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
			
			 | 
			 | 
			
				
 
 |