main
protsenkovi 5 months ago
parent 6744b97f59
commit e2db157055

@ -441,7 +441,7 @@ class SRMsbLsbFlipNet(SRNetBase):
output_max_value=255 output_max_value=255
) )
self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) self._extract_pattern_S = layers.PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.flip_functions = [ self._flip_functions = [
lambda x: x, lambda x: x,
lambda x: torch.flip(x, dims=[-2]), lambda x: torch.flip(x, dims=[-2]),
lambda x: torch.flip(x, dims=[-1]), lambda x: torch.flip(x, dims=[-1]),
@ -451,19 +451,23 @@ class SRMsbLsbFlipNet(SRNetBase):
def forward(self, x, config=None): def forward(self, x, config=None):
b,c,h,w = x.shape b,c,h,w = x.shape
x = x.reshape(b*c, 1, h, w) x = x.reshape(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) lsb = x % 16
for flip_f in self.flip_functions: msb = x - lsb
fliped_x = flip_f(x) output_msb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
fliped_lsb = fliped_x % 16 output_lsb = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
fliped_msb = fliped_x - fliped_lsb for flip_f in self._flip_functions:
output_msb = self.forward_stage(fliped_msb, self.scale, self._extract_pattern_S, self.msb_fn) rotated_msb = flip_f(msb)
output_lsb = self.forward_stage(fliped_lsb, self.scale, self._extract_pattern_S, self.lsb_fn) rotated_lsb = flip_f(lsb)
output_msb_r = self.forward_stage(rotated_msb, self.scale, self._extract_pattern_S, self.msb_fn)
output_lsb_r = self.forward_stage(rotated_lsb, self.scale, self._extract_pattern_S, self.lsb_fn)
output_msb += flip_f(output_msb_r)
output_lsb += flip_f(output_lsb_r)
output_msb /= 4
output_lsb /= 4
if not config is None and config.current_iter % config.display_step == 0: if not config is None and config.current_iter % config.display_step == 0:
config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter) config.writer.add_histogram('output_lsb', output_lsb.detach().cpu().numpy(), config.current_iter)
config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter) config.writer.add_histogram('output_msb', output_msb.detach().cpu().numpy(), config.current_iter)
output += flip_f(output_msb + output_lsb) x = output_msb + output_lsb
output /= 4
x = output
x = x.reshape(b, c, h*self.scale, w*self.scale) x = x.reshape(b, c, h*self.scale, w*self.scale)
return x return x

@ -13,7 +13,8 @@ from pathlib import Path
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from common.data import SRTrainDataset, SRTestDataset from common.data import SRTrainDataset, SRTestDataset
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop from common.utils import logger_info
from common.metrics import PSNR, cal_ssim
from common.test import test_steps from common.test import test_steps
from models import LoadCheckpoint from models import LoadCheckpoint
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True

Loading…
Cancel
Save