You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
optical_accelerator/Propagator_QuickDraw.ipynb

26 KiB

Пример на данных QuckDraw

https://github.com/googlecreativelab/quickdraw-dataset
Используется урезанная версия с 16 классами

In [1]:
import torch
from torch import flip
from torch.nn import Module
from torch import nn
from torch.nn.functional import conv2d
import torch.nn.functional as F
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score

import numpy as np
import math

import tqdm
from pprint import pprint, pformat
/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Служебные функции

In [2]:
# from utils import circular_aperture, imshow, pad_zeros, to_class_labels
def pad_zeros(input, size):
  h, w = input.shape[-2:]
  th, tw = size
    
  if len(input.shape) == 2:
    gg = torch.zeros(size, device=input.device)
    x, y = int(th/2 - h/2), int(tw/2 - w/2)
    gg[x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:]
    
  if len(input.shape) == 4:
    gg = torch.zeros(input.shape[:2] + size, device=input.device)
    x, y = int(th/2 - h/2), int(tw/2 - w/2)
    gg[:,:,x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:,:,:]

  return gg

def unpad_zeros(input, size):
  h, w = input.shape[-2:]
  th, tw = size
  dx,dy = h-th, w-tw
    
  if len(input.shape) == 2: 
    gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]
    
  if len(input.shape) == 4:
    gg = input[:,:,dx//2:dx//2+th, dy//2:dy//2+tw]
  return gg

def to_class_labels(softmax_distibutions):
    return torch.argmax(softmax_distibutions, dim=1).cpu()
In [3]:
def imshow(tensor, figsize=None, title="", **args):
    tensor = tensor.cpu().detach() if isinstance(tensor, torch.Tensor) else tensor
    tensor = list(tensor) if isinstance(tensor, torch.nn.modules.container.ParameterList) else tensor
    
    figsize = figsize if figsize else (13*0.8,5*0.8)
    
    if type(tensor) is list:
        for idx, el in enumerate(tensor):
            imshow(el, figsize=figsize, title=title, **args)
            plt.suptitle("{} {}".format(idx, title))
        return
    if len(tensor.shape)==4:
        for idx, el in enumerate(torch.squeeze(tensor, dim=1)):
            imshow(el, figsize=figsize, title=title, **args)
            plt.suptitle("{} {}".format(idx, title))
        return
    
    if tensor.dtype == torch.complex64:
        f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]})
        real_im = ax[0].imshow(tensor.real, **args)
        imag_im = ax[3].imshow(tensor.imag, **args)
        box = ax[1].get_position()
        box.x0 = box.x0 - 0.02
        box.x1 = box.x1 - 0.03
        ax[1].set_position(box)
        box = ax[4].get_position()
        box.x0 = box.x0 - 0.02
        box.x1 = box.x1 - 0.03
        ax[4].set_position(box)
        ax[0].set_title("real");
        ax[3].set_title("imag");
        f.colorbar(real_im, ax[1]);
        f.colorbar(imag_im, ax[4]);
        f.suptitle(title)
        ax[2].remove()
        return f, ax
    else:
        f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize)
        im = ax[0].imshow(tensor, **args)
        f.colorbar(im, ax[1])
        f.suptitle(title)
        return f, ax

Параметры данных и модели

In [4]:
CONFIG = type('', (), {})() # object for parameters

# свойства входных данных
CONFIG.classes = 16
CONFIG.image_size = 28
CONFIG.train_class_instances = 8000
CONFIG.test_class_instances = 100
CONFIG.train_data_path = './assets/quickdraw16_train.npy'
CONFIG.test_data_path = './assets/quickdraw16_test.npy'

# свойства модели оптической системы
CONFIG.kernel_size = 28
CONFIG.tile_size_scale_factor = 2
CONFIG.resolution_scale_factor = 2 
CONFIG.class_slots = 16
CONFIG.wavelength = 532e-9
# CONFIG.refractive_index = 1.5090
CONFIG.propagation_distance = 300
CONFIG.metric = 1e-3
CONFIG.pixel_size_meters = 36e-6
CONFIG.layers = 1

pprint(CONFIG.__dict__)
{'class_slots': 16,
 'classes': 16,
 'image_size': 28,
 'kernel_size': 28,
 'layers': 1,
 'metric': 0.001,
 'pixel_size_meters': 3.6e-05,
 'propagation_distance': 300,
 'resolution_scale_factor': 2,
 'test_class_instances': 100,
 'test_data_path': './assets/quickdraw16_test.npy',
 'tile_size_scale_factor': 2,
 'train_class_instances': 8000,
 'train_data_path': './assets/quickdraw16_train.npy',
 'wavelength': 5.32e-07}

Обучающие и тестовые данные

In [5]:
train_data = torch.tensor(np.load(CONFIG.train_data_path), dtype=torch.float32)
test_data = torch.tensor(np.load(CONFIG.test_data_path), dtype=torch.float32)
train_data = rearrange(train_data, "b (h w) -> b 1 h w", h=CONFIG.image_size, w=CONFIG.image_size)
test_data = rearrange(test_data, "b (h w) -> b 1 h w", h=CONFIG.image_size, w=CONFIG.image_size)
train_data.shape, test_data.shape
Out[5]:
(torch.Size([128000, 1, 28, 28]), torch.Size([1600, 1, 28, 28]))
In [6]:
train_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)
test_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)

train_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.train_class_instances)
test_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.test_class_instances)

train_targets.shape, test_targets.shape
Out[6]:
(torch.Size([128000, 16]), torch.Size([1600, 16]))

Модель системы

In [7]:
class OpticalSystem(Module):
    def __init__(self,
                 layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength = 532e-9, 
                 # refractive_index = 1.5090, 
                 propagation_distance = 300,
                 pixel_size_meters = 36e-6,
                 metric = 1e-3
                ):
        """"""
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        # self.refractive_index = refractive_index
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert(self.class_slots >= self.classes)
        self.empty_class_slots = self.class_slots - self.classes 
        
        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor
        
        self.height_maps = []
        for i in range(self.layers):
            heights = nn.Parameter(torch.ones([self.phase_mask_size, self.phase_mask_size], dtype=torch.float32))
            torch.nn.init.uniform_(heights, a=0.5*self.wavelength, b=1.5*self.wavelength)
            self.height_maps.append(heights)
        self.height_maps = torch.nn.ParameterList(self.height_maps)
        
        A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric
        B = A*self.phase_mask_size/self.tile_size 
        x = torch.linspace(-B, B, self.phase_mask_size+1)[:-1]
        x, y = torch.meshgrid(x, x, indexing='ij')
        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/B, torch.pi*self.phase_mask_size/2/B, self.phase_mask_size+1)[:-1]
        Kx, Ky = torch.meshgrid(kx, kx, indexing='ij')
        vv = torch.arange(0,self.phase_mask_size)
        vv = (-1)**vv
        a, b = torch.meshgrid(vv, vv, indexing='ij')
        lambda1 = self.wavelength / self.metric
        
        self.U = nn.Parameter((Kx**2 + Ky**2).float())
        self.vv = nn.Parameter((a*b).float())
        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))
        self.D = nn.Parameter(torch.exp(-1j*(x**2 + y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))
        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))
        self.U.requires_grad = False
        self.vv.requires_grad = False
        self.D.requires_grad = True
        self.coef.requires_grad = False
        

    def propagation(self, field, propagation_distance):
        F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv
    
    def opt_conv(self, inputs, heights):
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * self.D
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        amplitude = torch.sqrt(result.real**2 + result.imag**2)
        return amplitude
    
    def forward(self, image):
        """
        Алгоритм:
        1. Входное изображение увеличивается в self.resolution_scale_factor. [28,28] -> [56,56]
        2. Полученное изображение дополняется 0 до размера self.phase_mask_size. [56,56] -> [448, 448]
        3. Моделируется прохождение света через транспаранты
        4. Выходное изображение нарезается в набор областей self.tiles_per_dim x self.tiles_per_dim
        5. Области преобразуются в вектор длины self.class_slots операцией max и затем нормируется (нужна ли нормировка?)
        6. Вектор максимальных значений преобразутся в распределение вероятностей функцией softmax
        """
        # 1
        image = resize(
            image, 
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        # 2
        image = pad_zeros(
            image, 
            size = (self.phase_mask_size , 
                    self.phase_mask_size ),
        )
        # 3      
        x = image 
        for i, plate_heights in enumerate(self.height_maps):  
            x = self.opt_conv(x, plate_heights)
        convolved = x
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht = self.tile_size*self.resolution_scale_factor,
            wt = self.tile_size*self.resolution_scale_factor,
            m = self.tiles_per_dim,
            n = self.tiles_per_dim
        )
        # 5
        grid_to_depth = unpad_zeros(grid_to_depth, 
                                    (self.kernel_size_pixels*self.resolution_scale_factor,  
                                     self.kernel_size_pixels*self.resolution_scale_factor))
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size = self.kernel_size_pixels*self.resolution_scale_factor
        )               
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        max_pool /= max_pool.max(dim=1, keepdims=True).values
        # 6
        softmax = torch.nn.functional.softmax(max_pool, dim=1)
        return softmax, convolved
    
    def __repr__(self):
        tmp = {}
        for k,v in self.__dict__.items():
            if not k[0] == '_':
                tmp[k] = v
        tmp.update(self.__dict__['_modules'])
        tmp.update({k:f"{v.dtype} {v.shape}" for k,v in self.__dict__['_parameters'].items()})
        return pformat(tmp, indent=2)

Создание экземпляра модели, оптимизатора, функции потерь

In [8]:
model = OpticalSystem(
     layers = CONFIG.layers,
     kernel_size_pixels = CONFIG.kernel_size,
     tile_size_scale_factor = CONFIG.tile_size_scale_factor,
     resolution_scale_factor = CONFIG.resolution_scale_factor,
     class_slots = CONFIG.class_slots,
     classes = CONFIG.classes,
     wavelength = CONFIG.wavelength, 
     propagation_distance = CONFIG.propagation_distance,
     pixel_size_meters = CONFIG.pixel_size_meters,
     metric = CONFIG.metric
)
# comment to train from scratch
# model.load_state_dict(torch.load(CONFIG.phasemask_model_1_path))
model.eval()
model.cuda()
Out[8]:
{ 'D': 'torch.complex64 torch.Size([448, 448])',
  'U': 'torch.float32 torch.Size([448, 448])',
  'class_slots': 16,
  'classes': 16,
  'coef': 'torch.complex64 torch.Size([1])',
  'empty_class_slots': 0,
  'height_maps': ParameterList(  (0): Parameter containing: [torch.cuda.FloatTensor of size 448x448 (GPU 0)]),
  'k': 'torch.float32 torch.Size([1])',
  'kernel_size_pixels': 28,
  'layers': 1,
  'metric': 0.001,
  'phase_mask_size': 448,
  'pixel_size_meters': 3.6e-05,
  'propagation_distance': 300,
  'resolution_scale_factor': 2,
  'tile_size': 56,
  'tile_size_scale_factor': 2,
  'tiles_per_dim': 4,
  'training': False,
  'vv': 'torch.float32 torch.Size([448, 448])',
  'wavelength': 5.32e-07}
In [9]:
optimizer = torch.optim.Adam(params=model.cuda().parameters(), 
                             lr=1e-2)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
loss_function = torch.nn.CrossEntropyLoss()

Обучение

In [ ]:
# training loop
batch_size = 764
max_passes_through_dataset = 25
epochs = int(train_data.shape[0]/batch_size*max_passes_through_dataset)
ppp = tqdm.trange(epochs)

def init_batch_generator(train_data, train_labels, batch_size):
    """
    Возвращает функцию, вызов которой возвращает следующие batch_size
    примеров и им соответствуюющих меток из train_data, train_labels.
    
    Примеры выбираются последовательно, по кругу. Массивы с входными 
    примерами и метками классов перемешиваются в начале каждого круга.
    """
    def f():
        i = 0
        rnd_indx = torch.randperm(train_data.shape[0])
        train_data_shuffled = train_data[rnd_indx]
        train_labels_shuffled = train_labels[rnd_indx]
        while True:
            if i + batch_size > train_data.shape[0]:
                i = 0
                rnd_indx = torch.randperm(train_data.shape[0])
                train_data_shuffled = train_data[rnd_indx]
                train_labels_shuffled = train_labels[rnd_indx]
                
            batch_inputs = train_data_shuffled[i:i+batch_size]
            batch_targets = train_labels_shuffled[i:i+batch_size]
            i = i + batch_size
            yield batch_inputs, batch_targets
    return f()

batch_iterator = init_batch_generator(train_data, train_targets, batch_size)
i = 0

for epoch in ppp:
  batch_inputs, batch_targets = next(batch_iterator)
  batch_inputs = batch_inputs.cuda()
  batch_targets = batch_targets.cuda()
  i = i + batch_size
  passes_through_dataset = i//train_data.shape[0]
  # apply model
  predicted, convolved = model(batch_inputs)

  # correct model
  loss_value = loss_function(predicted, batch_targets)

  loss_value.backward()
  optimizer.step()

  # для небольших батчей следует уменьшать частоту вывода 
  if epoch % 2 == 0:
    acc = accuracy_score(to_class_labels(batch_targets), to_class_labels(predicted))
    ppp.set_postfix_str("loss: {:e}, acc: {:.2f}, lr: {:e}, passes_through_dataset: {}".format(loss_value, acc, scheduler.get_last_lr()[0], passes_through_dataset))
    
  if (scheduler.get_last_lr()[0] > 1e-13):
    scheduler.step()
 97%|█████████▋| 4071/4188 [11:23<00:19,  5.97it/s, loss: 2.736057e+00, acc: 0.69, lr: 1.704265e-04, passes_through_dataset: 24]

Тест

In [ ]:
inputs = test_data
targets = test_targets
batch_size = 64

predicted = []
batch_start = 0
while batch_start < test_data.shape[0]:
    batch_end = min(batch_start + batch_size, test_data.shape[0])
    batch_input = inputs[batch_start:batch_end].cuda() 
    batch_output, _ = model(batch_input)
    predicted.append(batch_output.detach().cpu())
    batch_start = batch_end

predicted = torch.concat(predicted)

test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))
"Accuracy on test dataset: ", test_acc
In [ ]:
imshow(model.height_maps, figsize=(10,10))
In [ ]:
class_id = 3
image = test_data[test_labels==class_id][:1]
imshow(image, title=f"Input images")
softmax, convolved = model(image.cuda())

for idx, psf in enumerate(convolved):
    psf = psf.squeeze()
    f, ax = imshow(psf, figsize=(5,5), title=f"Result of optical convolution with phase plate for image {idx}")
    ax[0].hlines(np.arange(0, psf.shape[0], psf.shape[0]//model.tiles_per_dim), 0, psf.shape[1]-1)
    ax[0].vlines(np.arange(0, psf.shape[1], psf.shape[1]//model.tiles_per_dim), 0, psf.shape[0]-1)
    y,x = (psf==torch.max(psf)).nonzero()[0]
    ax[0].text(x,y, "max", color='white');

Сохранение рельефа

In [ ]:
# from PIL import Image

# for idx, heights in enumerate(model.height_maps):
#     m = heights.abs().mean()
#     s = heights.abs().std()
#     m1, m2 = heights.abs().min(), heights.abs().max()
#     ar = heights.abs().cpu().detach().numpy() 
#     print(ar.dtype)
#     im = ar
#     im = im - im.min()
#     im = im / im.max()
#     im = im * 255
#     name_im = f"phasemask_{idx}.png"
#     name_np = f"phasemask_{idx}"
#     result = Image.fromarray(im.astype(np.uint8))
#     result.save(name_im)
#     np.save(name_np, ar)