You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
3.4 KiB
Python
92 lines
3.4 KiB
Python
import torch
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
from datetime import datetime
|
|
|
|
def _imshow_each(items, figsize, title, args):
    """Run ``imshow`` on every entry of ``items`` and collect the results.

    Each single-image figure is titled ``"<index> <title>"``. When an entry is
    itself a collection (so ``imshow`` returns a list of ``[figure, axes]``
    pairs), that list is appended as one nested entry rather than being
    unpacked as if it were a single ``(figure, axes)`` pair.
    """
    outs = []
    for idx, item in enumerate(items):
        result = imshow(item, figsize=figsize, title=title, **args)
        if isinstance(result, list):
            outs.append(result)
            continue
        fig, axes = result
        fig.suptitle("{} {}".format(idx, title))
        outs.append([fig, axes])
    return outs


def imshow(tensor, figsize=None, title="", **args):
    """Plot a 2-D array/tensor (or a collection of them) with colorbars.

    Dispatch:
      * list / tuple / ``torch.nn.ParameterList`` -> one figure per element.
      * 4-D input -> dim 1 is squeezed (only removed when it has size 1) and
        one figure is drawn per entry along dim 0.
      * complex torch tensor -> real and imaginary parts in two panels, each
        with its own colorbar.
      * anything else -> one ``Axes.imshow`` panel plus a colorbar axis.

    Args:
        tensor: data to plot; torch tensors are moved to CPU and detached.
        figsize: matplotlib figure size; defaults to ``(13*0.8, 5*0.8)``.
        title: figure suptitle; figures drawn for a collection are titled
            ``"<index> <title>"``.
        **args: forwarded unchanged to every ``Axes.imshow`` call.

    Returns:
        ``(figure, axes)`` for a single image, or a list of ``[figure, axes]``
        entries for collections and 4-D batches.
    """
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().detach()
    if isinstance(tensor, (torch.nn.ParameterList, tuple)):
        tensor = list(tensor)
    figsize = figsize if figsize else (13*0.8, 5*0.8)

    if isinstance(tensor, list):
        return _imshow_each(tensor, figsize, title, args)
    if len(tensor.shape) == 4:
        # Drop a singleton channel dim so (B, 1, H, W) yields B 2-D images.
        return _imshow_each(torch.squeeze(tensor, dim=1), figsize, title, args)

    if isinstance(tensor, torch.Tensor) and torch.is_complex(tensor):
        # Covers every complex dtype (complex64, complex128, ...): imshow cannot
        # draw complex values directly, so show the two parts side by side.
        fig, axes = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'width_ratios': [46.5, 46.5]})
        panels = ((axes[0], tensor.real, "real"), (axes[1], tensor.imag, "imag"))
        for axis, part, name in panels:
            image = axis.imshow(part, **args)
            axis.set_title(name)
            # Carve a thin colorbar axis off the right edge of this panel.
            cax = make_axes_locatable(axis).append_axes("right", size="5%", pad=0.05)
            fig.colorbar(image, cax=cax)
        fig.suptitle(title)
        fig.tight_layout()
        return fig, axes

    # Real data: image in the wide left axis, colorbar in the narrow right one.
    fig, axes = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95, 5]}, figsize=figsize)
    image = axes[0].imshow(tensor, **args)
    fig.colorbar(image, cax=axes[1])
    fig.suptitle(title)
    return fig, axes
|
|
|
|
|
|
def perm_roll(im, axis, amount):
    """Circularly shift ``im`` by ``amount`` positions along ``axis``.

    Element ``i`` of the result is element ``(i - amount) % n`` of the input,
    where ``n = im.shape[axis]``, so positive amounts move data towards higher
    indices. This is exactly what ``torch.roll`` does for a single dimension,
    so the library call is used instead of building and applying an index
    permutation by hand.

    Args:
        im: input tensor.
        axis: dimension to roll (negative values count from the end).
        amount: integer shift; may be negative or larger than the axis length.

    Returns:
        A new tensor with the same shape, dtype and device as ``im``.
    """
    return torch.roll(im, shifts=amount, dims=axis)
|
|
|
|
def shift_left(im):
    """Circularly roll the last two dimensions of ``im`` back by half their length.

    Each of the last two axes (length ``n``) is rolled by ``-((n + 1) // 2)``,
    i.e. half the length rounded up, towards lower indices. This undoes
    ``shift_right`` exactly and equals ``torch.fft.fftshift`` over
    dims ``(-2, -1)`` for both even and odd sizes.

    Args:
        im: tensor with at least two dimensions.

    Returns:
        A new tensor with the same shape, dtype and device as ``im``.
    """
    # Negate *after* the floor division: ``-(n + 1) // 2`` parses as
    # ``(-(n + 1)) // 2`` and rolls one element too far whenever n is even.
    row_shift = -((im.shape[-2] + 1) // 2)
    col_shift = -((im.shape[-1] + 1) // 2)
    return torch.roll(im, shifts=(row_shift, col_shift), dims=(-2, -1))
|
|
|
|
def shift_right(im):
    """Circularly roll the last two dimensions of ``im`` forward by half their length.

    Each of the last two axes (length ``n``) is rolled by ``(n + 1) // 2``,
    i.e. half the length rounded up, towards higher indices. For both even
    and odd ``n`` this matches ``torch.fft.ifftshift`` over dims ``(-2, -1)``.

    Args:
        im: tensor with at least two dimensions.

    Returns:
        A new tensor with the same shape, dtype and device as ``im``.
    """
    row_shift = (im.shape[-2] + 1) // 2
    col_shift = (im.shape[-1] + 1) // 2
    # Rolls on distinct dims commute, so both can be applied in one call.
    return torch.roll(im, shifts=(row_shift, col_shift), dims=(-2, -1))
|
|
|
|
|
|
def pad_zeros(input, size):
    """Zero-pad the last two dimensions of ``input`` to ``size``, centred.

    The data is placed starting at row ``(th - h) // 2`` and column
    ``(tw - w) // 2``; when the padding along an axis is odd, the extra zero
    goes after the data (bottom / right).

    Args:
        input: tensor whose last two dimensions are ``(h, w)``; any leading
            dimensions are preserved.
        size: target ``(th, tw)``; any two-element sequence.

    Returns:
        A new tensor of shape ``input.shape[:-2] + (th, tw)`` with the same
        dtype and device as ``input``.

    Raises:
        ValueError: if the target is smaller than the input along either axis.
    """
    h, w = input.shape[-2:]
    th, tw = size
    if th < h or tw < w:
        raise ValueError(
            "pad_zeros target size {} is smaller than input size {}".format((th, tw), (h, w))
        )
    # Match the input dtype: a default float buffer would drop the imaginary
    # part of complex inputs and silently cast float64 / integer inputs.
    out = torch.zeros(input.shape[:-2] + (th, tw), dtype=input.dtype, device=input.device)
    top, left = (th - h) // 2, (tw - w) // 2
    out[..., top:top + h, left:left + w] = input
    return out
|
|
|
|
def unpad_zeros(input, size):
    """Crop the centred ``size`` window out of the last two dimensions of ``input``.

    Uses the same offset convention as ``pad_zeros`` (start at
    ``(h - th) // 2`` / ``(w - tw) // 2``), so cropping a padded tensor back to
    its original size recovers the original data.

    Args:
        input: tensor whose last two dimensions are ``(h, w)``.
        size: crop size ``(th, tw)``; any two-element sequence.

    Returns:
        A slice (view) of ``input`` with shape ``input.shape[:-2] + (th, tw)``.

    Raises:
        ValueError: if the crop is larger than the input along either axis.
    """
    h, w = input.shape[-2:]
    th, tw = size
    if th > h or tw > w:
        raise ValueError(
            "unpad_zeros crop size {} is larger than input size {}".format((th, tw), (h, w))
        )
    top, left = (h - th) // 2, (w - tw) // 2
    return input[..., top:top + th, left:left + tw]
|
|
|
|
def to_class_labels(softmax_distibutions):
    """Return the index of the largest entry along dim 1, as a CPU tensor.

    Dim 1 is reduced and all other dims are kept. Presumably the input is
    (batch, num_classes, ...) class scores — TODO confirm against callers.
    """
    class_dim = 1
    labels = torch.argmax(softmax_distibutions, dim=class_dim)
    return labels.cpu()
|
|
|
|
def circular_aperture(h, w, r=None, is_inv=False):
    """Return an ``(h, w)`` mask that is 1 strictly inside a centred circle.

    Pixel ``(i, j)`` sits at integer offset ``(i - h // 2, j - w // 2)`` from
    the centre, so the centre is index ``(h // 2, w // 2)``. This is the
    zero-frequency position used by ``torch.fft.fftshift``, for both even and
    odd sizes.

    Args:
        h: mask height.
        w: mask width.
        r: radius; pixels with distance ``< r`` count as inside. Defaults to
            ``min(h // 2, w // 2)``.
        is_inv: if True, return the complement (0 inside, 1 outside).

    Returns:
        A tensor of 0.0 / 1.0 values in torch's default floating dtype
        (the dtype produced by ``torch.sqrt`` on integer offsets).
    """
    if r is None:
        r = min(h // 2, w // 2)
    # Build offsets as ``arange(n) - n // 2``: ``arange(-n // 2, n // 2)``
    # parses as ``(-n) // 2`` and, for odd n, puts the origin one pixel past
    # the centre, making the aperture asymmetric.
    rows = torch.arange(h) - h // 2
    cols = torch.arange(w) - w // 2
    x, y = torch.meshgrid(rows, cols, indexing='ij')
    circle_dist = torch.sqrt(x**2 + y**2)
    inside = circle_dist < r
    if is_inv:
        inside = ~inside
    return inside.to(circle_dist.dtype)