From 80d57261e7430c8be9c3e6afa45cc4e17c2d16fd Mon Sep 17 00:00:00 2001 From: vlpr Date: Wed, 15 May 2024 14:14:56 +0000 Subject: [PATCH] fix --- src/models/srlut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/srlut.py b/src/models/srlut.py index 11bee80..0f8e2fb 100644 --- a/src/models/srlut.py +++ b/src/models/srlut.py @@ -56,7 +56,7 @@ class SRLutR90(nn.Module): ): scale = int(stage_lut.shape[-1]) quantization_interval = 256//(stage_lut.shape[0]-1) - lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale) + lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model @@ -96,7 +96,7 @@ class SRLutR90Y(nn.Module): ): scale = int(stage_lut.shape[-1]) quantization_interval = 256//(stage_lut.shape[0]-1) - lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale) + lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) return lut_model