|  |  | @ -161,4 +161,67 @@ class SDYLutx2(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |                f"\n  stageY_1 size: {self.stageY_1.shape}" + \ |  |  |  |                f"\n  stageY_1 size: {self.stageY_1.shape}" + \ | 
			
		
	
		
		
			
				
					
					|  |  |  |                f"\n  stageS_2 size: {self.stageS_2.shape}" + \ |  |  |  |                f"\n  stageS_2 size: {self.stageS_2.shape}" + \ | 
			
		
	
		
		
			
				
					
					|  |  |  |                f"\n  stageD_2 size: {self.stageD_2.shape}" + \ |  |  |  |                f"\n  stageD_2 size: {self.stageD_2.shape}" + \ | 
			
		
	
		
		
			
				
					
					|  |  |  |                f"\n  stageY_2 size: {self.stageY_2.shape}" |  |  |  |                f"\n  stageY_2 size: {self.stageY_2.shape}" | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | class SDYLutCenteredx1(nn.Module): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def __init__( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self,  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         quantization_interval, | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         scale | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     ): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         super(SDYLutCenteredx1, self).__init__()      | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.scale = scale | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.quantization_interval = quantization_interval | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[1,1], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[1,1], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[1,1], window_size=3) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32)) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     @staticmethod | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def init_from_lut( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         stageS, stageD, stageY | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     ):    | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         scale = int(stageS.shape[-1]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         quantization_interval = 256//(stageS.shape[0]-1) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         lut_model = SDYLutCenteredx1(quantization_interval=quantization_interval, scale=scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         lut_model.stageS = nn.Parameter(torch.tensor(stageS).type(torch.float32)) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         lut_model.stageD = nn.Parameter(torch.tensor(stageD).type(torch.float32)) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         lut_model.stageY = nn.Parameter(torch.tensor(stageY).type(torch.float32)) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return lut_model | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def forward(self, x): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         b,c,h,w = x.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         x = x.view(b*c, 1, h, w).type(torch.float32) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         for rotations_count in range(4): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             rb,rc,rh,rw = rotated.shape | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS)          | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])             | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             output += s | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD)             | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             output += d | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY)             | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             output += y | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output /= 4*3 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         output = output.view(b, c, h*self.scale, w*self.scale) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return output | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def __repr__(self): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return f"{self.__class__.__name__}" + \ | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                f"\n  stageS size: {self.stageS.shape}" + \ | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                f"\n  stageD size: {self.stageD.shape}" + \ | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                f"\n  stageY size: {self.stageY.shape}" |