You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

529 lines
20 KiB
Python

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import round_func
class PercievePattern():
    """
    Gather a fixed pattern of pixels from a sliding window at every image position.

    Coordinates scheme: [channel, height, width]. The channel can be omitted, in which
    case the [height, width] coordinate is used for every channel in ``range(channels)``.
    Examples:
    1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]]
    2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]]

    Args:
        receptive_field_idxes: Iterable of 2-element ([h, w]) or 3-element ([c, h, w])
            coordinates selected inside each window.
        center: Leading replicate-padding amounts. ``center[0]`` pads the left side of
            the last (width) axis and ``center[1]`` the top side of the height axis
            (``F.pad`` orders pads from the last dimension backwards).
            NOTE(review): the coordinate scheme lists height before width while
            center[0] pads width - confirm this ordering is intended.
        window_size: Side length of the square window extracted by ``F.unfold``.
        channels: Number of channels a 2-element coordinate is replicated across.
    """
    def __init__(self, receptive_field_idxes=((0, 0), (0, 1), (1, 0), (1, 1)), center=(0, 0), window_size=2, channels=1):
        coords_list = [list(coords) for coords in receptive_field_idxes]
        for coords in coords_list:
            assert len(coords) in (2, 3), "Each coordinate must be [height, width] or [channel, height, width]."
        # Only the spatial part (last two entries) must fit inside the window; the channel
        # index addresses the channel axis of the unfolded tensor, not the window.
        assert window_size >= max(v for coords in coords_list for v in coords[-2:]) + 1

        coords_3d = []
        for coords in coords_list:
            if len(coords) == 2:
                coords_3d.extend([channel] + coords for channel in range(channels))
            else:
                coords_3d.append(coords)

        self.window_size = window_size
        self.center = center
        # Flat row index into the F.unfold output, whose rows are ordered
        # (channel, kernel row, kernel col).
        self.receptive_field_idxes = [
            int(c * window_size * window_size + h * window_size + w)
            for c, h, w in sorted(coords_3d)
        ]
        assert len(set(self.receptive_field_idxes)) == len(self.receptive_field_idxes), "Duplicated coordinates found. Coordinates scheme: [channel, height, width]."

    def __call__(self, x):
        """
        Extract the receptive-field pixels for every spatial position.

        Args:
            x: Tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_windows, len(receptive_field_idxes)); with
            0 <= center[i] < window_size, num_windows == height * width.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(f"Expected a 4D (batch, channels, height, width) tensor, got shape {tuple(x.shape)}")
        pad = [
            self.center[0], self.window_size - self.center[0] - 1,
            self.center[1], self.window_size - self.center[1] - 1,
        ]
        x = F.pad(x, pad=pad, mode='replicate')
        x = F.unfold(input=x, kernel_size=self.window_size)  # (batch, C * ws * ws, num_windows)
        return torch.stack([x[:, idx, :] for idx in self.receptive_field_idxes], 2)
class UpscaleBlock(nn.Module):
    """
    Densely connected MLP mapping a pixel neighbourhood to ``upscale_factor ** 2`` values.

    Inputs are normalised from [0, input_max_value] to [-1, 1]. Each projection layer
    sees the concatenation of the embedding and all previous layers' outputs, and the
    tanh head is rescaled to the range (0, output_max_value).

    Args:
        in_features: Size of the last input dimension.
        hidden_dim: Width of the embedding and of each projection layer.
        layers_count: Number of densely connected projection layers (must be > 0).
        upscale_factor: Output has ``upscale_factor ** 2`` features.
        input_max_value: Maximum input value used for normalisation.
        output_max_value: Maximum output value used for rescaling.
    """
    def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
        super(UpscaleBlock, self).__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
        # Layer i receives the embedding plus the outputs of the i previous layers.
        self.linear_projections = nn.ModuleList(
            [nn.Linear(in_features=(i + 1) * hidden_dim, out_features=hidden_dim, bias=True) for i in range(layers_count)]
        )
        self.project_channels = nn.Linear(in_features=(layers_count + 1) * hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
        self.in_bias = self.in_scale = input_max_value / 2
        self.out_bias = self.out_scale = output_max_value / 2

    def forward(self, x):
        """Map a tensor of shape (..., in_features) to (..., upscale_factor ** 2)."""
        x = (x - self.in_bias) / self.in_scale
        x = F.gelu(self.embed(x))
        for linear_projection in self.linear_projections:
            # Dense connection along the feature (last) axis, so any leading dims work.
            x = torch.cat([x, F.gelu(linear_projection(x))], dim=-1)
        x = torch.tanh(self.project_channels(x))
        return x * self.out_scale + self.out_bias
class RgbToYcbcr(nn.Module):
    r"""Convert an image from RGB to YCbCr.

    The image data is assumed to be in the range of (0, 1). Only the first three
    entries of the channel axis (dim -3) are read.

    Args:
        image: RGB image tensor of shape :math:`(*, 3, H, W)`.

    Returns:
        YCbCr version of the image, shape :math:`(*, 3, H, W)`.

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> ycbcr = RgbToYcbcr()
        >>> output = ycbcr(input)  # 2x3x4x5

    Adapted from
    https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105
    """
    def forward(self, image):
        red, green, blue = (image[..., channel, :, :] for channel in range(3))
        offset = 0.5
        luma = 0.299 * red + 0.587 * green + 0.114 * blue
        chroma_blue = (blue - luma) * 0.564 + offset
        chroma_red = (red - luma) * 0.713 + offset
        return torch.stack((luma, chroma_blue, chroma_red), dim=-3)
class YcbcrToRgb(nn.Module):
    r"""Convert an image from YCbCr to RGB.

    The image data is assumed to be in the range of (0, 1). Only the first three
    entries of the channel axis (dim -3) are read.

    Args:
        image: YCbCr image tensor of shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image, shape :math:`(*, 3, H, W)`.

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = YcbcrToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    Adapted from
    https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127
    """
    def forward(self, image):
        luma, chroma_blue, chroma_red = (image[..., channel, :, :] for channel in range(3))
        offset = 0.5
        # Re-centre the chroma channels around zero before the linear transform.
        cb_centered = chroma_blue - offset
        cr_centered = chroma_red - offset
        red = luma + 1.403 * cr_centered
        green = luma - 0.714 * cr_centered - 0.344 * cb_centered
        blue = luma + 1.773 * cb_centered
        return torch.stack((red, green, blue), dim=-3)
class EffKAN(torch.nn.Module):
    """
    Efficient Kolmogorov-Arnold Network (KAN) linear layer.

    The output is the sum of a base branch, ``F.linear(base_activation(x), base_weight)``,
    and a spline branch: a linear map over the B-spline basis expansion of every input
    feature.

    Adapted from
    https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        grid_size: Number of grid intervals covering ``grid_range``.
        spline_order: Order of the B-spline bases.
        scale_noise: Magnitude of the random noise the spline weights are fitted to at init.
        scale_base: Multiplies the ``a`` argument of the Kaiming init of ``base_weight``.
        scale_spline: Multiplies the initial spline weights, or the ``a`` argument of the
            ``spline_scaler`` Kaiming init when ``enable_standalone_scale_spline`` is True.
        enable_standalone_scale_spline: If True, learn a separate (out, in) spline scale.
        base_activation: Activation class, instantiated once and applied in the base branch.
        grid_eps: Weight of the uniform grid vs. the data-adaptive grid in ``update_grid``.
        grid_range: ``(low, high)`` interval covered by the initial grid.
    """
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super(EffKAN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector extended by `spline_order` knots beyond each end of
        # grid_range, replicated per input feature:
        # shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init the base weights and fit the spline weights to small random noise."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Noise values at the grid_size + 1 knots lying inside grid_range.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # Kaiming init is used for the spline scale rather than a constant
                # fill with scale_spline.
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion raises the order by one per iteration.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # One least-squares problem per input feature.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weights multiplied by the standalone per-edge scale, when enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Map a tensor of shape (..., in_features) to (..., out_features)."""
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        # reshape (not view) so non-contiguous inputs such as permuted tensors work.
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        return output.reshape(*original_shape[:-1], self.out_features)

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Refit the knot grid to the distribution of ``x``.

        The current spline outputs on ``x`` are computed first. After the grid is
        replaced by a blend of a quantile-based grid and a uniform grid, the spline
        weights are re-solved by least squares to reproduce those outputs.

        Args:
            x: Tensor of shape (batch, in_features).
            margin: Extra range added on both sides of the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by spline_order knots on each side, spaced by uniform_step.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
# Inspired by Kolmogorov-Arnold Networks, but using Chebyshev polynomial coefficients instead of B-spline coefficients.
class ChebyKANLayer(nn.Module):
    """
    KAN-style layer whose per-edge functions are Chebyshev polynomial expansions.

    Adapted from https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        degree: Highest polynomial order used (``degree + 1`` terms per edge).
    """
    def __init__(self, in_features, out_features, degree=8):
        super(ChebyKANLayer, self).__init__()
        self.inputdim = in_features
        self.outdim = out_features
        self.degree = degree

        self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1))
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1)))
        # Polynomial orders [0, 1, ..., degree]; kept as a buffer so existing state_dicts load.
        self.register_buffer("arange", torch.arange(0, degree + 1, 1))

    def forward(self, x):
        """Map a tensor of shape (..., in_features) to (..., out_features)."""
        # Chebyshev polynomials are defined on [-1, 1], so squash the input with tanh.
        x = torch.tanh(x)
        # acos has an infinite derivative at +-1, and a saturated tanh returns exactly
        # +-1.0 in floating point, which would turn the gradients into NaN. Keep the
        # values strictly inside the open interval.
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(-1 + eps, 1 - eps)
        # T_n(x) = cos(n * acos(x)) for n in [0, degree]: (..., inputdim, degree + 1)
        x = x.unsqueeze(-1).expand(*x.shape, self.degree + 1)
        x = torch.cos(x.acos() * self.arange)
        # Sum over input features and polynomial orders -> (..., outdim)
        return torch.einsum("...id,iod->...o", x, self.cheby_coeffs)
class UpscaleBlockChebyKAN(nn.Module):
    """
    Chebyshev-KAN variant of the upscale MLP.

    Inputs are normalised from [0, input_max_value] to [-1, 1], embedded linearly, and
    passed through ``layers_count`` ChebyKAN layers. Each layer's output is normalised
    by a single LayerNorm instance shared by all layers. A tanh head rescaled to
    (0, output_max_value) produces ``upscale_factor ** 2`` outputs.
    """
    def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8):
        super().__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
        self.linear_projections = nn.ModuleList(
            ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree)
            for _ in range(layers_count)
        )
        self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
        half_input = input_max_value / 2
        self.in_bias = half_input
        self.in_scale = half_input
        half_output = output_max_value / 2
        self.out_bias = half_output
        self.out_scale = half_output
        # Shared normalisation after every KAN layer, to counter gradient vanishing from tanh.
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        normalized = (x - self.in_bias) / self.in_scale
        hidden = self.embed(normalized)
        for kan_layer in self.linear_projections:
            hidden = self.layer_norm(kan_layer(hidden))
        squashed = torch.tanh(self.project_channels(hidden))
        return squashed * self.out_scale + self.out_bias
class UpscaleBlockEffKAN(nn.Module):
    """
    EffKAN variant of the upscale MLP.

    Inputs are normalised from [0, input_max_value] to [-1, 1], embedded linearly, and
    passed through ``layers_count`` EffKAN layers. Each layer's output is normalised by
    a single LayerNorm instance shared by all layers. A tanh head rescaled to
    (0, output_max_value) produces ``upscale_factor ** 2`` outputs.
    """
    def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
        super(UpscaleBlockEffKAN, self).__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
        # EffKAN.__init__ has no `bias` argument; passing one raises TypeError.
        self.linear_projections = nn.ModuleList(
            [EffKAN(in_features=hidden_dim, out_features=hidden_dim) for _ in range(layers_count)]
        )
        # forward() does not concatenate intermediate features, so the head only ever
        # sees hidden_dim features (as in UpscaleBlockChebyKAN).
        self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
        self.in_bias = self.in_scale = input_max_value / 2
        self.out_bias = self.out_scale = output_max_value / 2
        self.layer_norm = nn.LayerNorm(hidden_dim)  # To avoid gradient vanishing caused by tanh

    def forward(self, x):
        x = (x - self.in_bias) / self.in_scale
        x = self.embed(x)
        for linear_projection in self.linear_projections:
            x = self.layer_norm(linear_projection(x))
        x = torch.tanh(self.project_channels(x))
        return x * self.out_scale + self.out_bias
class ComplexGaborLayer2D(nn.Module):
    '''
    Implicit-representation layer with a 2D complex Gabor nonlinearity (WIRE).

    Adapted from https://github.com/vishwa91/wire/blob/main/modules/wire2d.py

    The output is exp(1j * omega_0 * a) * exp(-scale_0**2 * (|a|**2 + |b|**2)), where
    ``a`` and ``b`` are two independent linear projections of the input.

    Args:
        in_features: Input features.
        out_features: Output features.
        bias: If True, enable bias for both linear projections.
        is_first: If True, the projections use real (float) weights; otherwise complex
            (cfloat) weights, so the input must then be complex.
        omega0: Frequency of the Gabor sinusoid term.
        sigma0: Scaling of the Gabor Gaussian term.
        trainable: If True, omega_0 and scale_0 are trainable parameters.
    '''
    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega0=10.0, sigma0=10.0,
                 trainable=False):
        super().__init__()
        self.is_first = is_first
        self.in_features = in_features
        weight_dtype = torch.float if is_first else torch.cfloat

        # Registered as parameters either way; `trainable` only toggles requires_grad.
        self.omega_0 = nn.Parameter(omega0 * torch.ones(1), trainable)
        self.scale_0 = nn.Parameter(sigma0 * torch.ones(1), trainable)

        self.linear = nn.Linear(in_features, out_features, bias=bias, dtype=weight_dtype)
        # Second projection, used only in the Gaussian window.
        self.scale_orth = nn.Linear(in_features, out_features, bias=bias, dtype=weight_dtype)

    def forward(self, input):
        projected = self.linear(input)
        orthogonal = self.scale_orth(input)
        oscillation = torch.exp(1j * self.omega_0 * projected)
        radius_sq = projected.abs().square() + orthogonal.abs().square()
        envelope = torch.exp(-self.scale_0 * self.scale_0 * radius_sq)
        return oscillation * envelope