vlpr 7 months ago
parent 073e001784
commit 80d57261e7

@ -56,7 +56,7 @@ class SRLutR90(nn.Module):
): ):
scale = int(stage_lut.shape[-1]) scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1) quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale) lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model return lut_model
@ -96,7 +96,7 @@ class SRLutR90Y(nn.Module):
): ):
scale = int(stage_lut.shape[-1]) scale = int(stage_lut.shape[-1])
quantization_interval = 256//(stage_lut.shape[0]-1) quantization_interval = 256//(stage_lut.shape[0]-1)
lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale) lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32)) lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
return lut_model return lut_model

Loading…
Cancel
Save