@ -7,6 +7,7 @@ from .utils import  round_func
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  PercievePattern ( ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 2 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        assert  window_size  > =  ( np . max ( receptive_field_idxes ) + 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        assert  len ( receptive_field_idxes )  ==  4 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . receptive_field_idxes  =  np . array ( receptive_field_idxes ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . window_size  =  window_size 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . center  =  center 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -32,73 +33,38 @@ class PercievePattern():
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            x [ : , self . receptive_field_idxes [ 2 ] , : ] , 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            x [ : , self . receptive_field_idxes [ 3 ] , : ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        ] ,  2 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( x . shape [ 0 ] * x . shape [ 1 ] ,  1 ,  2 ,  2 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. –   2017. –   С  . 4700-4708. 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				# https://ar5iv.labs.arxiv.org/html/1608.06993 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				# refactoring to linear give slight speed up, but require total rewrite to be consistent 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  DenseConvUpscaleBlock ( nn . Module ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  32 ,  layers_count = 5 ,  upscale_factor = 1 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( DenseConvUpscaleBlock ,  self ) . __init__ ( )    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  UpscaleBlock ( nn . Module ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] ,  center = [ 0 , 0 ] ,  window_size = 2 ,  in_features = 4 ,  hidden_dim  =  32 ,  layers_count = 5 ,  upscale_factor = 1 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( UpscaleBlock ,  self ) . __init__ ( )    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        assert  layers_count  >  0      
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . percieve_pattern  =  PercievePattern ( receptive_field_idxes = receptive_field_idxes ,  center = center ,  window_size = window_size ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . upscale_factor  =  upscale_factor  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . hidden_dim  =  hidden_dim 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . embed  =  nn . Conv2d( 1 ,  hidden_dim ,  kernel_size = ( 2 ,  2 ) ,  padding = ' valid ' ,   stride = 1 ,  dilation = 1  ,  bias = True ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . embed  =  nn . Linear( in_features = in_features ,  out_features = hidden_dim  ,  bias = True ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self .  conv s =  [ ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . linear_proje cti ons =  [ ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  i  in  range ( layers_count ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            self .  conv s. append ( nn . Conv2d( in_channels  =  ( i + 1 ) * hidden_dim ,  out_channels  =  hidden_dim ,  kernel_size  =  1 ,  stride = 1 ,  padding = 0 ,  dilation = 1  ,  bias = True ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self .  conv s =  nn . ModuleList ( self .  conv s)           
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            self . linear_proje cti ons. append ( nn . Linear( in_features = ( i + 1 ) * hidden_dim ,  out_features = hidden_dim  ,  bias = True ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . linear_proje cti ons =  nn . ModuleList ( self . linear_proje cti ons)           
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  name ,  p  in  self . named_parameters ( ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            if  " weight "  in  name :  nn . init . kaiming_normal_ ( p ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            if  " bias "  in  name :  nn . init . constant_ ( p ,  0 )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . project_channels  =  nn . Conv2d ( in_channels  =  ( layers_count + 1 ) * hidden_dim ,  out_channels  =  upscale_factor  *  upscale_factor ,  kernel_size  =  1 ,  stride = 1 ,  padding = 0 ,  dilation = 1 ,  bias = True )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . shuffle  =  nn . PixelShuffle ( upscale_factor ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  ( x - 127.5 ) / 127.5     
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  torch . relu ( self . embed ( x ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  conv  in  self . convs : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            x  =  torch . cat ( [ x ,  torch . relu ( conv ( x ) ) ] ,  dim = 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  self . shuffle ( self . project_channels ( x ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  torch . tanh ( x )          
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x * 127.5  +  127.5  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  round_func ( x ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				class  ConvUpscaleBlock ( nn . Module ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  __init__ ( self ,  hidden_dim  =  32 ,  layers_count = 5 ,  upscale_factor = 1 ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        super ( ConvUpscaleBlock ,  self ) . __init__ ( )    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        assert  layers_count  >  0      
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . upscale_factor  =  upscale_factor  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . hidden_dim  =  hidden_dim 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . embed  =  nn . Conv2d ( 1 ,  hidden_dim ,  kernel_size = ( 2 ,  2 ) ,  padding = ' valid ' ,   stride = 1 ,  dilation = 1 ,  bias = True ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . convs  =  [ ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  i  in  range ( layers_count ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            self . convs . append ( nn . Conv2d ( in_channels  =  hidden_dim ,  out_channels  =  hidden_dim ,  kernel_size  =  1 ,  stride = 1 ,  padding = 0 ,  dilation = 1 ,  bias = True ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . convs  =  nn . ModuleList ( self . convs )           
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  name ,  p  in  self . named_parameters ( ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            if  " weight "  in  name :  nn . init . kaiming_normal_ ( p ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            if  " bias "  in  name :  nn . init . constant_ ( p ,  0 )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . project_channels  =  nn . Conv2d ( in_channels  =  hidden_dim ,  out_channels  =  upscale_factor  *  upscale_factor ,  kernel_size  =  1 ,  stride = 1 ,  padding = 0 ,  dilation = 1 ,  bias = True )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . shuffle  =  nn . PixelShuffle ( upscale_factor ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        self . project_channels  =  nn . Linear ( in_features = ( layers_count + 1 ) * hidden_dim ,  out_features = upscale_factor  *  upscale_factor ,  bias = True ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    def  forward ( self ,  x ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        b , c , h , w  =  x . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  ( x - 127.5 ) / 127.5     
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  self . percieve_pattern ( x )         
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  torch . relu ( self . embed ( x ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for   conv  in  self .  conv s: 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            x  =  torch .  relu(  conv ( x ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  self . shuffle( self .  project_channels( x ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  linear_projection  in  self . linear_projections : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				            x  =  torch . cat ( [ x ,  torch . relu ( linear_projection ( x ) ) ] ,  dim = 2 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  self . project_channels ( x ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  torch . tanh ( x )          
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x * 127.5  +  127.5  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  round_func ( x ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h ,  w ,  self . upscale_factor ,  self . upscale_factor ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        x  =  x . reshape ( b ,  c ,  h * self . upscale_factor ,  w * self . upscale_factor )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        return  x   
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				# https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105