optimization

main
Vladimir Protsenko 6 months ago
parent dca830bbe0
commit 801402503a

@ -23,13 +23,14 @@ class SDYNetx1(nn.Module):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
output += self.stage1_S(x)
output += self.stage1_D(x)
output += self.stage1_Y(x)
for rotations_count in range(1, 4):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.stage1_S(rotated), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.stage1_D(rotated), k=-rotations_count, dims=[-2, -1])
output += torch.rot90(self.stage1_Y(rotated), k=-rotations_count, dims=[-2, -1])
output /= 4*3
x = output
x = round_func(x)

Loading…
Cancel
Save