new framework impl. added chebykan and linear models.
							parent
							
								
									67bd678763
								
							
						
					
					
						commit
						d34cc7833e
					
				
											
												
													File diff suppressed because one or more lines are too long
												
											
										
									
								@ -0,0 +1,24 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
from common import layers
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
class SRBase(nn.Module):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super(SRBase, self).__init__()
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
    # def get_loss_fn(self):
 | 
			
		||||
    #     ssim_loss = losses.SSIM(data_range=255)
 | 
			
		||||
    #     l1_loss = losses.CharbonnierLoss()
 | 
			
		||||
    #     def loss_fn(pred, target):
 | 
			
		||||
    #         return ssim_loss(pred, target) + l1_loss(pred, target) 
 | 
			
		||||
    #     return loss_fn
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,26 @@
 | 
			
		||||
from common import layers
 | 
			
		||||
 | 
			
		||||
class Transferer():
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.registered_types = {}
 | 
			
		||||
 | 
			
		||||
    def register(self, input_class, output_class):
 | 
			
		||||
        self.registered_types[input_class] = output_class
 | 
			
		||||
 | 
			
		||||
    def transfer(self, input_model, batch_size, quantization_interval):
 | 
			
		||||
        input_class = input_model.__class__
 | 
			
		||||
        if not input_class in self.registered_types:
 | 
			
		||||
            raise Exception(f"No transfer function is registered for class {input_class}")
 | 
			
		||||
        transfered_model = self.transfer_model(input_model, self.registered_types[input_class], batch_size, quantization_interval)
 | 
			
		||||
        return transfered_model
 | 
			
		||||
 | 
			
		||||
    def transfer_model(self, model, output_model_class, batch_size, quantization_interval):
 | 
			
		||||
        qmodel = output_model_class(config = model.config)
 | 
			
		||||
        model.config.quantization_interval = quantization_interval
 | 
			
		||||
        for attr, value in model.named_children():
 | 
			
		||||
            if isinstance(value, layers.UpscaleBlock):
 | 
			
		||||
                getattr(qmodel, attr).stage = getattr(model, attr).stage.get_lut_model(quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
        return qmodel
 | 
			
		||||
 | 
			
		||||
TRANSFERER = Transferer()
 | 
			
		||||
@ -1,28 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
 | 
			
		||||
class SRNetBase(nn.Module):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super(SRNetBase, self).__init__()
 | 
			
		||||
    
 | 
			
		||||
    def forward_stage(self, x, percieve_pattern, stage):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        scale = stage.upscale_factor
 | 
			
		||||
        x = percieve_pattern(x)
 | 
			
		||||
        x = stage(x)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, 1, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0, 1, 2, 4, 3, 5)
 | 
			
		||||
        x = x.reshape(b, 1, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
@ -1,134 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from common.lut import select_index_4dlut_tetrahedral
 | 
			
		||||
from common import layers 
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
 | 
			
		||||
class HDBLut(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(HDBLut, self).__init__()     
 | 
			
		||||
        assert scale == 4
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
 | 
			
		||||
        self.stage1_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage1_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage1_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage1_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage1_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
        self.stage2_3H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage2_3D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage2_3B = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*3 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage2_2H = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
 | 
			
		||||
        self.stage2_2D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*2 + (2,2)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D, 
 | 
			
		||||
        stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D 
 | 
			
		||||
    ):   
 | 
			
		||||
        # quantization_interval = 256//(stage1_3H.shape[0]-1)
 | 
			
		||||
        quantization_interval = 16
 | 
			
		||||
        lut_model = HDBLut(quantization_interval=quantization_interval, scale=4)
 | 
			
		||||
        lut_model.stage1_3H = nn.Parameter(torch.tensor(stage1_3H).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_3D = nn.Parameter(torch.tensor(stage1_3D).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_3B = nn.Parameter(torch.tensor(stage1_3B).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_2H = nn.Parameter(torch.tensor(stage1_2H).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_2D = nn.Parameter(torch.tensor(stage1_2D).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
        lut_model.stage2_3H = nn.Parameter(torch.tensor(stage2_3H).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_3D = nn.Parameter(torch.tensor(stage2_3D).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_3B = nn.Parameter(torch.tensor(stage2_3B).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_2H = nn.Parameter(torch.tensor(stage2_2H).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_2D = nn.Parameter(torch.tensor(stage2_2D).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        print(np.prod(x.shape))
 | 
			
		||||
        x = percieve_pattern(x)
 | 
			
		||||
        shifts = torch.tensor([lut.shape[0]**d for d in range(len(lut.shape)-2)], device=x.device).flip(0).reshape(1,1,len(lut.shape)-2)
 | 
			
		||||
        print(x.shape, x.min(), x.max())
 | 
			
		||||
        x = torch.sum(x * shifts, dim=-1)
 | 
			
		||||
        print(x.shape)
 | 
			
		||||
        lut = torch.clamp(lut, 0, 255)
 | 
			
		||||
        lut = lut.reshape(-1, scale, scale)
 | 
			
		||||
        x = x.flatten().type(torch.int64)
 | 
			
		||||
        x = lut[x]
 | 
			
		||||
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        print(x.shape)
 | 
			
		||||
        # raise RuntimeError
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb        
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(4):
 | 
			
		||||
           rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= 4*3
 | 
			
		||||
        output_lsb /= 4*2
 | 
			
		||||
        print(output_msb.min(), output_msb.max())
 | 
			
		||||
        print(output_lsb.min(), output_lsb.max())
 | 
			
		||||
 | 
			
		||||
        output_msb = output_msb + output_lsb
 | 
			
		||||
        x = output_msb
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb        
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
 | 
			
		||||
        print("STAGE2", msb.min(), msb.max(), lsb.min(), lsb.max())
 | 
			
		||||
        for rotations_count in range(4):
 | 
			
		||||
           rotated_msb = torch.floor_divide(torch.rot90(msb, k=rotations_count, dims=[2, 3]), 16)
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb += torch.rot90(self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= 4*3
 | 
			
		||||
        output_lsb /= 4*2
 | 
			
		||||
        output_msb = output_msb + output_lsb
 | 
			
		||||
        x = output_msb
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x        
 | 
			
		||||
    
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}" + \
 | 
			
		||||
               f"\n  stage1_3H size: {self.stage1_3H.shape}" + \
 | 
			
		||||
               f"\n  stage1_3D size: {self.stage1_3D.shape}" + \
 | 
			
		||||
               f"\n  stage1_3B size: {self.stage1_3B.shape}" + \
 | 
			
		||||
               f"\n  stage1_2H size: {self.stage1_2H.shape}" + \
 | 
			
		||||
               f"\n  stage1_2D size: {self.stage1_2D.shape}" + \
 | 
			
		||||
               f"\n  stage2_3H size: {self.stage2_3H.shape}" + \
 | 
			
		||||
               f"\n  stage2_3D size: {self.stage2_3D.shape}" + \
 | 
			
		||||
               f"\n  stage2_3B size: {self.stage2_3B.shape}" + \
 | 
			
		||||
               f"\n  stage2_2H size: {self.stage2_2H.shape}" + \
 | 
			
		||||
               f"\n  stage2_2D size: {self.stage2_2D.shape}"
 | 
			
		||||
@ -1,441 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
from common import lut
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from . import hdblut
 | 
			
		||||
from common import layers
 | 
			
		||||
from itertools import cycle
 | 
			
		||||
from models.base import SRNetBase
 | 
			
		||||
 | 
			
		||||
class HDBNet(SRNetBase):
 | 
			
		||||
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4, rotations = 4):
 | 
			
		||||
        super(HDBNet, self).__init__()
 | 
			
		||||
        assert scale == 4
 | 
			
		||||
        self.scale = scale 
 | 
			
		||||
        self.rotations = rotations
 | 
			
		||||
        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
        self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
 | 
			
		||||
        self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
        self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
 | 
			
		||||
        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)        
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, stage):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = stage(x)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb     
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(self.rotations):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
 | 
			
		||||
                        self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
 | 
			
		||||
                        self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
 | 
			
		||||
           output_msb /= 3
 | 
			
		||||
           output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
 | 
			
		||||
                        self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
 | 
			
		||||
           output_lsb /= 2
 | 
			
		||||
           if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
                config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
                config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter) 
 | 
			
		||||
           output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
 | 
			
		||||
        output /= self.rotations
 | 
			
		||||
        x = output
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*4, w*4], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(self.rotations):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
 | 
			
		||||
                        self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
 | 
			
		||||
                        self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
 | 
			
		||||
           output_msb /= 3
 | 
			
		||||
           output_lsb = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
 | 
			
		||||
                        self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
 | 
			
		||||
           output_lsb /= 2
 | 
			
		||||
           if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
                config.writer.add_histogram('s2_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
                config.writer.add_histogram('s2_output_msb', output_msb.detach().cpu().numpy(), config.current_iter) 
 | 
			
		||||
           output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
 | 
			
		||||
        output /= self.rotations
 | 
			
		||||
        x = output
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
        stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
        stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
        lut_model = hdblut.HDBLut.init_from_numpy(
 | 
			
		||||
            stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D, 
 | 
			
		||||
            stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
 | 
			
		||||
        )
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
class HDBNetv2(SRNetBase):
 | 
			
		||||
    def __init__(self, hidden_dim = 64, layers_count = 2, scale = 4, rotations = 4):
 | 
			
		||||
        super(HDBNetv2, self).__init__()
 | 
			
		||||
        assert scale == 4
 | 
			
		||||
        self.scale = scale 
 | 
			
		||||
        self.rotations = rotations
 | 
			
		||||
        self.layers_count = layers_count
 | 
			
		||||
        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage1_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
        self.stage1_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
 | 
			
		||||
        self.stage2_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage2_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage2_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=255, output_max_value=255)
 | 
			
		||||
        self.stage2_2H = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
        self.stage2_2D = layers.UpscaleBlock(in_features=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=2, input_max_value=15, output_max_value=255)
 | 
			
		||||
 | 
			
		||||
        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2)        
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb     
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*2, w*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(self.rotations):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage1_3H) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage1_3D) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage1_3B)
 | 
			
		||||
           output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage1_2H) + \
 | 
			
		||||
                         self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage1_2D)
 | 
			
		||||
           if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
                config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
                config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter) 
 | 
			
		||||
           output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= self.rotations*3
 | 
			
		||||
        output_lsb /= self.rotations*2
 | 
			
		||||
        output = output_msb + output_lsb
 | 
			
		||||
        x = output.clamp(0, 255)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*2*2, w*2*2], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(self.rotations):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msbt = self.forward_stage(rotated_msb, 2, self._extract_pattern_3H, self.stage2_3H) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, 2, self._extract_pattern_3D, self.stage2_3D) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, 2, self._extract_pattern_3B, self.stage2_3B)
 | 
			
		||||
           output_lsbt = self.forward_stage(rotated_lsb, 2, self._extract_pattern_2H, self.stage2_2H) + \
 | 
			
		||||
                         self.forward_stage(rotated_lsb, 2, self._extract_pattern_2D, self.stage2_2D)
 | 
			
		||||
           if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
                config.writer.add_histogram('s2_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
                config.writer.add_histogram('s2_output_msb', output_msb.detach().cpu().numpy(), config.current_iter) 
 | 
			
		||||
           output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= self.rotations*3
 | 
			
		||||
        output_lsb /= self.rotations*2
 | 
			
		||||
        output = output_msb + output_lsb
 | 
			
		||||
        x = output.clamp(0, 255)
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
        stage1_3H = lut.transfer_3_input_SxS_output(self.stage1_3H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_3D = lut.transfer_3_input_SxS_output(self.stage1_3D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_3B = lut.transfer_3_input_SxS_output(self.stage1_3B, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_2H = lut.transfer_2_input_SxS_output(self.stage1_2H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage1_2D = lut.transfer_2_input_SxS_output(self.stage1_2D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
        stage2_3H = lut.transfer_3_input_SxS_output(self.stage2_3H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_3D = lut.transfer_3_input_SxS_output(self.stage2_3D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_3B = lut.transfer_3_input_SxS_output(self.stage2_3B, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_2H = lut.transfer_2_input_SxS_output(self.stage2_2H, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        stage2_2D = lut.transfer_2_input_SxS_output(self.stage2_2D, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
        lut_model = hdblut.HDBLut.init_from_numpy(
 | 
			
		||||
            stage1_3H, stage1_3D, stage1_3B, stage1_2H, stage1_2D, 
 | 
			
		||||
            stage2_3H, stage2_3D, stage2_3B, stage2_2H, stage2_2D
 | 
			
		||||
        )
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
class HDBLNet(SRNetBase):
 | 
			
		||||
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
        super(HDBLNet, self).__init__()
 | 
			
		||||
        self.scale = scale 
 | 
			
		||||
        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
 | 
			
		||||
        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, stage):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = stage(x)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
            
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb     
 | 
			
		||||
        output_msb = self.forward_stage(msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
 | 
			
		||||
                     self.forward_stage(msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
 | 
			
		||||
                     self.forward_stage(msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
 | 
			
		||||
        output_lsb = self.forward_stage(lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
 | 
			
		||||
        output_msb /= 3
 | 
			
		||||
        if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
            config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
            config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)            
 | 
			
		||||
        output = output_msb + output_lsb
 | 
			
		||||
        x = output.clamp(0, 255)
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
class HDBLNetR90(SRNetBase):
 | 
			
		||||
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
        super(HDBLNet, self).__init__()
 | 
			
		||||
        self.scale = scale 
 | 
			
		||||
        self.stage1_3H = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3D = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3B = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3L = layers.UpscaleBlock(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
 | 
			
		||||
        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, stage):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = stage(x)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
            
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb     
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(4):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
 | 
			
		||||
           output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
 | 
			
		||||
           if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
                config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
                config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
           output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= 4*3
 | 
			
		||||
        output_lsb /= 4
 | 
			
		||||
        output = output_msb + output_lsb
 | 
			
		||||
        x = output.clamp(0, 255)
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HDBLNetR90KAN(SRNetBase):
 | 
			
		||||
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
        super(HDBLNetR90KAN, self).__init__()
 | 
			
		||||
        self.scale = scale 
 | 
			
		||||
        self.stage1_3H = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3D = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3B = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
        self.stage1_3L = layers.UpscaleBlockChebyKAN(in_features=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=self.scale)
 | 
			
		||||
 | 
			
		||||
        self._extract_pattern_3H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[0,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_3L = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, stage):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = stage(x)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
            
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb     
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count in range(4):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msbt = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3H, self.stage1_3H) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3D, self.stage1_3D) + \
 | 
			
		||||
                         self.forward_stage(rotated_msb, self.scale, self._extract_pattern_3B, self.stage1_3B)
 | 
			
		||||
           output_lsbt = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_3L, self.stage1_3L)
 | 
			
		||||
           if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
                config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
                config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
           output_msb += torch.rot90(output_msbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(output_lsbt, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= 4*3
 | 
			
		||||
        output_lsb /= 4
 | 
			
		||||
        output = output_msb + output_lsb
 | 
			
		||||
        x = output.clamp(0, 255)
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HDBHNet(SRNetBase):
 | 
			
		||||
    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
        super(HDBHNet, self).__init__()
 | 
			
		||||
        self.scale = scale 
 | 
			
		||||
        self.hidden_dim = hidden_dim
 | 
			
		||||
        self.layers_count = layers_count
 | 
			
		||||
 | 
			
		||||
        self.msb_fns = SRNetBaseList([layers.UpscaleBlock(
 | 
			
		||||
            in_features=4,
 | 
			
		||||
            hidden_dim=hidden_dim,
 | 
			
		||||
            layers_count=layers_count,
 | 
			
		||||
            upscale_factor=self.scale
 | 
			
		||||
        ) for x in range(1)])
 | 
			
		||||
        self.lsb_fns = SRNetBaseList([layers.UpscaleBlock(
 | 
			
		||||
            in_features=4,
 | 
			
		||||
            hidden_dim=hidden_dim,
 | 
			
		||||
            layers_count=layers_count,
 | 
			
		||||
            upscale_factor=self.scale
 | 
			
		||||
        ) for x in range(1)])
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, stage):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = stage(x)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
            
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
 | 
			
		||||
        lsb = x % 16
 | 
			
		||||
        msb = x - lsb
 | 
			
		||||
 | 
			
		||||
        output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        for rotations_count, msb_fn, lsb_fn in zip(range(4), cycle(self.msb_fns), cycle(self.lsb_fns)):
 | 
			
		||||
           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		||||
           output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, msb_fn)           
 | 
			
		||||
           output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, lsb_fn)
 | 
			
		||||
           output_msb_r = round_func((output_msb_r / 255)*16) * 15 
 | 
			
		||||
           output_lsb_r = (output_lsb_r / 255) * 15  
 | 
			
		||||
           output_msb += torch.rot90(output_msb_r, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
           output_lsb += torch.rot90(output_lsb_r, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output_msb /= 4
 | 
			
		||||
        output_lsb /= 4
 | 
			
		||||
        if not config is None and config.current_iter % config.display_step == 0:
 | 
			
		||||
            config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
            config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
 | 
			
		||||
        x = output_msb + output_lsb
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
    
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        fourier_loss_fn = FocalFrequencyLoss()
 | 
			
		||||
        high_frequency_loss_fn = FourierLoss()
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            a = fourier_loss_fn(pred/255, target/255) * 1e8
 | 
			
		||||
            # b = F.mse_loss(pred/255, target/255) #* 1e3
 | 
			
		||||
            # c = high_frequency_loss_fn(pred/255, target/255) * 1e6
 | 
			
		||||
            return a #+ b #+ c
 | 
			
		||||
        return loss_fn
 | 
			
		||||
@ -0,0 +1,95 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
from common import lut
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from common import layers
 | 
			
		||||
from common import losses
 | 
			
		||||
from common.base import SRBase
 | 
			
		||||
from common.transferer import TRANSFERER
 | 
			
		||||
 | 
			
		||||
class SRNetBase(SRBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super(SRNetBase, self).__init__()
 | 
			
		||||
        self.config = None
 | 
			
		||||
        self.stage1_S = layers.UpscaleBlock(None)
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
            
 | 
			
		||||
    def forward(self, x, script_config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        x = self.stage1_S(x, self._extract_pattern_S)
 | 
			
		||||
        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
class SRNet(SRNetBase):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super(SRNet, self).__init__()
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.stage1_S.stage = layers.LinearUpscaleBlockNet(hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		||||
 | 
			
		||||
class SRLut(SRNetBase):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super(SRLut, self).__init__()
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.stage1_S.stage = layers.LinearUpscaleBlockLut(quantization_interval=self.config.quantization_interval, upscale_factor=self.config.upscale_factor)
 | 
			
		||||
 | 
			
		||||
TRANSFERER.register(SRNet, SRLut)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChebyKANBase(SRBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super(ChebyKANBase, self).__init__()
 | 
			
		||||
        self.config = None
 | 
			
		||||
        self.stage1_S = layers.UpscaleBlock(None)
 | 
			
		||||
        window_size = 7
 | 
			
		||||
        self._extract_pattern = layers.PercievePattern(
 | 
			
		||||
            receptive_field_idxes=[[i,j] for i in range(window_size) for j in range(window_size)], 
 | 
			
		||||
            center=[window_size//2,window_size//2], 
 | 
			
		||||
            window_size=window_size
 | 
			
		||||
        )
 | 
			
		||||
            
 | 
			
		||||
    def forward(self, x, script_config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        x = self.stage1_S(x, self._extract_pattern)
 | 
			
		||||
        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
class ChebyKANNet(ChebyKANBase):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super(ChebyKANNet, self).__init__()
 | 
			
		||||
        self.config = config
 | 
			
		||||
        window_size = 7
 | 
			
		||||
        self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
 | 
			
		||||
            in_features=window_size*window_size, 
 | 
			
		||||
            out_channels=1, 
 | 
			
		||||
            hidden_dim=16, 
 | 
			
		||||
            layers_count=self.config.layers_count, 
 | 
			
		||||
            upscale_factor=self.config.upscale_factor, 
 | 
			
		||||
            degree=8
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
class ChebyKANLut(ChebyKANBase):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super(ChebyKANLut, self).__init__()
 | 
			
		||||
        self.config = config
 | 
			
		||||
        window_size = 7
 | 
			
		||||
        self.stage1_S.stage = layers.ChebyKANUpscaleBlockNet(
 | 
			
		||||
            in_features=window_size*window_size, 
 | 
			
		||||
            out_channels=1, 
 | 
			
		||||
            hidden_dim=16, 
 | 
			
		||||
            layers_count=self.config.layers_count, 
 | 
			
		||||
            upscale_factor=self.config.upscale_factor
 | 
			
		||||
        ).get_lut_model(quantization_interval=self.config.quantization_interval)
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        ssim_loss = losses.SSIM(data_range=255)
 | 
			
		||||
        l1_loss = losses.CharbonnierLoss()
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return ssim_loss(pred, target) + l1_loss(pred, target) 
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
TRANSFERER.register(ChebyKANNet, ChebyKANLut)
 | 
			
		||||
@ -1,505 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
# from common.lut import  forward_rc_conv_centered, forward_rc_conv_rot90, forward_2x2_input_SxS_output
 | 
			
		||||
# from pathlib import Path
 | 
			
		||||
 | 
			
		||||
# class RCLutCentered_3x3(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutCentered_3x3, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         rc_conv_luts, dense_conv_lut
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(dense_conv_lut.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
 | 
			
		||||
#         lut_model = RCLutCentered_3x3(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
#         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
 | 
			
		||||
#         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
#         x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate') 
 | 
			
		||||
#         x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts)
 | 
			
		||||
#         x = x[:,:,self.window_size//2:-self.window_size//2+1,self.window_size//2:-self.window_size//2+1]
 | 
			
		||||
#         x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) 
 | 
			
		||||
#         x = x.view(b, c, x.shape[-2], x.shape[-1])      
 | 
			
		||||
#         return x        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  rc_conv_luts size: {self.rc_conv_luts.shape}",
 | 
			
		||||
#             f"  dense_conv_lut size: {self.dense_conv_lut.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
 | 
			
		||||
# class RCLutCentered_7x7(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         window_size,
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutCentered_7x7, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         rc_conv_luts, dense_conv_lut
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(dense_conv_lut.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
 | 
			
		||||
#         lut_model = RCLutCentered_7x7(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
#         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
 | 
			
		||||
#         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
#         x = forward_rc_conv_centered(index=x, lut=self.rc_conv_luts) 
 | 
			
		||||
#         x = forward_2x2_input_SxS_output(index=x, lut=self.dense_conv_lut) 
 | 
			
		||||
#         # x = repeat(x, 'b c h w -> b c (h repeat1) (w repeat2)', repeat1=4, repeat2=4)
 | 
			
		||||
#         x = x.view(b, c, x.shape[-2], x.shape[-1])      
 | 
			
		||||
#         return x        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  rc_conv_luts size: {self.rc_conv_luts.shape}",
 | 
			
		||||
#             f"  dense_conv_lut size: {self.dense_conv_lut.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
 | 
			
		||||
# class RCLutRot90_3x3(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutRot90_3x3, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         rc_conv_luts, dense_conv_lut
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(dense_conv_lut.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
 | 
			
		||||
#         lut_model = RCLutRot90_3x3(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
#         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
 | 
			
		||||
#         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)        
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
 | 
			
		||||
#             rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) 
 | 
			
		||||
#             unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += unrotated_prediction
 | 
			
		||||
#         output /= 4
 | 
			
		||||
#         output = output.view(b, c, output.shape[-2], output.shape[-1])      
 | 
			
		||||
#         return output        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  rc_conv_luts size: {self.rc_conv_luts.shape}",
 | 
			
		||||
#             f"  dense_conv_lut size: {self.dense_conv_lut.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
 | 
			
		||||
# class RCLutRot90_7x7(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutRot90_7x7, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.dense_conv_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.rc_conv_luts = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         rc_conv_luts, dense_conv_lut
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(dense_conv_lut.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(dense_conv_lut.shape[0]-1)
 | 
			
		||||
#         window_size = rc_conv_luts.shape[0]
 | 
			
		||||
#         lut_model = RCLutRot90_7x7(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
#         lut_model.dense_conv_lut = nn.Parameter(torch.tensor(dense_conv_lut).type(torch.float32))
 | 
			
		||||
#         lut_model.rc_conv_luts = nn.Parameter(torch.tensor(rc_conv_luts).type(torch.float32))
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)        
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             rotated = forward_rc_conv_rot90(index=rotated, lut=self.rc_conv_luts)
 | 
			
		||||
#             rotated_prediction = forward_2x2_input_SxS_output(index=rotated, lut=self.dense_conv_lut) 
 | 
			
		||||
#             unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += unrotated_prediction
 | 
			
		||||
#         output /= 4
 | 
			
		||||
#         output = output.view(b, c, output.shape[-2], output.shape[-1])      
 | 
			
		||||
#         return output        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  rc_conv_luts size: {self.rc_conv_luts.shape}",
 | 
			
		||||
#             f"  dense_conv_lut size: {self.dense_conv_lut.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
 | 
			
		||||
# class RCLutx1(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutx1, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         rc_conv_luts_3x3, dense_conv_lut_3x3,
 | 
			
		||||
#         rc_conv_luts_5x5, dense_conv_lut_5x5,
 | 
			
		||||
#         rc_conv_luts_7x7, dense_conv_lut_7x7
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(dense_conv_lut_3x3.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(dense_conv_lut_3x3.shape[0]-1)
 | 
			
		||||
 | 
			
		||||
#         lut_model = RCLutx1(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
 | 
			
		||||
#         lut_model.rc_conv_luts_3x3 = nn.Parameter(torch.tensor(rc_conv_luts_3x3).type(torch.float32))
 | 
			
		||||
#         lut_model.dense_conv_lut_3x3 = nn.Parameter(torch.tensor(dense_conv_lut_3x3).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.rc_conv_luts_5x5 = nn.Parameter(torch.tensor(rc_conv_luts_5x5).type(torch.float32))
 | 
			
		||||
#         lut_model.dense_conv_lut_5x5 = nn.Parameter(torch.tensor(dense_conv_lut_5x5).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.rc_conv_luts_7x7 = nn.Parameter(torch.tensor(rc_conv_luts_7x7).type(torch.float32))
 | 
			
		||||
#         lut_model.dense_conv_lut_7x7 = nn.Parameter(torch.tensor(dense_conv_lut_7x7).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
 | 
			
		||||
#         x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
 | 
			
		||||
#         x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)        
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_3x3, dense_conv_lut=self.dense_conv_lut_3x3), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_5x5, dense_conv_lut=self.dense_conv_lut_5x5), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.rc_conv_luts_7x7, dense_conv_lut=self.dense_conv_lut_7x7), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, output.shape[-2], output.shape[-1])     
 | 
			
		||||
#         return output        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  rc_conv_luts_3x3 size: {self.rc_conv_luts_3x3.shape}",
 | 
			
		||||
#             f"  dense_conv_lut_3x3 size: {self.dense_conv_lut_3x3.shape}",
 | 
			
		||||
#             f"  rc_conv_luts_5x5 size: {self.rc_conv_luts_5x5.shape}",
 | 
			
		||||
#             f"  dense_conv_lut_5x5 size: {self.dense_conv_lut_5x5.shape}",
 | 
			
		||||
#             f"  rc_conv_luts_7x7 size: {self.rc_conv_luts_7x7.shape}",
 | 
			
		||||
#             f"  dense_conv_lut_7x7 size: {self.dense_conv_lut_7x7.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# class RCLutx2(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutx2, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
#         self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
#         self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
#         self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
 | 
			
		||||
#         s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
 | 
			
		||||
#         s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
 | 
			
		||||
#         s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
 | 
			
		||||
#         s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
 | 
			
		||||
#         s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(s2_dense_conv_lut_3x3.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
 | 
			
		||||
 | 
			
		||||
#         lut_model = RCLutx2(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
 | 
			
		||||
#         lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
 | 
			
		||||
#         lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
 | 
			
		||||
#         lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
 | 
			
		||||
#         lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
 | 
			
		||||
#         lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
 | 
			
		||||
#         lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
 | 
			
		||||
#         lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
 | 
			
		||||
#         x = forward_rc_conv_rot90(index=index, lut=rc_conv_lut)
 | 
			
		||||
#         x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)  
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         x = output
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, output.shape[-2], output.shape[-1])     
 | 
			
		||||
#         return output        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
 | 
			
		||||
#             f"  s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
 | 
			
		||||
#             f"  s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
 | 
			
		||||
#             f"  s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
 | 
			
		||||
#             f"  s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
 | 
			
		||||
#             f"  s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
 | 
			
		||||
#             f"  s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
 | 
			
		||||
#             f"  s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
 | 
			
		||||
#             f"  s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
 | 
			
		||||
#             f"  s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
 | 
			
		||||
#             f"  s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
 | 
			
		||||
#             f"  s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# class RCLutx2Centered(nn.Module):
 | 
			
		||||
#     def __init__(
 | 
			
		||||
#         self, 
 | 
			
		||||
#         quantization_interval,
 | 
			
		||||
#         scale
 | 
			
		||||
#     ):
 | 
			
		||||
#         super(RCLutx2Centered, self).__init__()     
 | 
			
		||||
#         self.scale = scale
 | 
			
		||||
#         self.quantization_interval = quantization_interval
 | 
			
		||||
#         self.s1_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s1_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s1_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s1_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
#         self.s1_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
#         self.s1_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
#         self.s2_rc_conv_luts_3x3 = nn.Parameter(torch.randint(0, 255, size=(3, 3, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s2_rc_conv_luts_5x5 = nn.Parameter(torch.randint(0, 255, size=(5, 5, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s2_rc_conv_luts_7x7 = nn.Parameter(torch.randint(0, 255, size=(7, 7, 256//quantization_interval+1)).type(torch.float32))
 | 
			
		||||
#         self.s2_dense_conv_lut_3x3 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.s2_dense_conv_lut_5x5 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
#         self.s2_dense_conv_lut_7x7 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#     @staticmethod
 | 
			
		||||
#     def init_from_numpy(
 | 
			
		||||
#         s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3,
 | 
			
		||||
#         s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5,
 | 
			
		||||
#         s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7,
 | 
			
		||||
#         s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3,
 | 
			
		||||
#         s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5,
 | 
			
		||||
#         s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7
 | 
			
		||||
#     ):   
 | 
			
		||||
#         scale = int(s2_dense_conv_lut_3x3.shape[-1])
 | 
			
		||||
#         quantization_interval = 256//(s2_dense_conv_lut_3x3.shape[0]-1)
 | 
			
		||||
 | 
			
		||||
#         lut_model = RCLutx2Centered(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
 | 
			
		||||
#         lut_model.s1_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s1_rc_conv_luts_3x3).type(torch.float32))
 | 
			
		||||
#         lut_model.s1_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s1_dense_conv_lut_3x3).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s1_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s1_rc_conv_luts_5x5).type(torch.float32))
 | 
			
		||||
#         lut_model.s1_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s1_dense_conv_lut_5x5).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s1_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s1_rc_conv_luts_7x7).type(torch.float32))
 | 
			
		||||
#         lut_model.s1_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s1_dense_conv_lut_7x7).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s2_rc_conv_luts_3x3 = nn.Parameter(torch.tensor(s2_rc_conv_luts_3x3).type(torch.float32))
 | 
			
		||||
#         lut_model.s2_dense_conv_lut_3x3 = nn.Parameter(torch.tensor(s2_dense_conv_lut_3x3).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s2_rc_conv_luts_5x5 = nn.Parameter(torch.tensor(s2_rc_conv_luts_5x5).type(torch.float32))
 | 
			
		||||
#         lut_model.s2_dense_conv_lut_5x5 = nn.Parameter(torch.tensor(s2_dense_conv_lut_5x5).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         lut_model.s2_rc_conv_luts_7x7 = nn.Parameter(torch.tensor(s2_rc_conv_luts_7x7).type(torch.float32))
 | 
			
		||||
#         lut_model.s2_dense_conv_lut_7x7 = nn.Parameter(torch.tensor(s2_dense_conv_lut_7x7).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def _forward_rcblock(self, index, rc_conv_lut, dense_conv_lut):
 | 
			
		||||
#         x = forward_rc_conv_centered(index=index, lut=rc_conv_lut)
 | 
			
		||||
#         x = forward_2x2_input_SxS_output(index=x, lut=dense_conv_lut)
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w).type(torch.float32)  
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h, w], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_3x3, dense_conv_lut=self.s1_dense_conv_lut_3x3), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_5x5, dense_conv_lut=self.s1_dense_conv_lut_5x5), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s1_rc_conv_luts_7x7, dense_conv_lut=self.s1_dense_conv_lut_7x7), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         x = output
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_3x3, dense_conv_lut=self.s2_dense_conv_lut_3x3), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_5x5, dense_conv_lut=self.s2_dense_conv_lut_5x5), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#             output += torch.rot90(
 | 
			
		||||
#                 self._forward_rcblock(index=rotated, rc_conv_lut=self.s2_rc_conv_luts_7x7, dense_conv_lut=self.s2_dense_conv_lut_7x7), 
 | 
			
		||||
#                 k=-rotations_count, 
 | 
			
		||||
#                 dims=[2, 3]
 | 
			
		||||
#             )
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, output.shape[-2], output.shape[-1])     
 | 
			
		||||
#         return output        
 | 
			
		||||
    
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return "\n".join([
 | 
			
		||||
#             f"{self.__class__.__name__}(",
 | 
			
		||||
#             f"  s1_rc_conv_luts_3x3 size: {self.s1_rc_conv_luts_3x3.shape}",
 | 
			
		||||
#             f"  s1_dense_conv_lut_3x3 size: {self.s1_dense_conv_lut_3x3.shape}",
 | 
			
		||||
#             f"  s1_rc_conv_luts_5x5 size: {self.s1_rc_conv_luts_5x5.shape}",
 | 
			
		||||
#             f"  s1_dense_conv_lut_5x5 size: {self.s1_dense_conv_lut_5x5.shape}",
 | 
			
		||||
#             f"  s1_rc_conv_luts_7x7 size: {self.s1_rc_conv_luts_7x7.shape}",
 | 
			
		||||
#             f"  s1_dense_conv_lut_7x7 size: {self.s1_dense_conv_lut_7x7.shape}",
 | 
			
		||||
#             f"  s2_rc_conv_luts_3x3 size: {self.s2_rc_conv_luts_3x3.shape}",
 | 
			
		||||
#             f"  s2_dense_conv_lut_3x3 size: {self.s2_dense_conv_lut_3x3.shape}",
 | 
			
		||||
#             f"  s2_rc_conv_luts_5x5 size: {self.s2_rc_conv_luts_5x5.shape}",
 | 
			
		||||
#             f"  s2_dense_conv_lut_5x5 size: {self.s2_dense_conv_lut_5x5.shape}",
 | 
			
		||||
#             f"  s2_rc_conv_luts_7x7 size: {self.s2_rc_conv_luts_7x7.shape}",
 | 
			
		||||
#             f"  s2_dense_conv_lut_7x7 size: {self.s2_dense_conv_lut_7x7.shape}",
 | 
			
		||||
#             ")"])
 | 
			
		||||
@ -1,568 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from common import lut
 | 
			
		||||
from . import rclut
 | 
			
		||||
from common import layers 
 | 
			
		||||
 | 
			
		||||
# class ReconstructedConvCentered(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim,  window_size=7):
 | 
			
		||||
#         super(ReconstructedConvCentered, self).__init__()
 | 
			
		||||
#         self.window_size = window_size
 | 
			
		||||
#         self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
#         self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
 | 
			
		||||
#     def pixel_wise_forward(self, x):
 | 
			
		||||
#         x = (x-127.5)/127.5
 | 
			
		||||
#         out = torch.einsum('bwk,wh,wh -> bwk', x, self.projection1, self.projection2) 
 | 
			
		||||
#         out = torch.tanh(out)
 | 
			
		||||
#         out = out*127.5 + 127.5
 | 
			
		||||
#         return out 
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         original_shape = x.shape  
 | 
			
		||||
#         x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
 | 
			
		||||
#         x = F.unfold(x, self.window_size)
 | 
			
		||||
#         x = self.pixel_wise_forward(x)     
 | 
			
		||||
#         x = x.mean(1)  
 | 
			
		||||
#         x = x.reshape(*original_shape)    
 | 
			
		||||
#         x = round_func(x) 
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
 | 
			
		||||
 | 
			
		||||
# class RCBlockCentered(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
 | 
			
		||||
#         super(RCBlockCentered, self).__init__()   
 | 
			
		||||
#         self.window_size = window_size    
 | 
			
		||||
#         self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
 | 
			
		||||
#         self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,hs,ws = x.shape
 | 
			
		||||
#         x = self.rc_conv(x)
 | 
			
		||||
#         x = F.pad(x, pad=[0,1,0,1], mode='replicate') 
 | 
			
		||||
#         x = self.dense_conv_block(x)
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
# class RCNetCentered_3x3(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetCentered_3x3, self).__init__()
 | 
			
		||||
#         self.hidden_dim = hidden_dim
 | 
			
		||||
#         self.layers_count = layers_count
 | 
			
		||||
#         self.scale = scale       
 | 
			
		||||
#         self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
            
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         x = self.stage(x)
 | 
			
		||||
#         x = x.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         window_size = self.stage.rc_conv.window_size
 | 
			
		||||
#         rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
 | 
			
		||||
#         dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
#         lut_model = rclut.RCLutCentered_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
# class RCNetCentered_7x7(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetCentered_7x7, self).__init__()
 | 
			
		||||
#         self.hidden_dim = hidden_dim
 | 
			
		||||
#         self.layers_count = layers_count
 | 
			
		||||
#         self.scale = scale       
 | 
			
		||||
#         window_size = 7
 | 
			
		||||
#         self.stage = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=window_size)
 | 
			
		||||
            
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         x = self.stage(x)
 | 
			
		||||
#         x = x.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         window_size = self.stage.rc_conv.window_size
 | 
			
		||||
#         rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
 | 
			
		||||
#         dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
#         lut_model = rclut.RCLutCentered_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# class ReconstructedConvRot90(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim,  window_size=7):
 | 
			
		||||
#         super(ReconstructedConvRot90, self).__init__()
 | 
			
		||||
#         self.window_size = window_size
 | 
			
		||||
#         self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
#         self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
 | 
			
		||||
#     def pixel_wise_forward(self, x):
 | 
			
		||||
#         x = (x-127.5)/127.5
 | 
			
		||||
#         out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) 
 | 
			
		||||
#         out = torch.tanh(out)
 | 
			
		||||
#         out = out*127.5 + 127.5
 | 
			
		||||
#         return out 
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         original_shape = x.shape
 | 
			
		||||
#         x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
 | 
			
		||||
#         x = F.unfold(x, self.window_size)
 | 
			
		||||
#         x = self.pixel_wise_forward(x)        
 | 
			
		||||
#         x = x.mean(1)  
 | 
			
		||||
#         x = x.reshape(*original_shape)  
 | 
			
		||||
#         x = round_func(x) # quality likely suffer from this
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
 | 
			
		||||
 | 
			
		||||
# class RCBlockRot90(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
 | 
			
		||||
#         super(RCBlockRot90, self).__init__()   
 | 
			
		||||
#         self.window_size = window_size    
 | 
			
		||||
#         self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
 | 
			
		||||
#         self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,hs,ws = x.shape
 | 
			
		||||
#         x = self.rc_conv(x)
 | 
			
		||||
#         x = F.pad(x, pad=[0,1,0,1], mode='replicate') 
 | 
			
		||||
#         x = self.dense_conv_block(x)
 | 
			
		||||
        
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
# class RCNetRot90_3x3(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetRot90_3x3, self).__init__()
 | 
			
		||||
#         self.hidden_dim = hidden_dim
 | 
			
		||||
#         self.layers_count = layers_count
 | 
			
		||||
#         self.scale = scale       
 | 
			
		||||
#         self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         window_size = self.stage.rc_conv.window_size
 | 
			
		||||
#         rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
 | 
			
		||||
#         dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
#         lut_model = rclut.RCLutRot90_3x3.init_from_numpy(rc_conv_luts, dense_conv_lut)
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             rotated_prediction = self.stage(rotated)
 | 
			
		||||
#             unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += unrotated_prediction
 | 
			
		||||
#         output /= 4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
# class RCNetRot90_7x7(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetRot90_7x7, self).__init__()
 | 
			
		||||
#         self.hidden_dim = hidden_dim
 | 
			
		||||
#         self.layers_count = layers_count
 | 
			
		||||
#         self.scale = scale       
 | 
			
		||||
#         self.stage = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         window_size = self.stage.rc_conv.window_size
 | 
			
		||||
#         rc_conv_luts = lut.transfer_rc_conv(self.stage.rc_conv, quantization_interval=quantization_interval).reshape(window_size,window_size,-1)
 | 
			
		||||
#         dense_conv_lut = lut.transfer_2x2_input_SxS_output(self.stage.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
#         lut_model = rclut.RCLutRot90_7x7.init_from_numpy(rc_conv_luts, dense_conv_lut)
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             rotated_prediction = self.stage(rotated)
 | 
			
		||||
#             unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += unrotated_prediction
 | 
			
		||||
#         output /= 4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
# class RCNetx1(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetx1, self).__init__()
 | 
			
		||||
#         self.scale = scale  
 | 
			
		||||
#         self.hidden_dim = hidden_dim     
 | 
			
		||||
#         self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
#         self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
 | 
			
		||||
#         self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        
 | 
			
		||||
#         lut_model = rclut.RCLutx1.init_from_numpy(
 | 
			
		||||
#             rc_conv_luts_3x3=rc_conv_luts_3x3, dense_conv_lut_3x3=dense_conv_lut_3x3,
 | 
			
		||||
#             rc_conv_luts_5x5=rc_conv_luts_5x5, dense_conv_lut_5x5=dense_conv_lut_5x5,
 | 
			
		||||
#             rc_conv_luts_7x7=rc_conv_luts_7x7, dense_conv_lut_7x7=dense_conv_lut_7x7
 | 
			
		||||
#         )
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# class RCNetx2(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetx2, self).__init__()
 | 
			
		||||
#         self.scale = scale  
 | 
			
		||||
#         self.hidden_dim = hidden_dim     
 | 
			
		||||
#         self.stage1_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
 | 
			
		||||
#         self.stage1_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
 | 
			
		||||
#         self.stage1_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
 | 
			
		||||
#         self.stage2_3x3 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
#         self.stage2_5x5 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
 | 
			
		||||
#         self.stage2_7x7 = RCBlockRot90(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        
 | 
			
		||||
#         lut_model = rclut.RCLutx2.init_from_numpy(
 | 
			
		||||
#             s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
 | 
			
		||||
#             s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
 | 
			
		||||
#             s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
 | 
			
		||||
#             s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
 | 
			
		||||
#             s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
 | 
			
		||||
#             s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
 | 
			
		||||
#         )
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         x = output
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
# class RCNetx2Centered(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetx2Centered, self).__init__()
 | 
			
		||||
#         self.scale = scale  
 | 
			
		||||
#         self.hidden_dim = hidden_dim     
 | 
			
		||||
#         self.stage1_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
 | 
			
		||||
#         self.stage1_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
 | 
			
		||||
#         self.stage1_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
 | 
			
		||||
#         self.stage2_3x3 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
#         self.stage2_5x5 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
 | 
			
		||||
#         self.stage2_7x7 = RCBlockCentered(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        
 | 
			
		||||
#         lut_model = rclut.RCLutx2Centered.init_from_numpy(
 | 
			
		||||
#             s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
 | 
			
		||||
#             s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
 | 
			
		||||
#             s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
 | 
			
		||||
#             s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
 | 
			
		||||
#             s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
 | 
			
		||||
#             s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
 | 
			
		||||
#         )
 | 
			
		||||
#         return lut_model
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         x = output
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
# class ReconstructedConvRot90Unlutable(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim,  window_size=7):
 | 
			
		||||
#         super(ReconstructedConvRot90Unlutable, self).__init__()
 | 
			
		||||
#         self.window_size = window_size
 | 
			
		||||
#         self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
#         self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
 | 
			
		||||
#     def pixel_wise_forward(self, x):
 | 
			
		||||
#         x = (x-127.5)/127.5
 | 
			
		||||
#         out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) 
 | 
			
		||||
#         out = torch.tanh(out)
 | 
			
		||||
#         out = out*127.5 + 127.5
 | 
			
		||||
#         return out 
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         original_shape = x.shape
 | 
			
		||||
#         x = F.pad(x, pad=[0,self.window_size-1,0,self.window_size-1], mode='replicate')
 | 
			
		||||
#         x = F.unfold(x, self.window_size)
 | 
			
		||||
#         x = self.pixel_wise_forward(x)        
 | 
			
		||||
#         x = x.mean(1)  
 | 
			
		||||
#         x = x.reshape(*original_shape)  
 | 
			
		||||
#         # x = round_func(x) # quality likely suffer from this
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
 | 
			
		||||
 | 
			
		||||
# class RCBlockRot90Unlutable(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
 | 
			
		||||
#         super(RCBlockRot90Unlutable, self).__init__()   
 | 
			
		||||
#         self.window_size = window_size    
 | 
			
		||||
#         self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
 | 
			
		||||
#         self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,hs,ws = x.shape
 | 
			
		||||
#         x = self.rc_conv(x)
 | 
			
		||||
#         x = F.pad(x, pad=[0,1,0,1], mode='replicate') 
 | 
			
		||||
#         x = self.dense_conv_block(x)        
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
# class RCNetx2Unlutable(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetx2Unlutable, self).__init__()
 | 
			
		||||
#         self.scale = scale  
 | 
			
		||||
#         self.hidden_dim = hidden_dim     
 | 
			
		||||
#         self.stage1_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
 | 
			
		||||
#         self.stage1_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
 | 
			
		||||
#         self.stage1_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
 | 
			
		||||
#         self.stage2_3x3 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
#         self.stage2_5x5 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
 | 
			
		||||
#         self.stage2_7x7 = RCBlockRot90Unlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        
 | 
			
		||||
#         lut_model = rclut.RCLutx2.init_from_numpy(
 | 
			
		||||
#             s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
 | 
			
		||||
#             s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
 | 
			
		||||
#             s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
 | 
			
		||||
#             s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
 | 
			
		||||
#             s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
 | 
			
		||||
#             s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
 | 
			
		||||
#         )
 | 
			
		||||
#         return lut_model       
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         x = output
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# class ReconstructedConvCenteredUnlutable(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim,  window_size=7):
 | 
			
		||||
#         super(ReconstructedConvCenteredUnlutable, self).__init__()
 | 
			
		||||
#         self.window_size = window_size
 | 
			
		||||
#         self.projection1 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
#         self.projection2 = torch.nn.Parameter(torch.rand((window_size**2, hidden_dim))/window_size)
 | 
			
		||||
 | 
			
		||||
#     def pixel_wise_forward(self, x):
 | 
			
		||||
#         x = (x-127.5)/127.5
 | 
			
		||||
#         out = torch.einsum('bik,ij,ij -> bik', x, self.projection1, self.projection2) 
 | 
			
		||||
#         out = torch.tanh(out)
 | 
			
		||||
#         out = out*127.5 + 127.5
 | 
			
		||||
#         return out 
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         original_shape = x.shape
 | 
			
		||||
#         x = F.pad(x, pad=[self.window_size//2]*4, mode='replicate')
 | 
			
		||||
#         x = F.unfold(x, self.window_size)
 | 
			
		||||
#         x = self.pixel_wise_forward(x)        
 | 
			
		||||
#         x = x.mean(1)  
 | 
			
		||||
#         x = x.reshape(*original_shape)  
 | 
			
		||||
#         # x = round_func(x) # quality likely suffer from this
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
#     def __repr__(self):
 | 
			
		||||
#         return f"{self.__class__.__name__} projection1: {self.projection1.shape} projection2: {self.projection2.shape}"
 | 
			
		||||
 | 
			
		||||
# class RCBlockCenteredUnlutable(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 32, window_size=3, dense_conv_layer_count=4, upscale_factor=4):
 | 
			
		||||
#         super(RCBlockRot90Unlutable, self).__init__()   
 | 
			
		||||
#         self.window_size = window_size    
 | 
			
		||||
#         self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
 | 
			
		||||
#         self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,hs,ws = x.shape
 | 
			
		||||
#         x = self.rc_conv(x)
 | 
			
		||||
#         x = F.pad(x, pad=[0,1,0,1], mode='replicate') 
 | 
			
		||||
#         x = self.dense_conv_block(x)        
 | 
			
		||||
#         return x
 | 
			
		||||
 | 
			
		||||
# class RCNetx2CenteredUnlutable(nn.Module):
 | 
			
		||||
#     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
 | 
			
		||||
#         super(RCNetx2CenteredUnlutable, self).__init__()
 | 
			
		||||
#         self.scale = scale  
 | 
			
		||||
#         self.hidden_dim = hidden_dim     
 | 
			
		||||
#         self.stage1_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=3)
 | 
			
		||||
#         self.stage1_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=5)
 | 
			
		||||
#         self.stage1_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=1, window_size=7)
 | 
			
		||||
#         self.stage2_3x3 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=3)
 | 
			
		||||
#         self.stage2_5x5 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=5)
 | 
			
		||||
#         self.stage2_7x7 = RCBlockCenteredUnlutable(hidden_dim=hidden_dim, dense_conv_layer_count=layers_count, upscale_factor=scale, window_size=7)
 | 
			
		||||
 | 
			
		||||
#     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
 | 
			
		||||
#         s1_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage1_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s1_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage1_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage1_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s1_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage1_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s1_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage1_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s1_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage1_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_3x3 = lut.transfer_rc_conv(self.stage2_3x3.rc_conv, quantization_interval=quantization_interval).reshape(3,3,-1)
 | 
			
		||||
#         s2_dense_conv_lut_3x3 = lut.transfer_2x2_input_SxS_output(self.stage2_3x3.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_5x5 = lut.transfer_rc_conv(self.stage2_5x5.rc_conv, quantization_interval=quantization_interval).reshape(5,5,-1)
 | 
			
		||||
#         s2_dense_conv_lut_5x5 = lut.transfer_2x2_input_SxS_output(self.stage2_5x5.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
#         s2_rc_conv_luts_7x7 = lut.transfer_rc_conv(self.stage2_7x7.rc_conv, quantization_interval=quantization_interval).reshape(7,7,-1)
 | 
			
		||||
#         s2_dense_conv_lut_7x7 = lut.transfer_2x2_input_SxS_output(self.stage2_7x7.dense_conv_block, quantization_interval=quantization_interval, batch_size=batch_size)
 | 
			
		||||
        
 | 
			
		||||
#         lut_model = rclut.RCLutx2Centered.init_from_numpy(
 | 
			
		||||
#             s1_rc_conv_luts_3x3=s1_rc_conv_luts_3x3, s1_dense_conv_lut_3x3=s1_dense_conv_lut_3x3,
 | 
			
		||||
#             s1_rc_conv_luts_5x5=s1_rc_conv_luts_5x5, s1_dense_conv_lut_5x5=s1_dense_conv_lut_5x5,
 | 
			
		||||
#             s1_rc_conv_luts_7x7=s1_rc_conv_luts_7x7, s1_dense_conv_lut_7x7=s1_dense_conv_lut_7x7,
 | 
			
		||||
#             s2_rc_conv_luts_3x3=s2_rc_conv_luts_3x3, s2_dense_conv_lut_3x3=s2_dense_conv_lut_3x3,
 | 
			
		||||
#             s2_rc_conv_luts_5x5=s2_rc_conv_luts_5x5, s2_dense_conv_lut_5x5=s2_dense_conv_lut_5x5,
 | 
			
		||||
#             s2_rc_conv_luts_7x7=s2_rc_conv_luts_7x7, s2_dense_conv_lut_7x7=s2_dense_conv_lut_7x7
 | 
			
		||||
#         )
 | 
			
		||||
#         return lut_model      
 | 
			
		||||
 | 
			
		||||
#     def forward(self, x):
 | 
			
		||||
#         b,c,h,w = x.shape
 | 
			
		||||
#         x = x.view(b*c, 1, h, w)
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage1_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         x = output
 | 
			
		||||
#         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
#         for rotations_count in range(4):
 | 
			
		||||
#             rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_3x3(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_5x5(rotated), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
#             output += torch.rot90(self.stage2_7x7(rotated), k=-rotations_count, dims=[2, 3])            
 | 
			
		||||
#         output /= 3*4
 | 
			
		||||
#         output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
#         return output
 | 
			
		||||
 | 
			
		||||
@ -1,395 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from common.lut import select_index_4dlut_tetrahedral
 | 
			
		||||
from common.layers import PercievePattern
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
 | 
			
		||||
class SDYLutx1(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SDYLutx1, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stageS, stageD, stageY
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stageS.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stageS.shape[0]-1)
 | 
			
		||||
        lut_model = SDYLutx1(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
 | 
			
		||||
        lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
 | 
			
		||||
        lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS) 
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
 | 
			
		||||
        output /= 3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}" + \
 | 
			
		||||
               f"\n  stageS size: {self.stageS.shape}" + \
 | 
			
		||||
               f"\n  stageD size: {self.stageD.shape}" + \
 | 
			
		||||
               f"\n  stageY size: {self.stageY.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
class SDYLutx2(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SDYLutx2, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage2_S.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage2_S.shape[0]-1)
 | 
			
		||||
        lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S) 
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
 | 
			
		||||
        output /= 3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        x = output
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S) 
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
 | 
			
		||||
        output /= 3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}" + \
 | 
			
		||||
               f"\n  stage1_S size: {self.stage1_S.shape}" + \
 | 
			
		||||
               f"\n  stage1_D size: {self.stage1_D.shape}" + \
 | 
			
		||||
               f"\n  stage1_Y size: {self.stage1_Y.shape}" + \
 | 
			
		||||
               f"\n  stage2_S size: {self.stage2_S.shape}" + \
 | 
			
		||||
               f"\n  stage2_D size: {self.stage2_D.shape}" + \
 | 
			
		||||
               f"\n  stage2_Y size: {self.stage2_Y.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SDYLutx3(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SDYLutx3, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage3_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stage3_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stage3_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y, stage3_S, stage3_D, stage3_Y
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage3_S.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage3_S.shape[0]-1)
 | 
			
		||||
        lut_model = SDYLutx3(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
 | 
			
		||||
        lut_model.stage3_S = nn.Parameter(torch.tensor(stage3_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage3_D = nn.Parameter(torch.tensor(stage3_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage3_Y = nn.Parameter(torch.tensor(stage3_Y).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S) 
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
 | 
			
		||||
        output /= 3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        x = output
 | 
			
		||||
        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage2_S) 
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage2_D)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage2_Y)
 | 
			
		||||
        output /= 3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        x = output
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage3_S) 
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage3_D)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage3_Y)
 | 
			
		||||
        output /= 3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}" + \
 | 
			
		||||
               f"\n  stage1_S size: {self.stage1_S.shape}" + \
 | 
			
		||||
               f"\n  stage1_D size: {self.stage1_D.shape}" + \
 | 
			
		||||
               f"\n  stage1_Y size: {self.stage1_Y.shape}" + \
 | 
			
		||||
               f"\n  stage2_S size: {self.stage2_S.shape}" + \
 | 
			
		||||
               f"\n  stage2_D size: {self.stage2_D.shape}" + \
 | 
			
		||||
               f"\n  stage2_Y size: {self.stage2_Y.shape}" + \
 | 
			
		||||
               f"\n  stage3_S size: {self.stage3_S.shape}" + \
 | 
			
		||||
               f"\n  stage3_D size: {self.stage3_D.shape}" + \
 | 
			
		||||
               f"\n  stage3_Y size: {self.stage3_Y.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SDYLutR90x1(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SDYLutR90x1, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stageS, stageD, stageY
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stageS.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stageS.shape[0]-1)
 | 
			
		||||
        lut_model = SDYLutR90x1(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32))
 | 
			
		||||
        lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32))
 | 
			
		||||
        lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stageS) 
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stageD)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stageY)
 | 
			
		||||
        for rotations_count in range(1, 4):
 | 
			
		||||
            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stageS), k=-rotations_count, dims=[-2, -1]) 
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stageD), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stageY), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
        output /= 4*3
 | 
			
		||||
        output = round_func(output)
 | 
			
		||||
        output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}" + \
 | 
			
		||||
               f"\n  stageS size: {self.stageS.shape}" + \
 | 
			
		||||
               f"\n  stageD size: {self.stageD.shape}" + \
 | 
			
		||||
               f"\n  stageY size: {self.stageY.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SDYLutR90x2(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SDYLutR90x2, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
 | 
			
		||||
        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		||||
        self.stage1_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage1_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage1_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (1,1)).type(torch.float32))
 | 
			
		||||
        self.stage2_S = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stage2_D = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self.stage2_Y = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage2_S.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage2_S.shape[0]-1)
 | 
			
		||||
        lut_model = SDYLutR90x2(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
 | 
			
		||||
        lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.view(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
        output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) 
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S) 
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D)
 | 
			
		||||
        output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y)
 | 
			
		||||
        for rotations_count in range(1, 4):
 | 
			
		||||
            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_S, self.stage1_S), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_D, self.stage1_D), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, 1, self._extract_pattern_Y, self.stage1_Y), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
        output /= 4*3
 | 
			
		||||
        x = round_func(output)
 | 
			
		||||
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y)
 | 
			
		||||
        for rotations_count in range(1, 4):
 | 
			
		||||
            rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage2_S), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_D, self.stage2_D), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_Y, self.stage2_Y), k=-rotations_count, dims=[-2, -1])
 | 
			
		||||
        output /= 4*3 
 | 
			
		||||
        output = round_func(output) 
 | 
			
		||||
        output = output.view(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}" + \
 | 
			
		||||
               f"\n  stage1_S size: {self.stage1_S.shape}" + \
 | 
			
		||||
               f"\n  stage1_D size: {self.stage1_D.shape}" + \
 | 
			
		||||
               f"\n  stage1_Y size: {self.stage1_Y.shape}" + \
 | 
			
		||||
               f"\n  stage2_S size: {self.stage2_S.shape}" + \
 | 
			
		||||
               f"\n  stage2_D size: {self.stage2_D.shape}" + \
 | 
			
		||||
               f"\n  stage2_Y size: {self.stage2_Y.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -1,276 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from common.lut import select_index_4dlut_tetrahedral
 | 
			
		||||
from common import layers 
 | 
			
		||||
from common.utils import round_func
 | 
			
		||||
from models.base import SRNetBase
 | 
			
		||||
 | 
			
		||||
class SRLut(SRNetBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SRLut, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stage_lut
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage_lut.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage_lut.shape[0]-1)
 | 
			
		||||
        lut_model = SRLut(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w).type(torch.float32)
 | 
			
		||||
        x = self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
 | 
			
		||||
        x = x.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return x        
 | 
			
		||||
    
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}\n  lut size: {self.stage_lut.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
class SRLutY(SRNetBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SRLutY, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self.rgb_to_ycbcr = layers.RgbToYcbcr()
 | 
			
		||||
        self.ycbcr_to_rgb = layers.YcbcrToRgb()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy( 
 | 
			
		||||
        stage_lut
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage_lut.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage_lut.shape[0]-1)
 | 
			
		||||
        lut_model = SRLutY(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
    
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = self.rgb_to_ycbcr(x)
 | 
			
		||||
        y = x[:,0:1,:,:]
 | 
			
		||||
        cbcr = x[:,1:,:,:]
 | 
			
		||||
        cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
 | 
			
		||||
        
 | 
			
		||||
        output = self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
 | 
			
		||||
        output = torch.cat([output, cbcr_scaled], dim=1)  
 | 
			
		||||
        output = self.ycbcr_to_rgb(output).clamp(0, 255)        
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}\n  lut size: {self.stage_lut.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
class SRLutR90(SRNetBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SRLutR90, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy(
 | 
			
		||||
        stage_lut
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage_lut.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage_lut.shape[0]-1)
 | 
			
		||||
        lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = x.reshape(b*c, 1, h, w)
 | 
			
		||||
        output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
        output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage_lut)
 | 
			
		||||
        for rotations_count in range(1, 4):
 | 
			
		||||
            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output /= 4
 | 
			
		||||
        output = output.reshape(b, c, h*self.scale, w*self.scale)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}\n  lut size: {self.stage_lut.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SRLutR90Y(SRNetBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SRLutR90Y, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
        self.rgb_to_ycbcr = layers.RgbToYcbcr()
 | 
			
		||||
        self.ycbcr_to_rgb = layers.YcbcrToRgb()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy( 
 | 
			
		||||
        stage_lut
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage_lut.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage_lut.shape[0]-1)
 | 
			
		||||
        lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
    
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = self.rgb_to_ycbcr(x)
 | 
			
		||||
        y = x[:,0:1,:,:]
 | 
			
		||||
        cbcr = x[:,1:,:,:]
 | 
			
		||||
        cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
 | 
			
		||||
        
 | 
			
		||||
        output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
        output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
 | 
			
		||||
        for rotations_count in range(1,4):
 | 
			
		||||
            rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output /= 4
 | 
			
		||||
        output = torch.cat([output, cbcr_scaled], dim=1)  
 | 
			
		||||
        output = self.ycbcr_to_rgb(output).clamp(0, 255)        
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}\n  lut size: {self.stage_lut.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
 | 
			
		||||
class SRLutR90YCbCr(SRNetBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, 
 | 
			
		||||
        quantization_interval,
 | 
			
		||||
        scale
 | 
			
		||||
    ):
 | 
			
		||||
        super(SRLutR90YCbCr, self).__init__()     
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
        self.quantization_interval = quantization_interval
 | 
			
		||||
        self.stage_lut = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
 | 
			
		||||
        self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_from_numpy( 
 | 
			
		||||
        stage_lut
 | 
			
		||||
    ):   
 | 
			
		||||
        scale = int(stage_lut.shape[-1])
 | 
			
		||||
        quantization_interval = 256//(stage_lut.shape[0]-1)
 | 
			
		||||
        lut_model = SRLutR90YCbCr(quantization_interval=quantization_interval, scale=scale)
 | 
			
		||||
        lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
 | 
			
		||||
        return lut_model
 | 
			
		||||
    
 | 
			
		||||
    def forward_stage(self, x, scale, percieve_pattern, lut):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        x = percieve_pattern(x)   
 | 
			
		||||
        x = select_index_4dlut_tetrahedral(index=x, lut=lut)
 | 
			
		||||
        x = round_func(x)
 | 
			
		||||
        x = x.reshape(b, c, h, w, scale, scale)
 | 
			
		||||
        x = x.permute(0,1,2,4,3,5)
 | 
			
		||||
        x = x.reshape(b, c, h*scale, w*scale)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, config=None):
 | 
			
		||||
        b,c,h,w = x.shape
 | 
			
		||||
        y = x[:,0:1,:,:]
 | 
			
		||||
        cbcr = x[:,1:,:,:]
 | 
			
		||||
        cbcr_scaled = F.interpolate(cbcr, size=[h*self.scale, w*self.scale], mode='bilinear')
 | 
			
		||||
        output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=torch.float32, device=x.device)
 | 
			
		||||
        output += self.forward_stage(y, self.scale, self._extract_pattern_S, self.stage_lut)
 | 
			
		||||
        for rotations_count in range(1,4):
 | 
			
		||||
            rotated = torch.rot90(y, k=rotations_count, dims=[2, 3])
 | 
			
		||||
            output += torch.rot90(self.forward_stage(rotated, self.scale, self._extract_pattern_S, self.stage_lut), k=-rotations_count, dims=[2, 3])
 | 
			
		||||
        output /= 4
 | 
			
		||||
        output = torch.cat([output, cbcr_scaled], dim=1).clamp(0, 255)  
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}\n  lut size: {self.stage_lut.shape}"
 | 
			
		||||
 | 
			
		||||
    def get_loss_fn(self):
 | 
			
		||||
        def loss_fn(pred, target):
 | 
			
		||||
            return F.mse_loss(pred/255, target/255)
 | 
			
		||||
        return loss_fn
 | 
			
		||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								
					Loading…
					
					
				
		Reference in New Issue