diff --git a/src/models/srlut.py b/src/models/srlut.py index 11bee80..0f8e2fb 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -56,7 +56,7 @@ class SRLutR90(nn.Module): ): scale = int(stage_lut.shape[-1]) quantization_interval = 256//(stage_lut.shape[0]-1) - lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale) + lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model @@ -96,7 +96,7 @@ class SRLutR90Y(nn.Module): ): scale = int(stage_lut.shape[-1]) quantization_interval = 256//(stage_lut.shape[0]-1) - lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale) + lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model