From b63d0f98dfc7ae0c6be836bcdca9d796d6ff0617 Mon Sep 17 00:00:00 2001
From: vlpr
Date: Wed, 15 May 2024 10:54:58 +0000
Subject: [PATCH] use one dense backbone for all nets with linear layers

---
 src/common/layers.py | 74 ++++++++++++--------------------------------
 1 file changed, 20 insertions(+), 54 deletions(-)

diff --git a/src/common/layers.py b/src/common/layers.py
index 6f4bca3..8c88d31 100644
--- a/src/common/layers.py
+++ b/src/common/layers.py
@@ -7,6 +7,7 @@ from .utils import round_func
 class PercievePattern():
     def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2):
         assert window_size >= (np.max(receptive_field_idxes)+1)
+        assert len(receptive_field_idxes) == 4
         self.receptive_field_idxes = np.array(receptive_field_idxes)
         self.window_size = window_size
         self.center = center
@@ -32,74 +33,39 @@ class PercievePattern():
             x[:,self.receptive_field_idxes[2],:],
             x[:,self.receptive_field_idxes[3],:]
         ], 2)
-        x = x.reshape(x.shape[0]*x.shape[1], 1, 2, 2)
         return x
 
-# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2017. – С. 4700-4708.
-# https://ar5iv.labs.arxiv.org/html/1608.06993
-# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
-# refactoring to linear give slight speed up, but require total rewrite to be consistent
-class DenseConvUpscaleBlock(nn.Module):
-    def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
-        super(DenseConvUpscaleBlock, self).__init__()
+class UpscaleBlock(nn.Module):
+    def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, in_features=4, hidden_dim = 32, layers_count=5, upscale_factor=1):
+        super(UpscaleBlock, self).__init__()
         assert layers_count > 0
+        self.percieve_pattern = PercievePattern(receptive_field_idxes=receptive_field_idxes, center=center, window_size=window_size)
         self.upscale_factor = upscale_factor
         self.hidden_dim = hidden_dim
-        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
+        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
 
-        self.convs = []
+        self.linear_projections = []
         for i in range(layers_count):
-            self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
-        self.convs = nn.ModuleList(self.convs)
+            self.linear_projections.append(nn.Linear(in_features=(i+1)*hidden_dim, out_features=hidden_dim, bias=True))
+        self.linear_projections = nn.ModuleList(self.linear_projections)
 
-        for name, p in self.named_parameters():
-            if "weight" in name: nn.init.kaiming_normal_(p)
-            if "bias" in name: nn.init.constant_(p, 0)
-
-        self.project_channels = nn.Conv2d(in_channels = (layers_count+1)*hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
-        self.shuffle = nn.PixelShuffle(upscale_factor)
+        self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
 
     def forward(self, x):
+        b,c,h,w = x.shape
         x = (x-127.5)/127.5
+        x = self.percieve_pattern(x)
         x = torch.relu(self.embed(x))
-        for conv in self.convs:
-            x = torch.cat([x, torch.relu(conv(x))], dim=1)
-        x = self.shuffle(self.project_channels(x))
+        for linear_projection in self.linear_projections:
+            x = torch.cat([x, torch.relu(linear_projection(x))], dim=2)
+        x = self.project_channels(x)
         x = torch.tanh(x)
         x = x*127.5 + 127.5
         x = round_func(x)
-        return x
-
-class ConvUpscaleBlock(nn.Module):
-    def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
-        super(ConvUpscaleBlock, self).__init__()
-        assert layers_count > 0
-        self.upscale_factor = upscale_factor
-        self.hidden_dim = hidden_dim
-        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
-
-        self.convs = []
-        for i in range(layers_count):
-            self.convs.append(nn.Conv2d(in_channels = hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
-        self.convs = nn.ModuleList(self.convs)
-
-        for name, p in self.named_parameters():
-            if "weight" in name: nn.init.kaiming_normal_(p)
-            if "bias" in name: nn.init.constant_(p, 0)
-
-        self.project_channels = nn.Conv2d(in_channels = hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)
-        self.shuffle = nn.PixelShuffle(upscale_factor)
-
-    def forward(self, x):
-        x = (x-127.5)/127.5
-        x = torch.relu(self.embed(x))
-        for conv in self.convs:
-            x = torch.relu(conv(x))
-        x = self.shuffle(self.project_channels(x))
-        x = torch.tanh(x)
-        x = x*127.5 + 127.5
-        x = round_func(x)
-        return x
+        x = x.reshape(b, c, h, w, self.upscale_factor, self.upscale_factor)
+        x = x.permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*self.upscale_factor, w*self.upscale_factor)
+        return x
 
 # https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
 class RgbToYcbcr(nn.Module):
@@ -161,4 +127,4 @@ class YcbcrToRgb(nn.Module):
         r = y + 1.403 * cr_shifted
         g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
         b = y + 1.773 * cb_shifted
-        return torch.stack([r, g, b], -3)
+        return torch.stack([r, g, b], -3)
\ No newline at end of file
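
Usage sketch (not part of the patch): a minimal forward pass through the new UpscaleBlock, assuming the patched module is importable as common.layers (i.e. src/ is on the Python path) and that round_func from .utils resolves as in the repo; the parameter values and tensor sizes below are illustrative only.

    import torch
    from common.layers import UpscaleBlock

    # 2x2 receptive field, matching the new `assert len(receptive_field_idxes) == 4`,
    # so the embedding Linear sees in_features=4 values per pixel position.
    block = UpscaleBlock(
        receptive_field_idxes=[[0, 0], [0, 1], [1, 0], [1, 1]],
        center=[0, 0],
        window_size=2,
        in_features=4,
        hidden_dim=32,
        layers_count=5,
        upscale_factor=2,
    )

    # Input is expected in the 0..255 range; forward() normalizes by 127.5.
    lr = torch.randint(0, 256, (1, 1, 16, 16)).float()
    with torch.no_grad():
        sr = block(lr)

    # The closing reshape/permute replaces nn.PixelShuffle: output is (b, c, h*r, w*r).
    print(sr.shape)  # expected: torch.Size([1, 1, 32, 32])

The final reshape(b, c, h, w, r, r) -> permute(0,1,2,4,3,5) -> reshape(b, c, h*r, w*r) reproduces the depth-to-space step that nn.PixelShuffle performed in the removed convolutional blocks, and the per-position Linear layers over the PercievePattern output are the "total rewrite" that the removed comment said a linear refactor would require.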