@@ -11,25 +11,23 @@ import argparse
 class ImageDemoOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--net_model_path', '-n', type=str, default="../models/last_transfered_net.pth", help="Net model path folder")
-        self.parser.add_argument('--lut_model_path', '-l', type=str, default="../models/last_transfered_lut.pth", help="Lut model path folder")
+        self.parser.add_argument('--model_paths', '-n', nargs='+', type=str, default=["../models/last_transfered_net.pth", "../models/last_transfered_lut.pth"], help="Model paths for comparison")
         self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
         self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
         self.parser.add_argument('--project_path', type=str, default="../", help="Project path.")
         self.parser.add_argument('--output_path', type=str, default="../models/", help="Project path.")
         self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
         self.parser.add_argument('--mirror', action='store_true', default=False)
     def parse_args(self):
         args = self.parser.parse_args()
         args.project_path = Path(args.project_path).resolve()
         args.output_path = Path(args.output_path).resolve()
         args.hr_image_path = Path(args.hr_image_path).resolve()
         args.lr_image_path = Path(args.lr_image_path).resolve()
-        args.net_model_path = Path(args.net_model_path).resolve()
-        args.lut_model_path = Path(args.lut_model_path).resolve()
+        args.model_paths = [Path(x).resolve() for x in args.model_paths]
         return args
     def __repr__(self):
-        config = self.parser.parse_args()
+        config = self.parse_args()
         message = ''
         message += '----------------- Options ---------------\n'
         for k, v in sorted(vars(config).items()):
@@ -45,12 +43,10 @@ config_inst = ImageDemoOptions()
 config = config_inst.parse_args()
 start_script_time = datetime.now()
-net_model = LoadCheckpoint(config.net_model_path).cuda()
-lut_model = LoadCheckpoint(config.lut_model_path).cuda()
-print(net_model)
-print(lut_model)
 print(config_inst)
+models = [LoadCheckpoint(x).cuda() for x in config.model_paths]
+for m in models:
+    print(m)
 lr_image = cv2.imread(str(config.lr_image_path))[:, :, ::-1]
 image_gt = cv2.imread(str(config.hr_image_path))[:, :, ::-1]
@@ -62,15 +58,36 @@ image_gt = image_gt.copy()
 input_image = torch.tensor(lr_image).type(torch.float32).permute(2, 0, 1)[None, ...].cuda()
-with torch.inference_mode():
-    net_prediction = net_model(input_image).cpu().type(torch.uint8).squeeze().permute(1, 2, 0).numpy().copy()
-    lut_prediction = lut_model(input_image).cpu().type(torch.uint8).squeeze().permute(1, 2, 0).numpy().copy()
+predictions = []
+for model in models:
+    with torch.inference_mode():
+        prediction = model(input_image).cpu().type(torch.uint8).squeeze().permute(1, 2, 0).numpy().copy()
+    predictions.append(prediction)
 image_gt = cv2.putText(image_gt, 'GT', org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
-image_net = cv2.putText(net_prediction, net_model.__class__.__name__, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
-image_lut = cv2.putText(lut_prediction, lut_model.__class__.__name__, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
+images_predicted = []
+for model_path, model, prediction in zip(config.model_paths, models, predictions):
+    prediction = cv2.putText(prediction, model_path.stem, org=(20, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
+    images_predicted.append(prediction)
+image_count = 1 + len(images_predicted)
+t = np.sqrt(image_count).astype(np.int32)
+residual = image_count % t
+if residual != 0:
+    column_count = image_count
+    row_count = 1
+else:
+    column_count = image_count // t
+    row_count = t
+images = [image_gt] + images_predicted
-Image.fromarray(np.concatenate([image_gt, image_net, image_lut], 1)).save(config.project_path / "models" / 'last_transfered_demo.png')
+columns = []
+for i in range(row_count):
+    row = []
+    for j in range(column_count):
+        row.append(images[i * column_count + j])
+    columns.append(np.concatenate(row, axis=1))
+canvas = np.concatenate(columns, axis=0).astype(np.uint8)
+Image.fromarray(canvas).save(config.output_path / 'image_demo.png')
 print(datetime.now() - start_script_time)
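
The layout rule added in the last hunk (ground-truth tile plus one tile per model, arranged on a near-square grid when the tile count divides evenly by int(sqrt(count)), single row otherwise) can be exercised in isolation. A minimal standalone sketch of the same rule; `tile_images` is an illustrative name, not a function from the patch:

```python
import numpy as np

def tile_images(images):
    """Near-square grid when the count divides evenly, otherwise a single row
    (mirrors the row/column computation in the patched demo)."""
    image_count = len(images)
    t = int(np.sqrt(image_count))
    if image_count % t != 0:
        column_count, row_count = image_count, 1
    else:
        column_count, row_count = image_count // t, t
    rows = []
    for i in range(row_count):
        rows.append(np.concatenate(images[i * column_count:(i + 1) * column_count], axis=1))
    return np.concatenate(rows, axis=0).astype(np.uint8)

# Three 64x64 RGB tiles (GT plus the two default models) land in a single 64x192 row,
# the same side-by-side layout the removed hard-coded concatenation produced.
tiles = [np.full((64, 64, 3), v, dtype=np.uint8) for v in (0, 128, 255)]
print(tile_images(tiles).shape)  # (64, 192, 3)
```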