diff --git a/src/models/__init__.py b/src/models/__init__.py index b1c21bb..2a8e030 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -18,6 +18,7 @@ AVAILABLE_MODELS = { 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, 'SDYNetx1': sdynet.SDYNetx1, 'SDYLutx1': sdylut.SDYLutx1, + 'SDYNetCenteredx1': sdynet.SDYNetCenteredx1, 'SDYLutCenteredx1': sdylut.SDYLutCenteredx1, 'SDYNetx2': sdynet.SDYNetx2, 'SDYLutx2': sdylut.SDYLutx2, 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, diff --git a/src/models/sdylut.py b/src/models/sdylut.py index e71fc1d..813cdbd 100644 --- a/src/models/sdylut.py +++ b/src/models/sdylut.py @@ -161,4 +161,67 @@ class SDYLutx2(nn.Module): f"\n stageY_1 size: {self.stageY_1.shape}" + \ f"\n stageS_2 size: {self.stageS_2.shape}" + \ f"\n stageD_2 size: {self.stageD_2.shape}" + \ - f"\n stageY_2 size: {self.stageY_2.shape}" \ No newline at end of file + f"\n stageY_2 size: {self.stageY_2.shape}" + + + +class SDYLutCenteredx1(nn.Module): + def __init__( + self, + quantization_interval, + scale + ): + super(SDYLutCenteredx1, self).__init__() + self.scale = scale + self.quantization_interval = quantization_interval + self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[1,1], window_size=3) + self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[1,1], window_size=3) + self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[1,1], window_size=3) + self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) + + @staticmethod + def init_from_lut( + stageS, stageD, stageY + ): + scale = int(stageS.shape[-1]) + quantization_interval = 256//(stageS.shape[0]-1) + lut_model = SDYLutCenteredx1(quantization_interval=quantization_interval, scale=scale) + lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32)) + lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32)) + lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32)) + return lut_model + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w).type(torch.float32) + output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for rotations_count in range(4): + rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) + rb,rc,rh,rw = rotated.shape + + s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS) + s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) + s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) + output += s + + d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD) + d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) + d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) + output += d + + y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY) + y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) + y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) + output += y + + output /= 4*3 + output = output.view(b, c, h*self.scale, w*self.scale) + return output + + def __repr__(self): + return f"{self.__class__.__name__}" + \ + f"\n stageS size: {self.stageS.shape}" + \ + f"\n stageD size: {self.stageD.shape}" + \ + f"\n stageY size: {self.stageY.shape}" \ No newline at end of file