@ -92,6 +92,54 @@ class SDYNetx2(SRNetBase):
return F . mse_loss ( pred / 255 , target / 255 )
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
return loss_fn
class SDYNetx2Inv ( SRNetBase ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYNetx2Inv , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_S = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . _extract_pattern_D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . stage1_S = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage2_S = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage2_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage2_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
x = x . reshape ( b * c , 1 , h , w )
output = 0.0
output + = self . forward_stage ( x , self . _extract_pattern_S , self . stage1_S )
output + = self . forward_stage ( x , self . _extract_pattern_D , self . stage1_D )
output + = self . forward_stage ( x , self . _extract_pattern_Y , self . stage1_Y )
output / = 3
x = output
output = 0.0
output + = self . forward_stage ( x , self . _extract_pattern_S , self . stage2_S )
output + = self . forward_stage ( x , self . _extract_pattern_D , self . stage2_D )
output + = self . forward_stage ( x , self . _extract_pattern_Y , self . stage2_Y )
output / = 3
x = ( output + x ) / 2
x = x . reshape ( b , c , h * self . scale , w * self . scale )
return x
def get_lut_model ( self , quantization_interval = 16 , batch_size = 2 * * 10 ) :
stage1_S = lut . transfer_2x2_input_SxS_output ( self . stage1_S , quantization_interval = quantization_interval , batch_size = batch_size )
stage1_D = lut . transfer_2x2_input_SxS_output ( self . stage1_D , quantization_interval = quantization_interval , batch_size = batch_size )
stage1_Y = lut . transfer_2x2_input_SxS_output ( self . stage1_Y , quantization_interval = quantization_interval , batch_size = batch_size )
stage2_S = lut . transfer_2x2_input_SxS_output ( self . stage2_S , quantization_interval = quantization_interval , batch_size = batch_size )
stage2_D = lut . transfer_2x2_input_SxS_output ( self . stage2_D , quantization_interval = quantization_interval , batch_size = batch_size )
stage2_Y = lut . transfer_2x2_input_SxS_output ( self . stage2_Y , quantization_interval = quantization_interval , batch_size = batch_size )
lut_model = sdylut . SDYLutx2 . init_from_numpy ( stage1_S , stage1_D , stage1_Y , stage2_S , stage2_D , stage2_Y )
return lut_model
def get_loss_fn ( self ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
class SDYNetx3 ( SRNetBase ) :
class SDYNetx3 ( SRNetBase ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYNetx3 , self ) . __init__ ( )
super ( SDYNetx3 , self ) . __init__ ( )
@ -151,7 +199,7 @@ class SDYNetx3(SRNetBase):
return F . mse_loss ( pred / 255 , target / 255 )
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
return loss_fn
class SDYNetR90x1 ( nn. Modul e) :
class SDYNetR90x1 ( SRNetBas e) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYNetR90x1 , self ) . __init__ ( )
super ( SDYNetR90x1 , self ) . __init__ ( )
self . scale = scale
self . scale = scale
@ -162,16 +210,6 @@ class SDYNetR90x1(nn.Module):
self . stage1_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
def forward_stage ( self , x , scale , percieve_pattern , stage ) :
b , c , h , w = x . shape
x = percieve_pattern ( x )
x = stage ( x )
x = round_func ( x )
x = x . reshape ( b , c , h , w , scale , scale )
x = x . permute ( 0 , 1 , 2 , 4 , 3 , 5 )
x = x . reshape ( b , c , h * scale , w * scale )
return x
def forward ( self , x , config = None ) :
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
b , c , h , w = x . shape
x = x . reshape ( b * c , 1 , h , w )
x = x . reshape ( b * c , 1 , h , w )
@ -219,10 +257,7 @@ class SDYNetR90x2(SRNetBase):
b , c , h , w = x . shape
b , c , h , w = x . shape
x = x . view ( b * c , 1 , h , w )
x = x . view ( b * c , 1 , h , w )
output_1 = torch . zeros ( [ b * c , 1 , h , w ] , dtype = x . dtype , device = x . device )
output_1 = torch . zeros ( [ b * c , 1 , h , w ] , dtype = x . dtype , device = x . device )
output_1 + = self . forward_stage ( x , 1 , self . _extract_pattern_S , self . stage1_S )
for rotations_count in range ( 4 ) :
output_1 + = self . forward_stage ( x , 1 , self . _extract_pattern_D , self . stage1_D )
output_1 + = self . forward_stage ( x , 1 , self . _extract_pattern_Y , self . stage1_Y )
for rotations_count in range ( 1 , 4 ) :
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_S , self . stage1_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_S , self . stage1_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_D , self . stage1_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_D , self . stage1_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
@ -230,10 +265,7 @@ class SDYNetR90x2(SRNetBase):
output_1 / = 4 * 3
output_1 / = 4 * 3
x = output_1
x = output_1
output_2 = torch . zeros ( [ b * c , 1 , h * self . scale , w * self . scale ] , dtype = x . dtype , device = x . device )
output_2 = torch . zeros ( [ b * c , 1 , h * self . scale , w * self . scale ] , dtype = x . dtype , device = x . device )
output_2 + = self . forward_stage ( x , self . scale , self . _extract_pattern_S , self . stage2_S )
for rotations_count in range ( 4 ) :
output_2 + = self . forward_stage ( x , self . scale , self . _extract_pattern_D , self . stage2_D )
output_2 + = self . forward_stage ( x , self . scale , self . _extract_pattern_Y , self . stage2_Y )
for rotations_count in range ( 1 , 4 ) :
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_S , self . stage2_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_S , self . stage2_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_D , self . stage2_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_D , self . stage2_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
@ -257,3 +289,249 @@ class SDYNetR90x2(SRNetBase):
def loss_fn ( pred , target ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
return loss_fn
class SDYEHONetR90x1 ( SRNetBase ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYEHONetR90x1 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_S = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . _extract_pattern_D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_E = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 0 ] , [ 0 , 3 ] , [ 3 , 3 ] ] , center = [ 0 , 0 ] , window_size = 4 )
self . _extract_pattern_H = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] ] , center = [ 0 , 0 ] , window_size = 4 )
self . _extract_pattern_O = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 1 ] , [ 2 , 2 ] , [ 1 , 3 ] ] , center = [ 0 , 0 ] , window_size = 4 )
self . stage1_S = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_E = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_H = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_O = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
x = x . view ( b * c , 1 , h , w )
output_1 = torch . zeros ( [ b * c , 1 , h * self . scale , w * self . scale ] , dtype = x . dtype , device = x . device )
for rotations_count in range ( 4 ) :
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_S , self . stage1_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_D , self . stage1_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_Y , self . stage1_Y ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_E , self . stage1_E ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_H , self . stage1_H ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_O , self . stage1_O ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 / = 4 * 6
x = output_1
x = x . view ( b , c , h * self . scale , w * self . scale )
return x
def get_lut_model ( self , quantization_interval = 16 , batch_size = 2 * * 10 ) :
raise NotImplementedError
def get_loss_fn ( self ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
class SDYEHONetR90x2 ( SRNetBase ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYEHONetR90x2 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_S = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . _extract_pattern_D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_E = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 0 ] , [ 0 , 3 ] , [ 3 , 3 ] ] , center = [ 0 , 0 ] , window_size = 4 )
self . _extract_pattern_H = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] ] , center = [ 0 , 0 ] , window_size = 4 )
self . _extract_pattern_O = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 1 ] , [ 2 , 2 ] , [ 1 , 3 ] ] , center = [ 0 , 0 ] , window_size = 4 )
self . stage1_S = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_E = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_H = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_O = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage2_S = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage2_D = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage2_Y = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage2_E = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage2_H = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage2_O = layers . UpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
x = x . view ( b * c , 1 , h , w )
output_1 = torch . zeros ( [ b * c , 1 , h , w ] , dtype = x . dtype , device = x . device )
for rotations_count in range ( 4 ) :
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_S , self . stage1_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_D , self . stage1_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_Y , self . stage1_Y ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_E , self . stage1_E ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_H , self . stage1_H ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 + = torch . rot90 ( self . forward_stage ( rotated , 1 , self . _extract_pattern_O , self . stage1_O ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_1 / = 4 * 6
x = output_1
output_2 = torch . zeros ( [ b * c , 1 , h * self . scale , w * self . scale ] , dtype = x . dtype , device = x . device )
for rotations_count in range ( 4 ) :
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_S , self . stage2_S ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_D , self . stage2_D ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_Y , self . stage2_Y ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_E , self . stage2_E ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_H , self . stage2_H ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 + = torch . rot90 ( self . forward_stage ( rotated , self . scale , self . _extract_pattern_O , self . stage2_O ) , k = - rotations_count , dims = [ - 2 , - 1 ] )
output_2 / = 4 * 6
x = output_2
x = x . view ( b , c , h * self . scale , w * self . scale )
return x
def get_lut_model ( self , quantization_interval = 16 , batch_size = 2 * * 10 ) :
raise NotImplementedError
def get_loss_fn ( self ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
class SDYMixNetx1 ( SRNetBase ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYMixNetx1 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_1 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] , [ 2 , 2 ] ] , center = [ 1 , 1 ] , window_size = 4 )
self . _extract_pattern_2 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 3 ] , [ 3 , 1 ] , [ 3 , 3 ] ] , center = [ 1 , 1 ] , window_size = 4 )
self . _extract_pattern_3 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 2 , 2 ] , [ 3 , 2 ] , [ 2 , 3 ] ] , center = [ 1 , 1 ] , window_size = 4 )
self . _extract_pattern_4 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] ] , center = [ 1 , 1 ] , window_size = 4 )
self . _extract_pattern_5 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 0 , 2 ] , [ 0 , 3 ] ] , center = [ 1 , 1 ] , window_size = 4 )
self . _extract_pattern_6 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 2 , 0 ] , [ 3 , 0 ] ] , center = [ 1 , 1 ] , window_size = 4 )
self . stage1_1 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_2 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_3 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_4 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_5 = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_6 = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_Mix = layers . UpscaleBlockChebyKAN ( in_features = 6 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
x = x . reshape ( b * c , 1 , h , w )
output = self . forward_stage ( x , self . _extract_pattern_1 , self . stage1_1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_2 , self . stage1_2 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_3 , self . stage1_3 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_4 , self . stage1_4 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_5 , self . stage1_5 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_6 , self . stage1_6 ) ] , dim = 1 )
output = output . permute ( 0 , 2 , 3 , 1 ) . view ( b , h * w , 6 )
output = self . stage1_Mix ( output )
output = output . view ( b , 1 , h , w , self . scale , self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 )
x = output
x = x . reshape ( b , c , h * self . scale , w * self . scale )
return x
def get_loss_fn ( self ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
class SDYMixNetx1v2 ( SRNetBase ) :
"""
22
12 23 32 21
11 13 33 31
10 14 34 30
01 03 43 41
00 04 44 40
"""
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYMixNetx1v2 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_1 = layers . PercievePattern ( receptive_field_idxes = [ [ 2 , 2 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_2 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] , [ 2 , 1 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_3 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 3 ] , [ 3 , 3 ] , [ 3 , 1 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_4 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 0 ] , [ 1 , 4 ] , [ 3 , 4 ] , [ 3 , 0 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_5 = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 1 ] , [ 0 , 3 ] , [ 4 , 3 ] , [ 4 , 1 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_6 = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 4 ] , [ 4 , 4 ] , [ 4 , 0 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . stage1_1 = layers . UpscaleBlock ( in_features = 1 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_2 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_3 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_4 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_5 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_6 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stage1_Mix = layers . UpscaleBlockChebyKAN ( in_features = 6 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
x = x . reshape ( b * c , 1 , h , w )
output = self . forward_stage ( x , self . _extract_pattern_1 , self . stage1_1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_2 , self . stage1_2 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_3 , self . stage1_3 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_4 , self . stage1_4 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_5 , self . stage1_5 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_6 , self . stage1_6 ) ] , dim = 1 )
output = output . permute ( 0 , 2 , 3 , 1 ) . view ( b , h * w , 6 )
output = self . stage1_Mix ( output )
output = output . view ( b , 1 , h , w , self . scale , self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 )
x = output
x = x . reshape ( b , c , h * self . scale , w * self . scale )
return x
def get_loss_fn ( self ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn
class SDYMixNetx1v3 ( SRNetBase ) :
"""
22
12 23 32 21
11 13 33 31
10 14 34 30
01 03 43 41
00 04 44 40
"""
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYMixNetx1v3 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_0 = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] ] , center = [ 0 , 0 ] , window_size = 1 , channels = 6 )
self . _extract_pattern_1 = layers . PercievePattern ( receptive_field_idxes = [ [ 2 , 2 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_2 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] , [ 2 , 1 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_3 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 3 ] , [ 3 , 3 ] , [ 3 , 1 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_4 = layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 0 ] , [ 1 , 4 ] , [ 3 , 4 ] , [ 3 , 0 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_5 = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 1 ] , [ 0 , 3 ] , [ 4 , 3 ] , [ 4 , 1 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . _extract_pattern_6 = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 4 ] , [ 4 , 4 ] , [ 4 , 0 ] ] , center = [ 2 , 2 ] , window_size = 5 )
self . stage1_1 = layers . UpscaleBlock ( in_features = 1 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_2 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_3 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_4 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_5 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_6 = layers . UpscaleBlock ( in_features = 4 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stage1_Mix = layers . UpscaleBlockChebyKAN ( in_features = 6 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
x = x . reshape ( b * c , 1 , h , w )
output = self . forward_stage ( x , self . _extract_pattern_1 , self . stage1_1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_2 , self . stage1_2 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_3 , self . stage1_3 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_4 , self . stage1_4 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_5 , self . stage1_5 ) ] , dim = 1 )
output = torch . cat ( [ output , self . forward_stage ( x , self . _extract_pattern_6 , self . stage1_6 ) ] , dim = 1 )
output = self . forward_stage ( output , self . _extract_pattern_0 , self . stage1_Mix )
x = output
x = x . reshape ( b , c , h * self . scale , w * self . scale )
return x
def get_loss_fn ( self ) :
def loss_fn ( pred , target ) :
return F . mse_loss ( pred / 255 , target / 255 )
return loss_fn