Example on QuickDraw data¶
https://github.com/googlecreativelab/quickdraw-dataset
A trimmed-down version of the dataset with 16 classes is used.
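Such a subset can be assembled from the per-category numpy_bitmap files distributed with the QuickDraw dataset (28x28 grayscale bitmaps flattened to 784 values). A minimal sketch, assuming the category files have been downloaded locally; the 16 category names and the 8000/100 split below are illustrative, not necessarily the ones used to build quickdraw16_train.npy:

import numpy as np

# Hypothetical 16-category subset, 8000 train + 100 test drawings per class.
categories = ["apple", "banana", "cat", "dog", "car", "airplane", "fish", "house",
              "tree", "star", "clock", "eye", "face", "flower", "sun", "umbrella"]
train_parts, test_parts = [], []
for name in categories:
    bitmaps = np.load(f"full_numpy_bitmap_{name}.npy")  # shape [N, 784], uint8
    train_parts.append(bitmaps[:8000])
    test_parts.append(bitmaps[8000:8100])
np.save("./assets/quickdraw16_train.npy", np.concatenate(train_parts))
np.save("./assets/quickdraw16_test.npy", np.concatenate(test_parts))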
In [1]:
import torch
from torch import flip
from torch.nn import Module
from torch import nn
from torch.nn.functional import conv2d
import torch.nn.functional as F
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
import numpy as np
import math
import tqdm
from pprint import pprint, pformat
Utility functions¶
In [2]:
# from utils import circular_aperture, imshow, pad_zeros, to_class_labels

def pad_zeros(input, size):
    """Zero-pad a 2D or 4D tensor to `size`, keeping the content centered."""
    h, w = input.shape[-2:]
    th, tw = size
    if len(input.shape) == 2:
        gg = torch.zeros(size, device=input.device)
        x, y = int(th/2 - h/2), int(tw/2 - w/2)
        gg[x:int(th/2 + h/2), y:int(tw/2 + w/2)] = input[:, :]
    if len(input.shape) == 4:
        gg = torch.zeros(input.shape[:2] + size, device=input.device)
        x, y = int(th/2 - h/2), int(tw/2 - w/2)
        gg[:, :, x:int(th/2 + h/2), y:int(tw/2 + w/2)] = input[:, :, :, :]
    return gg

def unpad_zeros(input, size):
    """Crop the central `size` region out of a 2D or 4D tensor; inverse of pad_zeros."""
    h, w = input.shape[-2:]
    th, tw = size
    dx, dy = h - th, w - tw
    if len(input.shape) == 2:
        gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]
    if len(input.shape) == 4:
        gg = input[:, :, dx//2:dx//2+th, dy//2:dy//2+tw]
    return gg

def to_class_labels(softmax_distributions):
    """Collapse rows of per-class probabilities into integer class labels."""
    return torch.argmax(softmax_distributions, dim=1).cpu()
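A quick sanity check of the padding helpers (a minimal sketch with arbitrary numbers): pad_zeros centers a tensor inside a larger zero canvas and unpad_zeros crops it back out, so the round trip is lossless.

sample = torch.arange(16.).reshape(1, 1, 4, 4)   # tiny 4x4 "image" with batch and channel dims
padded = pad_zeros(sample, (8, 8))               # centered inside an 8x8 zero canvas
restored = unpad_zeros(padded, (4, 4))           # crop the central 4x4 region back out
assert torch.equal(restored, sample)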
In [3]:
def imshow(tensor, figsize=None, title="", **args):
    tensor = tensor.cpu().detach() if isinstance(tensor, torch.Tensor) else tensor
    tensor = list(tensor) if isinstance(tensor, torch.nn.modules.container.ParameterList) else tensor
    figsize = figsize if figsize else (13*0.8, 5*0.8)
    if type(tensor) is list:
        for idx, el in enumerate(tensor):
            imshow(el, figsize=figsize, title=title, **args)
            plt.suptitle("{} {}".format(idx, title))
        return
    if len(tensor.shape) == 4:
        for idx, el in enumerate(torch.squeeze(tensor, dim=1)):
            imshow(el, figsize=figsize, title=title, **args)
            plt.suptitle("{} {}".format(idx, title))
        return
    if tensor.dtype == torch.complex64:
        # complex tensors: show real and imaginary parts side by side
        f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5, 3, 1, 46.5, 3]})
        real_im = ax[0].imshow(tensor.real, **args)
        imag_im = ax[3].imshow(tensor.imag, **args)
        box = ax[1].get_position()
        box.x0 = box.x0 - 0.02
        box.x1 = box.x1 - 0.03
        ax[1].set_position(box)
        box = ax[4].get_position()
        box.x0 = box.x0 - 0.02
        box.x1 = box.x1 - 0.03
        ax[4].set_position(box)
        ax[0].set_title("real")
        ax[3].set_title("imag")
        f.colorbar(real_im, ax[1])
        f.colorbar(imag_im, ax[4])
        f.suptitle(title)
        ax[2].remove()
        return f, ax
    else:
        f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95, 5]}, figsize=figsize)
        im = ax[0].imshow(tensor, **args)
        f.colorbar(im, ax[1])
        f.suptitle(title)
        return f, ax
Data and model parameters¶
In [4]:
CONFIG = type('', (), {})()  # empty object to hold parameters

# input data properties
CONFIG.classes = 16
CONFIG.image_size = 28
CONFIG.train_class_instances = 8000
CONFIG.test_class_instances = 100
CONFIG.train_data_path = './assets/quickdraw16_train.npy'
CONFIG.test_data_path = './assets/quickdraw16_test.npy'

# optical system model properties
CONFIG.kernel_size = 28
CONFIG.tile_size_scale_factor = 2
CONFIG.resolution_scale_factor = 2
CONFIG.class_slots = 16
CONFIG.wavelength = 532e-9
# CONFIG.refractive_index = 1.5090
CONFIG.propagation_distance = 300
CONFIG.metric = 1e-3
CONFIG.pixel_size_meters = 36e-6
CONFIG.layers = 1

pprint(CONFIG.__dict__)
{'class_slots': 16, 'classes': 16, 'image_size': 28, 'kernel_size': 28, 'layers': 1, 'metric': 0.001, 'pixel_size_meters': 3.6e-05, 'propagation_distance': 300, 'resolution_scale_factor': 2, 'test_class_instances': 100, 'test_data_path': './assets/quickdraw16_test.npy', 'tile_size_scale_factor': 2, 'train_class_instances': 8000, 'train_data_path': './assets/quickdraw16_train.npy', 'wavelength': 5.32e-07}
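The derived sizes of the optical model follow directly from these parameters; a small check of the arithmetic, mirroring the expressions in OpticalSystem.__init__ below:

tile_size = CONFIG.kernel_size * CONFIG.tile_size_scale_factor                 # 28 * 2 = 56
tiles_per_dim = int(np.ceil(np.sqrt(CONFIG.class_slots)))                      # ceil(sqrt(16)) = 4
phase_mask_size = tile_size * tiles_per_dim * CONFIG.resolution_scale_factor   # 56 * 4 * 2 = 448
print(tile_size, tiles_per_dim, phase_mask_size)                               # 56 4 448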
Training and test data¶
In [5]:
train_data = torch.tensor(np.load(CONFIG.train_data_path), dtype=torch.float32)
test_data = torch.tensor(np.load(CONFIG.test_data_path), dtype=torch.float32)
train_data = rearrange(train_data, "b (h w) -> b 1 h w", h=CONFIG.image_size, w=CONFIG.image_size)
test_data = rearrange(test_data, "b (h w) -> b 1 h w", h=CONFIG.image_size, w=CONFIG.image_size)
train_data.shape, test_data.shape
Out[5]:
(torch.Size([128000, 1, 28, 28]), torch.Size([1600, 1, 28, 28]))
In [6]:
train_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)
test_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)
train_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.train_class_instances)
test_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.test_class_instances)
train_targets.shape, test_targets.shape
Out[6]:
(torch.Size([128000, 16]), torch.Size([1600, 16]))
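The eye + repeat_interleave pattern produces one-hot target rows grouped by class, matching the class-ordered layout of the data files. A toy illustration with 3 classes and 2 instances per class:

torch.eye(3).repeat_interleave(repeats=2, dim=0)
# tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [0., 0., 1.]])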
System model¶
In [7]:
class OpticalSystem(Module):
    def __init__(self,
                 layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength = 532e-9,
                 # refractive_index = 1.5090,
                 propagation_distance = 300,
                 pixel_size_meters = 36e-6,
                 metric = 1e-3
                 ):
        """Free-space optical classifier with trainable phase elements."""
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        # self.refractive_index = refractive_index
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert(self.class_slots >= self.classes)
        self.empty_class_slots = self.class_slots - self.classes
        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor

        self.height_maps = []
        for i in range(self.layers):
            heights = nn.Parameter(torch.ones([self.phase_mask_size, self.phase_mask_size], dtype=torch.float32))
            torch.nn.init.uniform_(heights, a=0.5*self.wavelength, b=1.5*self.wavelength)
            self.height_maps.append(heights)
        self.height_maps = torch.nn.ParameterList(self.height_maps)

        # spatial and frequency grids for the propagation model
        A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric
        B = A*self.phase_mask_size/self.tile_size
        x = torch.linspace(-B, B, self.phase_mask_size+1)[:-1]
        x, y = torch.meshgrid(x, x, indexing='ij')
        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/B, torch.pi*self.phase_mask_size/2/B, self.phase_mask_size+1)[:-1]
        Kx, Ky = torch.meshgrid(kx, kx, indexing='ij')
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1)**vv
        a, b = torch.meshgrid(vv, vv, indexing='ij')
        lambda1 = self.wavelength / self.metric

        self.U = nn.Parameter((Kx**2 + Ky**2).float())
        self.vv = nn.Parameter((a*b).float())  # checkerboard of +-1 acting as an fftshift
        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))
        self.D = nn.Parameter(torch.exp(-1j*(x**2 + y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))
        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))
        self.U.requires_grad = False
        self.vv.requires_grad = False
        self.D.requires_grad = True   # D is the trainable phase element
        self.coef.requires_grad = False

    def propagation(self, field, propagation_distance):
        F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv

    def opt_conv(self, inputs, heights):
        # NB: `heights` is currently unused here; the trainable element is self.D
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * self.D
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        amplitude = torch.sqrt(result.real**2 + result.imag**2)
        return amplitude

    def forward(self, image):
        """
        Algorithm:
        1. The input image is upscaled by self.resolution_scale_factor: [28,28] -> [56,56].
        2. The result is zero-padded to self.phase_mask_size: [56,56] -> [448,448].
        3. Light propagation through the phase plates is simulated.
        4. The output image is cut into a grid of self.tiles_per_dim x self.tiles_per_dim regions.
        5. The regions are reduced to a vector of length self.class_slots by a max operation,
           which is then normalized (is the normalization needed?).
        6. The vector of maxima is turned into a probability distribution by softmax.
        """
        # 1
        image = resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor, image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        # 2
        image = pad_zeros(
            image,
            size=(self.phase_mask_size, self.phase_mask_size),
        )
        # 3
        x = image
        for i, plate_heights in enumerate(self.height_maps):
            x = self.opt_conv(x, plate_heights)
        convolved = x
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=self.tile_size*self.resolution_scale_factor,
            wt=self.tile_size*self.resolution_scale_factor,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim
        )
        # 5
        grid_to_depth = unpad_zeros(grid_to_depth, (self.kernel_size_pixels*self.resolution_scale_factor,
                                                    self.kernel_size_pixels*self.resolution_scale_factor))
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size=self.kernel_size_pixels*self.resolution_scale_factor
        )
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        max_pool /= max_pool.max(dim=1, keepdims=True).values
        # 6
        softmax = torch.nn.functional.softmax(max_pool, dim=1)
        return softmax, convolved

    def __repr__(self):
        tmp = {}
        for k, v in self.__dict__.items():
            if not k[0] == '_':
                tmp[k] = v
        tmp.update(self.__dict__['_modules'])
        tmp.update({k: f"{v.dtype} {v.shape}" for k, v in self.__dict__['_parameters'].items()})
        return pformat(tmp, indent=2)
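The six steps of forward can be verified by tracing tensor shapes. A minimal sketch on a CPU instance with the same configuration (the batch of random noise exists only to trace shapes):

cpu_model = OpticalSystem(layers=1, kernel_size_pixels=28, tile_size_scale_factor=2,
                          resolution_scale_factor=2, class_slots=16, classes=16)
dummy = torch.rand(2, 1, 28, 28)                  # steps 1-2: upscaled to 56x56, padded to 448x448
softmax_out, convolved_out = cpu_model(dummy)
print(convolved_out.shape)  # torch.Size([2, 1, 448, 448]) -- field amplitude after step 3
print(softmax_out.shape)    # torch.Size([2, 16])          -- one probability per class slot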
Creating the model, optimizer, and loss function¶
In [8]:
model = OpticalSystem(
    layers = CONFIG.layers,
    kernel_size_pixels = CONFIG.kernel_size,
    tile_size_scale_factor = CONFIG.tile_size_scale_factor,
    resolution_scale_factor = CONFIG.resolution_scale_factor,
    class_slots = CONFIG.class_slots,
    classes = CONFIG.classes,
    wavelength = CONFIG.wavelength,
    propagation_distance = CONFIG.propagation_distance,
    pixel_size_meters = CONFIG.pixel_size_meters,
    metric = CONFIG.metric
)
# uncomment to load a saved phase mask; leave commented to train from scratch
# model.load_state_dict(torch.load(CONFIG.phasemask_model_1_path))
model.eval()
model.cuda()
Out[8]:
{ 'D': 'torch.complex64 torch.Size([448, 448])', 'U': 'torch.float32 torch.Size([448, 448])', 'class_slots': 16, 'classes': 16, 'coef': 'torch.complex64 torch.Size([1])', 'empty_class_slots': 0, 'height_maps': ParameterList( (0): Parameter containing: [torch.cuda.FloatTensor of size 448x448 (GPU 0)]), 'k': 'torch.float32 torch.Size([1])', 'kernel_size_pixels': 28, 'layers': 1, 'metric': 0.001, 'phase_mask_size': 448, 'pixel_size_meters': 3.6e-05, 'propagation_distance': 300, 'resolution_scale_factor': 2, 'tile_size': 56, 'tile_size_scale_factor': 2, 'tiles_per_dim': 4, 'training': False, 'vv': 'torch.float32 torch.Size([448, 448])', 'wavelength': 5.32e-07}
In [9]:
optimizer = torch.optim.Adam(params=model.cuda().parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
loss_function = torch.nn.CrossEntropyLoss()
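With gamma=0.999 the learning rate decays exponentially per scheduler step, lr_t = 1e-2 * 0.999^t. A quick check against the rate logged near the end of the training run below:

# after ~4070 steps the schedule gives roughly the value seen in the progress bar
print(1e-2 * 0.999**4070)   # ~1.70e-4, matching "lr: 1.704265e-04" below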
Training¶
In [ ]:
# training loop
batch_size = 764
max_passes_through_dataset = 25
epochs = int(train_data.shape[0]/batch_size*max_passes_through_dataset)
ppp = tqdm.trange(epochs)

def init_batch_generator(train_data, train_labels, batch_size):
    """
    Returns a generator that yields the next batch_size examples and their
    corresponding labels from train_data, train_labels. Examples are taken
    sequentially, wrapping around; the input and label arrays are reshuffled
    at the start of every pass.
    """
    def f():
        i = 0
        rnd_indx = torch.randperm(train_data.shape[0])
        train_data_shuffled = train_data[rnd_indx]
        train_labels_shuffled = train_labels[rnd_indx]
        while True:
            if i + batch_size > train_data.shape[0]:
                i = 0
                rnd_indx = torch.randperm(train_data.shape[0])
                train_data_shuffled = train_data[rnd_indx]
                train_labels_shuffled = train_labels[rnd_indx]
            batch_inputs = train_data_shuffled[i:i+batch_size]
            batch_targets = train_labels_shuffled[i:i+batch_size]
            i = i + batch_size
            yield batch_inputs, batch_targets
    return f()

batch_iterator = init_batch_generator(train_data, train_targets, batch_size)

i = 0
for epoch in ppp:
    batch_inputs, batch_targets = next(batch_iterator)
    batch_inputs = batch_inputs.cuda()
    batch_targets = batch_targets.cuda()
    i = i + batch_size
    passes_through_dataset = i//train_data.shape[0]

    # apply model
    predicted, convolved = model(batch_inputs)

    # correct model
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss_value = loss_function(predicted, batch_targets)
    loss_value.backward()
    optimizer.step()

    # for small batches, the reporting frequency should be reduced
    if epoch % 2 == 0:
        acc = accuracy_score(to_class_labels(batch_targets), to_class_labels(predicted))
        ppp.set_postfix_str("loss: {:e}, acc: {:.2f}, lr: {:e}, passes_through_dataset: {}".format(
            loss_value, acc, scheduler.get_last_lr()[0], passes_through_dataset))
    if (scheduler.get_last_lr()[0] > 1e-13):
        scheduler.step()
97%|█████████▋| 4071/4188 [11:23<00:19, 5.97it/s, loss: 2.736057e+00, acc: 0.69, lr: 1.704265e-04, passes_through_dataset: 24]
Test¶
In [ ]:
inputs = test_data
targets = test_targets
batch_size = 64

predicted = []
batch_start = 0
while batch_start < test_data.shape[0]:
    batch_end = min(batch_start + batch_size, test_data.shape[0])
    batch_input = inputs[batch_start:batch_end].cuda()
    batch_output, _ = model(batch_input)
    predicted.append(batch_output.detach().cpu())
    batch_start = batch_end
predicted = torch.concat(predicted)

test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))
"Accuracy on test dataset: ", test_acc
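confusion_matrix was imported at the top but never used; a short follow-up that breaks the test accuracy down per class, a sketch reusing the targets/predicted tensors from the cell above:

cm = confusion_matrix(to_class_labels(targets), to_class_labels(predicted))
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(cm)                       # rows: true class, columns: predicted class
ax.set_xlabel("predicted class")
ax.set_ylabel("true class")
print(np.trace(cm), "correct of", cm.sum())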
In [ ]:
imshow(model.height_maps, figsize=(10,10))
In [ ]:
class_id = 3
image = test_data[test_labels==class_id][:1]
imshow(image, title="Input images")
softmax, convolved = model(image.cuda())
for idx, psf in enumerate(convolved):
    psf = psf.squeeze()
    f, ax = imshow(psf, figsize=(5,5), title=f"Result of optical convolution with phase plate for image {idx}")
    # overlay the tile grid and mark the brightest spot
    ax[0].hlines(np.arange(0, psf.shape[0], psf.shape[0]//model.tiles_per_dim), 0, psf.shape[1]-1)
    ax[0].vlines(np.arange(0, psf.shape[1], psf.shape[1]//model.tiles_per_dim), 0, psf.shape[0]-1)
    y, x = (psf==torch.max(psf)).nonzero()[0]
    ax[0].text(x, y, "max", color='white');
Saving the height map¶
In [ ]:
# from PIL import Image
# for idx, heights in enumerate(model.height_maps):
#     m = heights.abs().mean()
#     s = heights.abs().std()
#     m1, m2 = heights.abs().min(), heights.abs().max()
#     ar = heights.abs().cpu().detach().numpy()
#     print(ar.dtype)
#     im = ar
#     im = im - im.min()
#     im = im / im.max()
#     im = im * 255
#     name_im = f"phasemask_{idx}.png"
#     name_np = f"phasemask_{idx}"
#     result = Image.fromarray(im.astype(np.uint8))
#     result.save(name_im)
#     np.save(name_np, ar)
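If the cell above is uncommented and run, the raw relief can be restored losslessly from the .npy file (the PNG is only an 8-bit normalized preview). A minimal round-trip sketch:

heights_restored = np.load("phasemask_0.npy")                    # raw float32 relief saved above
original = model.height_maps[0].abs().cpu().detach().numpy()
assert np.allclose(heights_restored, original)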