import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .utils import round_func

class PercievePattern():
    """
    Coordinates scheme: [channel, height, width]. The channel can be omitted to address all channels.
    Examples:
    1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]
    2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]]
    """
    def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1):
        assert window_size >= (np.max([x for y in receptive_field_idxes for x in y]) + 1)
        tmp = []
        for coords in receptive_field_idxes:
            if len(coords) < 3:
                # No channel given: broadcast the [height, width] coordinate to all channels.
                for i in range(channels):
                    tmp.append([i,] + coords)
            if len(coords) == 3:
                tmp.append(coords)
        receptive_field_idxes = np.array(sorted(tmp))
        self.window_size = window_size
        self.center = center
        # Flatten [channel, height, width] into indices of the unfolded C*K*K window.
        self.receptive_field_idxes = [
            receptive_field_idxes[i, 0]*self.window_size*self.window_size + receptive_field_idxes[i, 1]*self.window_size + receptive_field_idxes[i, 2]
            for i in range(len(receptive_field_idxes))
        ]
        assert len(np.unique(self.receptive_field_idxes)) == len(self.receptive_field_idxes), \
            "Duplicated coordinates found. Coordinates scheme: [channel, height, width]."

    def __call__(self, x):
        b, c, h, w = x.shape
        x = F.pad(
            x,
            pad=[
                self.center[0], self.window_size - self.center[0] - 1,
                self.center[1], self.window_size - self.center[1] - 1
            ],
            mode='replicate'
        )
        x = F.unfold(input=x, kernel_size=self.window_size)
        x = torch.stack([x[:, self.receptive_field_idxes[i], :] for i in range(len(self.receptive_field_idxes))], 2)
        return x
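
# A minimal usage sketch (illustrative only, not part of the module): assuming a
# single-channel image in [0, 255], PercievePattern turns a BxCxHxW tensor into a
# per-pixel vector of the receptive-field values selected above.
#
#   pattern = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]],
#                             center=[0,0], window_size=2, channels=1)
#   image = torch.randint(0, 256, (1, 1, 8, 8)).float()
#   values = pattern(image)   # shape: (1, 8*8, 4)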

class UpscaleBlock(nn.Module):
    def __init__(self, in_features=4, hidden_dim=32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
        super(UpscaleBlock, self).__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)

        # Densely connected projections: layer i consumes the concatenation of all previous activations.
        self.linear_projections = []
        for i in range(layers_count):
            self.linear_projections.append(nn.Linear(in_features=(i+1)*hidden_dim, out_features=hidden_dim, bias=True))
        self.linear_projections = nn.ModuleList(self.linear_projections)

        self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)

        self.in_bias = self.in_scale = input_max_value/2
        self.out_bias = self.out_scale = output_max_value/2

    def forward(self, x):
        # Normalize inputs to roughly [-1, 1]; denormalize outputs back to the target value range.
        x = (x - self.in_bias) / self.in_scale
        x = torch.nn.functional.gelu(self.embed(x))
        for linear_projection in self.linear_projections:
            x = torch.cat([x, torch.nn.functional.gelu(linear_projection(x))], dim=2)
        x = self.project_channels(x)
        x = torch.tanh(x)
        x = x*self.out_scale + self.out_bias
        return x
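
# A hedged end-to-end sketch (assumed workflow, reusing the names from the sketch above):
# UpscaleBlock predicts an upscale_factor x upscale_factor block per pixel from the values
# gathered by PercievePattern; F.pixel_shuffle folds those blocks into an upscaled image.
#
#   scale = 2
#   net = UpscaleBlock(in_features=4, hidden_dim=32, layers_count=4, upscale_factor=scale)
#   values = pattern(image)                          # (B, H*W, 4), values in [0, 255]
#   blocks = net(values)                             # (B, H*W, scale*scale)
#   b, _, h, w = image.shape
#   out = blocks.permute(0, 2, 1).reshape(b, scale*scale, h, w)
#   out = F.pixel_shuffle(out, scale)                # (B, 1, H*scale, W*scale)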

class RgbToYcbcr(nn.Module):
    r"""Convert an image from RGB to YCbCr.

    The image data is assumed to be in the range of (0, 1).

    Returns:
        YCbCr version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> ycbcr = RgbToYcbcr()
        >>> output = ycbcr(input)  # 2x3x4x5

    https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
    """

    def forward(self, image):
        r = image[..., 0, :, :]
        g = image[..., 1, :, :]
        b = image[..., 2, :, :]

        delta = 0.5
        y = 0.299 * r + 0.587 * g + 0.114 * b
        cb = (b - y) * 0.564 + delta
        cr = (r - y) * 0.713 + delta
        return torch.stack([y, cb, cr], -3)

class YcbcrToRgb(nn.Module):
    r"""Convert an image from YCbCr to RGB.

    The image data is assumed to be in the range of (0, 1).

    Returns:
        RGB version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = YcbcrToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127
    """

    def forward(self, image):
        y = image[..., 0, :, :]
        cb = image[..., 1, :, :]
        cr = image[..., 2, :, :]

        delta = 0.5
        cb_shifted = cb - delta
        cr_shifted = cr - delta

        r = y + 1.403 * cr_shifted
        g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
        b = y + 1.773 * cb_shifted
        return torch.stack([r, g, b], -3)
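
# A hedged round-trip sketch (a common Y-channel-only super-resolution workflow, assumed
# here rather than taken from this file): super-resolve only the luma channel and
# reattach the chroma channels before converting back.
#
#   to_ycbcr, to_rgb = RgbToYcbcr(), YcbcrToRgb()
#   ycbcr = to_ycbcr(torch.rand(2, 3, 4, 5))
#   y, cbcr = ycbcr[:, :1], ycbcr[:, 1:]
#   restored = to_rgb(torch.cat([y, cbcr], dim=1))   # approximately the original input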

class EffKAN(torch.nn.Module):
    """
    https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6
    """
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(EffKAN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()
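    # A minimal shape-check sketch (illustrative only): the loop above is the Cox-de Boor
    # recursion, so each scalar input activates grid_size + spline_order basis functions
    # per input feature. With the defaults (grid_size=5, spline_order=3) that is 8.
    #
    #   layer = EffKAN(in_features=4, out_features=8)
    #   bases = layer.b_splines(torch.rand(16, 4) * 2 - 1)
    #   assert bases.shape == (16, 4, 8)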
    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )
    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.view(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.view(*original_shape[:-1], self.out_features)
        return output
    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a crude approximation of the original L1 regularization described in the
        paper, since the original formulation requires computing absolutes and entropy from
        the expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear call in a memory-efficient implementation.

        The L1 regularization is computed here as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
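
# A minimal usage sketch for EffKAN (illustrative only): it can be used like a Linear
# layer; update_grid and regularization_loss are optional extras from the upstream
# efficient-kan code.
#
#   layer = EffKAN(in_features=16, out_features=32)
#   x = torch.randn(64, 16)
#   y = layer(x)                                  # (64, 32)
#   layer.update_grid(x)                          # adapt the spline grid to the data
#   reg = 1e-4 * layer.regularization_loss()      # optional penalty added to the loss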

# This is inspired by Kolmogorov-Arnold Networks, but uses Chebyshev polynomials instead of spline coefficients.
class ChebyKANLayer(nn.Module):
    """
    https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py
    """
    def __init__(self, in_features, out_features, degree=8):
        super(ChebyKANLayer, self).__init__()
        self.inputdim = in_features
        self.outdim = out_features
        self.degree = degree

        self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1))
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1)))
        self.register_buffer("arange", torch.arange(0, degree + 1, 1))

    def forward(self, x):
        # Since the Chebyshev polynomials are defined on [-1, 1],
        # we need to normalize x to [-1, 1] using tanh.
        b, hw, c = x.shape
        x = torch.tanh(x)
        # View and repeat the input degree + 1 times
        x = x.view((b, hw, self.inputdim, 1)).expand(
            -1, -1, -1, self.degree + 1
        )  # shape = (batch_size, hw, inputdim, self.degree + 1)
        # Apply acos
        x = x.acos()
        # Multiply by arange [0 .. degree]
        x *= self.arange
        # Apply cos
        x = x.cos()
        # Compute the Chebyshev interpolation
        y = torch.einsum(
            "btid,iod->bto", x, self.cheby_coeffs
        )  # shape = (batch_size, hw, outdim)

        y = y.view(b, hw, self.outdim)
        return y
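
# A minimal sketch (illustrative only): the layer evaluates the Chebyshev polynomials
# T_n(x) = cos(n * arccos(x)) for n = 0..degree and mixes them with learned coefficients.
#
#   layer = ChebyKANLayer(in_features=4, out_features=8, degree=8)
#   y = layer(torch.randn(2, 100, 4))   # (2, 100, 8)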

class UpscaleBlockChebyKAN(nn.Module):
    def __init__(self, in_features=4, hidden_dim=32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8):
        super(UpscaleBlockChebyKAN, self).__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)

        self.linear_projections = []
        for i in range(layers_count):
            self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree))
        self.linear_projections = nn.ModuleList(self.linear_projections)

        self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)

        self.in_bias = self.in_scale = input_max_value/2
        self.out_bias = self.out_scale = output_max_value/2
        self.layer_norm = nn.LayerNorm(hidden_dim)  # To avoid gradient vanishing caused by tanh

    def forward(self, x):
        x = (x - self.in_bias) / self.in_scale
        x = self.embed(x)
        for linear_projection in self.linear_projections:
            x = self.layer_norm(linear_projection(x))
        x = self.project_channels(x)
        x = torch.tanh(x)
        x = x*self.out_scale + self.out_bias
        return x

class UpscaleBlockEffKAN(nn.Module):
    def __init__(self, in_features=4, hidden_dim=32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
        super(UpscaleBlockEffKAN, self).__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)

        self.linear_projections = []
        for i in range(layers_count):
            self.linear_projections.append(EffKAN(in_features=hidden_dim, out_features=hidden_dim))
        self.linear_projections = nn.ModuleList(self.linear_projections)

        # The forward pass below does not concatenate intermediate activations,
        # so the final projection consumes hidden_dim features.
        self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)

        self.in_bias = self.in_scale = input_max_value/2
        self.out_bias = self.out_scale = output_max_value/2
        self.layer_norm = nn.LayerNorm(hidden_dim)  # To avoid gradient vanishing caused by tanh

    def forward(self, x):
        x = (x - self.in_bias) / self.in_scale
        x = self.embed(x)
        for linear_projection in self.linear_projections:
            x = self.layer_norm(linear_projection(x))
        x = self.project_channels(x)
        x = torch.tanh(x)
        x = x*self.out_scale + self.out_bias
        return x

class ComplexGaborLayer2D(nn.Module):
    '''
    Implicit representation with complex Gabor nonlinearity with 2D activation function.

    https://github.com/vishwa91/wire/blob/main/modules/wire2d.py

    Inputs:
        in_features: Input features
        out_features: Output features
        bias: if True, enable bias for the linear operation
        is_first: Legacy SIREN parameter
        omega_0: Legacy SIREN parameter
        omega0: Frequency of Gabor sinusoid term
        sigma0: Scaling of Gabor Gaussian term
        trainable: If True, omega and sigma are trainable parameters
    '''

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega0=10.0, sigma0=10.0,
                 trainable=False):
        super().__init__()
        self.omega_0 = omega0
        self.scale_0 = sigma0
        self.is_first = is_first

        self.in_features = in_features

        if self.is_first:
            dtype = torch.float
        else:
            dtype = torch.cfloat

        # Set trainable parameters if they are to be simultaneously optimized
        self.omega_0 = nn.Parameter(self.omega_0*torch.ones(1), trainable)
        self.scale_0 = nn.Parameter(self.scale_0*torch.ones(1), trainable)

        self.linear = nn.Linear(in_features,
                                out_features,
                                bias=bias,
                                dtype=dtype)

        # Second Gaussian window
        self.scale_orth = nn.Linear(in_features,
                                    out_features,
                                    bias=bias,
                                    dtype=dtype)

    def forward(self, input):
        lin = self.linear(input)

        scale_x = lin
        scale_y = self.scale_orth(input)

        freq_term = torch.exp(1j*self.omega_0*lin)

        arg = scale_x.abs().square() + scale_y.abs().square()
        gauss_term = torch.exp(-self.scale_0*self.scale_0*arg)

        return freq_term*gauss_term
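
# A minimal usage sketch (illustrative, assumptions noted inline): the first layer takes
# real-valued inputs (is_first=True); deeper layers receive the complex output of the
# previous layer, and a real prediction is typically recovered via .real at the end.
#
#   layer1 = ComplexGaborLayer2D(2, 16, is_first=True)
#   layer2 = ComplexGaborLayer2D(16, 16)
#   coords = torch.rand(1024, 2)
#   out = layer2(layer1(coords))   # complex-valued, shape (1024, 16)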