|  |  | @ -8,6 +8,7 @@ from common import layers | 
			
		
	
		
		
			
				
					
					|  |  |  | from pathlib import Path |  |  |  | from pathlib import Path | 
			
		
	
		
		
			
				
					
					|  |  |  | from . import sdylut |  |  |  | from . import sdylut | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | class SDYNetx1(nn.Module): |  |  |  | class SDYNetx1(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): |  |  |  |     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): | 
			
		
	
		
		
			
				
					
					|  |  |  |         super(SDYNetx1, self).__init__() |  |  |  |         super(SDYNetx1, self).__init__() | 
			
		
	
	
		
		
			
				
					|  |  | @ -29,6 +30,101 @@ class SDYNetx1(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |         x = x.reshape(b, c, h*scale, w*scale) |  |  |  |         x = x.reshape(b, c, h*scale, w*scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |         return x |  |  |  |         return x | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def forward(self, x): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         b,c,h,w = x.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b*c, 1, h, w) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage1_S) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage1_D) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage1_Y) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output /= 3 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = output | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = round_func(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b, c, h*self.scale, w*self.scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return x | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def get_lut_model(self, quantization_interval=16, batch_size=2**10): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stageS = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stageD = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stageY = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         lut_model = sdylut.SDYLutx1.init_from_numpy(stageS, stageD, stageY) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return lut_model | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | class SDYNetx2(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         super(SDYNetx2, self).__init__() | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.scale = scale    | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage2_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage2_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage2_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def forward_stage(self, x, scale, percieve_pattern, stage): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         b,c,h,w = x.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = percieve_pattern(x)    | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = stage(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = round_func(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b, c, h, w, scale, scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.permute(0,1,2,4,3,5) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b, c, h*scale, w*scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return x | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def forward(self, x): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         b,c,h,w = x.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b*c, 1, h, w) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output = torch.zeros([b*c, 1, h, w], dtype=x.dtype, device=x.device) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, 1, self._extract_pattern_S, self.stage1_S) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, 1, self._extract_pattern_D, self.stage1_D) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, 1, self._extract_pattern_Y, self.stage1_Y) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output /= 3 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = output | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = round_func(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, self.scale, self._extract_pattern_S, self.stage2_S) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, self.scale, self._extract_pattern_D, self.stage2_D) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output += self.forward_stage(x, self.scale, self._extract_pattern_Y, self.stage2_Y) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output /= 3 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = output | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = round_func(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b, c, h*self.scale, w*self.scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return x | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def get_lut_model(self, quantization_interval=16, batch_size=2**10): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stage2_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stage2_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return lut_model | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | class SDYNetR90x1(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         super(SDYNetR90x1, self).__init__() | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.scale = scale    | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_Y = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage1_S = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage1_D = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stage1_Y = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def forward_stage(self, x, scale, percieve_pattern, stage): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         b,c,h,w = x.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = percieve_pattern(x)    | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = stage(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = round_func(x) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b, c, h, w, scale, scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.permute(0,1,2,4,3,5) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.reshape(b, c, h*scale, w*scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return x | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     def forward(self, x): |  |  |  |     def forward(self, x): | 
			
		
	
		
		
			
				
					
					|  |  |  |         b,c,h,w = x.shape |  |  |  |         b,c,h,w = x.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |         x = x.reshape(b*c, 1, h, w) |  |  |  |         x = x.reshape(b*c, 1, h, w) | 
			
		
	
	
		
		
			
				
					|  |  | @ -51,12 +147,12 @@ class SDYNetx1(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |         stageS = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) |  |  |  |         stageS = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |         stageD = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) |  |  |  |         stageD = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |         stageY = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) |  |  |  |         stageY = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |         lut_model = sdylut.SDYLutx1.init_from_numpy(stageS, stageD, stageY) |  |  |  |         lut_model = sdylut.SDYLutR90x1.init_from_numpy(stageS, stageD, stageY) | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |         return lut_model |  |  |  |         return lut_model | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | class SDYNetx2(nn.Module): |  |  |  | class SDYNetR90x2(nn.Module): | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): |  |  |  |     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): | 
			
		
	
		
		
			
				
					
					|  |  |  |         super(SDYNetx2, self).__init__() |  |  |  |         super(SDYNetR90x2, self).__init__() | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |         self.scale = scale    |  |  |  |         self.scale = scale    | 
			
		
	
		
		
			
				
					
					|  |  |  |         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) |  |  |  |         self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) | 
			
		
	
		
		
			
				
					
					|  |  |  |         self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) |  |  |  |         self._extract_pattern_D = layers.PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) | 
			
		
	
	
		
		
			
				
					|  |  | @ -113,5 +209,5 @@ class SDYNetx2(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |         stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size) |  |  |  |         stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |         stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size) |  |  |  |         stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |         stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size) |  |  |  |         stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size) | 
			
		
	
		
		
			
				
					
					|  |  |  |         lut_model = sdylut.SDYLutx2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) |  |  |  |         lut_model = sdylut.SDYLutR90x2.init_from_numpy(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |         return lut_model |  |  |  |         return lut_model |