@@ -7,6 +7,7 @@ from .utils import round_func
 class PercievePattern():
     def __init__(self, receptive_field_idxes=[[0, 0], [0, 1], [1, 0], [1, 1]], center=[0, 0], window_size=2):
         assert window_size >= (np.max(receptive_field_idxes) + 1)
+        assert len(receptive_field_idxes) == 4
         self.receptive_field_idxes = np.array(receptive_field_idxes)
         self.window_size = window_size
         self.center = center
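The two constructor checks encode the pattern contract: every tap coordinate must fall inside the window, and (per the newly added assert) exactly four taps are expected, matching the four gathers in `forward` below. A quick check with the default 2x2 pattern:

```python
import numpy as np

receptive_field_idxes = [[0, 0], [0, 1], [1, 0], [1, 1]]
window_size = 2

# every tap fits inside the window: max coordinate is 1, the window covers 0..1
assert window_size >= (np.max(receptive_field_idxes) + 1)
# exactly four taps, matching the four indexed gathers in forward
assert len(receptive_field_idxes) == 4
```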
@@ -32,73 +33,38 @@ class PercievePattern():
             x[:, self.receptive_field_idxes[2], :],
             x[:, self.receptive_field_idxes[3], :]
         ], 2)
-        x = x.reshape(x.shape[0] * x.shape[1], 1, 2, 2)
         return x
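The hunk omits the middle of `forward`, so the shapes are partly guesswork; presumably the input is unfolded so that one dimension enumerates window pixels and another enumerates spatial positions, letting the tail above gather the four taps and `cat` them into per-pixel 4-vectors, which is the layout `nn.Linear(in_features=4)` in the new `UpscaleBlock.embed` expects. A minimal stand-in sketch (`perceive_2x2` is hypothetical, not the repo's code):

```python
import torch
import torch.nn.functional as F

def perceive_2x2(x, window_size=2):
    # hypothetical stand-in for the elided middle of PercievePattern.forward
    b, c, h, w = x.shape
    x = F.pad(x, (0, window_size - 1, 0, window_size - 1), mode='replicate')
    x = F.unfold(x, window_size)              # (b, c*ws*ws, h*w)
    x = x.view(b * c, window_size ** 2, h * w)
    x = x.permute(0, 2, 1)                    # one ws*ws vector per pixel
    return x

inp = torch.rand(1, 1, 6, 6) * 255
print(perceive_2x2(inp).shape)  # torch.Size([1, 36, 4]) -> feeds nn.Linear(in_features=4)
```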
 # Huang G. et al. Densely connected convolutional networks // Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017. pp. 4700-4708.
 # https://ar5iv.labs.arxiv.org/html/1608.06993
 # https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
 # refactoring to linear gives a slight speed up, but requires a total rewrite to be consistent
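On the speed-up note above: a 1x1 `Conv2d` over channels and an `nn.Linear` over a `(N, HW, C)` layout compute the same affine map, so the dense block can switch to `Linear` without changing the function it represents. A small equivalence check (names here are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(32, 16, kernel_size=1, bias=True)
lin = nn.Linear(32, 16, bias=True)
# a 1x1 conv kernel of shape (out, in, 1, 1) is just an (out, in) matrix
lin.weight.data = conv.weight.data.view(16, 32).clone()
lin.bias.data = conv.bias.data.clone()

x = torch.randn(2, 32, 8, 8)
y_conv = conv(x)                                 # (2, 16, 8, 8)
y_lin = lin(x.flatten(2).transpose(1, 2))        # (2, 64, 16)
y_lin = y_lin.transpose(1, 2).view_as(y_conv)
print(torch.allclose(y_conv, y_lin, atol=1e-6))  # True
```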
-class DenseConvUpscaleBlock(nn.Module):
-    def __init__(self, hidden_dim=32, layers_count=5, upscale_factor=1):
-        super(DenseConvUpscaleBlock, self).__init__()
+class UpscaleBlock(nn.Module):
+    def __init__(self, receptive_field_idxes=[[0, 0], [0, 1], [1, 0], [1, 1]], center=[0, 0], window_size=2, in_features=4, hidden_dim=32, layers_count=5, upscale_factor=1):
+        super(UpscaleBlock, self).__init__()
         assert layers_count > 0
+        self.percieve_pattern = PercievePattern(receptive_field_idxes=receptive_field_idxes, center=center, window_size=window_size)
         self.upscale_factor = upscale_factor
         self.hidden_dim = hidden_dim
-        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
+        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
-        self.convs = []
+        self.linear_projections = []
         for i in range(layers_count):
-            self.convs.append(nn.Conv2d(in_channels=(i + 1) * hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=True))
-        self.convs = nn.ModuleList(self.convs)
+            self.linear_projections.append(nn.Linear(in_features=(i + 1) * hidden_dim, out_features=hidden_dim, bias=True))
+        self.linear_projections = nn.ModuleList(self.linear_projections)
         for name, p in self.named_parameters():
             if "weight" in name: nn.init.kaiming_normal_(p)
             if "bias" in name: nn.init.constant_(p, 0)
-        self.project_channels = nn.Conv2d(in_channels=(layers_count + 1) * hidden_dim, out_channels=upscale_factor * upscale_factor, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
-        self.shuffle = nn.PixelShuffle(upscale_factor)
-    def forward(self, x):
-        x = (x - 127.5) / 127.5
-        x = torch.relu(self.embed(x))
-        for conv in self.convs:
-            x = torch.cat([x, torch.relu(conv(x))], dim=1)
-        x = self.shuffle(self.project_channels(x))
-        x = torch.tanh(x)
-        x = x * 127.5 + 127.5
-        x = round_func(x)
-        return x
-class ConvUpscaleBlock(nn.Module):
-    def __init__(self, hidden_dim=32, layers_count=5, upscale_factor=1):
-        super(ConvUpscaleBlock, self).__init__()
-        assert layers_count > 0
-        self.upscale_factor = upscale_factor
-        self.hidden_dim = hidden_dim
-        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
-        self.convs = []
-        for i in range(layers_count):
-            self.convs.append(nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=True))
-        self.convs = nn.ModuleList(self.convs)
-        for name, p in self.named_parameters():
-            if "weight" in name: nn.init.kaiming_normal_(p)
-            if "bias" in name: nn.init.constant_(p, 0)
-        self.project_channels = nn.Conv2d(in_channels=hidden_dim, out_channels=upscale_factor * upscale_factor, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
-        self.shuffle = nn.PixelShuffle(upscale_factor)
+        self.project_channels = nn.Linear(in_features=(layers_count + 1) * hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
     def forward(self, x):
+        b, c, h, w = x.shape
         x = (x - 127.5) / 127.5
+        x = self.percieve_pattern(x)
         x = torch.relu(self.embed(x))
-        for conv in self.convs:
-            x = torch.relu(conv(x))
-        x = self.shuffle(self.project_channels(x))
+        for linear_projection in self.linear_projections:
+            x = torch.cat([x, torch.relu(linear_projection(x))], dim=2)
+        x = self.project_channels(x)
         x = torch.tanh(x)
         x = x * 127.5 + 127.5
         x = round_func(x)
+        x = x.reshape(b, c, h, w, self.upscale_factor, self.upscale_factor)
+        x = x.permute(0, 1, 2, 4, 3, 5)
+        x = x.reshape(b, c, h * self.upscale_factor, w * self.upscale_factor)
         return x
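A smoke test for the refactored block, assuming `PercievePattern.forward` returns one 4-vector per input pixel so the final reshape lines up; the reshape/permute tail reorders `(h, w, s, s)` to `(h, s, w, s)`, i.e. the same pixel layout `nn.PixelShuffle` produced in the removed Conv2d path:

```python
import torch

# hypothetical usage; UpscaleBlock as defined in this diff
block = UpscaleBlock(upscale_factor=2)
lr = torch.randint(0, 256, (1, 1, 8, 8)).float()  # forward expects a uint8-range tensor
sr = block(lr)
print(sr.shape)  # expected: torch.Size([1, 1, 16, 16])
```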
 # https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105