diff --git a/src/models/sdynet.py b/src/models/sdynet.py index 0f87a54..e500c59 100644 --- a/src/models/sdynet.py +++ b/src/models/sdynet.py @@ -92,7 +92,7 @@ class SDYNetx2(nn.Module): output_1 += y output_1 /= 4*3 - output_1 = output_1.view(b, c, h, w) + output_1 = output_1.view(b*c, 1, h, w) x = output_1 output_2 = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)