@ -8,15 +8,16 @@ import torch
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				from  torch . utils . data  import  Dataset ,  DataLoader 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				from  torch . utils . data . distributed  import  DistributedSampler 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				from  pathlib  import  Path 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				from  common . utils  import  PIL_CONVERT_COLOR ,  pil2numpy 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				image_extensions  =  [ ' .jpg ' ,  ' .png ' ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def  load_images_cached ( images_dir_path  ): 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def  load_images_cached ( images_dir_path , color_model  ): 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    image_paths  =  sorted ( [ f  for  f  in  Path ( images_dir_path ) . glob ( " * " )  if  f . suffix . lower ( )  in  image_extensions ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    cache_path  =  Path ( images_dir_path ) . parent  /  f " { Path ( images_dir_path ) . stem } _ cache.npy" 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    cache_path  =  Path ( images_dir_path ) . parent  /  f " { Path ( images_dir_path ) . stem } _ { color_model } _  cache.npy" 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    cache_path  =  cache_path . resolve ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    if  not  Path ( cache_path ) . exists ( ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        print ( " Caching to: " ,  cache_path ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        value  =  { f : np. array  ( Image . open ( f ) )  for  f  in  image_paths } 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        value  =  { f : pil2numpy( PIL_CONVERT_COLOR [ color_model ]  ( Image . open ( f ) ) )  for  f  in  image_paths }   
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        np . save ( cache_path ,  value ,  allow_pickle = True )         
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    else : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        value  =  np . load ( cache_path ,  allow_pickle = True ) . item ( ) 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -24,12 +25,12 @@ def load_images_cached(images_dir_path):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    return  list ( value . keys ( ) ) ,  list ( value . values ( ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SRTrainDataset ( Dataset ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hr_dir_path ,  lr_dir_path ,  patch_size ,   rigid_aug= True ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hr_dir_path ,  lr_dir_path ,  patch_size ,  color_model =  " RGB " ,   rigid_aug= True ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SRTrainDataset ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . sz  =  patch_size 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . rigid_aug  =  rigid_aug 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . hr_image_names ,  self . hr_images  =  load_images_cached ( hr_dir_path  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . lr_image_names ,  self . lr_images  =  load_images_cached ( lr_dir_path  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . hr_image_names ,  self . hr_images  =  load_images_cached ( hr_dir_path , color_model = color_model  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . lr_image_names ,  self . lr_images  =  load_images_cached ( lr_dir_path , color_model = color_model  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        assert  len ( self . hr_images )  ==  len ( self . lr_images ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __getitem__ ( self ,  idx ) : 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -48,18 +49,19 @@ class SRTrainDataset(Dataset):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            i  =  random . randint ( 0 ,  lr_image . shape [ 0 ]  -  self . sz ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            j  =  random . randint ( 0 ,  lr_image . shape [ 1 ]  -  self . sz ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            # c = random.choice([0, 1, 2]) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            hr_patch  =  hr_image [ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                ( i * scale ) : ( i * scale  +  self . sz * scale ) , 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                ( j * scale ) : ( j * scale  +  self . sz * scale ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            lr_patch  =  lr_image [ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                i : ( i  +  self . sz ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                j : ( j  +  self . sz ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            if  len ( hr_image . shape )  ==  3 : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                hr_patch  =  hr_image [ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                    ( i * scale ) : ( i * scale  +  self . sz * scale ) , 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                    ( j * scale ) : ( j * scale  +  self . sz * scale ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                    : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                lr_patch  =  lr_image [ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                    i : ( i  +  self . sz ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                    j : ( j  +  self . sz ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                    : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            if  self . rigid_aug : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				                if  random . uniform ( 0 ,  1 )  <  0.5 : 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -85,10 +87,10 @@ class SRTrainDataset(Dataset):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  len ( self . hr_images ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SRTestDataset ( Dataset ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hr_dir_path ,  lr_dir_path  ): 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hr_dir_path ,  lr_dir_path , color_model  ): 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SRTestDataset ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . hr_image_paths ,  self . hr_images  =  load_images_cached ( hr_dir_path  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . lr_image_paths ,  self . lr_images  =  load_images_cached ( lr_dir_path  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . hr_image_paths ,  self . hr_images  =  load_images_cached ( hr_dir_path , color_model = color_model  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . lr_image_paths ,  self . lr_images  =  load_images_cached ( lr_dir_path , color_model = color_model  ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        assert  len ( self . hr_images )  ==  len ( self . lr_images ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __getitem__ ( self ,  idx ) :