@ -84,7 +84,7 @@ def forward_2x2_input_SxS_output(index, lut):
b , c , hs , ws = index . shape
scale = lut . shape [ - 1 ]
index = F . pad ( input = index , pad = [ 0 , 1 , 0 , 1 ] , mode = ' replicate ' ) #?
out = select_index_4dlut_tetrahedral 2 (
out = select_index_4dlut_tetrahedral (
ixA = index ,
ixB = torch . roll ( index , shifts = [ 0 , - 1 ] , dims = [ 2 , 3 ] ) ,
ixC = torch . roll ( index , shifts = [ - 1 , 0 ] , dims = [ 2 , 3 ] ) ,
@ -101,7 +101,7 @@ def forward_rc_conv_centered(index, lut):
window_size = lut . shape [ 0 ]
index = F . pad ( index , pad = [ window_size / / 2 ] * 4 , mode = ' replicate ' )
window_indexes = lut . shape [ : - 1 ]
index = index . unsqueeze ( - 1 )
# index = index.unsqueeze(-1 )
x = torch . zeros_like ( index )
for i in range ( window_indexes [ - 2 ] ) :
for j in range ( window_indexes [ - 1 ] ) :
@ -118,7 +118,7 @@ def forward_rc_conv_rot90(index, lut):
window_size = lut . shape [ 0 ]
index = F . pad ( index , pad = [ 0 , window_size - 1 ] * 2 , mode = ' replicate ' )
window_indexes = lut . shape [ : - 1 ]
index = index . unsqueeze ( - 1 )
# index = index.unsqueeze(-1 )
x = torch . zeros_like ( index )
for i in range ( window_indexes [ - 2 ] ) :
for j in range ( window_indexes [ - 1 ] ) :
@ -135,166 +135,30 @@ def forward_rc_conv_rot90(index, lut):
##################### UTILS ##########################
def select_index_1dlut_linear ( ixA , lut ) :
dimA = lut . shape [ 0 ]
qA = 256 / ( dimA - 1 )
outDims = lut . shape [ 1 : ]
lut = lut . reshape ( dimA , * outDims ) . permute ( * ( i + 1 for i in range ( len ( outDims ) ) ) , 0 )
index_loop_indexes = ixA . shape
lut_loop_indexes = lut . shape [ : - 1 ]
ixA = ixA . view ( * ( ixA . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
msbA = torch. floor_divide ( ixA , qA ) . type ( torch . int64 )
msbB = torch. floor_divide ( ixA , qA ) . type ( torch . int64 ) + 1
lsb _index = ixA % qA
lut = lut . view ( * ( ( 1 , ) * len ( index_loop_indexes ) + lut . shape ) ) . expand ( index_loop_indexes + lut . shape )
out A = torch . gather ( input = lut , dim = - 1 , index = msbA )
outB = torch . gather ( input = lut , dim = - 1 , index = msbB )
out = outA + ( lsb_index / qA ) * ( outB - outA )
out = out . squeeze( - 1 )
lut = torch . clamp ( lut , 0 , 255 )
b, c , h , w = ixA . shape
ixA = ixA . flatten ( )
L = lut . shape [ 0 ]
Q = 256 / ( L - 1 )
msbA = torch . floor_divide ( ixA , Q ) . type ( torch . int64 )
msbB = msbA + 1
msbA = msbA. flatten ( )
msbB = msbB. flatten ( )
lsb = ixA % Q
outA = lut [ msbA ]
out B = lut [ msbB ]
lsb_coef = lsb / Q
out = outA + lsb_coef * ( outB - outA )
out = out . reshape( ( b , c , h , w ) )
return out
def select_index_1dlut_msb ( ixA , lut ) :
dimA = lut . shape [ 0 ]
outDims = lut . shape [ 1 : ]
lut = lut . reshape ( dimA , * outDims ) . permute ( * ( i + 1 for i in range ( len ( outDims ) ) ) , 0 )
index_loop_indexes = ixA . shape
lut_loop_indexes = lut . shape [ : - 1 ]
ixA = ixA . view ( * ( ixA . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
msb_index = torch . floor_divide ( ixA , 256 / ( dimA - 1 ) ) . type ( torch . int64 ) * dimA * * 0
lut = lut . view ( * ( ( 1 , ) * len ( index_loop_indexes ) + lut . shape ) ) . expand ( index_loop_indexes + lut . shape )
out = torch . gather ( input = lut , dim = - 1 , index = msb_index )
out = out . squeeze ( - 1 )
return out
def select_index_4dlut_msb ( ixA , ixB , ixC , ixD , lut ) :
dimA , dimB , dimC , dimD = lut . shape [ : 4 ]
qA , qB , qC , qD = 256 / ( dimA - 1 ) , 256 / ( dimB - 1 ) , 256 / ( dimC - 1 ) , 256 / ( dimD - 1 )
outDims = lut . shape [ 4 : ]
lut = lut . reshape ( dimA * dimB * dimC * dimD , * outDims ) . permute ( * ( i + 1 for i in range ( len ( outDims ) ) ) , 0 )
index_loop_indexes = ixA . shape
lut_loop_indexes = lut . shape [ : - 1 ]
ixA = ixA . view ( * ( ixA . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixB = ixB . view ( * ( ixB . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixC = ixC . view ( * ( ixC . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixD = ixD . view ( * ( ixD . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
msb_index = torch . floor_divide ( ixA , qA ) * dimA * * 3
msb_index + = torch . floor_divide ( ixB , qB ) * dimB * * 2
msb_index + = torch . floor_divide ( ixC , qC ) * dimC * * 1
msb_index + = torch . floor_divide ( ixD , qD ) * dimD * * 0
lut = lut . view ( * ( ( 1 , ) * len ( index_loop_indexes ) + lut . shape ) ) . expand ( index_loop_indexes + lut . shape )
out = torch . gather ( input = lut , dim = - 1 , index = msb_index . type ( torch . int64 ) )
out = out . squeeze ( - 1 )
return out
def select_index_4dlut_linear ( ixA , ixB , ixC , ixD , lut ) :
dimA , dimB , dimC , dimD = lut . shape [ : 4 ]
qA , qB , qC , qD = 256 / ( dimA - 1 ) , 256 / ( dimB - 1 ) , 256 / ( dimC - 1 ) , 256 / ( dimD - 1 )
outDims = lut . shape [ 4 : ]
lut = lut . reshape ( dimA * dimB * dimC * dimD , * outDims ) . permute ( * ( i + 1 for i in range ( len ( outDims ) ) ) , 0 )
index_loop_indexes = ixA . shape
lut_loop_indexes = lut . shape [ : - 1 ]
lut = lut . view ( * ( ( 1 , ) * len ( index_loop_indexes ) + lut . shape ) ) . expand ( index_loop_indexes + lut . shape )
ixA = ixA . view ( * ( ixA . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixB = ixB . view ( * ( ixB . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixC = ixC . view ( * ( ixC . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixD = ixD . view ( * ( ixD . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
msb_index = torch . floor_divide ( ixA , qA ) . type ( torch . int64 ) * dimA * * 3
msb_index + = torch . floor_divide ( ixB , qB ) . type ( torch . int64 ) * dimB * * 2
msb_index + = torch . floor_divide ( ixC , qC ) . type ( torch . int64 ) * dimC * * 1
msb_index + = torch . floor_divide ( ixD , qD ) . type ( torch . int64 ) * dimD * * 0
outA = torch . gather ( input = lut , dim = - 1 , index = msb_index )
msb_index = ( torch . floor_divide ( ixA , qA ) . type ( torch . int64 ) + 1 ) * dimA * * 3
msb_index + = ( torch . floor_divide ( ixB , qB ) . type ( torch . int64 ) + 1 ) * dimB * * 2
msb_index + = ( torch . floor_divide ( ixC , qC ) . type ( torch . int64 ) + 1 ) * dimC * * 1
msb_index + = ( torch . floor_divide ( ixD , qD ) . type ( torch . int64 ) + 1 ) * dimD * * 0
outB = torch . gather ( input = lut , dim = - 1 , index = msb_index )
lsb_coef = ( ( ixA + ixB + ixC + ixD ) / 4 % qA ) / qA
out = outA + lsb_coef * ( outB - outA )
out = out . squeeze ( - 1 )
return out
def barycentric_interpolate ( masks , coefs , vertices ) :
i = torch . all ( torch . stack ( masks ) , dim = 0 , keepdim = False )
coefs = torch . stack ( coefs ) * i
vertices = torch . stack ( vertices )
out = ( coefs * vertices ) . sum ( 0 )
return i , out
def select_index_4dlut_tetrahedral ( ixA , ixB , ixC , ixD , lut ) :
dimA , dimB , dimC , dimD = lut . shape [ : 4 ]
qA , qB , qC , qD = 256 / ( dimA - 1 ) , 256 / ( dimB - 1 ) , 256 / ( dimC - 1 ) , 256 / ( dimD - 1 )
outDims = lut . shape [ 4 : ]
lut = lut . reshape ( dimA * dimB * dimC * dimD , * outDims ) . permute ( * ( i + 1 for i in range ( len ( outDims ) ) ) , 0 )
index_loop_indexes = ixA . shape
lut_loop_indexes = lut . shape [ : - 1 ]
lut = lut . view ( * ( ( 1 , ) * len ( index_loop_indexes ) + lut . shape ) ) . expand ( index_loop_indexes + lut . shape )
ixA = ixA . view ( * ( ixA . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixB = ixB . view ( * ( ixB . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixC = ixC . view ( * ( ixC . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
ixD = ixD . view ( * ( ixD . shape + ( 1 , ) * len ( lut_loop_indexes ) ) ) . expand ( index_loop_indexes + lut_loop_indexes ) . unsqueeze ( - 1 )
msbA = torch . floor_divide ( ixA , qA ) . type ( torch . int64 )
msbB = torch . floor_divide ( ixB , qB ) . type ( torch . int64 )
msbC = torch . floor_divide ( ixC , qC ) . type ( torch . int64 )
msbD = torch . floor_divide ( ixD , qD ) . type ( torch . int64 )
fa , fb , fc , fd = ixA % qA , ixB % qB , ixC % qC , ixD % qD
fab , fac , fad , fbc , fbd , fcd = fa > fb , fa > fc , fa > fd , fb > fc , fb > fd , fc > fd
strides = torch . tensor ( [ dimA * * 3 , dimB * * 2 , dimC * * 1 , dimD * * 0 ] , device = lut . device ) . view ( - 1 , * ( ( 1 , ) * len ( msbA . shape ) ) )
p0000 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB , msbC , msbD ] ) * strides ) . sum ( 0 ) )
p0001 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB , msbC , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p0010 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB , msbC + 1 , msbD ] ) * strides ) . sum ( 0 ) )
p0011 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB , msbC + 1 , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p0100 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB + 1 , msbC , msbD ] ) * strides ) . sum ( 0 ) )
p0101 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB + 1 , msbC , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p0110 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB + 1 , msbC + 1 , msbD ] ) * strides ) . sum ( 0 ) )
p0111 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA , msbB + 1 , msbC + 1 , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p1000 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB , msbC , msbD ] ) * strides ) . sum ( 0 ) )
p1001 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB , msbC , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p1010 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB , msbC + 1 , msbD ] ) * strides ) . sum ( 0 ) )
p1011 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB , msbC + 1 , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p1100 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB + 1 , msbC , msbD ] ) * strides ) . sum ( 0 ) )
p1101 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB + 1 , msbC , msbD + 1 ] ) * strides ) . sum ( 0 ) )
p1110 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB + 1 , msbC + 1 , msbD ] ) * strides ) . sum ( 0 ) )
p1111 = torch . gather ( input = lut , dim = - 1 , index = ( torch . stack ( [ msbA + 1 , msbB + 1 , msbC + 1 , msbD + 1 ] ) * strides ) . sum ( 0 ) )
i1 , out1 = barycentric_interpolate ( [ fab , fbc , fcd ] , [ qA - fa , fa - fb , fb - fc , fc - fd , fd ] , [ p0000 , p1000 , p1100 , p1110 , p1111 ] )
i2 , out2 = barycentric_interpolate ( [ fab , fbc , fbd , ~ ( i1 ) ] , [ qA - fa , fa - fb , fb - fd , fd - fc , fc ] , [ p0000 , p1000 , p1100 , p1101 , p1111 ] )
i3 , out3 = barycentric_interpolate ( [ fab , fbc , fad , ~ ( i1 ) , ~ ( i2 ) ] , [ qA - fa , fa - fd , fd - fb , fb - fc , fc ] , [ p0000 , p1000 , p1001 , p1101 , p1111 ] )
i4 , out4 = barycentric_interpolate ( [ fab , fbc , ~ ( i1 ) , ~ ( i2 ) , ~ ( i3 ) ] , [ qA - fd , fd - fa , fa - fb , fb - fc , fc ] , [ p0000 , p0001 , p1001 , p1101 , p1111 ] )
i5 , out5 = barycentric_interpolate ( [ fab , fac , fbd , ~ ( fbc ) ] , [ qA - fa , fa - fc , fc - fb , fb - fd , fd ] , [ p0000 , p1000 , p1010 , p1110 , p1111 ] )
i6 , out6 = barycentric_interpolate ( [ fab , fac , fcd , ~ ( fbc ) , ~ ( i5 ) ] , [ qA - fa , fa - fc , fc - fd , fd - fb , fb ] , [ p0000 , p1000 , p1010 , p1011 , p1111 ] )
i7 , out7 = barycentric_interpolate ( [ fab , fac , fad , ~ ( fbc ) , ~ ( i5 ) , ~ ( i6 ) ] , [ qA - fa , fa - fd , fd - fc , fc - fb , fb ] , [ p0000 , p1000 , p1001 , p1011 , p1111 ] )
i8 , out8 = barycentric_interpolate ( [ fab , fac , ~ ( fbc ) , ~ ( i5 ) , ~ ( i6 ) , ~ ( i7 ) ] , [ qA - fd , fd - fa , fa - fc , fc - fb , fb ] , [ p0000 , p0001 , p1001 , p1011 , p1111 ] )
i9 , out9 = barycentric_interpolate ( [ fab , fbd , ~ ( fbc ) , ~ ( fac ) ] , [ qA - fc , fc - fa , fa - fb , fb - fd , fd ] , [ p0000 , p0010 , p1010 , p1110 , p1111 ] )
i10 , out10 = barycentric_interpolate ( [ fab , fad , ~ ( fbc ) , ~ ( fac ) , ~ ( i9 ) ] , [ qA - fc , fc - fa , fa - fd , fd - fb , fb ] , [ p0000 , p0010 , p1010 , p1011 , p1111 ] )
i11 , out11 = barycentric_interpolate ( [ fab , fcd , ~ ( fbc ) , ~ ( fac ) , ~ ( i9 ) , ~ ( i10 ) ] , [ qA - fc , fc - fd , fd - fa , fa - fb , fb ] , [ p0000 , p0010 , p0011 , p1011 , p1111 ] )
i12 , out12 = barycentric_interpolate ( [ fab , ~ ( fbc ) , ~ ( fac ) , ~ ( i9 ) , ~ ( i10 ) , ~ ( i11 ) ] , [ qA - fd , fd - fc , fc - fa , fa - fb , fb ] , [ p0000 , p0001 , p0011 , p1011 , p1111 ] )
i13 , out13 = barycentric_interpolate ( [ fac , fcd , ~ ( fab ) ] , [ qA - fb , fb - fa , fa - fc , fc - fd , fd ] , [ p0000 , p0100 , p1100 , p1110 , p1111 ] )
i14 , out14 = barycentric_interpolate ( [ fac , fad , ~ ( fab ) , ~ ( i13 ) ] , [ qA - fb , fb - fa , fa - fd , fd - fc , fc ] , [ p0000 , p0100 , p1100 , p1101 , p1111 ] )
i15 , out15 = barycentric_interpolate ( [ fac , fbd , ~ ( fab ) , ~ ( i13 ) , ~ ( i14 ) ] , [ qA - fb , fb - fd , fd - fa , fa - fc , fc ] , [ p0000 , p0100 , p0101 , p1101 , p1111 ] )
i16 , out16 = barycentric_interpolate ( [ fac , ~ ( fab ) , ~ ( i13 ) , ~ ( i14 ) , ~ ( i15 ) ] , [ qA - fd , fd - fb , fb - fa , fa - fc , fc ] , [ p0000 , p0001 , p0101 , p1101 , p1111 ] )
i17 , out17 = barycentric_interpolate ( [ fbc , fad , ~ ( fab ) , ~ ( fac ) ] , [ qA - fb , fb - fc , fc - fa , fa - fd , fd ] , [ p0000 , p0100 , p0110 , p1110 , p1111 ] )
i18 , out18 = barycentric_interpolate ( [ fbc , fcd , ~ ( fab ) , ~ ( fac ) , ~ ( i17 ) ] , [ qA - fb , fb - fc , fc - fd , fd - fa , fa ] , [ p0000 , p0100 , p0110 , p0111 , p1111 ] )
i19 , out19 = barycentric_interpolate ( [ fbc , fbd , ~ ( fab ) , ~ ( fac ) , ~ ( i17 ) , ~ ( i18 ) ] , [ qA - fb , fb - fd , fd - fc , fc - fa , fa ] , [ p0000 , p0100 , p0101 , p0111 , p1111 ] )
i20 , out20 = barycentric_interpolate ( [ fbc , ~ ( fab ) , ~ ( fac ) , ~ ( i17 ) , ~ ( i18 ) , ~ ( i19 ) ] , [ qA - fd , fd - fb , fb - fc , fc - fa , fa ] , [ p0000 , p0001 , p0101 , p0111 , p1111 ] )
i21 , out21 = barycentric_interpolate ( [ fad , ~ ( fab ) , ~ ( fac ) , ~ ( fbc ) ] , [ qA - fc , fc - fb , fb - fa , fa - fd , fd ] , [ p0000 , p0010 , p0110 , p1110 , p1111 ] )
i22 , out22 = barycentric_interpolate ( [ fbd , ~ ( fab ) , ~ ( fac ) , ~ ( fbc ) , ~ ( i21 ) ] , [ qA - fc , fc - fb , fb - fd , fd - fa , fa ] , [ p0000 , p0010 , p0110 , p0111 , p1111 ] )
i23 , out23 = barycentric_interpolate ( [ fcd , ~ ( fab ) , ~ ( fac ) , ~ ( fbc ) , ~ ( i21 ) , ~ ( i22 ) ] , [ qA - fc , fc - fd , fd - fb , fb - fa , fa ] , [ p0000 , p0010 , p0011 , p0111 , p1111 ] )
i24 , out24 = barycentric_interpolate ( [ ~ ( fab ) , ~ ( fac ) , ~ ( fbc ) , ~ ( i21 ) , ~ ( i22 ) , ~ ( i23 ) ] , [ qA - fd , fd - fc , fc - fb , fb - fa , fa ] , [ p0000 , p0001 , p0011 , p0111 , p1111 ] )
out = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8 + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16 + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24
out / = qA
out = out . squeeze ( - 1 )
return out
def select_index_4dlut_tetrahedral2 ( ixA , ixB , ixC , ixD , lut ) : #self, weight, upscale, mode, img_in, bd):
def select_index_4dlut_tetrahedral ( ixA , ixB , ixC , ixD , lut ) : #self, weight, upscale, mode, img_in, bd):
lut = torch . clamp ( lut , 0 , 255 )
dimA , dimB , dimC , dimD = lut . shape [ : 4 ]
q = 256 / ( dimA - 1 )
L = dimA
upscale = lut . shape [ - 1 ]
weight = lut
weight = lut . reshape ( L * * 4 , upscale , upscale )
img_a1 = torch . floor_divide ( ixA , q ) . type ( torch . int64 )
img_b1 = torch . floor_divide ( ixB , q ) . type ( torch . int64 )
@ -404,6 +268,7 @@ def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, ups
i24 = i = torch . all ( torch . cat ( [ ~ ( fab ) , ~ ( fac ) , ~ ( fbc ) , ~ i21 [ : , None ] , ~ i22 [ : , None ] , ~ i23 [ : , None ] ] , dim = 1 ) , dim = 1 ) ; out [ i ] = ( q - fd [ i ] ) * p0000 [ i ] + ( fd [ i ] - fc [ i ] ) * p0001 [ i ] + ( fc [ i ] - fb [ i ] ) * p0011 [ i ] + ( fb [ i ] - fa [ i ] ) * p0111 [ i ] + ( fa [ i ] ) * p1111 [ i ]
out = out . reshape ( ( img_a1 . shape [ 0 ] , img_a1 . shape [ 1 ] , img_a1 . shape [ 2 ] , img_a1 . shape [ 3 ] , upscale , upscale ) )
out = out . permute ( 0 , 1 , 2 , 4 , 3 , 5 ) . reshape ( ( img_a1 . shape [ 0 ] , img_a1 . shape [ 1 ] , img_a1 . shape [ 2 ] * upscale , img_a1 . shape [ 3 ] * upscale ) )
# out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale) )
out = out / q
return out
# print(out.shape)
return out