import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from .utils import round_func class PercievePattern(): """ Coordinates scheme: [channel, height, width]. Channel can be ommited for all channels. Examples: 1. receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]] 2. receptive_field_idxes=[[0,0,0],[0,1,0],[1,1,0],[1,1]] """ def __init__(self, receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2, channels=1): assert window_size >= (np.max([x for y in receptive_field_idxes for x in y])+1) tmp = [] for coords in receptive_field_idxes: if len(coords) < 3: for i in range(channels): tmp.append([i,] + coords) if len(coords) == 3: tmp.append(coords) receptive_field_idxes = np.array(sorted(tmp)) self.window_size = window_size self.center = center self.receptive_field_idxes = [receptive_field_idxes[i,0]*self.window_size*self.window_size + receptive_field_idxes[i,1]*self.window_size + receptive_field_idxes[i,2] for i in range(len(receptive_field_idxes))] assert len(np.unique(self.receptive_field_idxes) == len(self.receptive_field_idxes)), "Duplicated coordinates found. Coordinates scheme: [channel, height, width]." def __call__(self, x): b,c,h,w = x.shape x = F.pad( x, pad=[ self.center[0], self.window_size-self.center[0]-1, self.center[1], self.window_size-self.center[1]-1 ], mode='replicate' ) x = F.unfold(input=x, kernel_size=self.window_size) x = torch.stack([x[:,self.receptive_field_idxes[i],:] for i in range(len(self.receptive_field_idxes))], 2) return x class UpscaleBlock(nn.Module): def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255): super(UpscaleBlock, self).__init__() assert layers_count > 0 self.upscale_factor = upscale_factor self.hidden_dim = hidden_dim self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True) self.linear_projections = [] for i in range(layers_count): self.linear_projections.append(nn.Linear(in_features=(i+1)*hidden_dim, out_features=hidden_dim, bias=True)) self.linear_projections = nn.ModuleList(self.linear_projections) self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) self.in_bias = self.in_scale = input_max_value/2 self.out_bias = self.out_scale = output_max_value/2 def forward(self, x): x = (x-self.in_bias)/self.in_scale x = torch.nn.functional.gelu(self.embed(x)) for linear_projection in self.linear_projections: x = torch.cat([x, torch.nn.functional.gelu(linear_projection(x))], dim=2) x = self.project_channels(x) x = torch.tanh(x) x = x*self.out_scale + self.out_bias return x class RgbToYcbcr(nn.Module): r"""Convert an image from RGB to YCbCr. The image data is assumed to be in the range of (0, 1). Returns: YCbCr version of the image. Shape: - image: :math:`(*, 3, H, W)` - output: :math:`(*, 3, H, W)` Examples: >>> input = torch.rand(2, 3, 4, 5) >>> ycbcr = RgbToYcbcr() >>> output = ycbcr(input) # 2x3x4x5 https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L105 """ def forward(self, image): r = image[..., 0, :, :] g = image[..., 1, :, :] b = image[..., 2, :, :] delta = 0.5 y = 0.299 * r + 0.587 * g + 0.114 * b cb = (b - y) * 0.564 + delta cr = (r - y) * 0.713 + delta return torch.stack([y, cb, cr], -3) class YcbcrToRgb(nn.Module): r"""Convert an image from YCbCr to Rgb. The image data is assumed to be in the range of (0, 1). Returns: RGB version of the image. Shape: - image: :math:`(*, 3, H, W)` - output: :math:`(*, 3, H, W)` Examples: >>> input = torch.rand(2, 3, 4, 5) >>> rgb = YcbcrToRgb() >>> output = rgb(input) # 2x3x4x5 https://github.com/kornia/kornia/blob/2c084f8dc108b3f0f3c8983ac3f25bf88638d01a/kornia/color/ycbcr.py#L127 """ def forward(self, image): y = image[..., 0, :, :] cb = image[..., 1, :, :] cr = image[..., 2, :, :] delta = 0.5 cb_shifted = cb - delta cr_shifted = cr - delta r = y + 1.403 * cr_shifted g = y - 0.714 * cr_shifted - 0.344 * cb_shifted b = y + 1.773 * cb_shifted return torch.stack([r, g, b], -3) class EffKAN(torch.nn.Module): """ https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6 """ def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(EffKAN, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order : -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.size(-1) == self.in_features original_shape = x.shape x = x.view(-1, self.in_features) base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) output = base_output + spline_output output = output.view(*original_shape[:-1], self.out_features) return output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) # This is inspired by Kolmogorov-Arnold Networks but using Chebyshev polynomials instead of splines coefficients class ChebyKANLayer(nn.Module): """ https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py """ def __init__(self, in_features, out_features, degree=8): super(ChebyKANLayer, self).__init__() self.inputdim = in_features self.outdim = out_features self.degree = degree self.cheby_coeffs = nn.Parameter(torch.empty(in_features, out_features, degree + 1)) nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (in_features * (degree + 1))) self.register_buffer("arange", torch.arange(0, degree + 1, 1)) def forward(self, x): # Since Chebyshev polynomial is defined in [-1, 1] # We need to normalize x to [-1, 1] using tanh b,hw,c = x.shape x = torch.tanh(x) # View and repeat input degree + 1 times x = x.view((b, hw, self.inputdim, 1)).expand( -1, -1, -1, self.degree + 1 ) # shape = (batch_size, inputdim, self.degree + 1) # Apply acos x = x.acos() # Multiply by arange [0 .. degree] x *= self.arange # Apply cos x = x.cos() # Compute the Chebyshev interpolation y = torch.einsum( "btid,iod->bto", x, self.cheby_coeffs ) # shape = (batch_size, hw, outdim) y = y.view(b, hw, self.outdim) return y class UpscaleBlockChebyKAN(nn.Module): def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255, degree=8): super(UpscaleBlockChebyKAN, self).__init__() assert layers_count > 0 self.upscale_factor = upscale_factor self.hidden_dim = hidden_dim self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True) self.linear_projections = [] for i in range(layers_count): self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, degree=degree)) self.linear_projections = nn.ModuleList(self.linear_projections) self.project_channels = nn.Linear(in_features=hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) self.in_bias = self.in_scale = input_max_value/2 self.out_bias = self.out_scale = output_max_value/2 self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh def forward(self, x): x = (x-self.in_bias)/self.in_scale x = self.embed(x) for linear_projection in self.linear_projections: x = self.layer_norm(linear_projection(x)) x = self.project_channels(x) x = torch.tanh(x) x = x*self.out_scale + self.out_bias return x class UpscaleBlockEffKAN(nn.Module): def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255): super(UpscaleBlockEffKAN, self).__init__() assert layers_count > 0 self.upscale_factor = upscale_factor self.hidden_dim = hidden_dim self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True) self.linear_projections = [] for i in range(layers_count): self.linear_projections.append(EffKAN(in_features=hidden_dim, out_features=hidden_dim, bias=True)) self.linear_projections = nn.ModuleList(self.linear_projections) self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) self.in_bias = self.in_scale = input_max_value/2 self.out_bias = self.out_scale = output_max_value/2 self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh def forward(self, x): x = (x-self.in_bias)/self.in_scale x = self.embed(x) for linear_projection in self.linear_projections: x = self.layer_norm(linear_projection(x)) x = self.project_channels(x) x = torch.tanh(x) x = x*self.out_scale + self.out_bias return x class ComplexGaborLayer2D(nn.Module): ''' Implicit representation with complex Gabor nonlinearity with 2D activation function https://github.com/vishwa91/wire/blob/main/modules/wire2d.py Inputs; in_features: Input features out_features; Output features bias: if True, enable bias for the linear operation is_first: Legacy SIREN parameter omega_0: Legacy SIREN parameter omega0: Frequency of Gabor sinusoid term sigma0: Scaling of Gabor Gaussian term trainable: If True, omega and sigma are trainable parameters ''' def __init__(self, in_features, out_features, bias=True, is_first=False, omega0=10.0, sigma0=10.0, trainable=False): super().__init__() self.omega_0 = omega0 self.scale_0 = sigma0 self.is_first = is_first self.in_features = in_features if self.is_first: dtype = torch.float else: dtype = torch.cfloat # Set trainable parameters if they are to be simultaneously optimized self.omega_0 = nn.Parameter(self.omega_0*torch.ones(1), trainable) self.scale_0 = nn.Parameter(self.scale_0*torch.ones(1), trainable) self.linear = nn.Linear(in_features, out_features, bias=bias, dtype=dtype) # Second Gaussian window self.scale_orth = nn.Linear(in_features, out_features, bias=bias, dtype=dtype) def forward(self, input): lin = self.linear(input) scale_x = lin scale_y = self.scale_orth(input) freq_term = torch.exp(1j*self.omega_0*lin) arg = scale_x.abs().square() + scale_y.abs().square() gauss_term = torch.exp(-self.scale_0*self.scale_0*arg) return freq_term*gauss_term