@ -92,6 +92,54 @@ class SDYNetx2(SRNetBase):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYNetx2Inv ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYNetx2Inv ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_S  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 2 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_D  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] ,  center = [ 0 , 0 ] ,  window_size = 3 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_Y  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 3 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_S  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_D  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Y  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_S  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_D  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_Y  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  0.0 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  + =  self . forward_stage ( x ,  self . _extract_pattern_S ,  self . stage1_S ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  + =  self . forward_stage ( x ,  self . _extract_pattern_D ,  self . stage1_D ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  + =  self . forward_stage ( x ,  self . _extract_pattern_Y ,  self . stage1_Y ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  / =  3 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  0.0 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  + =  self . forward_stage ( x ,  self . _extract_pattern_S ,  self . stage2_S ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  + =  self . forward_stage ( x ,  self . _extract_pattern_D ,  self . stage2_D ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  + =  self . forward_stage ( x ,  self . _extract_pattern_Y ,  self . stage2_Y ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  / =  3 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  ( output  +  x ) / 2 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h * self . scale ,  w * self . scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_lut_model ( self ,  quantization_interval = 16 ,  batch_size = 2 * * 10 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        stage1_S  =  lut . transfer_2x2_input_SxS_output ( self . stage1_S ,  quantization_interval = quantization_interval ,  batch_size = batch_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        stage1_D  =  lut . transfer_2x2_input_SxS_output ( self . stage1_D ,  quantization_interval = quantization_interval ,  batch_size = batch_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        stage1_Y  =  lut . transfer_2x2_input_SxS_output ( self . stage1_Y ,  quantization_interval = quantization_interval ,  batch_size = batch_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        stage2_S  =  lut . transfer_2x2_input_SxS_output ( self . stage2_S ,  quantization_interval = quantization_interval ,  batch_size = batch_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        stage2_D  =  lut . transfer_2x2_input_SxS_output ( self . stage2_D ,  quantization_interval = quantization_interval ,  batch_size = batch_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        stage2_Y  =  lut . transfer_2x2_input_SxS_output ( self . stage2_Y ,  quantization_interval = quantization_interval ,  batch_size = batch_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        lut_model  =  sdylut . SDYLutx2 . init_from_numpy ( stage1_S ,  stage1_D ,  stage1_Y ,  stage2_S ,  stage2_D ,  stage2_Y ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  lut_model 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYNetx3 ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYNetx3 ,  self ) . __init__ ( ) 
 
			
		 
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
			
			 
			 
			
				@ -151,7 +199,7 @@ class SDYNetx3(SRNetBase):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYNetR90x1 ( nn. Modul  e) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYNetR90x1 ( SRNetBas e) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYNetR90x1 ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -162,16 +210,6 @@ class SDYNetR90x1(nn.Module):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_D  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Y  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward_stage ( self ,  x ,  scale ,  percieve_pattern ,  stage ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  percieve_pattern ( x )    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  stage ( x ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  round_func ( x ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h ,  w ,  scale ,  scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h * scale ,  w * scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
			
			 
			 
			
				@ -219,10 +257,7 @@ class SDYNetR90x2(SRNetBase):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . view ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  =  torch . zeros ( [ b * c ,  1 ,  h ,  w ] ,  dtype = x . dtype ,  device = x . device ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  + =  self . forward_stage ( x ,  1 ,  self . _extract_pattern_S ,  self . stage1_S ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  + =  self . forward_stage ( x ,  1 ,  self . _extract_pattern_D ,  self . stage1_D ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  + =  self . forward_stage ( x ,  1 ,  self . _extract_pattern_Y ,  self . stage1_Y ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 1 , 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            rotated  =  torch . rot90 ( x ,  k = rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_S ,  self . stage1_S ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_D ,  self . stage1_D ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -230,10 +265,7 @@ class SDYNetR90x2(SRNetBase):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  / =  4 * 3 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output_1 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_2  =  torch . zeros ( [ b * c ,  1 ,  h * self . scale ,  w * self . scale ] ,  dtype = x . dtype ,  device = x . device ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_2  + =  self . forward_stage ( x ,  self . scale ,  self . _extract_pattern_S ,  self . stage2_S ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_2  + =  self . forward_stage ( x ,  self . scale ,  self . _extract_pattern_D ,  self . stage2_D ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_2  + =  self . forward_stage ( x ,  self . scale ,  self . _extract_pattern_Y ,  self . stage2_Y ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 1 , 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            rotated  =  torch . rot90 ( x ,  k = rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_S ,  self . stage2_S ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_D ,  self . stage2_D ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -253,6 +285,252 @@ class SDYNetR90x2(SRNetBase):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        lut_model  =  sdylut . SDYLutR90x2 . init_from_numpy ( stage1_S ,  stage1_D ,  stage1_Y ,  stage2_S ,  stage2_D ,  stage2_Y ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  lut_model 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYEHONetR90x1 ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYEHONetR90x1 ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_S  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 2 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_D  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] ,  center = [ 0 , 0 ] ,  window_size = 3 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_Y  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 3 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_E  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 0 ] , [ 0 , 3 ] , [ 3 , 3 ] ] ,  center = [ 0 , 0 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_H  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] ] ,  center = [ 0 , 0 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_O  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 1 ] , [ 2 , 2 ] , [ 1 , 3 ] ] ,  center = [ 0 , 0 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_S  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_D  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Y  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_E  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_H  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_O  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . view ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  =  torch . zeros ( [ b * c ,  1 ,  h * self . scale ,  w * self . scale ] ,  dtype = x . dtype ,  device = x . device ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            rotated  =  torch . rot90 ( x ,  k = rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_S ,  self . stage1_S ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_D ,  self . stage1_D ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_Y ,  self . stage1_Y ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_E ,  self . stage1_E ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_H ,  self . stage1_H ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_O ,  self . stage1_O ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  / =  4 * 6 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output_1 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . view ( b ,  c ,  h * self . scale ,  w * self . scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_lut_model ( self ,  quantization_interval = 16 ,  batch_size = 2 * * 10 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        raise  NotImplementedError 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn         
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYEHONetR90x2 ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYEHONetR90x2 ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_S  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 2 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_D  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] ,  center = [ 0 , 0 ] ,  window_size = 3 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_Y  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 3 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_E  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 0 ] , [ 0 , 3 ] , [ 3 , 3 ] ] ,  center = [ 0 , 0 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_H  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] ] ,  center = [ 0 , 0 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_O  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 3 , 1 ] , [ 2 , 2 ] , [ 1 , 3 ] ] ,  center = [ 0 , 0 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_S  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_D  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Y  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_E  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_H  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_O  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_S  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_D  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_Y  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_E  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_H  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage2_O  =  layers . UpscaleBlock ( hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . view ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  =  torch . zeros ( [ b * c ,  1 ,  h ,  w ] ,  dtype = x . dtype ,  device = x . device ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            rotated  =  torch . rot90 ( x ,  k = rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_S ,  self . stage1_S ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_D ,  self . stage1_D ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_Y ,  self . stage1_Y ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_E ,  self . stage1_E ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_H ,  self . stage1_H ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_1  + =  torch . rot90 ( self . forward_stage ( rotated ,  1 ,  self . _extract_pattern_O ,  self . stage1_O ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] )               
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_1  / =  4 * 6 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output_1 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_2  =  torch . zeros ( [ b * c ,  1 ,  h * self . scale ,  w * self . scale ] ,  dtype = x . dtype ,  device = x . device ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  rotations_count  in  range ( 4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            rotated  =  torch . rot90 ( x ,  k = rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_S ,  self . stage2_S ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_D ,  self . stage2_D ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_Y ,  self . stage2_Y ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_E ,  self . stage2_E ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_H ,  self . stage2_H ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            output_2  + =  torch . rot90 ( self . forward_stage ( rotated ,  self . scale ,  self . _extract_pattern_O ,  self . stage2_O ) ,  k = - rotations_count ,  dims = [ - 2 ,  - 1 ] ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output_2  / =  4 * 6 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output_2 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . view ( b ,  c ,  h * self . scale ,  w * self . scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_lut_model ( self ,  quantization_interval = 16 ,  batch_size = 2 * * 10 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        raise  NotImplementedError 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn         
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYMixNetx1 ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYMixNetx1 ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_1  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] , [ 2 , 2 ] ] ,  center = [ 1 , 1 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_2  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 3 ] , [ 3 , 1 ] , [ 3 , 3 ] ] ,  center = [ 1 , 1 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_3  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 2 , 2 ] , [ 3 , 2 ] , [ 2 , 3 ] ] ,  center = [ 1 , 1 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_4  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] ] ,  center = [ 1 , 1 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_5  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 0 , 2 ] , [ 0 , 3 ] ] ,  center = [ 1 , 1 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_6  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 2 , 0 ] , [ 3 , 0 ] ] ,  center = [ 1 , 1 ] ,  window_size = 4 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_1  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_2  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_3  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_4  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_5  =  layers . UpscaleBlock ( in_features = 3 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_6  =  layers . UpscaleBlock ( in_features = 3 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Mix  =  layers . UpscaleBlockChebyKAN ( in_features = 6 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  self . forward_stage ( x ,  self . _extract_pattern_1 ,  self . stage1_1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_2 ,  self . stage1_2 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_3 ,  self . stage1_3 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_4 ,  self . stage1_4 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_5 ,  self . stage1_5 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_6 ,  self . stage1_6 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  output . permute ( 0 ,  2 ,  3 ,  1 ) . view ( b ,  h * w ,  6 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  self . stage1_Mix ( output ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  output . view ( b ,  1 ,  h ,  w ,  self . scale ,  self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h * self . scale ,  w * self . scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYMixNetx1v2 ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    """ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    22 			
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    12 	23 	32 	21 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    11 	13 	33 	31 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    10 	14 	34 	30 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    01 	03 	43 	41 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    00 	04 	44 	40 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    """ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYMixNetx1v2 ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_1  =  layers . PercievePattern ( receptive_field_idxes = [ [ 2 , 2 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_2  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] , [ 2 , 1 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_3  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 3 ] , [ 3 , 3 ] , [ 3 , 1 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_4  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 0 ] , [ 1 , 4 ] , [ 3 , 4 ] , [ 3 , 0 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_5  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 1 ] , [ 0 , 3 ] , [ 4 , 3 ] , [ 4 , 1 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_6  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 4 ] , [ 4 , 4 ] , [ 4 , 0 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_1  =  layers . UpscaleBlock ( in_features = 1 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_2  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_3  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_4  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_5  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_6  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Mix  =  layers . UpscaleBlockChebyKAN ( in_features = 6 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  self . forward_stage ( x ,  self . _extract_pattern_1 ,  self . stage1_1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_2 ,  self . stage1_2 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_3 ,  self . stage1_3 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_4 ,  self . stage1_4 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_5 ,  self . stage1_5 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_6 ,  self . stage1_6 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  output . permute ( 0 ,  2 ,  3 ,  1 ) . view ( b ,  h * w ,  6 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  self . stage1_Mix ( output ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  output . view ( b ,  1 ,  h ,  w ,  self . scale ,  self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h * self . scale ,  w * self . scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  loss_fn 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  SDYMixNetx1v3 ( SRNetBase ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    """ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    22 			
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    12 	23 	32 	21 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    11 	13 	33 	31 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    10 	14 	34 	30 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    01 	03 	43 	41 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    00 	04 	44 	40 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    """ 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  64 ,  layers_count  =  4 ,  scale  =  4 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( SDYMixNetx1v3 ,  self ) . __init__ ( ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . scale  =  scale    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_0  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] ] ,  center = [ 0 , 0 ] ,  window_size = 1 ,  channels = 6 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_1  =  layers . PercievePattern ( receptive_field_idxes = [ [ 2 , 2 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_2  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 2 ] , [ 2 , 1 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_3  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 1 ] , [ 1 , 3 ] , [ 3 , 3 ] , [ 3 , 1 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_4  =  layers . PercievePattern ( receptive_field_idxes = [ [ 1 , 0 ] , [ 1 , 4 ] , [ 3 , 4 ] , [ 3 , 0 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_5  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 1 ] , [ 0 , 3 ] , [ 4 , 3 ] , [ 4 , 1 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . _extract_pattern_6  =  layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 4 ] , [ 4 , 4 ] , [ 4 , 0 ] ] ,  center = [ 2 , 2 ] ,  window_size = 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_1  =  layers . UpscaleBlock ( in_features = 1 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_2  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_3  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_4  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_5  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_6  =  layers . UpscaleBlock ( in_features = 4 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . stage1_Mix  =  layers . UpscaleBlockChebyKAN ( in_features = 6 ,  hidden_dim = hidden_dim ,  layers_count = layers_count ,  upscale_factor = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ,  config = None ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b * c ,  1 ,  h ,  w ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  self . forward_stage ( x ,  self . _extract_pattern_1 ,  self . stage1_1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_2 ,  self . stage1_2 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_3 ,  self . stage1_3 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_4 ,  self . stage1_4 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_5 ,  self . stage1_5 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  torch . cat ( [ output ,  self . forward_stage ( x ,  self . _extract_pattern_6 ,  self . stage1_6 ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        output  =  self . forward_stage ( output ,  self . _extract_pattern_0 ,  self . stage1_Mix ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  output 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h * self . scale ,  w * self . scale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  get_loss_fn ( self ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        def  loss_fn ( pred ,  target ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            return  F . mse_loss ( pred / 255 ,  target / 255 )