diff --git a/src/models/srnet.py b/src/models/srnet.py index df74b5e..7c4ecd9 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -441,7 +441,7 @@ class SRMsbLsbFlipNet(SRNetBase): output_max_value=255 ) self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) - self.flip_functions = [ + self._flip_functions = [ lambda x: x, lambda x: torch.flip(x, dims=[-2]), lambda x: torch.flip(x, dims=[-1]), @@ -451,19 +451,23 @@ class SRMsbLsbFlipNet(SRNetBase): def forward(self, x, config=None): b,c,h,w = x.shape x = x.reshape(b*c, 1, h, w) - output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for flip_f in self.flip_functions: - fliped_x = flip_f(x) - fliped_lsb = fliped_x % 16 - fliped_msb = fliped_x - fliped_lsb - output_msb = self.forward_stage(fliped_msb, self.scale, self._extract_pattern_S, self.msb_fn) - output_lsb = self.forward_stage(fliped_lsb, self.scale, self._extract_pattern_S, self.lsb_fn) - if not config is None and config.current_iter % config.display_step == 0: - config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter) - config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter) - output += flip_f(output_msb + output_lsb) - output /= 4 - x = output + lsb = x % 16 + msb = x - lsb + output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) + for flip_f in self._flip_functions: + rotated_msb = flip_f(msb) + rotated_lsb = flip_f(lsb) + output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, self.msb_fn) + output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, self.lsb_fn) + output_msb += flip_f(output_msb_r) + output_lsb += flip_f(output_lsb_r) + output_msb /= 4 + output_lsb /= 4 + if not config is None and config.current_iter % config.display_step == 0: + config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter) + config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter) + x = output_msb + output_lsb x = x.reshape(b, c, h*self.scale, w*self.scale) return x diff --git a/src/test.py b/src/test.py index 7cdb134..d155649 100644 --- a/src/test.py +++ b/src/test.py @@ -13,7 +13,8 @@ from pathlib import Path from torch.utils.tensorboard import SummaryWriter from torch.utils.data import Dataset, DataLoader from common.data import SRTrainDataset, SRTestDataset -from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop +from common.utils import logger_info +from common.metrics import PSNR, cal_ssim from common.test import test_steps from models import LoadCheckpoint torch.backends.cudnn.benchmark = True