@ -84,7 +84,7 @@ def forward_2x2_input_SxS_output(index, lut):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    b , c , hs , ws  =  index . shape 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    scale  =  lut . shape [ - 1 ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    index  =  F . pad ( input = index ,  pad = [ 0 , 1 , 0 , 1 ] ,  mode = ' replicate ' )  #? 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    out  =  select_index_4dlut_tetrahedral 2 ( 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    out  =  select_index_4dlut_tetrahedral ( 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        ixA  =  index ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        ixB  =  torch . roll ( index ,  shifts = [ 0 ,  - 1 ] ,  dims = [ 2 , 3 ] ) ,  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        ixC  =  torch . roll ( index ,  shifts = [ - 1 ,  0 ] ,  dims = [ 2 , 3 ] ) ,  
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -101,7 +101,7 @@ def forward_rc_conv_centered(index, lut):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    window_size  =  lut . shape [ 0 ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    index  =  F . pad ( index ,  pad = [ window_size / / 2 ] * 4 ,  mode = ' replicate ' )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    window_indexes  =  lut . shape [ : - 1 ]          
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    index  =  index . unsqueeze ( - 1  )
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    # index = index.unsqueeze(-1  )
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    x  =  torch . zeros_like ( index ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    for  i  in  range ( window_indexes [ - 2 ] ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  j  in  range ( window_indexes [ - 1 ] ) : 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -118,7 +118,7 @@ def forward_rc_conv_rot90(index, lut):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    window_size  =  lut . shape [ 0 ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    index  =  F . pad ( index ,  pad = [ 0 ,  window_size - 1 ] * 2 ,  mode = ' replicate ' )  
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    window_indexes  =  lut . shape [ : - 1 ]          
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    index  =  index . unsqueeze ( - 1  )
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    # index = index.unsqueeze(-1  )
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    x  =  torch . zeros_like ( index ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    for  i  in  range ( window_indexes [ - 2 ] ) : 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				        for  j  in  range ( window_indexes [ - 1 ] ) : 
 
			
		 
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
			
			 
			 
			
				@ -135,166 +135,30 @@ def forward_rc_conv_rot90(index, lut):
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				##################### UTILS ########################## 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def select_index_1dlut_linear(ixA, lut):
				    """Look up ``ixA`` in a 1-D LUT with linear interpolation between nodes.

				    The LUT's first axis holds ``lut.shape[0]`` nodes spaced
				    ``256 / (lut.shape[0] - 1)`` apart. Each index value is split into
				    a node number (floor division by the spacing) and a remainder; the
				    result blends that node and the next one by ``remainder / spacing``.

				    Args:
				        ixA: index tensor of any shape. Values are assumed to lie in
				            ``[0, 256)`` so that ``node + 1`` stays inside the LUT
				            (NOTE(review): confirm against the callers).
				        lut: tensor whose first axis is the node axis; any further
				            axes are carried through to the output.

				    Returns:
				        Tensor of shape ``ixA.shape + lut.shape[1:]``.
				    """
				    # LUT entries are restricted to the 8-bit value range before use.
				    lut = torch.clamp(lut, 0, 255)
				    node_spacing = 256 / (lut.shape[0] - 1)
				    flat_ix = ixA.flatten()
				    lower_node = torch.floor_divide(flat_ix, node_spacing).type(torch.int64)
				    upper_node = lower_node + 1
				    lower_val = lut[lower_node]
				    upper_val = lut[upper_node]
				    # Trailing singleton axes let the weight broadcast over any extra
				    # LUT output dimensions.
				    weight = ((flat_ix % node_spacing) / node_spacing).view(-1, *((1,) * (lut.dim() - 1)))
				    out = lower_val + weight * (upper_val - lower_val)
				    return out.reshape(tuple(ixA.shape) + tuple(lut.shape[1:]))
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def select_index_1dlut_msb(ixA, lut):
				    """Look up ``ixA`` in a 1-D LUT using only the node number (no interpolation).

				    The node number is ``floor(ixA / (256 / (lut.shape[0] - 1)))``.

				    Args:
				        ixA: index tensor of any shape. Values are assumed to lie in
				            ``[0, 256)`` (NOTE(review): confirm against the callers).
				        lut: tensor whose first axis is the node axis; any further
				            axes are carried through to the output.

				    Returns:
				        Tensor of shape ``ixA.shape + lut.shape[1:]`` holding the LUT
				        entry of the selected node for every index element.
				    """
				    dimA = lut.shape[0]
				    node_spacing = 256 / (dimA - 1)
				    msb_index = torch.floor_divide(ixA, node_spacing).type(torch.int64)
				    # Integer-tensor indexing on the node axis yields the same values as
				    # expanding the LUT over every index position and gathering.
				    return lut[msb_index]
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def select_index_4dlut_msb(ixA, ixB, ixC, ixD, lut):
				    """Look up four index tensors in a 4-D LUT using node numbers only.

				    Each index is reduced to a node number along its own LUT axis
				    (floor division by ``256 / (dim - 1)``); the four node numbers are
				    combined into a row-major flat index into the LUT's first four axes.

				    Args:
				        ixA, ixB, ixC, ixD: index tensors of matching shape. Values are
				            assumed to lie in ``[0, 256)`` (NOTE(review): confirm
				            against the callers).
				        lut: tensor whose first four axes are the node axes; any
				            further axes are carried through to the output.

				    Returns:
				        Tensor of shape ``ixA.shape + lut.shape[4:]``.
				    """
				    dimA, dimB, dimC, dimD = lut.shape[:4]
				    qA, qB, qC, qD = 256 / (dimA - 1), 256 / (dimB - 1), 256 / (dimC - 1), 256 / (dimD - 1)
				    flat_lut = lut.reshape(dimA * dimB * dimC * dimD, *lut.shape[4:])
				    # Row-major strides of the (dimA, dimB, dimC, dimD) grid. For a LUT
				    # with four equal axes these equal dim**3, dim**2, dim, 1.
				    strideC = dimD
				    strideB = dimC * strideC
				    strideA = dimB * strideB
				    msb_index = torch.floor_divide(ixA, qA).type(torch.int64) * strideA
				    msb_index = msb_index + torch.floor_divide(ixB, qB).type(torch.int64) * strideB
				    msb_index = msb_index + torch.floor_divide(ixC, qC).type(torch.int64) * strideC
				    msb_index = msb_index + torch.floor_divide(ixD, qD).type(torch.int64)
				    return flat_lut[msb_index]
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def select_index_4dlut_linear(ixA, ixB, ixC, ixD, lut):
				    """Look up four index tensors in a 4-D LUT, blending along the cell diagonal.

				    Two LUT entries are read per element: the cell's lower corner
				    (node numbers a, b, c, d) and its opposite corner (a+1, b+1, c+1,
				    d+1). They are blended with the weight
				    ``(((ixA + ixB + ixC + ixD) / 4) % qA) / qA`` where
				    ``qA = 256 / (dimA - 1)``.

				    Args:
				        ixA, ixB, ixC, ixD: index tensors of matching shape. Values are
				            assumed to lie in ``[0, 256)`` so the upper corner stays
				            inside the LUT (NOTE(review): confirm against the callers).
				        lut: tensor whose first four axes are the node axes; any
				            further axes are carried through to the output.

				    Returns:
				        Tensor of shape ``ixA.shape + lut.shape[4:]``.
				    """
				    dimA, dimB, dimC, dimD = lut.shape[:4]
				    qA, qB, qC, qD = 256 / (dimA - 1), 256 / (dimB - 1), 256 / (dimC - 1), 256 / (dimD - 1)
				    outDims = lut.shape[4:]
				    flat_lut = lut.reshape(dimA * dimB * dimC * dimD, *outDims)
				    # Row-major strides of the (dimA, dimB, dimC, dimD) grid. For a LUT
				    # with four equal axes these equal dim**3, dim**2, dim, 1.
				    strideC = dimD
				    strideB = dimC * strideC
				    strideA = dimB * strideB
				    lower_index = torch.floor_divide(ixA, qA).type(torch.int64) * strideA
				    lower_index = lower_index + torch.floor_divide(ixB, qB).type(torch.int64) * strideB
				    lower_index = lower_index + torch.floor_divide(ixC, qC).type(torch.int64) * strideC
				    lower_index = lower_index + torch.floor_divide(ixD, qD).type(torch.int64)
				    # Advancing every node number by one moves by the sum of all strides.
				    upper_index = lower_index + (strideA + strideB + strideC + 1)
				    outA = flat_lut[lower_index]
				    outB = flat_lut[upper_index]
				    # The weight uses the mean of the four indices and the A-axis spacing only.
				    lsb_coef = (((ixA + ixB + ixC + ixD) / 4) % qA) / qA
				    # Trailing singleton axes let the weight broadcast over the LUT
				    # output dimensions.
				    lsb_coef = lsb_coef.reshape(tuple(lsb_coef.shape) + (1,) * len(outDims))
				    return outA + lsb_coef * (outB - outA)
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def barycentric_interpolate(masks, coefs, vertices):
				    """Blend ``vertices`` with ``coefs`` wherever every mask holds.

				    Args:
				        masks: sequence of boolean tensors that can be stacked (same
				            shape); an element is selected only if all masks are True
				            there.
				        coefs: sequence of weight tensors, one per vertex.
				        vertices: sequence of value tensors, one per weight.

				    Returns:
				        ``(selected, blended)``: the combined boolean mask, and the sum
				        over ``coefs[k] * vertices[k]`` with the weights zeroed wherever
				        the combined mask is False.
				    """
				    selected = torch.stack(masks).all(dim=0)
				    # Multiplying by the boolean mask zeroes the weights of unselected elements.
				    weights = torch.stack(coefs) * selected
				    corners = torch.stack(vertices)
				    blended = torch.sum(weights * corners, dim=0)
				    return selected, blended
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def select_index_4dlut_tetrahedral(ixA, ixB, ixC, ixD, lut):
				    """Look up four index tensors in a 4-D LUT with simplex (tetrahedral) interpolation.

				    Each index is split into a node number (floor division by
				    ``256 / (dim - 1)``) and a remainder. The ordering of the four
				    remainders selects one of 24 simplices of the enclosing grid cell;
				    ``barycentric_interpolate`` blends that simplex's five corner
				    entries, and the masks passed to it make every case after the first
				    matching one contribute zero.

				    Args:
				        ixA, ixB, ixC, ixD: index tensors of matching shape. Values are
				            assumed to lie in ``[0, 256)`` so that ``node + 1`` stays
				            inside the LUT (NOTE(review): confirm against the callers).
				        lut: tensor whose first four axes are the node axes; any
				            further axes are carried through to the output.

				    Returns:
				        Tensor of shape ``ixA.shape + lut.shape[4:]``.

				    Note:
				        The blend weights and the final normalisation use ``qA`` for
				        every axis, so the weighting presumably expects four equal LUT
				        axes (NOTE(review): confirm). Only the flat-index strides are
				        general.
				    """
				    dimA, dimB, dimC, dimD = lut.shape[:4]
				    qA, qB, qC, qD = 256 / (dimA - 1), 256 / (dimB - 1), 256 / (dimC - 1), 256 / (dimD - 1)
				    outDims = lut.shape[4:]
				    # Flatten the four node axes and move them last so one gather along
				    # dim=-1 can serve every output position.
				    lut = lut.reshape(dimA * dimB * dimC * dimD, *outDims).permute(*(i + 1 for i in range(len(outDims))), 0)
				    index_loop_indexes = ixA.shape
				    lut_loop_indexes = lut.shape[:-1]
				    lut = lut.view(*((1,) * len(index_loop_indexes) + lut.shape)).expand(index_loop_indexes + lut.shape)
				    def _expand(ix):
				        # Broadcast an index tensor over the LUT output axes and add the gather axis.
				        return ix.view(*(ix.shape + (1,) * len(lut_loop_indexes))).expand(index_loop_indexes + lut_loop_indexes).unsqueeze(-1)
				    ixA, ixB, ixC, ixD = _expand(ixA), _expand(ixB), _expand(ixC), _expand(ixD)
				    msbA = torch.floor_divide(ixA, qA).type(torch.int64)
				    msbB = torch.floor_divide(ixB, qB).type(torch.int64)
				    msbC = torch.floor_divide(ixC, qC).type(torch.int64)
				    msbD = torch.floor_divide(ixD, qD).type(torch.int64)
				    fa, fb, fc, fd = ixA % qA, ixB % qB, ixC % qC, ixD % qD
				    fab, fac, fad, fbc, fbd, fcd = fa > fb, fa > fc, fa > fd, fb > fc, fb > fd, fc > fd
				    # Row-major strides of the (dimA, dimB, dimC, dimD) grid. For a LUT
				    # with four equal axes these equal dim**3, dim**2, dim, 1.
				    strideC = dimD
				    strideB = dimC * strideC
				    strideA = dimB * strideB
				    def _vertex(da, db, dc, dd):
				        # LUT entry at the cell corner offset by (da, db, dc, dd) from the lower node.
				        flat = (msbA + da) * strideA + (msbB + db) * strideB + (msbC + dc) * strideC + (msbD + dd)
				        return torch.gather(input=lut, dim=-1, index=flat)
				    p0000, p0001, p0010, p0011 = _vertex(0, 0, 0, 0), _vertex(0, 0, 0, 1), _vertex(0, 0, 1, 0), _vertex(0, 0, 1, 1)
				    p0100, p0101, p0110, p0111 = _vertex(0, 1, 0, 0), _vertex(0, 1, 0, 1), _vertex(0, 1, 1, 0), _vertex(0, 1, 1, 1)
				    p1000, p1001, p1010, p1011 = _vertex(1, 0, 0, 0), _vertex(1, 0, 0, 1), _vertex(1, 0, 1, 0), _vertex(1, 0, 1, 1)
				    p1100, p1101, p1110, p1111 = _vertex(1, 1, 0, 0), _vertex(1, 1, 0, 1), _vertex(1, 1, 1, 0), _vertex(1, 1, 1, 1)
				    # Cases 1-4: fa > fb and fb > fc.
				    i1, out1 = barycentric_interpolate([fab, fbc, fcd], [qA - fa, fa - fb, fb - fc, fc - fd, fd], [p0000, p1000, p1100, p1110, p1111])
				    i2, out2 = barycentric_interpolate([fab, fbc, fbd, ~(i1)], [qA - fa, fa - fb, fb - fd, fd - fc, fc], [p0000, p1000, p1100, p1101, p1111])
				    i3, out3 = barycentric_interpolate([fab, fbc, fad, ~(i1), ~(i2)], [qA - fa, fa - fd, fd - fb, fb - fc, fc], [p0000, p1000, p1001, p1101, p1111])
				    i4, out4 = barycentric_interpolate([fab, fbc, ~(i1), ~(i2), ~(i3)], [qA - fd, fd - fa, fa - fb, fb - fc, fc], [p0000, p0001, p1001, p1101, p1111])
				    # Cases 5-8: fa > fb, fa > fc and not fb > fc.
				    i5, out5 = barycentric_interpolate([fab, fac, fbd, ~(fbc)], [qA - fa, fa - fc, fc - fb, fb - fd, fd], [p0000, p1000, p1010, p1110, p1111])
				    i6, out6 = barycentric_interpolate([fab, fac, fcd, ~(fbc), ~(i5)], [qA - fa, fa - fc, fc - fd, fd - fb, fb], [p0000, p1000, p1010, p1011, p1111])
				    i7, out7 = barycentric_interpolate([fab, fac, fad, ~(fbc), ~(i5), ~(i6)], [qA - fa, fa - fd, fd - fc, fc - fb, fb], [p0000, p1000, p1001, p1011, p1111])
				    i8, out8 = barycentric_interpolate([fab, fac, ~(fbc), ~(i5), ~(i6), ~(i7)], [qA - fd, fd - fa, fa - fc, fc - fb, fb], [p0000, p0001, p1001, p1011, p1111])
				    # Cases 9-12: fa > fb, not fb > fc and not fa > fc.
				    i9, out9 = barycentric_interpolate([fab, fbd, ~(fbc), ~(fac)], [qA - fc, fc - fa, fa - fb, fb - fd, fd], [p0000, p0010, p1010, p1110, p1111])
				    i10, out10 = barycentric_interpolate([fab, fad, ~(fbc), ~(fac), ~(i9)], [qA - fc, fc - fa, fa - fd, fd - fb, fb], [p0000, p0010, p1010, p1011, p1111])
				    i11, out11 = barycentric_interpolate([fab, fcd, ~(fbc), ~(fac), ~(i9), ~(i10)], [qA - fc, fc - fd, fd - fa, fa - fb, fb], [p0000, p0010, p0011, p1011, p1111])
				    i12, out12 = barycentric_interpolate([fab, ~(fbc), ~(fac), ~(i9), ~(i10), ~(i11)], [qA - fd, fd - fc, fc - fa, fa - fb, fb], [p0000, p0001, p0011, p1011, p1111])
				    # Cases 13-16: fa > fc and not fa > fb.
				    i13, out13 = barycentric_interpolate([fac, fcd, ~(fab)], [qA - fb, fb - fa, fa - fc, fc - fd, fd], [p0000, p0100, p1100, p1110, p1111])
				    i14, out14 = barycentric_interpolate([fac, fad, ~(fab), ~(i13)], [qA - fb, fb - fa, fa - fd, fd - fc, fc], [p0000, p0100, p1100, p1101, p1111])
				    i15, out15 = barycentric_interpolate([fac, fbd, ~(fab), ~(i13), ~(i14)], [qA - fb, fb - fd, fd - fa, fa - fc, fc], [p0000, p0100, p0101, p1101, p1111])
				    i16, out16 = barycentric_interpolate([fac, ~(fab), ~(i13), ~(i14), ~(i15)], [qA - fd, fd - fb, fb - fa, fa - fc, fc], [p0000, p0001, p0101, p1101, p1111])
				    # Cases 17-20: fb > fc, not fa > fb and not fa > fc.
				    i17, out17 = barycentric_interpolate([fbc, fad, ~(fab), ~(fac)], [qA - fb, fb - fc, fc - fa, fa - fd, fd], [p0000, p0100, p0110, p1110, p1111])
				    i18, out18 = barycentric_interpolate([fbc, fcd, ~(fab), ~(fac), ~(i17)], [qA - fb, fb - fc, fc - fd, fd - fa, fa], [p0000, p0100, p0110, p0111, p1111])
				    i19, out19 = barycentric_interpolate([fbc, fbd, ~(fab), ~(fac), ~(i17), ~(i18)], [qA - fb, fb - fd, fd - fc, fc - fa, fa], [p0000, p0100, p0101, p0111, p1111])
				    i20, out20 = barycentric_interpolate([fbc, ~(fab), ~(fac), ~(i17), ~(i18), ~(i19)], [qA - fd, fd - fb, fb - fc, fc - fa, fa], [p0000, p0001, p0101, p0111, p1111])
				    # Cases 21-24: none of fa > fb, fa > fc, fb > fc holds.
				    i21, out21 = barycentric_interpolate([fad, ~(fab), ~(fac), ~(fbc)], [qA - fc, fc - fb, fb - fa, fa - fd, fd], [p0000, p0010, p0110, p1110, p1111])
				    i22, out22 = barycentric_interpolate([fbd, ~(fab), ~(fac), ~(fbc), ~(i21)], [qA - fc, fc - fb, fb - fd, fd - fa, fa], [p0000, p0010, p0110, p0111, p1111])
				    i23, out23 = barycentric_interpolate([fcd, ~(fab), ~(fac), ~(fbc), ~(i21), ~(i22)], [qA - fc, fc - fd, fd - fb, fb - fa, fa], [p0000, p0010, p0011, p0111, p1111])
				    i24, out24 = barycentric_interpolate([~(fab), ~(fac), ~(fbc), ~(i21), ~(i22), ~(i23)], [qA - fd, fd - fc, fc - fb, fb - fa, fa], [p0000, p0001, p0011, p0111, p1111])
				    out = (
				        out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8
				        + out9 + out10 + out11 + out12 + out13 + out14 + out15 + out16
				        + out17 + out18 + out19 + out20 + out21 + out22 + out23 + out24
				    )
				    # The weights above are in remainder units; dividing by the node
				    # spacing normalises them.
				    out /= qA
				    out = out.squeeze(-1)
				    return out
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def  select_index_4dlut_tetrahedral2 ( ixA ,  ixB ,  ixC ,  ixD ,  lut ) :  #self, weight, upscale, mode, img_in, bd): 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				def  select_index_4dlut_tetrahedral ( ixA ,  ixB ,  ixC ,  ixD ,  lut ) :  #self, weight, upscale, mode, img_in, bd): 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    lut  =  torch . clamp ( lut ,  0 ,  255 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    dimA ,  dimB ,  dimC ,  dimD  =  lut . shape [ : 4 ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    q  =  256 / ( dimA - 1 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    L  =  dimA 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    upscale  =  lut . shape [ - 1 ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    weight  =  lut 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    weight  =  lut . reshape ( L * * 4 , upscale , upscale ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    img_a1  =  torch . floor_divide ( ixA ,  q ) . type ( torch . int64 ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    img_b1  =  torch . floor_divide ( ixB ,  q ) . type ( torch . int64 ) 
 
			
		 
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
			
			 
			 
			
				@ -404,6 +268,7 @@ def select_index_4dlut_tetrahedral2(ixA, ixB, ixC, ixD, lut): #self, weight, ups
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    i24  =  i  =  torch . all ( torch . cat ( [ ~ ( fab ) ,  ~ ( fac ) ,  ~ ( fbc ) ,  ~ i21 [ : ,  None ] ,  ~ i22 [ : ,  None ] ,  ~ i23 [ : ,  None ] ] ,  dim = 1 ) ,  dim = 1 ) ;  out [ i ]  =  ( q  -  fd [ i ] )  *  p0000 [ i ]  +  ( fd [ i ]  -  fc [ i ] )  *  p0001 [ i ]  +  ( fc [ i ]  -  fb [ i ] )  *  p0011 [ i ]  +  ( fb [ i ]  -  fa [ i ] )  *  p0111 [ i ]  +  ( fa [ i ] )  *  p1111 [ i ] 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    out  =  out . reshape ( ( img_a1 . shape [ 0 ] ,  img_a1 . shape [ 1 ] ,  img_a1 . shape [ 2 ] ,  img_a1 . shape [ 3 ] ,  upscale ,  upscale ) ) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    out  =  out . permute ( 0 ,  1 ,  2 ,  4 ,  3 ,  5 ) . reshape ( ( img_a1 . shape [ 0 ] ,  img_a1 . shape [ 1 ] ,  img_a1 . shape [ 2 ]  *  upscale ,  img_a1 . shape [ 3 ]  *  upscale )  )
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    # out = out.permute(0, 1, 2, 4, 3, 5).reshape((img_a1.shape[0], img_a1.shape[1], img_a1.shape[2] * upscale, img_a1.shape[3] * upscale)  )
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    out  =  out  /  q 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    return  out 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    # print(out.shape) 
 
			
		 
		
	
		
			
				 
				 
			
			 
			 
			
				    return  out