You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

166 lines
6.2 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils import round_func
class PercievePattern():
    """Sample a fixed 4-cell pattern around every pixel and pack it as a 2x2 patch.

    For each spatial position the image is viewed through a
    ``window_size`` x ``window_size`` window (borders are padded by
    replication), and the four window cells listed in
    ``receptive_field_idxes`` are gathered into one ``2 x 2`` patch.

    Args:
        receptive_field_idxes: ``[row, col]`` offsets inside the window; only
            the first four entries are used (fewer than four raises
            ``IndexError``). Defaults to ``[[0, 0], [0, 1], [1, 0], [1, 1]]``.
        center: position of the current pixel inside the window. Note the
            order used by the padding below: ``center[0]`` is the number of
            columns padded on the left and ``center[1]`` the number of rows
            padded on top. Defaults to ``[0, 0]``.
            NOTE(review): this is (col, row), the opposite of the [row, col]
            order of ``receptive_field_idxes`` — confirm this is intended.
        window_size: side length of the sampling window.
    """

    def __init__(self, receptive_field_idxes=None, center=None, window_size=2):
        # None defaults avoid sharing one mutable list between instances.
        if receptive_field_idxes is None:
            receptive_field_idxes = [[0, 0], [0, 1], [1, 0], [1, 1]]
        if center is None:
            center = [0, 0]
        assert window_size >= (np.max(receptive_field_idxes) + 1)
        idxes = np.array(receptive_field_idxes)
        self.window_size = window_size
        self.center = center
        # F.unfold lays out each window row-major, so cell [row, col] lives
        # at flat position row * window_size + col.
        self.receptive_field_idxes = [
            int(idxes[k, 0]) * window_size + int(idxes[k, 1]) for k in range(4)
        ]

    def __call__(self, x):
        """Return a ``(b * c * h * w, 1, 2, 2)`` tensor of sampled patches.

        ``x`` must be 4-D ``(b, c, h, w)``. Every channel is treated as an
        independent single-channel image; patches are ordered by batch,
        channel, then row-major pixel position.
        """
        b, c, h, w = x.shape
        # Fold channels into the batch: the flat indices only address one
        # channel's window, so without this only channel 0 would be sampled.
        x = x.reshape(b * c, 1, h, w)
        x = F.pad(
            x,
            pad=[self.center[0], self.window_size - self.center[0] - 1,
                 self.center[1], self.window_size - self.center[1] - 1],
            mode='replicate'
        )
        # (b*c, window_size**2, h*w): one column of window values per pixel.
        x = F.unfold(input=x, kernel_size=self.window_size)
        # Keep the four pattern cells -> (b*c, 4, h*w), then pixels first.
        x = x[:, self.receptive_field_idxes, :].transpose(1, 2)
        return x.reshape(x.shape[0] * x.shape[1], 1, 2, 2)
# Huang G. et al. Densely connected convolutional networks // Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2017. pp. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
# Refactoring these layers into linear layers gives a slight speed-up, but would require a total rewrite to stay consistent.
class DenseConvUpscaleBlock(nn.Module):
    """Densely connected stack of 1x1 convolutions ending in a pixel shuffle.

    A 2x2 'valid' convolution embeds a single-channel input into
    ``hidden_dim`` channels. Each of the ``layers_count`` 1x1 layers sees the
    concatenation of every earlier feature map and appends ``hidden_dim`` new
    channels. A final 1x1 projection yields ``upscale_factor ** 2`` channels,
    which ``PixelShuffle`` rearranges spatially. For an ``(N, 1, 2, 2)``
    input the result is ``(N, 1, upscale_factor, upscale_factor)``.
    NOTE(review): callers presumably pass 2x2 patches with pixel values in
    0..255 — confirm.
    """

    def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
        super().__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid',
                               stride=1, dilation=1, bias=True)
        # Layer number `depth` receives the embedding plus `depth` earlier outputs.
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=(depth + 1) * hidden_dim, out_channels=hidden_dim,
                      kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
            for depth in range(layers_count)
        ])
        # Kaiming-normal weights and zero biases for embed and convs only.
        for param_name, param in self.named_parameters():
            if "weight" in param_name:
                nn.init.kaiming_normal_(param)
            elif "bias" in param_name:
                nn.init.zeros_(param)
        # Created after the loop above, so it keeps PyTorch's default Conv2d
        # initialisation. NOTE(review): confirm this is intentional.
        total_channels = (layers_count + 1) * hidden_dim
        self.project_channels = nn.Conv2d(in_channels=total_channels,
                                          out_channels=upscale_factor * upscale_factor,
                                          kernel_size=1, stride=1, padding=0,
                                          dilation=1, bias=True)
        self.shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        # Affine map sending 0 -> -1 and 255 -> 1.
        x = (x - 127.5) / 127.5
        features = torch.relu(self.embed(x))
        for layer in self.convs:
            features = torch.cat([features, torch.relu(layer(features))], dim=1)
        squashed = torch.tanh(self.shuffle(self.project_channels(features)))
        # Map tanh's (-1, 1) back to (0, 255), then round via the project's
        # round_func (from .utils; exact gradient behaviour defined there).
        return round_func(squashed * 127.5 + 127.5)
class ConvUpscaleBlock(nn.Module):
    """Plain (non-dense) stack of 1x1 convolutions ending in a pixel shuffle.

    A 2x2 'valid' convolution embeds a single-channel input into
    ``hidden_dim`` channels. It is followed by ``layers_count``
    hidden_dim -> hidden_dim 1x1 layers with ReLU, and a 1x1 projection to
    ``upscale_factor ** 2`` channels that ``PixelShuffle`` rearranges
    spatially. For an ``(N, 1, 2, 2)`` input the result is
    ``(N, 1, upscale_factor, upscale_factor)``.
    NOTE(review): callers presumably pass 2x2 patches with pixel values in
    0..255 — confirm.
    """

    def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
        super().__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid',
                               stride=1, dilation=1, bias=True)
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim,
                      kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
            for _ in range(layers_count)
        ])
        # Kaiming-normal weights and zero biases for embed and convs only.
        for param_name, param in self.named_parameters():
            if "weight" in param_name:
                nn.init.kaiming_normal_(param)
            elif "bias" in param_name:
                nn.init.zeros_(param)
        # Created after the loop above, so it keeps PyTorch's default Conv2d
        # initialisation. NOTE(review): confirm this is intentional.
        self.project_channels = nn.Conv2d(in_channels=hidden_dim,
                                          out_channels=upscale_factor * upscale_factor,
                                          kernel_size=1, stride=1, padding=0,
                                          dilation=1, bias=True)
        self.shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        # Affine map sending 0 -> -1 and 255 -> 1.
        x = (x - 127.5) / 127.5
        hidden = torch.relu(self.embed(x))
        for layer in self.convs:
            hidden = torch.relu(layer(hidden))
        squashed = torch.tanh(self.shuffle(self.project_channels(hidden)))
        # Map tanh's (-1, 1) back to (0, 255), then round via the project's
        # round_func (from .utils; exact gradient behaviour defined there).
        return round_func(squashed * 127.5 + 127.5)
class RgbToYcbcr(nn.Module):
    r"""Convert an RGB image to YCbCr.

    Channels are read from dimension ``-3`` in R, G, B order. Pixel values
    are assumed to lie in [0, 1]; the chroma channels are offset by 0.5.

    Returns:
        YCbCr version of the image (Y, Cb, Cr stacked along dimension -3).

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> ycbcr = RgbToYcbcr()
        >>> output = ycbcr(input)  # 2x3x4x5
    """

    def forward(self, image):
        red, green, blue = (image[..., channel, :, :] for channel in range(3))
        luma = 0.299 * red + 0.587 * green + 0.114 * blue
        chroma_offset = 0.5
        chroma_blue = (blue - luma) * 0.564 + chroma_offset
        chroma_red = (red - luma) * 0.713 + chroma_offset
        return torch.stack((luma, chroma_blue, chroma_red), dim=-3)
class YcbcrToRgb(nn.Module):
    r"""Convert a YCbCr image to RGB.

    Channels are read from dimension ``-3`` in Y, Cb, Cr order, with the
    chroma channels carrying a 0.5 offset. Values are assumed to lie in
    [0, 1].

    Returns:
        RGB version of the image (R, G, B stacked along dimension -3).

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = YcbcrToRgb()
        >>> output = rgb(input)  # 2x3x4x5
    """

    def forward(self, image):
        luma, chroma_blue, chroma_red = (image[..., channel, :, :] for channel in range(3))
        chroma_offset = 0.5
        # Remove the 0.5 offset so chroma is centred on zero.
        cb_centred = chroma_blue - chroma_offset
        cr_centred = chroma_red - chroma_offset
        red = luma + 1.403 * cr_centred
        green = luma - 0.714 * cr_centred - 0.344 * cb_centred
        blue = luma + 1.773 * cb_centred
        return torch.stack((red, green, blue), dim=-3)