| 
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -115,4 +115,339 @@ class YcbcrToRgb(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        r = y + 1.403 * cr_shifted
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        b = y + 1.773 * cb_shifted
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return torch.stack([r, g, b], -3)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return torch.stack([r, g, b], -3)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class EffKAN(torch.nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    https://github.com/Blealtan/efficient-kan/blob/605b8c2ae24b8085bd11b96896a17eabe12f6736/src/efficient_kan/kan.py#L6
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        in_features,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        out_features,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid_size=5,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        spline_order=3,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        scale_noise=0.1,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        scale_base=1.0,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        scale_spline=1.0,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        enable_standalone_scale_spline=True,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        base_activation=torch.nn.SiLU,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid_eps=0.02,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid_range=[-1, 1],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    ):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super(EffKAN, self).__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.in_features = in_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.out_features = out_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.grid_size = grid_size
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.spline_order = spline_order
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        h = (grid_range[1] - grid_range[0]) / grid_size
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid = (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                torch.arange(-spline_order, grid_size + spline_order + 1) * h
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                + grid_range[0]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            .expand(in_features, -1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            .contiguous()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.register_buffer("grid", grid)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.spline_weight = torch.nn.Parameter(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            torch.Tensor(out_features, in_features, grid_size + spline_order)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if enable_standalone_scale_spline:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.spline_scaler = torch.nn.Parameter(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                torch.Tensor(out_features, in_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.scale_noise = scale_noise
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.scale_base = scale_base
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.scale_spline = scale_spline
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.enable_standalone_scale_spline = enable_standalone_scale_spline
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.base_activation = base_activation()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.grid_eps = grid_eps
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.reset_parameters()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def reset_parameters(self):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        with torch.no_grad():
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            noise = (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    - 1 / 2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                * self.scale_noise
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                / self.grid_size
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.spline_weight.data.copy_(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                * self.curve2coeff(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    self.grid.T[self.spline_order : -self.spline_order],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    noise,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            if self.enable_standalone_scale_spline:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def b_splines(self, x: torch.Tensor):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Compute the B-spline bases for the given input tensor.
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Args:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Returns:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert x.dim() == 2 and x.size(1) == self.in_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid: torch.Tensor = (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.grid
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )  # (in_features, grid_size + 2 * spline_order + 1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.unsqueeze(-1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for k in range(1, self.spline_order + 1):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            bases = (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                (x - grid[:, : -(k + 1)])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                / (grid[:, k:-1] - grid[:, : -(k + 1)])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                * bases[:, :, :-1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            ) + (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                (grid[:, k + 1 :] - x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                * bases[:, :, 1:]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert bases.size() == (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            x.size(0),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.in_features,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.grid_size + self.spline_order,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return bases.contiguous()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Compute the coefficients of the curve that interpolates the given points.
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Args:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Returns:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert x.dim() == 2 and x.size(1) == self.in_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert y.size() == (x.size(0), self.in_features, self.out_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        A = self.b_splines(x).transpose(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            0, 1
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )  # (in_features, batch_size, grid_size + spline_order)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        solution = torch.linalg.lstsq(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            A, B
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        ).solution  # (in_features, grid_size + spline_order, out_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        result = solution.permute(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            2, 0, 1
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )  # (out_features, in_features, grid_size + spline_order)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert result.size() == (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.out_features,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.in_features,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.grid_size + self.spline_order,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return result.contiguous()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    @property
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def scaled_spline_weight(self):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return self.spline_weight * (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.spline_scaler.unsqueeze(-1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            if self.enable_standalone_scale_spline
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            else 1.0
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x: torch.Tensor):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert x.size(-1) == self.in_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        original_shape = x.shape
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.view(-1, self.in_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        base_output = F.linear(self.base_activation(x), self.base_weight)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        spline_output = F.linear(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.b_splines(x).view(x.size(0), -1),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.scaled_spline_weight.view(self.out_features, -1),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output = base_output + spline_output
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output = output.view(*original_shape[:-1], self.out_features)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return output
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    @torch.no_grad()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def update_grid(self, x: torch.Tensor, margin=0.01):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert x.dim() == 2 and x.size(1) == self.in_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        batch = x.size(0)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        splines = self.b_splines(x)  # (batch, in, coeff)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        unreduced_spline_output = unreduced_spline_output.permute(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            1, 0, 2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )  # (batch, in, out)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # sort each channel individually to collect data distribution
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x_sorted = torch.sort(x, dim=0)[0]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid_adaptive = x_sorted[
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            torch.linspace(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid_uniform = (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            torch.arange(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                self.grid_size + 1, dtype=torch.float32, device=x.device
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            ).unsqueeze(1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            * uniform_step
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            + x_sorted[0]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            - margin
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        grid = torch.concatenate(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                grid[:1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                - uniform_step
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                grid,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                grid[-1:]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                + uniform_step
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            dim=0,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.grid.copy_(grid.T)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Compute the regularization loss.
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        This is a dumb simulation of the original L1 regularization as stated in the
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        paper, since the original one requires computing absolutes and entropy from the
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        behind the F.linear function if we want an memory efficient implementation.
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        The L1 regularization is now computed as mean absolute value of the spline
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        weights. The authors implementation also includes this term in addition to the
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        sample-based regularization.
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        l1_fake = self.spline_weight.abs().mean(-1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        regularization_loss_activation = l1_fake.sum()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        p = l1_fake / regularization_loss_activation
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        regularization_loss_entropy = -torch.sum(p * p.log())
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            regularize_activation * regularization_loss_activation
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            + regularize_entropy * regularization_loss_entropy
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				# This is inspired by Kolmogorov-Arnold Networks but using Chebyshev polynomials instead of splines coefficients
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class ChebyKANLayer(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, in_features, out_features, degree=5):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super(ChebyKANLayer, self).__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.inputdim = in_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.outdim = out_features
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.degree = degree
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.register_buffer("arange", torch.arange(0, degree + 1, 1))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # Since Chebyshev polynomial is defined in [-1, 1]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # We need to normalize x to [-1, 1] using tanh
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        b,hw,c = x.shape
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = torch.tanh(x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # View and repeat input degree + 1 times
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.view((b, hw, self.inputdim, 1)).expand(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            -1, -1, -1, self.degree + 1
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )  # shape = (batch_size, inputdim, self.degree + 1)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # Apply acos
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.acos()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # Multiply by arange [0 .. degree]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x *= self.arange
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # Apply cos
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.cos()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # Compute the Chebyshev interpolation
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        y = torch.einsum(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            "btid,iod->bto", x, self.cheby_coeffs
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )  # shape = (batch_size, hw, outdim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        y = y.view(b, hw, self.outdim)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return y
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class UpscaleBlockChebyKAN(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super(UpscaleBlockChebyKAN, self).__init__()   
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert layers_count > 0     
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.upscale_factor = upscale_factor 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.hidden_dim = hidden_dim
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.linear_projections = []
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for i in range(layers_count):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.linear_projections.append(ChebyKANLayer(in_features=hidden_dim, out_features=hidden_dim, bias=True))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.linear_projections = nn.ModuleList(self.linear_projections)          
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.in_bias = self.in_scale = input_max_value/2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.out_bias = self.out_scale = output_max_value/2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = (x-self.in_bias)/self.in_scale
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = self.embed(x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for linear_projection in self.linear_projections:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            x = self.layer_norm(linear_projection(x)) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = self.project_channels(x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = torch.tanh(x)         
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x*self.out_scale + self.out_bias
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x  
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class UpscaleBlockEffKAN(nn.Module):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, in_features=4, hidden_dim = 32, layers_count=4, upscale_factor=1, input_max_value=255, output_max_value=255):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super(UpscaleBlockEffKAN, self).__init__()   
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert layers_count > 0     
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.upscale_factor = upscale_factor 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.hidden_dim = hidden_dim
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.embed = nn.Linear(in_features=in_features, out_features=hidden_dim, bias=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.linear_projections = []
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for i in range(layers_count):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            self.linear_projections.append(EffKAN(in_features=hidden_dim, out_features=hidden_dim, bias=True))
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.linear_projections = nn.ModuleList(self.linear_projections)          
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.project_channels = nn.Linear(in_features=(layers_count+1)*hidden_dim, out_features=upscale_factor * upscale_factor, bias=True)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.in_bias = self.in_scale = input_max_value/2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.out_bias = self.out_scale = output_max_value/2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.layer_norm = nn.LayerNorm(hidden_dim) # To avoid gradient vanishing caused by tanh
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = (x-self.in_bias)/self.in_scale
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = self.embed(x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for linear_projection in self.linear_projections:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            x = self.layer_norm(linear_projection(x)) 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = self.project_channels(x)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = torch.tanh(x)         
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x*self.out_scale + self.out_bias
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x  
 |