main
protsenkovi 6 months ago
parent 53fa827515
commit ca58e33209

@ -115,7 +115,7 @@ class SRNetR90Y(nn.Module):
output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S)
for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output += torch.rot90(self.forward_stage(rotated, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
output /= 4
output = torch.cat([output, cbcr_scaled], dim=1)
output = self.ycbcr_to_rgb(output).clamp(0, 255)

Loading…
Cancel
Save