diff --git a/src/models/sdynet.py b/src/models/sdynet.py index 64f46db..c4b6f38 100644 --- a/src/models/sdynet.py +++ b/src/models/sdynet.py @@ -15,9 +15,9 @@ class SDYNetx1(nn.Module): s_pattern = [[0,0],[0,1],[1,0],[1,1]] d_pattern = [[0,0],[2,0],[0,2],[2,2]] y_pattern = [[0,0],[1,1],[1,2],[2,1]] - self.stage1_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stage1_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stage1_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage1_S = layers.UpscaleBlock(receptive_field_idxes=s_pattern, center=[0,0], window_size=2, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage1_D = layers.UpscaleBlock(receptive_field_idxes=d_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage1_Y = layers.UpscaleBlock(receptive_field_idxes=y_pattern, center=[0,0], window_size=3, hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) def forward(self, x): b,c,h,w = x.shape