|
|
@ -115,7 +115,7 @@ class SRNetR90Y(nn.Module):
|
|
|
|
output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S)
|
|
|
|
output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S)
|
|
|
|
for rotations_count in range(1,4):
|
|
|
|
for rotations_count in range(1,4):
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
|
|
|
|
output += torch.rot90(self.forward_stage(rotated, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3])
|
|
|
|
output /= 4
|
|
|
|
output /= 4
|
|
|
|
output = torch.cat([output, cbcr_scaled], dim=1)
|
|
|
|
output = torch.cat([output, cbcr_scaled], dim=1)
|
|
|
|
output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
|
|
|
output = self.ycbcr_to_rgb(output).clamp(0, 255)
|
|
|
|