From ca58e33209be1e74dad25e690b8af65d6b57fc3c Mon Sep 17 00:00:00 2001 From: protsenkovi Date: Mon, 20 May 2024 08:55:32 +0400 Subject: [PATCH] bugfix --- src/models/srnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/srnet.py b/src/models/srnet.py index c2af809..d759ef1 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -115,7 +115,7 @@ class SRNetR90Y(nn.Module): output += self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S) for rotations_count in range(1,4): rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - output += torch.rot90(self.forward_stage(x, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) + output += torch.rot90(self.forward_stage(rotated, self.scale, self._unfold_pattern_S, self.stage1_S), k=-rotations_count, dims=[2, 3]) output /= 4 output = torch.cat([output, cbcr_scaled], dim=1) output = self.ycbcr_to_rgb(output).clamp(0, 255)