From 073e0017845eb89e1a49b0cedb1a6c23fcaa4520 Mon Sep 17 00:00:00 2001 From: vlpr Date: Wed, 15 May 2024 13:42:18 +0000 Subject: [PATCH] fix --- src/models/srnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/srnet.py b/src/models/srnet.py index d7f23b5..6b8a9f7 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -52,7 +52,7 @@ class SRNetR90(nn.Module): b,c,h,w = x.shape x = x.view(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.stage1_S(rotated) + output += self.stage1_S(x) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3]) @@ -90,7 +90,7 @@ class SRNetR90Y(nn.Module): x = y.view(b, 1, h, w) output = torch.zeros([b, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - output += self.stage1_S(rotated) + output += self.stage1_S(x) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[2, 3])