main
protsenkovi 3 days ago
parent cf7bab78af
commit 4590ba367b

@ -15200,23 +15200,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([[[[ 0., 16., 32., 48.],\n",
" [ 64., 80., 96., 112.],\n",
" [128., 144., 160., 176.],\n",
" [192., 208., 224., 240.]]]]),\n",
" tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]]]])]"
"((tensor([[[[ 0., 16., 32., 48.],\n",
" [ 64., 80., 96., 112.],\n",
" [128., 144., 160., 176.],\n",
" [192., 208., 224., 240.]]]]),\n",
" tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]]]])),\n",
" (tensor([[[[ 0., 16., 32., 48.],\n",
" [ 64., 80., 96., 112.],\n",
" [128., 144., 160., 176.],\n",
" [192., 208., 224., 240.]]]]),\n",
" tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]]]])))"
]
},
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -15230,18 +15238,68 @@
"def bit_plane_slicing(x, bit_mask='11110000'):\n",
" m = int(bit_mask, 2)\n",
" masks = [m, 255-m]\n",
" images = []\n",
" for mask in masks:\n",
" images.append((x.type(torch.LongTensor) & mask).type(torch.FloatTensor).to(x.device))\n",
" return images\n",
" msb = (x.type(torch.LongTensor) & m).type(torch.FloatTensor).to(x.device)\n",
" lsb = (x.type(torch.LongTensor) & (255-m)).type(torch.FloatTensor).to(x.device)\n",
" return msb, lsb\n",
"\n",
"bit_plane_slicing(a)"
"def bit_plane_slicing2(x):\n",
" lsb = a % 16\n",
" msb = a - lsb\n",
" return msb, lsb\n",
"\n",
"bit_plane_slicing(a), bit_plane_slicing2(a)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n = 4000\n",
"a = (torch.arange(n**2) + torch.arange(start=0, end=16**2, step=16)).view(1,1,n,n).type(torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36.3 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"bit_plane_slicing(a)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11 µs ± 242 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"bit_plane_slicing2(a)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
@ -15256,7 +15314,7 @@
" [12., 13., 14., 15.]]]]))"
]
},
"execution_count": 10,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}

@ -113,25 +113,33 @@ class HDBNetBase(SRBase):
def forward(self, x, script_config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
# 1. check equal to bit_plane_slicing(batch_L255, bit_mask='11110000') ok
# 2. inference in 0,1 to -1,1 range |
lsb = x % 16
msb = x - lsb
output = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
output_msb = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.config.upscale_factor, w*self.config.upscale_factor], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated_msb = torch.rot90(msb, k=rotations_count, dims=[2, 3])
rotated_lsb = torch.rot90(lsb, k=rotations_count, dims=[2, 3])
output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \
self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \
self.stage_3B( rotated_msb, self._extract_pattern_3B )
output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \
self.stage_2D( rotated_lsb, self._extract_pattern_2D )
output_msb /= 3
output_lsb /= 2
if not script_config is None and script_config.current_iter % script_config.display_step == 0:
script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter)
script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter)
output += torch.rot90(output_msb + output_lsb, k=-rotations_count, dims=[2, 3]).clamp(0, 255)
output /= 4
x = output
rotated_output_msb = self.stage_3H( rotated_msb, self._extract_pattern_3H ) + \
self.stage_3D( rotated_msb, self._extract_pattern_3D ) + \
self.stage_3B( rotated_msb, self._extract_pattern_3B )
rotated_output_lsb = self.stage_2H( rotated_lsb, self._extract_pattern_2H ) + \
self.stage_2D( rotated_lsb, self._extract_pattern_2D )
rotated_output_msb /= 3
rotated_output_lsb /= 2
output_msb += torch.rot90(rotated_output_msb, k=-rotations_count, dims=[2, 3])
output_lsb += torch.rot90(rotated_output_lsb, k=-rotations_count, dims=[2, 3])
output_msb /= 4
output_lsb /= 4
if not script_config is None and script_config.current_iter % script_config.display_step == 0:
script_config.writer.add_histogram('s1_output_lsb', output_lsb.detach().cpu().numpy(), script_config.current_iter)
script_config.writer.add_histogram('s1_output_msb', output_msb.detach().cpu().numpy(), script_config.current_iter)
x = nn.Upsample(scale_factor=self.config.upscale_factor, mode='nearest')(x) + (output_msb*16 + output_lsb - 127)
x = x.clamp(0, 255)
x = x.reshape(b, c, h*self.config.upscale_factor, w*self.config.upscale_factor)
return x
@ -144,11 +152,11 @@ class HDBNet(HDBNetBase):
def __init__(self, config):
super(HDBNet, self).__init__()
self.config = config
self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=15, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_3H.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_3D.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_3B.stage = layers.LinearUpscaleBlockNet(in_features=3, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_2H.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
self.stage_2D.stage = layers.LinearUpscaleBlockNet(in_features=2, input_max_value=255, output_max_value=15, hidden_dim=self.config.hidden_dim, layers_count=self.config.layers_count, upscale_factor=self.config.upscale_factor)
class HDBLut(HDBNetBase):
def __init__(self, config):

Loading…
Cancel
Save