lut_reproduce/src/common/layers.py

63 lines
3.1 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PercievePattern:
    """Gather a 4-point receptive field for every pixel as a tiny 2x2 patch.

    For every window position produced by ``F.unfold`` on the padded input,
    the four samples selected by ``receptive_field_idxes`` are packed into a
    ``(1, 2, 2)`` patch, so that a 2x2 convolution can consume them.

    Args:
        receptive_field_idxes: exactly four ``[row, col]`` offsets inside a
            ``window_size x window_size`` window.
            Defaults to ``[[0, 0], [0, 1], [1, 0], [1, 1]]``.
        center: two leading replicate-padding amounts. ``center[0]`` pads the
            width axis (left) and ``center[1]`` pads the height axis (top),
            following the ``F.pad`` argument order. Defaults to ``[0, 0]``.
        window_size: side length of the square unfolding window.

    Raises:
        ValueError: if there are not exactly four ``[row, col]`` offsets, or
            if an offset lies outside ``[0, window_size)``.
    """

    def __init__(self, receptive_field_idxes=None, center=None, window_size=2):
        # None sentinels avoid sharing one mutable default list between instances.
        if receptive_field_idxes is None:
            receptive_field_idxes = [[0, 0], [0, 1], [1, 0], [1, 1]]
        if center is None:
            center = [0, 0]

        idxes = np.asarray(receptive_field_idxes)
        # __call__ reshapes the gathered samples into 2x2 patches, so exactly
        # four points are required.
        if idxes.shape != (4, 2):
            raise ValueError(
                "receptive_field_idxes must contain exactly four [row, col] "
                f"pairs, got array of shape {idxes.shape}"
            )
        if idxes.min() < 0 or idxes.max() >= window_size:
            raise ValueError(
                f"receptive_field_idxes must lie in [0, {window_size}) "
                f"for window_size={window_size}"
            )

        self.window_size = window_size
        self.center = center
        # F.unfold flattens each window row-major, so [row, col] maps to
        # row * window_size + col.
        self.receptive_field_idxes = [
            int(row) * window_size + int(col) for row, col in idxes
        ]

    def __call__(self, x):
        """Return each pixel's 4-point neighbourhood as ``(N, 1, 2, 2)`` patches.

        Args:
            x: 4-D tensor ``(b, c, h, w)``. The flat indices address the first
               channel of each unfolded window, so only channel 0 is sampled
               when ``c > 1``. Presumably callers pass ``c == 1``; confirm.

        Returns:
            Tensor of shape ``(b * h * w, 1, 2, 2)``. The padding adds
            ``window_size - 1`` pixels per axis, so ``F.unfold`` yields exactly
            ``h * w`` window positions per image.
        """
        ws = self.window_size
        left, top = self.center[0], self.center[1]
        # F.pad order for the last two dims is (left, right, top, bottom).
        padded = F.pad(x, pad=[left, ws - left - 1, top, ws - top - 1], mode='replicate')
        # (b, c * ws * ws, L): one column per window position.
        windows = F.unfold(input=padded, kernel_size=ws)
        # (b, L, 4): the selected samples of every window.
        samples = torch.stack(
            [windows[:, idx, :] for idx in self.receptive_field_idxes], 2
        )
        return samples.reshape(samples.shape[0] * samples.shape[1], 1, 2, 2)
# Huang G. et al. Densely connected convolutional networks // Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017. pp. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
class DenseConvUpscaleBlock(nn.Module):
    """DenseNet-style stack of 1x1 convolutions ending in a pixel shuffle.

    Pipeline:

    1. Normalize the input with ``(x - 127.5) / 127.5``, which maps
       ``[0, 255]`` to ``[-1, 1]``. This presumably expects 8-bit pixel
       values.
    2. Apply a 2x2 "valid" convolution.
    3. Apply ``layers_count`` densely connected 1x1 convolutions. Each one
       appends its ``hidden_dim`` features to the running concatenation.
    4. Project to ``upscale_factor ** 2`` channels and pixel-shuffle.
    5. Apply ``tanh``, map back to ``[0, 255]`` and round.

    Args:
        hidden_dim: channels produced by the first conv and by every dense layer.
        layers_count: number of dense 1x1 layers; must be positive.
        upscale_factor: spatial upscale factor applied by ``nn.PixelShuffle``.

    Raises:
        ValueError: if ``layers_count`` is not positive.
    """

    def __init__(self, hidden_dim=32, layers_count=5, upscale_factor=1):
        super().__init__()
        if layers_count <= 0:
            raise ValueError(f"layers_count must be positive, got {layers_count}")
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
        # Dense layer i sees the first conv's features plus the outputs of the
        # i previous layers, i.e. (i + 1) * hidden_dim input channels.
        self.convs = nn.ModuleList(
            nn.Conv2d(in_channels=(i + 1) * hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
            for i in range(layers_count)
        )
        # Kaiming-normal weights and zero biases for every module created so far.
        for name, p in self.named_parameters():
            if "weight" in name:
                nn.init.kaiming_normal_(p)
            if "bias" in name:
                nn.init.constant_(p, 0)
        # NOTE: project_channels is created after the init loop above, so it
        # keeps PyTorch's default initialisation. The construction order is kept
        # as-is so seeded runs reproduce the same initial weights.
        self.project_channels = nn.Conv2d(in_channels=(layers_count + 1) * hidden_dim, out_channels=upscale_factor * upscale_factor, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
        self.shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Map a batch of patches to upscaled, integer-valued patches in ``[0, 255]``."""
        x = (x - 127.5) / 127.5
        x = torch.relu(self.percieve(x))
        for conv in self.convs:
            # Dense connectivity: keep all previous features and append new ones.
            x = torch.cat([x, torch.relu(conv(x))], dim=1)
        x = self.shuffle(self.project_channels(x))
        x = torch.tanh(x)
        # Rounding uses a local straight-through estimator. This module neither
        # defines nor imports a `round_func` helper, so calling one here would
        # raise NameError.
        return self._round_ste(x * 127.5 + 127.5)

    @staticmethod
    def _round_ste(x):
        """Round in the forward pass while letting gradients pass through unchanged."""
        # Forward value equals torch.round(x); the detached difference gives an
        # identity gradient w.r.t. x.
        return x + (torch.round(x) - x).detach()