| 
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -108,30 +108,38 @@ class HDBNetBase(SRBase):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]],       center=[0,0], window_size=2)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]],       center=[0,0], window_size=2)   
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]],       center=[0,0], window_size=2)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def forward(self, x, script_config=None):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        b,c,h,w = x.shape
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.reshape(b*c, 1, h, w)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # 1. check equal to bit_plane_slicing(batch_L255, bit_mask='11110000') ok
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # 2. inference in 0,1 to -1,1 range |
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        lsb = x % 16
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        msb = x - lsb     
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        msb = x - lsb   
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output_msb = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output_lsb = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        for rotations_count in range(4):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        self.stage_3B( rotated_msb, self._extract_pattern_3B )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        self.stage_2D( rotated_lsb, self._extract_pattern_2D )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output_msb /= 3
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output_lsb /= 2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           if not script_config is None and script_config.current_iter % script_config.display_step == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output /= 4
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = output
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           rotated_output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                                self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                                self.stage_3B( rotated_msb, self._extract_pattern_3B )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           rotated_output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                                self.stage_2D( rotated_lsb, self._extract_pattern_2D )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           rotated_output_msb /= 3
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           rotated_output_lsb /= 2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output_msb += torch.rot90(rotated_output_msb, k=-rotations_count, dims=[2, 3])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				           output_lsb += torch.rot90(rotated_output_lsb, k=-rotations_count, dims=[2, 3])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output_msb /= 4
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        output_lsb /= 4
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if not script_config is None and script_config.current_iter % script_config.display_step == 0:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = nn.Upsample(scale_factor=self.config.upscale_factor, mode='nearest')(x) + (output_msb*16 + output_lsb - 127)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.clamp(0, 255)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return x
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -144,11 +152,11 @@ class HDBNet(HDBNetBase):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, config):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        super(HDBNet, self).__init__()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.config = config
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=255,  output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=255,  output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class HDBLut(HDBNetBase):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def __init__(self, config):
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
			
			 | 
			 | 
			
				
 
 |