|
|
@ -56,7 +56,7 @@ class SRLutR90(nn.Module):
|
|
|
|
):
|
|
|
|
):
|
|
|
|
scale = int(stage_lut.shape[-1])
|
|
|
|
scale = int(stage_lut.shape[-1])
|
|
|
|
quantization_interval = 256//(stage_lut.shape[0]-1)
|
|
|
|
quantization_interval = 256//(stage_lut.shape[0]-1)
|
|
|
|
lut_model = SRLutRot90(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model = SRLutR90(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
|
|
|
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
|
|
|
return lut_model
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|
|
|
|
@ -96,7 +96,7 @@ class SRLutR90Y(nn.Module):
|
|
|
|
):
|
|
|
|
):
|
|
|
|
scale = int(stage_lut.shape[-1])
|
|
|
|
scale = int(stage_lut.shape[-1])
|
|
|
|
quantization_interval = 256//(stage_lut.shape[0]-1)
|
|
|
|
quantization_interval = 256//(stage_lut.shape[0]-1)
|
|
|
|
lut_model = SRLutRot90Y(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model = SRLutR90Y(quantization_interval=quantization_interval, scale=scale)
|
|
|
|
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
|
|
|
lut_model.stage_lut = nn.Parameter(torch.tensor(stage_lut).type(torch.float32))
|
|
|
|
return lut_model
|
|
|
|
return lut_model
|
|
|
|
|
|
|
|
|
|
|
|