@ -4,49 +4,17 @@ import torch.nn.functional as F
import numpy as np
from common . utils import round_func
from common import lut
from common . layers import PercievePattern
from common . layers import PercievePattern , DenseConvUpscaleBlock
from pathlib import Path
from . import sdylut
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2017. – С . 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock ( nn . Module ) :
def __init__ ( self , hidden_dim = 32 , layers_count = 5 , upscale_factor = 1 ) :
super ( DenseConvUpscaleBlock , self ) . __init__ ( )
assert layers_count > 0
self . upscale_factor = upscale_factor
self . hidden_dim = hidden_dim
self . percieve = nn . Conv2d ( 1 , hidden_dim , kernel_size = ( 2 , 2 ) , padding = ' valid ' , stride = 1 , dilation = 1 , bias = True )
self . convs = [ ]
for i in range ( layers_count ) :
self . convs . append ( nn . Conv2d ( in_channels = ( i + 1 ) * hidden_dim , out_channels = hidden_dim , kernel_size = 1 , stride = 1 , padding = 0 , dilation = 1 , bias = True ) )
self . convs = nn . ModuleList ( self . convs )
for name , p in self . named_parameters ( ) :
if " weight " in name : nn . init . kaiming_normal_ ( p )
if " bias " in name : nn . init . constant_ ( p , 0 )
self . project_channels = nn . Conv2d ( in_channels = ( layers_count + 1 ) * hidden_dim , out_channels = upscale_factor * upscale_factor , kernel_size = 1 , stride = 1 , padding = 0 , dilation = 1 , bias = True )
self . shuffle = nn . PixelShuffle ( upscale_factor )
def forward ( self , x ) :
x = ( x - 127.5 ) / 127.5
x = torch . relu ( self . percieve ( x ) )
for conv in self . convs :
x = torch . cat ( [ x , torch . relu ( conv ( x ) ) ] , dim = 1 )
x = self . shuffle ( self . project_channels ( x ) )
x = torch . tanh ( x )
x = round_func ( x * 127.5 + 127.5 )
return x
class SDYNetx1 ( nn . Module ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYNetx1 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] )
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . stageS = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stageD = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stageY = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
@ -90,9 +58,9 @@ class SDYNetx2(nn.Module):
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYNetx2 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] )
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . stageS_1 = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stageD_1 = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
self . stageY_1 = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 1 )
@ -160,3 +128,50 @@ class SDYNetx2(nn.Module):
stageY_2 = lut . transfer_2x2_input_SxS_output ( self . stageY_2 , quantization_interval = quantization_interval , batch_size = batch_size )
lut_model = sdylut . SDYLutx2 . init_from_lut ( stageS_1 , stageD_1 , stageY_1 , stageS_2 , stageD_2 , stageY_2 )
return lut_model
class SDYNetCenteredx1 ( nn . Module ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( SDYNetCenteredx1 , self ) . __init__ ( )
self . scale = scale
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 1 , 1 ] , window_size = 3 )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 1 , 1 ] , window_size = 3 )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 1 , 1 ] , window_size = 3 )
self . stageS = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stageD = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
self . stageY = DenseConvUpscaleBlock ( hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = scale )
def forward ( self , x ) :
b , c , h , w = x . shape
x = x . view ( b * c , 1 , h , w )
output = torch . zeros ( [ b * c , 1 , h * self . scale , w * self . scale ] , dtype = x . dtype , device = x . device )
for rotations_count in range ( 4 ) :
rotated = torch . rot90 ( x , k = rotations_count , dims = [ - 2 , - 1 ] )
rb , rc , rh , rw = rotated . shape
s = self . stageS ( self . _extract_pattern_S ( rotated ) )
s = s . view ( rb * rc , 1 , rh , rw , self . scale , self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) . reshape ( rb * rc , 1 , rh * self . scale , rw * self . scale )
s = torch . rot90 ( s , k = - rotations_count , dims = [ - 2 , - 1 ] )
output + = s
d = self . stageD ( self . _extract_pattern_D ( rotated ) )
d = d . view ( rb * rc , 1 , rh , rw , self . scale , self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) . reshape ( rb * rc , 1 , rh * self . scale , rw * self . scale )
d = torch . rot90 ( d , k = - rotations_count , dims = [ - 2 , - 1 ] )
output + = d
y = self . stageY ( self . _extract_pattern_Y ( rotated ) )
y = y . view ( rb * rc , 1 , rh , rw , self . scale , self . scale ) . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) . reshape ( rb * rc , 1 , rh * self . scale , rw * self . scale )
y = torch . rot90 ( y , k = - rotations_count , dims = [ - 2 , - 1 ] )
output + = y
output / = 4 * 3
output = output . view ( b , c , h * self . scale , w * self . scale )
return output
def get_lut_model ( self , quantization_interval = 16 , batch_size = 2 * * 10 ) :
stageS = lut . transfer_2x2_input_SxS_output ( self . stageS , quantization_interval = quantization_interval , batch_size = batch_size )
stageD = lut . transfer_2x2_input_SxS_output ( self . stageD , quantization_interval = quantization_interval , batch_size = batch_size )
stageY = lut . transfer_2x2_input_SxS_output ( self . stageY , quantization_interval = quantization_interval , batch_size = batch_size )
lut_model = sdylut . SDYLutCenteredx1 . init_from_lut ( stageS , stageD , stageY )
return lut_model