diff --git a/src/models/sdynet.py b/src/models/sdynet.py index 75f0fb3..4ec03d1 100644 --- a/src/models/sdynet.py +++ b/src/models/sdynet.py @@ -23,13 +23,14 @@ class SDYNetx1(nn.Module): b,c,h,w = x.shape x = x.view(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) - for rotations_count in range(4): + output += self.stage1_S(x) + output += self.stage1_D(x) + output += self.stage1_Y(x) + for rotations_count in range(1, 4): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) - rb,rc,rh,rw = rotated.shape - output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1]) + output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1]) output += torch.rot90(self.stage1_D(rotated), k=-rotations_count, dims=[-2, -1]) output += torch.rot90(self.stage1_Y(rotated), k=-rotations_count, dims=[-2, -1]) - output /= 4*3 x = output x = round_func(x)