@ -15,9 +15,9 @@ class SDYLutx1(nn.Module):
super ( SDYLutx1 , self ) . __init__ ( )
self . scale = scale
self . quantization_interval = quantization_interval
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] )
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . stageS = nn . Parameter ( torch . randint ( 0 , 255 , size = ( 256 / / quantization_interval + 1 , ) * 4 + ( scale , scale ) ) . type ( torch . float32 ) )
self . stageD = nn . Parameter ( torch . randint ( 0 , 255 , size = ( 256 / / quantization_interval + 1 , ) * 4 + ( scale , scale ) ) . type ( torch . float32 ) )
self . stageY = nn . Parameter ( torch . randint ( 0 , 255 , size = ( 256 / / quantization_interval + 1 , ) * 4 + ( scale , scale ) ) . type ( torch . float32 ) )
@ -78,9 +78,9 @@ class SDYLutx2(nn.Module):
super ( SDYLutx2 , self ) . __init__ ( )
self . scale = scale
self . quantization_interval = quantization_interval
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] )
self . _extract_pattern_S = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 0 , 1 ] , [ 1 , 0 ] , [ 1 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_D = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 2 , 0 ] , [ 0 , 2 ] , [ 2 , 2 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . _extract_pattern_Y = PercievePattern ( receptive_field_idxes = [ [ 0 , 0 ] , [ 1 , 1 ] , [ 1 , 2 ] , [ 2 , 1 ] ] , center = [ 0 , 0 ] , window_size = 3 )
self . stageS_1 = nn . Parameter ( torch . randint ( 0 , 255 , size = ( 256 / / quantization_interval + 1 , ) * 4 + ( scale , scale ) ) . type ( torch . float32 ) )
self . stageD_1 = nn . Parameter ( torch . randint ( 0 , 255 , size = ( 256 / / quantization_interval + 1 , ) * 4 + ( scale , scale ) ) . type ( torch . float32 ) )
self . stageY_1 = nn . Parameter ( torch . randint ( 0 , 255 , size = ( 256 / / quantization_interval + 1 , ) * 4 + ( scale , scale ) ) . type ( torch . float32 ) )