vlpr 7 months ago
parent 72ac9d99fd
commit 073e001784

@ -52,7 +52,7 @@ class SRNetR90(nn.Module):
b,c,h,w = x.shape b,c,h,w = x.shape
x = x.view(b*c, 1, h, w) x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.stage1_S(rotated) output += self.stage1_S(x)
for rotations_count in range(1,4): for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])
@ -90,7 +90,7 @@ class SRNetR90Y(nn.Module):
x = y.view(b, 1, h, w) x = y.view(b, 1, h, w)
output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
output += self.stage1_S(rotated) output += self.stage1_S(x)
for rotations_count in range(1,4): for rotations_count in range(1,4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])

Loading…
Cancel
Save