main
protsenkovi 5 months ago
parent 8257ec7e57
commit 2218057822

@ -44,9 +44,9 @@ class UpscaleBlock(nn.Module):
def forward(self, x):
x = (x-self.in_bias)/self.in_scale
x = torch.relu(self.embed(x))
x = torch.nn.functional.gelu(self.embed(x))
for linear_projection in self.linear_projections:
x = torch.cat([x, torch.relu(linear_projection(x))], dim=2)
x = torch.cat([x, torch.nn.functional.gelu(linear_projection(x))], dim=2)
x = self.project_channels(x)
x = torch.tanh(x)
x = x*self.out_scale + self.out_bias

@ -5,6 +5,7 @@ from . import srlut
from . import sdynet
from . import hdbnet
from . import hdblut
from common import losses
import torch
import numpy as np
from pathlib import Path
@ -32,6 +33,8 @@ AVAILABLE_MODELS = {
'HDBLut': hdblut.HDBLut,
'HDBLNet': hdbnet.HDBLNet,
'HDBHNet': hdbnet.HDBHNet,
'SRMsbLsbNet': srnet.SRMsbLsbNet,
'SRMsbLsbShift2Net': srnet.SRMsbLsbShift2Net,
'SRMsbLsbR90Net': srnet.SRMsbLsbR90Net,
'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net,
# 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,

@ -18,17 +18,17 @@ class HDBLut(nn.Module):
self.scale = scale
self.quantization_interval = quantization_interval
self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (2,2)).type(torch.float32))
self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
@ -41,9 +41,9 @@ class HDBLut(nn.Module):
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
):
scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = HDBLut(quantization_interval=quantization_interval, scale=scale)
# quantization_interval = 256//(stage1_3H.shape[0]-1)
quantization_interval = 16
lut_model = HDBLut(quantization_interval=quantization_interval, scale=4)
lut_model.stage1_3H = nn.Parameter(torch.tensor(stage1_3H).type(torch.float32))
lut_model.stage1_3D = nn.Parameter(torch.tensor(stage1_3D).type(torch.float32))
lut_model.stage1_3B = nn.Parameter(torch.tensor(stage1_3B).type(torch.float32))
@ -59,15 +59,26 @@ class HDBLut(nn.Module):
def forward_stage(self, x, scale, percieve_pattern, lut):
b,c,h,w = x.shape
print(np.prod(x.shape))
x = percieve_pattern(x)
x = select_index_4dlut_tetrahedral(index=x, lut=lut)
shifts = torch.tensor([lut.shape[0]**d for d in range(len(lut.shape)-2)], device=x.device).flip(0).reshape(1,1,len(lut.shape)-2)
print(x.shape, x.min(), x.max())
x = torch.sum(x * shifts, dim=-1)
print(x.shape)
lut = torch.clamp(lut, 0, 255)
lut = lut.reshape(-1, scale, scale)
x = x.flatten().type(torch.int64)
x = lut[x]
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
print(x.shape)
# raise RuntimeError
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
@ -75,7 +86,7 @@ class HDBLut(nn.Module):
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
@ -84,14 +95,18 @@ class HDBLut(nn.Module):
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
print(output_msb.min(), output_msb.max())
print(output_lsb.min(), output_lsb.max())
output_msb = output_msb + output_lsb
x = output_msb
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
print("STAGE2", msb.min(), msb.max(), lsb.min(), lsb.max())
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
@ -106,4 +121,14 @@ class HDBLut(nn.Module):
return x
def __repr__(self):
return f"{self.__class__.__name__}\n lut size: {self.stage_lut.shape}"
return f"{self.__class__.__name__}" + \
f"\n stage1_3H size: {self.stage1_3H.shape}" + \
f"\n stage1_3D size: {self.stage1_3D.shape}" + \
f"\n stage1_3B size: {self.stage1_3B.shape}" + \
f"\n stage1_2H size: {self.stage1_2H.shape}" + \
f"\n stage1_2D size: {self.stage1_2D.shape}" + \
f"\n stage2_3H size: {self.stage2_3H.shape}" + \
f"\n stage2_3D size: {self.stage2_3D.shape}" + \
f"\n stage2_3B size: {self.stage2_3B.shape}" + \
f"\n stage2_2H size: {self.stage2_2H.shape}" + \
f"\n stage2_2D size: {self.stage2_2D.shape}"

@ -5,25 +5,26 @@ import numpy as np
from common.utils import round_func
from common import lut
from pathlib import Path
# from . import srlut
from . import hdblut
from common import layers
from itertools import cycle
class HDBNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBNet, self).__init__()
assert scale == 4
self.scale = scale
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2*2)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2*2)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2*2)
self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2*2)
self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2*2)
# self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
# self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
# self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
# self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
# self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
@ -41,59 +42,213 @@ class HDBNet(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
output_msb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2*2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2*2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2*2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2*2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2*2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
output_msb = output_msb + output_lsb
x = output_msb
output_msb = round_func((output_msb / 255) * 16) * 15
output_lsb = (output_lsb / 255) * 15
# print(output_msb.min(), output_msb.max(), output_lsb.min(), output_lsb.max())
x = output_msb + output_lsb
# lsb = x % 16
# msb = x - lsb
# output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
# output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
# for rotations_count in range(4):
# rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
# rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
# output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
# output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
# output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
# output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
# output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
# output_msb /= 4*3
# output_lsb /= 4*2
# output_msb = round_func((output_msb / 255) * 16) * 15
# output_lsb = (output_lsb / 255) * 15
# # print(output_msb.min(), output_msb.max(), output_lsb.min(), output_lsb.max())
# x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = hdblut.HDBLut.init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class HDBLNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBLNet, self).__init__()
self.scale = scale
self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
output_msb += torch.rot90(self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L), k=-rotations_count, dims=[2, 3])
output_msb /= 4*3
output_lsb /= 4*2
output_msb = output_msb + output_lsb
x = output_msb
output_lsb /= 4
output_msb = round_func((output_msb / 255) * 16) * 15
output_lsb = (output_lsb / 255) * 15
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
# def get_lut_model(self, quantization_interval=16, batch_size=2**10):
# stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
# stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
# lut_model = hdblut.HDBLut.init_from_numpy(
# stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
# stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
# )
# return lut_model
class HDBHNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(HDBHNet, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(1)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
) for x in range(1)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
output_msb_r = round_func((output_msb_r / 255)*16) * 15
output_lsb_r = (output_lsb_r / 255) * 15
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
output_lsb /= 4
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage1_3H = lut.transfer_2x2_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3D = lut.transfer_2x2_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_3B = lut.transfer_2x2_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2H = lut.transfer_2x2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_2D = lut.transfer_2x2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3H = lut.transfer_2x2_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3D = lut.transfer_2x2_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_3B = lut.transfer_2x2_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2H = lut.transfer_2x2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_2D = lut.transfer_2x2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
raise NotImplementedError
lut_model = hdblut.HDBLut.init_from_numpy(
stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D,
stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
)
return lut_model
def get_loss_fn(self):
fourier_loss_fn = FocalFrequencyLoss()
high_frequency_loss_fn = FourierLoss()
def loss_fn(pred, target):
a = fourier_loss_fn(pred/255, target/255) * 1e8
# b = F.mse_loss(pred/255, target/255) #* 1e3
# c = high_frequency_loss_fn(pred/255, target/255) * 1e6
return a #+ b #+ c
return loss_fn

@ -9,16 +9,9 @@ from . import srlut
from common import layers
from common import losses
class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlock(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
class SRNetBase(nn.Module):
def __init__(self):
super(SRNetBase, self).__init__()
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
@ -30,6 +23,26 @@ class SRNet(nn.Module):
x = x.reshape(b, c, h*scale, w*scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNet, self).__init__()
self.scale = scale
self.stage1_S = layers.UpscaleBlock(
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
@ -42,12 +55,8 @@ class SRNet(nn.Module):
lut_model = srlut.SRLut.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetY(nn.Module):
class SRNetY(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetY, self).__init__()
self.scale = scale
@ -60,16 +69,6 @@ class SRNetY(nn.Module):
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
@ -88,12 +87,8 @@ class SRNetY(nn.Module):
lut_model = srlut.SRLutY.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetR90(nn.Module):
class SRNetR90(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90, self).__init__()
self.scale = scale
@ -104,16 +99,6 @@ class SRNetR90(nn.Module):
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
@ -131,12 +116,8 @@ class SRNetR90(nn.Module):
lut_model = srlut.SRLutR90.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRNetR90Y(nn.Module):
class SRNetR90Y(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetR90Y, self).__init__()
self.scale = scale
@ -150,16 +131,6 @@ class SRNetR90Y(nn.Module):
self.rgb_to_ycbcr = layers.RgbToYcbcr()
self.ycbcr_to_rgb = layers.YcbcrToRgb()
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = self.rgb_to_ycbcr(x)
@ -183,15 +154,10 @@ class SRNetR90Y(nn.Module):
lut_model = srlut.SRLutR90Y.init_from_numpy(stage_lut)
return lut_model
def get_loss_fn(self):
def loss_fn(pred, target):
return F.mse_loss(pred/255, target/255)
return loss_fn
class SRMsbLsbR90Net(nn.Module):
class SRMsbLsbNet(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsbR90Net, self).__init__()
super(SRMsbLsbNet, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
@ -202,7 +168,7 @@ class SRMsbLsbR90Net(nn.Module):
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=255,
output_max_value=15
output_max_value=255
)
self.lsb_fn = layers.UpscaleBlock(
in_features=4,
@ -210,20 +176,110 @@ class SRMsbLsbR90Net(nn.Module):
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=15,
output_max_value=15
output_max_value=255
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_S, self.msb_fn)
output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
class SRMsbLsbShift2Net(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsbShift2Net, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.count = 4
self.msb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=255,
output_max_value=255
) for x in range(self.count)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=15,
output_max_value=255
) for x in range(self.count)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for i, msb_fn, lsb_fn in zip(range(self.count), self.msb_fns, self.lsb_fns):
output_msb_s = self.forward_stage(msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_s = self.forward_stage(lsb, self.scale, self._extract_pattern_S, lsb_fn)
output_msb += torch.nn.functional.pad(output_msb_s, [i, 0, i, 0], mode='replicate')[:,:,:h*self.scale,:w*self.scale]
output_lsb += torch.nn.functional.pad(output_lsb_s, [i, 0, i, 0], mode='replicate')[:,:,:h*self.scale,:w*self.scale]
output_msb /= self.count
output_lsb /= self.count
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
class SRMsbLsbR90Net(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsbR90Net, self).__init__()
self.scale = scale
self.hidden_dim = hidden_dim
self.layers_count = layers_count
self.msb_fn = layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=255,
output_max_value=255
)
self.lsb_fn = layers.UpscaleBlock(
in_features=4,
hidden_dim=hidden_dim,
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=15,
output_max_value=255
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
@ -236,10 +292,8 @@ class SRMsbLsbR90Net(nn.Module):
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
output_msb_r = round_func(output_msb_r) * 15
output_lsb_r = round_func(output_lsb_r)
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, self.msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
@ -254,14 +308,8 @@ class SRMsbLsbR90Net(nn.Module):
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
fourier_loss_fn = losses.FocalFrequencyLoss()
def loss_fn(pred, target):
return fourier_loss_fn(pred, target)
return loss_fn
class SRMsbLsb4R90Net(nn.Module):
class SRMsbLsb4R90Net(SRNetBase):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRMsbLsb4R90Net, self).__init__()
self.scale = scale
@ -274,7 +322,7 @@ class SRMsbLsb4R90Net(nn.Module):
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=255,
output_max_value=15
output_max_value=255
) for x in range(4)])
self.lsb_fns = nn.ModuleList([layers.UpscaleBlock(
in_features=4,
@ -282,20 +330,10 @@ class SRMsbLsb4R90Net(nn.Module):
layers_count=layers_count,
upscale_factor=self.scale,
input_max_value=15,
output_max_value=15
output_max_value=255
) for x in range(4)])
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
def forward_stage(self, x, scale, percieve_pattern, stage):
b,c,h,w = x.shape
x = percieve_pattern(x)
x = stage(x)
x = round_func(x)
x = x.reshape(b, c, h, w, scale, scale)
x = x.permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*scale, w*scale)
return x
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
@ -310,8 +348,6 @@ class SRMsbLsb4R90Net(nn.Module):
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
output_msb_r = round_func(output_msb_r) * 15
output_lsb_r = round_func(output_lsb_r)
output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
output_msb /= 4
@ -320,14 +356,9 @@ class SRMsbLsb4R90Net(nn.Module):
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
x = output_msb + output_lsb
x = x.clamp(0, 255)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
raise NotImplementedError
def get_loss_fn(self):
fourier_loss_fn = losses.FocalFrequencyLoss()
def loss_fn(pred, target):
return fourier_loss_fn(pred, target)
return loss_fn

@ -98,6 +98,7 @@ def prepare_experiment_folder(config):
config.logs_dir.mkdir()
if __name__ == "__main__":
# torch.set_float32_matmul_precision('high')
script_start_time = datetime.now()
config_inst = TrainOptions()
@ -112,7 +113,8 @@ if __name__ == "__main__":
if 'lut' in config.model.lower():
model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
model = model.to(torch.device(config.device))
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-2, betas=(0.9, 0.95))
# model = torch.compile(model)
optimizer = AdamWScheduleFree(model.parameters(), betas=(0.9, 0.95))
print(optimizer)
prepare_experiment_folder(config)

Loading…
Cancel
Save