main
protsenkovi 5 months ago
parent 6744b97f59
commit e2db157055

@ -441,7 +441,7 @@ class SRMsbLsbFlipNet(SRNetBase):
output_max_value=255
)
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.flip_functions = [
self._flip_functions = [
lambda x: x,
lambda x: torch.flip(x, dims=[-2]),
lambda x: torch.flip(x, dims=[-1]),
@ -451,19 +451,23 @@ class SRMsbLsbFlipNet(SRNetBase):
def forward(self, x, config=None):
b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for flip_f in self.flip_functions:
fliped_x = flip_f(x)
fliped_lsb = fliped_x % 16
fliped_msb = fliped_x - fliped_lsb
output_msb = self.forward_stage(fliped_msb, self.scale, self._extract_pattern_S, self.msb_fn)
output_lsb = self.forward_stage(fliped_lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
lsb = x % 16
msb = x - lsb
output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for flip_f in self._flip_functions:
rotated_msb = flip_f(msb)
rotated_lsb = flip_f(lsb)
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, self.msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
output_msb += flip_f(output_msb_r)
output_lsb += flip_f(output_lsb_r)
output_msb /= 4
output_lsb /= 4
if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output += flip_f(output_msb + output_lsb)
output /= 4
x = output
x = output_msb + output_lsb
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x

@ -13,7 +13,8 @@ from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from common.utils import logger_info
from common.metrics import PSNR, cal_ssim
from common.test import test_steps
from models import LoadCheckpoint
torch.backends.cudnn.benchmark = True

Loading…
Cancel
Save