Refactor DenseConvUpscaleBlock, explicit window_size for PercievePattern

main
Vladimir Protsenko 8 months ago
parent b1f2f6d76b
commit 6240dd05fc

@ -4,9 +4,11 @@ import torch.nn.functional as F
import numpy as np import numpy as np
class PercievePattern(): class PercievePattern():
def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]): def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2):
assert window_size >= (np.max(receptive_field_idxes)+1)
self.receptive_field_idxes = np.array(receptive_field_idxes) self.receptive_field_idxes = np.array(receptive_field_idxes)
self.window_size = np.max(self.receptive_field_idxes) + 1 self.window_size = window_size
self.center = center
self.receptive_field_idxes = [ self.receptive_field_idxes = [
self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1], self.receptive_field_idxes[0,0]*self.window_size + self.receptive_field_idxes[0,1],
self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1], self.receptive_field_idxes[1,0]*self.window_size + self.receptive_field_idxes[1,1],
@ -16,7 +18,8 @@ class PercievePattern():
def __call__(self, x): def __call__(self, x):
b,c,h,w = x.shape b,c,h,w = x.shape
x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate') x = F.pad(x, pad=[self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1], mode='replicate')
x = F.unfold(input=x, kernel_size=self.window_size) x = F.unfold(input=x, kernel_size=self.window_size)
x = torch.stack([ x = torch.stack([
x[:,self.receptive_field_idxes[0],:], x[:,self.receptive_field_idxes[0],:],
@ -25,4 +28,36 @@ class PercievePattern():
x[:,self.receptive_field_idxes[3],:] x[:,self.receptive_field_idxes[3],:]
], 2) ], 2)
x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2) x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2)
return x return x
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. С. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x

@ -6,38 +6,7 @@ from common.utils import round_func
from pathlib import Path from pathlib import Path
from common import lut from common import lut
from . import rclut from . import rclut
from common.layers import DenseConvUpscaleBlock
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. С. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x
class ReconstructedConvCentered(nn.Module): class ReconstructedConvCentered(nn.Module):
def __init__(self, hidden_dim, window_size=7): def __init__(self, hidden_dim, window_size=7):

@ -4,49 +4,17 @@ import torch.nn.functional as F
import numpy as np import numpy as np
from common.utils import round_func from common.utils import round_func
from common import lut from common import lut
from common.layers import PercievePattern from common.layers import PercievePattern, DenseConvUpscaleBlock
from pathlib import Path from pathlib import Path
from . import sdylut from . import sdylut
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. С. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x
class SDYNetx1(nn.Module): class SDYNetx1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx1, self).__init__() super(SDYNetx1, self).__init__()
self.scale = scale self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]) self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]]) self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]]) self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) self.stageS = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageD = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) self.stageD = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageY = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) self.stageY = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
@ -90,9 +58,9 @@ class SDYNetx2(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetx2, self).__init__() super(SDYNetx2, self).__init__()
self.scale = scale self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]) self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]]) self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]]) self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
@ -159,4 +127,51 @@ class SDYNetx2(nn.Module):
stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size) stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size)
stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size) stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2) lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2)
return lut_model
class SDYNetCenteredx1(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SDYNetCenteredx1, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[1,1], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[1,1], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[1,1], window_size=3)
self.stageS = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageD = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageY = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = self.stageS(self._extract_pattern_S(rotated))
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = self.stageD(self._extract_pattern_D(rotated))
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = self.stageY(self._extract_pattern_Y(rotated))
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output /= 4*3
output = output.view(b, c, h*self.scale, w*self.scale)
return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stageS = lut.transfer_2x2_input_SxS_output(self.stageS, quantization_interval=quantization_interval, batch_size=batch_size)
stageD = lut.transfer_2x2_input_SxS_output(self.stageD, quantization_interval=quantization_interval, batch_size=batch_size)
stageY = lut.transfer_2x2_input_SxS_output(self.stageY, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutCenteredx1.init_from_lut(stageS, stageD, stageY)
return lut_model return lut_model

@ -6,38 +6,7 @@ from common.utils import round_func
from common import lut from common import lut
from pathlib import Path from pathlib import Path
from .srlut import SRLut, SRLutRot90 from .srlut import SRLut, SRLutRot90
from common.layers import DenseConvUpscaleBlock
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. С. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
self.convs = nn.ModuleList(self.convs)
for name, p in self.named_parameters():
if "weight" in name: nn.init.kaiming_normal_(p)
if "bias" in name: nn.init.constant_(p, 0)
self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
self.shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
return x
class SRNet(nn.Module): class SRNet(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):

Loading…
Cancel
Save