master
			
			
		
		
						commit
						a27bdfbbcd
					
				@ -0,0 +1,28 @@
 | 
			
		||||
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
 | 
			
		||||
 | 
			
		||||
ARG USER
 | 
			
		||||
ARG GROUP
 | 
			
		||||
ARG UID
 | 
			
		||||
ARG GID
 | 
			
		||||
 | 
			
		||||
RUN groupadd -g ${GID} ${GROUP}
 | 
			
		||||
RUN useradd -u ${UID} -g ${GROUP} -s /bin/bash -m ${USER} 
 | 
			
		||||
 | 
			
		||||
RUN mkdir /wd
 | 
			
		||||
RUN chown ${USER}:${GROUP} /wd
 | 
			
		||||
WORKDIR /wd
 | 
			
		||||
 | 
			
		||||
USER ${UID}:${GID}
 | 
			
		||||
 | 
			
		||||
#RUN conda init bash
 | 
			
		||||
#RUN conda create -n jupyter-env jupyterlab -y
 | 
			
		||||
#RUN echo "conda activate jupyter-env" >> /home/${USER}/.bashrc
 | 
			
		||||
 | 
			
		||||
RUN pip install jupyterlab matplotlib einops scikit-learn
 | 
			
		||||
 | 
			
		||||
EXPOSE 9000
 | 
			
		||||
 | 
			
		||||
SHELL ["/bin/bash", "--login", "-i", "-c"]
 | 
			
		||||
ENV SHELL=/bin/bash
 | 
			
		||||
 | 
			
		||||
CMD jupyter lab --ip 0.0.0.0 --port 9000
 | 
			
		||||
											
												
													File diff suppressed because one or more lines are too long
												
											
										
									
								@ -0,0 +1,8 @@
 | 
			
		||||
#/bin/bash
 | 
			
		||||
 | 
			
		||||
docker build . \
 | 
			
		||||
	-t ${USER}_pytorch \
 | 
			
		||||
	--build-arg USER=${USER} \
 | 
			
		||||
	--build-arg GROUP=${USER} \
 | 
			
		||||
	--build-arg UID=$(id -u ${USER}) \
 | 
			
		||||
	--build-arg GID=$(id -g ${USER})
 | 
			
		||||
											
												Binary file not shown.
											
										
									
								
											
												Binary file not shown.
											
										
									
								@ -0,0 +1 @@
 | 
			
		||||
docker run -d --gpus all -p 9000:9000 -v $(pwd):/wd ${USER}_pytorch
 | 
			
		||||
@ -0,0 +1,100 @@
 | 
			
		||||
import torch
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
def imshow(tensor, figsize=None, title="", **args):
 | 
			
		||||
    figsize = figsize if figsize else (13*0.8,5*0.8)
 | 
			
		||||
    
 | 
			
		||||
    if type(tensor) is list:
 | 
			
		||||
        for idx, el in enumerate(tensor):
 | 
			
		||||
            imshow(el, figsize=figsize, title=title, **args)
 | 
			
		||||
            plt.suptitle("{} {}".format(idx, title))
 | 
			
		||||
        return
 | 
			
		||||
    if len(tensor.shape)==4:
 | 
			
		||||
        for idx, el in enumerate(torch.squeeze(tensor, dim=1)):
 | 
			
		||||
            imshow(el, figsize=figsize, title=title, **args)
 | 
			
		||||
            plt.suptitle("{} {}".format(idx, title))
 | 
			
		||||
        return
 | 
			
		||||
    
 | 
			
		||||
    tensor = tensor.detach().cpu() if type(tensor) == torch.Tensor else tensor
 | 
			
		||||
    if tensor.dtype == torch.complex64:
 | 
			
		||||
        f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]})
 | 
			
		||||
        real_im = ax[0].imshow(tensor.real, **args)
 | 
			
		||||
        imag_im = ax[3].imshow(tensor.imag, **args)
 | 
			
		||||
        box = ax[1].get_position()
 | 
			
		||||
        box.x0 = box.x0 - 0.02
 | 
			
		||||
        box.x1 = box.x1 - 0.03
 | 
			
		||||
        ax[1].set_position(box)
 | 
			
		||||
        box = ax[4].get_position()
 | 
			
		||||
        box.x0 = box.x0 - 0.02
 | 
			
		||||
        box.x1 = box.x1 - 0.03
 | 
			
		||||
        ax[4].set_position(box)
 | 
			
		||||
        ax[0].set_title("real");
 | 
			
		||||
        ax[3].set_title("imag");
 | 
			
		||||
        f.colorbar(real_im, ax[1]);
 | 
			
		||||
        f.colorbar(imag_im, ax[4]);
 | 
			
		||||
        f.suptitle(title)
 | 
			
		||||
        ax[2].remove()
 | 
			
		||||
        return f, ax
 | 
			
		||||
    else:
 | 
			
		||||
        f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize)
 | 
			
		||||
        im = ax[0].imshow(tensor, **args)
 | 
			
		||||
        f.colorbar(im, ax[1])
 | 
			
		||||
        f.suptitle(title)
 | 
			
		||||
        return f, ax
 | 
			
		||||
    
 | 
			
		||||
def perm_roll(im, axis, amount):
 | 
			
		||||
  permutation = torch.roll(torch.arange(im.shape[axis], device=im.device), amount, dims=0)
 | 
			
		||||
  return torch.index_select(im, axis, permutation)
 | 
			
		||||
 | 
			
		||||
def shift_left(im):
 | 
			
		||||
  tt = perm_roll(im, axis=-2, amount=-(im.shape[-2]+1)//2)
 | 
			
		||||
  tt = perm_roll(tt, axis=-1, amount=-(im.shape[-1]+1)//2)
 | 
			
		||||
  return tt
 | 
			
		||||
 | 
			
		||||
def shift_right(im):
 | 
			
		||||
  tt = perm_roll(im, axis=-2, amount=(im.shape[-2]+1)//2)
 | 
			
		||||
  tt = perm_roll(tt, axis=-1, amount=(im.shape[-1]+1)//2)
 | 
			
		||||
  return tt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pad_zeros(input, size):
 | 
			
		||||
  h, w = input.shape[-2:]
 | 
			
		||||
  th, tw = size
 | 
			
		||||
    
 | 
			
		||||
  if len(input.shape) == 2:
 | 
			
		||||
    gg = torch.zeros(size, device=input.device)
 | 
			
		||||
    x, y = int(th/2 - h/2), int(tw/2 - w/2)
 | 
			
		||||
    gg[x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:]
 | 
			
		||||
    
 | 
			
		||||
  if len(input.shape) == 4:
 | 
			
		||||
    gg = torch.zeros(input.shape[:2] + size, device=input.device)
 | 
			
		||||
    x, y = int(th/2 - h/2), int(tw/2 - w/2)
 | 
			
		||||
    gg[:,:,x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:,:,:]
 | 
			
		||||
 | 
			
		||||
  return gg
 | 
			
		||||
 | 
			
		||||
def unpad_zeros(input, size):
 | 
			
		||||
  h, w = input.shape[-2:]
 | 
			
		||||
  th, tw = size
 | 
			
		||||
  dx,dy = h-th, w-tw
 | 
			
		||||
    
 | 
			
		||||
  if len(input.shape) == 2: 
 | 
			
		||||
    gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]
 | 
			
		||||
    
 | 
			
		||||
  if len(input.shape) == 4:
 | 
			
		||||
    gg = input[:,:,dx//2:dx//2+th, dy//2:dy//2+tw]
 | 
			
		||||
  return gg
 | 
			
		||||
 | 
			
		||||
def circular_aperture(h, w, r=None, is_inv=False):
 | 
			
		||||
    if r is None:
 | 
			
		||||
        r = min(h//2, w//2)
 | 
			
		||||
    x, y = torch.meshgrid(torch.arange(-h//2, h//2), torch.arange(-w//2, w//2), indexing='ij')
 | 
			
		||||
    circle_dist = torch.sqrt(x**2 + y**2)
 | 
			
		||||
    if is_inv:
 | 
			
		||||
        circle_aperture = torch.where(circle_dist<r, torch.zeros_like(circle_dist), torch.ones_like(circle_dist))
 | 
			
		||||
    else:
 | 
			
		||||
        circle_aperture = torch.where(circle_dist<r, torch.ones_like(circle_dist), torch.zeros_like(circle_dist))
 | 
			
		||||
    return circle_aperture
 | 
			
		||||
 | 
			
		||||
def to_class_labels(softmax_distibutions):
 | 
			
		||||
    return torch.argmax(softmax_distibutions, dim=1).cpu()
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue