@ -8,15 +8,16 @@ import torch
from torch . utils . data import Dataset , DataLoader
from torch . utils . data . distributed import DistributedSampler
from pathlib import Path
from common . utils import PIL_CONVERT_COLOR , pil2numpy
image_extensions = [ ' .jpg ' , ' .png ' ]
def load_images_cached ( images_dir_path ):
def load_images_cached ( images_dir_path , color_model ):
image_paths = sorted ( [ f for f in Path ( images_dir_path ) . glob ( " * " ) if f . suffix . lower ( ) in image_extensions ] )
cache_path = Path ( images_dir_path ) . parent / f " { Path ( images_dir_path ) . stem } _ cache.npy"
cache_path = Path ( images_dir_path ) . parent / f " { Path ( images_dir_path ) . stem } _ { color_model } _ cache.npy"
cache_path = cache_path . resolve ( )
if not Path ( cache_path ) . exists ( ) :
print ( " Caching to: " , cache_path )
value = { f : np. array ( Image . open ( f ) ) for f in image_paths }
value = { f : pil2numpy( PIL_CONVERT_COLOR [ color_model ] ( Image . open ( f ) ) ) for f in image_paths }
np . save ( cache_path , value , allow_pickle = True )
else :
value = np . load ( cache_path , allow_pickle = True ) . item ( )
@ -24,12 +25,12 @@ def load_images_cached(images_dir_path):
return list ( value . keys ( ) ) , list ( value . values ( ) )
class SRTrainDataset ( Dataset ) :
def __init__ ( self , hr_dir_path , lr_dir_path , patch_size , rigid_aug= True ) :
def __init__ ( self , hr_dir_path , lr_dir_path , patch_size , color_model = " RGB " , rigid_aug= True ) :
super ( SRTrainDataset , self ) . __init__ ( )
self . sz = patch_size
self . rigid_aug = rigid_aug
self . hr_image_names , self . hr_images = load_images_cached ( hr_dir_path )
self . lr_image_names , self . lr_images = load_images_cached ( lr_dir_path )
self . hr_image_names , self . hr_images = load_images_cached ( hr_dir_path , color_model = color_model )
self . lr_image_names , self . lr_images = load_images_cached ( lr_dir_path , color_model = color_model )
assert len ( self . hr_images ) == len ( self . lr_images )
def __getitem__ ( self , idx ) :
@ -48,8 +49,8 @@ class SRTrainDataset(Dataset):
i = random . randint ( 0 , lr_image . shape [ 0 ] - self . sz )
j = random . randint ( 0 , lr_image . shape [ 1 ] - self . sz )
# c = random.choice([0, 1, 2])
if len ( hr_image . shape ) == 3 :
hr_patch = hr_image [
( i * scale ) : ( i * scale + self . sz * scale ) ,
( j * scale ) : ( j * scale + self . sz * scale ) ,
@ -61,6 +62,7 @@ class SRTrainDataset(Dataset):
:
]
if self . rigid_aug :
if random . uniform ( 0 , 1 ) < 0.5 :
hr_patch = np . fliplr ( hr_patch )
@ -85,10 +87,10 @@ class SRTrainDataset(Dataset):
return len ( self . hr_images )
class SRTestDataset ( Dataset ) :
def __init__ ( self , hr_dir_path , lr_dir_path ):
def __init__ ( self , hr_dir_path , lr_dir_path , color_model ):
super ( SRTestDataset , self ) . __init__ ( )
self . hr_image_paths , self . hr_images = load_images_cached ( hr_dir_path )
self . lr_image_paths , self . lr_images = load_images_cached ( lr_dir_path )
self . hr_image_paths , self . hr_images = load_images_cached ( hr_dir_path , color_model = color_model )
self . lr_image_paths , self . lr_images = load_images_cached ( lr_dir_path , color_model = color_model )
assert len ( self . hr_images ) == len ( self . lr_images )
def __getitem__ ( self , idx ) :