import torch import matplotlib.pyplot as plt def imshow(tensor, figsize=None, title="", **args): figsize = figsize if figsize else (13*0.8,5*0.8) if type(tensor) is list: for idx, el in enumerate(tensor): imshow(el, figsize=figsize, title=title, **args) plt.suptitle("{} {}".format(idx, title)) return if len(tensor.shape)==4: for idx, el in enumerate(torch.squeeze(tensor, dim=1)): imshow(el, figsize=figsize, title=title, **args) plt.suptitle("{} {}".format(idx, title)) return tensor = tensor.detach().cpu() if type(tensor) == torch.Tensor else tensor if tensor.dtype == torch.complex64: f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]}) real_im = ax[0].imshow(tensor.real, **args) imag_im = ax[3].imshow(tensor.imag, **args) box = ax[1].get_position() box.x0 = box.x0 - 0.02 box.x1 = box.x1 - 0.03 ax[1].set_position(box) box = ax[4].get_position() box.x0 = box.x0 - 0.02 box.x1 = box.x1 - 0.03 ax[4].set_position(box) ax[0].set_title("real"); ax[3].set_title("imag"); f.colorbar(real_im, ax[1]); f.colorbar(imag_im, ax[4]); f.suptitle(title) ax[2].remove() return f, ax else: f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize) im = ax[0].imshow(tensor, **args) f.colorbar(im, ax[1]) f.suptitle(title) return f, ax def perm_roll(im, axis, amount): permutation = torch.roll(torch.arange(im.shape[axis], device=im.device), amount, dims=0) return torch.index_select(im, axis, permutation) def shift_left(im): tt = perm_roll(im, axis=-2, amount=-(im.shape[-2]+1)//2) tt = perm_roll(tt, axis=-1, amount=-(im.shape[-1]+1)//2) return tt def shift_right(im): tt = perm_roll(im, axis=-2, amount=(im.shape[-2]+1)//2) tt = perm_roll(tt, axis=-1, amount=(im.shape[-1]+1)//2) return tt def pad_zeros(input, size): h, w = input.shape[-2:] th, tw = size if len(input.shape) == 2: gg = torch.zeros(size, device=input.device) x, y = int(th/2 - h/2), int(tw/2 - w/2) gg[x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:] if len(input.shape) == 4: gg = torch.zeros(input.shape[:2] + size, device=input.device) x, y = int(th/2 - h/2), int(tw/2 - w/2) gg[:,:,x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:,:,:] return gg def unpad_zeros(input, size): h, w = input.shape[-2:] th, tw = size dx,dy = h-th, w-tw if len(input.shape) == 2: gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)] if len(input.shape) == 4: gg = input[:,:,dx//2:dx//2+th, dy//2:dy//2+tw] return gg def circular_aperture(h, w, r=None, is_inv=False): if r is None: r = min(h//2, w//2) x, y = torch.meshgrid(torch.arange(-h//2, h//2), torch.arange(-w//2, w//2), indexing='ij') circle_dist = torch.sqrt(x**2 + y**2) if is_inv: circle_aperture = torch.where(circle_dist