HDBNet (HKLUT) implementation

Branch: main
Author: protsenkovi, 2 months ago
Parent: debd2acf83 · Commit: cf7bab78af

@@ -12,7 +12,7 @@ class SRBase(nn.Module):
     def get_loss_fn(self):
         def loss_fn(pred, target):
-            return F.mse_loss(pred/255, target/255)
+            return F.mse_loss(pred/127.5-127.5, target/127.5-127.5)
         return loss_fn
     # def get_loss_fn(self):
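The new normalization reads `pred/127.5-127.5` rather than the more usual `(pred-127.5)/127.5`, but inside MSE this makes no difference: a constant shift applied to both arguments cancels, so only the 1/127.5 scaling matters, and relative to the old `pred/255` form the loss is simply rescaled by (255/127.5)² = 4. A quick sanity check:

```python
import torch
import torch.nn.functional as F

pred = torch.randint(0, 256, (1, 1, 8, 8)).float()
target = torch.randint(0, 256, (1, 1, 8, 8)).float()

# A shared shift cancels inside MSE: mean(((a-c)-(b-c))^2) == mean((a-b)^2).
a = F.mse_loss(pred / 127.5 - 127.5, target / 127.5 - 127.5)
b = F.mse_loss(pred / 127.5, target / 127.5)
assert torch.allclose(a, b)

# Relative to the old pred/255 scaling, the loss is rescaled by (255/127.5)^2 = 4.
c = F.mse_loss(pred / 255, target / 255)
assert torch.allclose(a, 4 * c)
```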

@@ -65,7 +65,6 @@ class SRTrainDataset(Dataset):
             :
         ]
         if self.rigid_aug:
             if random.uniform(0, 1) < 0.5:
                 hr_patch = np.fliplr(hr_patch)
@@ -78,9 +77,9 @@ class SRTrainDataset(Dataset):
             k = random.choice([0, 1, 2, 3])
             hr_patch = np.rot90(hr_patch, k)
             lr_patch = np.rot90(lr_patch, k)
         hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
         lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
         hr_patch = hr_patch.permute(2,0,1)
         lr_patch = lr_patch.permute(2,0,1)
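For context, the augmentation above applies identical flips and rotations to the HR and LR patches so the pair stays pixel-aligned, and the `.copy()` calls drop the negative strides that `np.fliplr`/`np.rot90` leave behind before tensor conversion. A minimal standalone sketch of the same pipeline (the `np.flipud` branch is an assumption, since the hunk is truncated):

```python
import random
import numpy as np
import torch

def rigid_aug(hr_patch, lr_patch):
    # Apply identical flips/rotations to both patches so the pair stays aligned.
    if random.uniform(0, 1) < 0.5:
        hr_patch = np.fliplr(hr_patch)
        lr_patch = np.fliplr(lr_patch)
    if random.uniform(0, 1) < 0.5:  # assumed second flip, not visible in the hunk
        hr_patch = np.flipud(hr_patch)
        lr_patch = np.flipud(lr_patch)
    k = random.choice([0, 1, 2, 3])
    hr_patch = np.rot90(hr_patch, k)
    lr_patch = np.rot90(lr_patch, k)
    # .copy() removes the negative strides left by fliplr/rot90 before torch.tensor.
    hr = torch.tensor(hr_patch.copy()).type(torch.float32).permute(2, 0, 1)  # HWC -> CHW
    lr = torch.tensor(lr_patch.copy()).type(torch.float32).permute(2, 0, 1)
    return hr, lr

hr, lr = rigid_aug(np.random.rand(48, 48, 1), np.random.rand(24, 24, 1))
```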

@@ -44,7 +44,7 @@ class ChebyKANBase(SRBase):
         super(ChebyKANBase, self).__init__()
         self.config = None
         self.stage1_S = layers.UpscaleBlock(None)
-        window_size = 7
+        window_size = 5
         self._extract_pattern = layers.PercievePattern(
             receptive_field_idxes=[[i,j] for i in range(window_size) for j in range(window_size)],
             center=[window_size//2,window_size//2],
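The comprehension above enumerates every offset of a `window_size` × `window_size` neighborhood, so after this commit the pattern has 25 taps centered at `[2,2]`, matching `in_features=window_size*window_size` in the blocks below. A quick check:

```python
window_size = 5
receptive_field_idxes = [[i, j] for i in range(window_size) for j in range(window_size)]
center = [window_size // 2, window_size // 2]

assert len(receptive_field_idxes) == window_size * window_size == 25
assert receptive_field_idxes[0] == [0, 0] and receptive_field_idxes[-1] == [4, 4]
assert center == [2, 2]
```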
@@ -62,21 +62,21 @@ class ChebyKANNet(ChebyKANBase):
     def __init__(self, config):
         super(ChebyKANNet, self).__init__()
         self.config = config
-        window_size = 7
+        window_size = 5
         self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
             in_features=window_size*window_size,
             out_channels=1,
             hidden_dim=16,
-            layers_count=self.config.layers_count,
+            layers_count=2,#self.config.layers_count,
             upscale_factor=self.config.upscale_factor,
-            degree=8
+            degree=3
         )

 class ChebyKANLut(ChebyKANBase):
     def __init__(self, config):
         super(ChebyKANLut, self).__init__()
         self.config = config
-        window_size = 7
+        window_size = 5
         self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
             in_features=window_size*window_size,
             out_channels=1,
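The `degree` change (8 to 3) shrinks the per-input Chebyshev basis from 9 functions (T₀..T₈) to 4 (T₀..T₃). As an illustration only (this is the standard recurrence, not the internals of `layers.ChebyKANUpscaleBlockNet`):

```python
import torch

def cheby_basis(x, degree):
    # Chebyshev polynomials of the first kind: T0=1, T1=x, Tn = 2*x*T(n-1) - T(n-2).
    # x is expected in [-1, 1]; tanh squashing is a common choice in ChebyKAN layers.
    T = [torch.ones_like(x), x]
    for _ in range(2, degree + 1):
        T.append(2 * x * T[-1] - T[-2])
    return torch.stack(T[: degree + 1], dim=-1)  # shape (..., degree+1)

x = torch.tanh(torch.randn(8))
assert cheby_basis(x, 3).shape == (8, 4)  # degree=3 -> 4 basis functions
assert cheby_basis(x, 8).shape == (8, 9)  # degree=8 -> 9 basis functions
```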
@@ -92,4 +92,72 @@ class ChebyKANLut(ChebyKANBase):
             return ssim_loss(pred, target) + l1_loss(pred, target)
         return loss_fn

 TRANSFERER.register(ChebyKANNet, ChebyKANLut)
+
+class HDBNetBase(SRBase):
+    def __init__(self):
+        super(HDBNetBase, self).__init__()
+        self.config = None
+        self.stage_3H = layers.UpscaleBlock(None)
+        self.stage_3D = layers.UpscaleBlock(None)
+        self.stage_3B = layers.UpscaleBlock(None)
+        self.stage_2H = layers.UpscaleBlock(None)
+        self.stage_2D = layers.UpscaleBlock(None)
+        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
+        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
+        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
+
+    def forward(self, x, script_config=None):
+        b,c,h,w = x.shape
+        x = x.reshape(b*c, 1, h, w)
+        lsb = x % 16
+        msb = x - lsb
+        output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
+        for rotations_count in range(4):
+            rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
+            rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
+            output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \
+                         self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \
+                         self.stage_3B( rotated_msb, self._extract_pattern_3B )
+            output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \
+                         self.stage_2D( rotated_lsb, self._extract_pattern_2D )
+            output_msb /= 3
+            output_lsb /= 2
+            if not script_config is None and script_config.current_iter % script_config.display_step == 0:
+                script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter)
+                script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter)
+            output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
+        output /= 4
+        x = output
+        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
+        return x
+
+    def get_loss_fn(self):
+        def loss_fn(pred, target):
+            return F.l1_loss(pred/127.5-127.5, target/127.5-127.5)
+        return loss_fn
+
+class HDBNet(HDBNetBase):
+    def __init__(self, config):
+        super(HDBNet, self).__init__()
+        self.config = config
+        self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+        self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
+
+class HDBLut(HDBNetBase):
+    def __init__(self, config):
+        super(HDBLut, self).__init__()
+        self.config = config
+        self.stage_3H.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_3D.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_3B.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_2H.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+        self.stage_2D.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
+
+TRANSFERER.register(HDBNet, HDBLut)
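To summarize the new forward pass: each 8-bit pixel is split into a 4-bit MSB part (a multiple of 16) and a 4-bit LSB part; three 3-pixel patterns (3H, 3D, 3B) handle the MSB while two 2-pixel patterns (2H, 2D) handle the LSB, and the result is averaged over a 4-way rotation ensemble. The point of the split is index-space size: 4-bit branches keep the lookup tables tiny compared with indexing full 8-bit pixels. A self-contained sketch of the bookkeeping, with a hypothetical identity stand-in for the stages and no upscaling:

```python
import torch

x = torch.randint(0, 256, (1, 1, 6, 6)).float()

# 8-bit value = 4-bit MSB part (a multiple of 16) + 4-bit LSB part.
lsb = x % 16   # in [0, 15]
msb = x - lsb  # in {0, 16, ..., 240}
assert torch.equal(msb + lsb, x)

# Rotation ensemble: process each of the 4 rotations, rotate back, average.
def toy_stage(t):  # hypothetical stand-in for the LUT/network stages (identity here)
    return t

out = torch.zeros_like(x)
for k in range(4):
    r = torch.rot90(msb, k=k, dims=[2, 3])
    out += torch.rot90(toy_stage(r), k=-k, dims=[2, 3]).clamp(0, 255)
out /= 4
assert torch.equal(out, msb)  # with an identity stage the rotations cancel exactly

# Why split the bits: index combinations needed at 4-bit precision per branch.
msb_indexes = 16 ** 3 * 3    # three 3-pixel MSB patterns (3H, 3D, 3B)
lsb_indexes = 16 ** 2 * 2    # two 2-pixel LSB patterns (2H, 2D)
dense_indexes = 256 ** 3     # one 3-pixel pattern indexed at full 8 bits
print(msb_indexes + lsb_indexes, "vs", dense_indexes)  # 12800 vs 16777216
```

With the identity stage the rotations cancel exactly, which is the invariance the ensemble relies on; in the real model each branch additionally expands every index to an `upscale_factor`² output block.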

@@ -242,7 +242,6 @@ if __name__ == "__main__":
         if i % config.display_step == 0:
             config.writer.add_scalar('loss_Pixel', l_accum[0] / config.display_step, i)
-            config.writer.add_scalar('loss', loss.item(), i)
             config.logger.info("{} | Iter:{:6d}, Sample:{:6d}, loss:{:.2e}, prepare_data_time:{:.4f}, forward_backward_time:{:.4f}".format(
                 model.__class__.__name__,
                 i,
