From 4590ba367b922e2f879adf90edfcc53493ec5850 Mon Sep 17 00:00:00 2001 From: protsenkovi Date: Sat, 2 Nov 2024 11:36:56 +0400 Subject: [PATCH] update --- explore.ipynb | 90 ++++++++++++++++++++++++++++++++++++-------- src/models/models.py | 50 +++++++++++++----------- 2 files changed, 103 insertions(+), 37 deletions(-) diff --git a/explore.ipynb b/explore.ipynb index cb4dfe1..34b6c48 100644 --- a/explore.ipynb +++ b/explore.ipynb @@ -15200,23 +15200,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[tensor([[[[ 0., 16., 32., 48.],\n", - " [ 64., 80., 96., 112.],\n", - " [128., 144., 160., 176.],\n", - " [192., 208., 224., 240.]]]]),\n", - " tensor([[[[ 0., 1., 2., 3.],\n", - " [ 4., 5., 6., 7.],\n", - " [ 8., 9., 10., 11.],\n", - " [12., 13., 14., 15.]]]])]" + "((tensor([[[[ 0., 16., 32., 48.],\n", + " [ 64., 80., 96., 112.],\n", + " [128., 144., 160., 176.],\n", + " [192., 208., 224., 240.]]]]),\n", + " tensor([[[[ 0., 1., 2., 3.],\n", + " [ 4., 5., 6., 7.],\n", + " [ 8., 9., 10., 11.],\n", + " [12., 13., 14., 15.]]]])),\n", + " (tensor([[[[ 0., 16., 32., 48.],\n", + " [ 64., 80., 96., 112.],\n", + " [128., 144., 160., 176.],\n", + " [192., 208., 224., 240.]]]]),\n", + " tensor([[[[ 0., 1., 2., 3.],\n", + " [ 4., 5., 6., 7.],\n", + " [ 8., 9., 10., 11.],\n", + " [12., 13., 14., 15.]]]])))" ] }, - "execution_count": 9, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -15230,18 +15238,68 @@ "def bit_plane_slicing(x, bit_mask='11110000'):\n", " m = int(bit_mask, 2)\n", " masks = [m, 255-m]\n", - " images = []\n", - " for mask in masks:\n", - " images.append((x.type(torch.LongTensor) & mask).type(torch.FloatTensor).to(x.device))\n", - " return images\n", + " msb = (x.type(torch.LongTensor) & m).type(torch.FloatTensor).to(x.device)\n", + " lsb = (x.type(torch.LongTensor) & (255-m)).type(torch.FloatTensor).to(x.device)\n", + " return msb, lsb\n", "\n", - "bit_plane_slicing(a)" + "def bit_plane_slicing2(x):\n", + " lsb = a % 16\n", + " msb = a - lsb\n", + " return msb, lsb\n", + "\n", + "bit_plane_slicing(a), bit_plane_slicing2(a)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, + "outputs": [], + "source": [ + "n = 4000\n", + "a = (torch.arange(n**2) + torch.arange(start=0, end=16**2, step=16)).view(1,1,n,n).type(torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "36.3 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "bit_plane_slicing(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11 µs ± 242 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "bit_plane_slicing2(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, "outputs": [ { "data": { @@ -15256,7 +15314,7 @@ " [12., 13., 14., 15.]]]]))" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/src/models/models.py b/src/models/models.py index 3a23370..8713321 100644 --- a/src/models/models.py +++ b/src/models/models.py @@ -108,30 +108,38 @@ class HDBNetBase(SRBase): self._extract_pattern_3D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1],[2,2]], center=[0,0], window_size=3) self._extract_pattern_3B = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,2],[2,1]], center=[0,0], window_size=3) self._extract_pattern_2H = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1]], center=[0,0], window_size=2) - self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2) + self._extract_pattern_2D = layers.PercievePattern(receptive_field_idxes=[[0,0],[1,1]], center=[0,0], window_size=2) def forward(self, x, script_config=None): b,c,h,w = x.shape x = x.reshape(b*c, 1, h, w) + # 1. check equal to bit_plane_slicing(batch_L255, bit_mask='11110000') ok + # 2. inference in 0,1 to -1,1 range | lsb = x % 16 - msb = x - lsb - output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device) + msb = x - lsb + output_msb = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device) + output_lsb = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device) for rotations_count in range(4): rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3]) rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3]) - output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \ - self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \ - self.stage_3B( rotated_msb, self._extract_pattern_3B ) - output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \ - self.stage_2D( rotated_lsb, self._extract_pattern_2D ) - output_msb /= 3 - output_lsb /= 2 - if not script_config is None and script_config.current_iter % script_config.display_step == 0: - script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter) - script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter) - output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255) - output /= 4 - x = output + rotated_output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \ + self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \ + self.stage_3B( rotated_msb, self._extract_pattern_3B ) + rotated_output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \ + self.stage_2D( rotated_lsb, self._extract_pattern_2D ) + rotated_output_msb /= 3 + rotated_output_lsb /= 2 + output_msb += torch.rot90(rotated_output_msb, k=-rotations_count, dims=[2, 3]) + output_lsb += torch.rot90(rotated_output_lsb, k=-rotations_count, dims=[2, 3]) + output_msb /= 4 + output_lsb /= 4 + + if not script_config is None and script_config.current_iter % script_config.display_step == 0: + script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter) + script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter) + + x = nn.Upsample(scale_factor=self.config.upscale_factor, mode='nearest')(x) + (output_msb*16 + output_lsb - 127) + x = x.clamp(0, 255) x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor) return x @@ -144,11 +152,11 @@ class HDBNet(HDBNetBase): def __init__(self, config): super(HDBNet, self).__init__() self.config = config - self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) - self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) - self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) - self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) - self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) + self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) + self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) + self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) + self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) + self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor) class HDBLut(HDBNetBase): def __init__(self, config):