diff --git a/src/common/layers.py b/src/common/layers.py index 1b2e0b7..38b1b54 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -115,4 +115,339 @@ class YcbcrToRgb(nn.Module): r = y + 1.403 * cr_shifted g = y - 0.714 * cr_shifted - 0.344 * cb_shifted b = y + 1.773 * cb_shifted - return torch.stack([r, g, b], -3) \ No newline at end of file + return torch.stack([r, g, b], -3) + + + + +class EffKAN(torch.nn.Module): + """ + https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6 + """ + def __init__( + self, + in_features, + out_features, + grid_size=5, + spline_order=3, + scale_noise=0.1, + scale_base=1.0, + scale_spline=1.0, + enable_standalone_scale_spline=True, + base_activation=torch.nn.SiLU, + grid_eps=0.02, + grid_range=[-1, 1], + ): + super(EffKAN, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.grid_size = grid_size + self.spline_order = spline_order + + h = (grid_range[1] - grid_range[0]) / grid_size + grid = ( + ( + torch.arange(-spline_order, grid_size + spline_order + 1) * h + + grid_range[0] + ) + .expand(in_features, -1) + .contiguous() + ) + self.register_buffer("grid", grid) + + self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + self.spline_weight = torch.nn.Parameter( + torch.Tensor(out_features, in_features, grid_size + spline_order) + ) + if enable_standalone_scale_spline: + self.spline_scaler = torch.nn.Parameter( + torch.Tensor(out_features, in_features) + ) + + self.scale_noise = scale_noise + self.scale_base = scale_base + self.scale_spline = scale_spline + self.enable_standalone_scale_spline = enable_standalone_scale_spline + self.base_activation = base_activation() + self.grid_eps = grid_eps + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) + with torch.no_grad(): + noise = ( + ( + torch.rand(self.grid_size + 1, self.in_features, self.out_features) + - 1 / 2 + ) + * self.scale_noise + / self.grid_size + ) + self.spline_weight.data.copy_( + (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) + * self.curve2coeff( + self.grid.T[self.spline_order : -self.spline_order], + noise, + ) + ) + if self.enable_standalone_scale_spline: + # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) + torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) + + def b_splines(self, x: torch.Tensor): + """ + Compute the B-spline bases for the given input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_features). + + Returns: + torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). + """ + assert x.dim() == 2 and x.size(1) == self.in_features + + grid: torch.Tensor = ( + self.grid + ) # (in_features, grid_size + 2 * spline_order + 1) + x = x.unsqueeze(-1) + bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) + for k in range(1, self.spline_order + 1): + bases = ( + (x - grid[:, : -(k + 1)]) + / (grid[:, k:-1] - grid[:, : -(k + 1)]) + * bases[:, :, :-1] + ) + ( + (grid[:, k + 1 :] - x) + / (grid[:, k + 1 :] - grid[:, 1:(-k)]) + * bases[:, :, 1:] + ) + + assert bases.size() == ( + x.size(0), + self.in_features, + self.grid_size + self.spline_order, + ) + return bases.contiguous() + + def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): + """ + Compute the coefficients of the curve that interpolates the given points. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_features). + y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). + + Returns: + torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). + """ + assert x.dim() == 2 and x.size(1) == self.in_features + assert y.size() == (x.size(0), self.in_features, self.out_features) + + A = self.b_splines(x).transpose( + 0, 1 + ) # (in_features, batch_size, grid_size + spline_order) + B = y.transpose(0, 1) # (in_features, batch_size, out_features) + solution = torch.linalg.lstsq( + A, B + ).solution # (in_features, grid_size + spline_order, out_features) + result = solution.permute( + 2, 0, 1 + ) # (out_features, in_features, grid_size + spline_order) + + assert result.size() == ( + self.out_features, + self.in_features, + self.grid_size + self.spline_order, + ) + return result.contiguous() + + @property + def scaled_spline_weight(self): + return self.spline_weight * ( + self.spline_scaler.unsqueeze(-1) + if self.enable_standalone_scale_spline + else 1.0 + ) + + def forward(self, x: torch.Tensor): + assert x.size(-1) == self.in_features + original_shape = x.shape + x = x.view(-1, self.in_features) + + base_output = F.linear(self.base_activation(x), self.base_weight) + spline_output = F.linear( + self.b_splines(x).view(x.size(0), -1), + self.scaled_spline_weight.view(self.out_features, -1), + ) + output = base_output + spline_output + + output = output.view(*original_shape[:-1], self.out_features) + return output + + @torch.no_grad() + def update_grid(self, x: torch.Tensor, margin=0.01): + assert x.dim() == 2 and x.size(1) == self.in_features + batch = x.size(0) + + splines = self.b_splines(x) # (batch, in, coeff) + splines = splines.permute(1, 0, 2) # (in, batch, coeff) + orig_coeff = self.scaled_spline_weight # (out, in, coeff) + orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) + unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) + unreduced_spline_output = unreduced_spline_output.permute( + 1, 0, 2 + ) # (batch, in, out) + + # sort each channel individually to collect data distribution + x_sorted = torch.sort(x, dim=0)[0] + grid_adaptive = x_sorted[ + torch.linspace( + 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device + ) + ] + + uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size + grid_uniform = ( + torch.arange( + self.grid_size + 1, dtype=torch.float32, device=x.device + ).unsqueeze(1) + * uniform_step + + x_sorted[0] + - margin + ) + + grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive + grid = torch.concatenate( + [ + grid[:1] + - uniform_step + * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), + grid, + grid[-1:] + + uniform_step + * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), + ], + dim=0, + ) + + self.grid.copy_(grid.T) + self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) + + def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): + """ + Compute the regularization loss. + + This is a dumb simulation of the original L1 regularization as stated in the + paper, since the original one requires computing absolutes and entropy from the + expanded (batch, in_features, out_features) intermediate tensor, which is hidden + behind the F.linear function if we want an memory efficient implementation. + + The L1 regularization is now computed as mean absolute value of the spline + weights. The authors implementation also includes this term in addition to the + sample-based regularization. + """ + l1_fake = self.spline_weight.abs().mean(-1) + regularization_loss_activation = l1_fake.sum() + p = l1_fake / regularization_loss_activation + regularization_loss_entropy = -torch.sum(p * p.log()) + return ( + regularize_activation * regularization_loss_activation + + regularize_entropy * regularization_loss_entropy + ) + +# This is inspired by Kolmogorov-Arnold Networks but using Chebyshev polynomials instead of splines coefficients +class ChebyKANLayer(nn.Module): + """ + https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py + """ + def __init__(self, in_features, out_features, degree=5): + super(ChebyKANLayer, self).__init__() + self.inputdim = in_features + self.outdim = out_features + self.degree = degree + + self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1)) + nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1))) + self.register_buffer("arange", torch.arange(0, degree + 1, 1)) + + def forward(self, x): + # Since Chebyshev polynomial is defined in [-1, 1] + # We need to normalize x to [-1, 1] using tanh + b,hw,c = x.shape + x = torch.tanh(x) + # View and repeat input degree + 1 times + x = x.view((b, hw, self.inputdim, 1)).expand( + -1, -1, -1, self.degree + 1 + ) # shape = (batch_size, inputdim, self.degree + 1) + # Apply acos + x = x.acos() + # Multiply by arange [0 .. degree] + x *= self.arange + # Apply cos + x = x.cos() + # Compute the Chebyshev interpolation + y = torch.einsum( + "btid,iod->bto", x, self.cheby_coeffs + ) # shape = (batch_size, hw, outdim) + y = y.view(b, hw, self.outdim) + return y + + +class UpscaleBlockChebyKAN(nn.Module): + def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255): + super(UpscaleBlockChebyKAN, self).__init__() + assert layers_count > 0 + self.upscale_factor = upscale_factor + self.hidden_dim = hidden_dim + self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True) + + self.linear_projections = [] + for i in range(layers_count): + self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, bias=True)) + self.linear_projections = nn.ModuleList(self.linear_projections) + + self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) + + self.in_bias = self.in_scale = input_max_value/2 + self.out_bias = self.out_scale = output_max_value/2 + self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh + + def forward(self, x): + x = (x-self.in_bias)/self.in_scale + x = self.embed(x) + for linear_projection in self.linear_projections: + x = self.layer_norm(linear_projection(x)) + x = self.project_channels(x) + x = torch.tanh(x) + x = x*self.out_scale + self.out_bias + return x + +class UpscaleBlockEffKAN(nn.Module): + def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255): + super(UpscaleBlockEffKAN, self).__init__() + assert layers_count > 0 + self.upscale_factor = upscale_factor + self.hidden_dim = hidden_dim + self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True) + + self.linear_projections = [] + for i in range(layers_count): + self.linear_projections.append(EffKAN(in_features=hidden_dim, out_features=hidden_dim, bias=True)) + self.linear_projections = nn.ModuleList(self.linear_projections) + + self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True) + + self.in_bias = self.in_scale = input_max_value/2 + self.out_bias = self.out_scale = output_max_value/2 + self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh + + def forward(self, x): + x = (x-self.in_bias)/self.in_scale + x = self.embed(x) + for linear_projection in self.linear_projections: + x = self.layer_norm(linear_projection(x)) + x = self.project_channels(x) + x = torch.tanh(x) + x = x*self.out_scale + self.out_bias + return x \ No newline at end of file