@ -8,23 +8,24 @@ from pathlib import Path
from . import hdblut
from . import hdblut
from common import layers
from common import layers
from itertools import cycle
from itertools import cycle
from models . base import SRNetBase
class HDBNet ( nn. Modul e) :
class HDBNet ( SRNetBas e) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
def __init__ ( self , hidden_dim = 64 , layers_count = 4 , scale = 4 ) :
super ( HDBNet , self ) . __init__ ( )
super ( HDBNet , self ) . __init__ ( )
assert scale == 4
assert scale == 4
self . scale = scale
self . scale = scale
self . stage1_3H = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 * 2 )
self . stage1_3H = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 255 , output_max_value = 255 )
self . stage1_3D = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 * 2 )
self . stage1_3D = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 255 , output_max_value = 255 )
self . stage1_3B = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 * 2 )
self . stage1_3B = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 255 , output_max_value = 255 )
self . stage1_2H = layers . UpscaleBlock ( in_features = 2 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 * 2 )
self . stage1_2H = layers . UpscaleBlock ( in_features = 2 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 15 , output_max_value = 255 )
self . stage1_2D = layers . UpscaleBlock ( in_features = 2 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 * 2 )
self . stage1_2D = layers . UpscaleBlock ( in_features = 2 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 15 , output_max_value = 255 )
# self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2 )
self . stage2_3H = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 255 , output_max_value = 255 )
# self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2 )
self . stage2_3D = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 255 , output_max_value = 255 )
# self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2 )
self . stage2_3B = layers . UpscaleBlock ( in_features = 3 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 255 , output_max_value = 255 )
# self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2 )
self . stage2_2H = layers . UpscaleBlock ( in_features = 2 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 15 , output_max_value = 255 )
# self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2 )
self . stage2_2D = layers . UpscaleBlock ( in_features = 2 , hidden_dim = hidden_dim , layers_count = layers_count , upscale_factor = 2 , input_max_value = 15 , output_max_value = 255 )
self . _extract_pattern_3H = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 0 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_3H = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 0 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_3D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_3D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
@ -32,6 +33,8 @@ class HDBNet(nn.Module):
self . _extract_pattern_2H = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . _extract_pattern_2H = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . _extract_pattern_2D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . _extract_pattern_2D = layers . PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 2 )
self . rotations = 4
def forward_stage ( self , x , scale , percieve_pattern , stage ) :
def forward_stage ( self , x , scale , percieve_pattern , stage ) :
b , c , h , w = x . shape
b , c , h , w = x . shape
x = percieve_pattern ( x )
x = percieve_pattern ( x )
@ -41,47 +44,48 @@ class HDBNet(nn.Module):
x = x . permute ( 0 , 1 , 2 , 4 , 3 , 5 )
x = x . permute ( 0 , 1 , 2 , 4 , 3 , 5 )
x = x . reshape ( b , c , h * scale , w * scale )
x = x . reshape ( b , c , h * scale , w * scale )
return x
return x
def forward ( self , x , config = None ) :
def forward ( self , x , config = None ) :
b , c , h , w = x . shape
b , c , h , w = x . shape
x = x . reshape ( b * c , 1 , h , w )
x = x . reshape ( b * c , 1 , h , w )
lsb = x % 16
lsb = x % 16
msb = x - lsb
msb = x - lsb
output_msb = torch . zeros ( [ b * c , 1 , h * 2 * 2 , w * 2 * 2 ] , dtype = x . dtype , device = x . device )
output = torch . zeros ( [ b * c , 1 , h * 2 , w * 2 ] , dtype = x . dtype , device = x . device )
output_lsb = torch . zeros ( [ b * c , 1 , h * 2 * 2 , w * 2 * 2 ] , dtype = x . dtype , device = x . device )
for rotations_count in range ( self . rotations ) :
for rotations_count in range ( 4 ) :
rotated_msb = torch . rot90 ( msb , k = rotations_count , dims = [ 2 , 3 ] )
rotated_msb = torch . rot90 ( msb , k = rotations_count , dims = [ 2 , 3 ] )
rotated_lsb = torch . rot90 ( lsb , k = rotations_count , dims = [ 2 , 3 ] )
rotated_lsb = torch . rot90 ( lsb , k = rotations_count , dims = [ 2 , 3 ] )
output_msb + = torch . rot90 ( self . forward_stage ( rotated_msb , 2 * 2 , self . _extract_pattern_3H , self . stage1_3H ) , k = - rotations_count , dims = [ 2 , 3 ] )
output_msb = self . forward_stage ( rotated_msb , 2 , self . _extract_pattern_3H , self . stage1_3H ) + \
output_msb + = torch . rot90 ( self . forward_stage ( rotated_msb , 2 * 2 , self . _extract_pattern_3D , self . stage1_3D ) , k = - rotations_count , dims = [ 2 , 3 ] )
self . forward_stage ( rotated_msb , 2 , self . _extract_pattern_3D , self . stage1_3D ) + \
output_msb + = torch . rot90 ( self . forward_stage ( rotated_msb , 2 * 2 , self . _extract_pattern_3B , self . stage1_3B ) , k = - rotations_count , dims = [ 2 , 3 ] )
self . forward_stage ( rotated_msb , 2 , self . _extract_pattern_3B , self . stage1_3B )
output_lsb + = torch . rot90 ( self . forward_stage ( rotated_lsb , 2 * 2 , self . _extract_pattern_2H , self . stage1_2H ) , k = - rotations_count , dims = [ 2 , 3 ] )
output_msb / = 3
output_lsb + = torch . rot90 ( self . forward_stage ( rotated_lsb , 2 * 2 , self . _extract_pattern_2D , self . stage1_2D ) , k = - rotations_count , dims = [ 2 , 3 ] )
output_lsb = self . forward_stage ( rotated_lsb , 2 , self . _extract_pattern_2H , self . stage1_2H ) + \
output_msb / = 4 * 3
self . forward_stage ( rotated_lsb , 2 , self . _extract_pattern_2D , self . stage1_2D )
output_lsb / = 4 * 2
output_lsb / = 2
output_msb = round_func ( ( output_msb / 255 ) * 16 ) * 15
if not config is None and config . current_iter % config . display_step == 0 :
output_lsb = ( output_lsb / 255 ) * 15
config . writer . add_histogram ( ' s1_output_lsb ' , output_lsb . detach ( ) . cpu ( ) . numpy ( ) , config . current_iter )
# print(output_msb.min(), output_msb.max(), output_lsb.min(), output_lsb.max())
config . writer . add_histogram ( ' s1_output_msb ' , output_msb . detach ( ) . cpu ( ) . numpy ( ) , config . current_iter )
x = output_msb + output_lsb
output + = torch . rot90 ( output_msb + output_lsb , k = - rotations_count , dims = [ 2 , 3 ] ) . clamp ( 0 , 255 )
# lsb = x % 16
output / = self . rotations
# msb = x - lsb
x = output
lsb = x % 16
# output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
msb = x - lsb
# output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
output = torch . zeros ( [ b * c , 1 , h * 4 , w * 4 ] , dtype = x . dtype , device = x . device )
# for rotations_count in range(4):
for rotations_count in range ( self . rotations ) :
# rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_msb = torch . rot90 ( msb , k = rotations_count , dims = [ 2 , 3 ] )
# rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch . rot90 ( lsb , k = rotations_count , dims = [ 2 , 3 ] )
# output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
output_msb = self . forward_stage ( rotated_msb , 2 , self . _extract_pattern_3H , self . stage2_3H ) + \
# output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
self . forward_stage ( rotated_msb , 2 , self . _extract_pattern_3D , self . stage2_3D ) + \
# output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
self . forward_stage ( rotated_msb , 2 , self . _extract_pattern_3B , self . stage2_3B )
# output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
output_msb / = 3
# output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
output_lsb = self . forward_stage ( rotated_lsb , 2 , self . _extract_pattern_2H , self . stage2_2H ) + \
# output_msb /= 4*3
self . forward_stage ( rotated_lsb , 2 , self . _extract_pattern_2D , self . stage2_2D )
# output_lsb /= 4*2
output_lsb / = 2
# output_msb = round_func((output_msb / 255) * 16) * 15
if not config is None and config . current_iter % config . display_step == 0 :
# output_lsb = (output_lsb / 255) * 15
config . writer . add_histogram ( ' s2_output_lsb ' , output_lsb . detach ( ) . cpu ( ) . numpy ( ) , config . current_iter )
# # print(output_msb.min(), output_msb.max(), output_lsb.min(), output_lsb.max())
config . writer . add_histogram ( ' s2_output_msb ' , output_msb . detach ( ) . cpu ( ) . numpy ( ) , config . current_iter )
# x = output_msb + output_lsb
output + = torch . rot90 ( output_msb + output_lsb , k = - rotations_count , dims = [ 2 , 3 ] ) . clamp ( 0 , 255 )
output / = self . rotations
x = output
x = x . reshape ( b , c , h * self . scale , w * self . scale )
x = x . reshape ( b , c , h * self . scale , w * self . scale )
return x
return x