{
"cells": [
{
"cell_type": "markdown",
"id": "f0d9d491-74b8-4567-b2ba-79e99ab499ee",
"metadata": {},
"source": [
"# Пример на данных QuckDraw \n",
"https://github.com/googlecreativelab/quickdraw-dataset \n",
"Используется урезанная версия с 16 классами "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c4455b6a-5ed9-4499-98b4-42e18fa1f6d8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"from torch import flip\n",
"from torch.nn import Module\n",
"from torch import nn\n",
"from torch.nn.functional import conv2d\n",
"import torch.nn.functional as F\n",
"from torchvision.transforms.functional import resize, InterpolationMode\n",
"from einops import rearrange\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
"\n",
"import numpy as np\n",
"import math\n",
"\n",
"import tqdm\n",
"from pprint import pprint, pformat"
]
},
{
"cell_type": "markdown",
"id": "32211b41-628a-4376-adf3-e93cc58bc2b4",
"metadata": {},
"source": [
"### Служебные функции"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9aefa1b1-e4da-4e02-b691-7745f9759e5a",
"metadata": {},
"outputs": [],
"source": [
"# from utils import circular_aperture, imshow, pad_zeros, to_class_labels\n",
"def pad_zeros(input, size):\n",
" h, w = input.shape[-2:]\n",
" th, tw = size\n",
" \n",
" if len(input.shape) == 2:\n",
" gg = torch.zeros(size, device=input.device)\n",
" x, y = int(th/2 - h/2), int(tw/2 - w/2)\n",
" gg[x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:]\n",
" \n",
" if len(input.shape) == 4:\n",
" gg = torch.zeros(input.shape[:2] + size, device=input.device)\n",
" x, y = int(th/2 - h/2), int(tw/2 - w/2)\n",
" gg[:,:,x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:,:,:]\n",
"\n",
" return gg\n",
"\n",
"def unpad_zeros(input, size):\n",
" h, w = input.shape[-2:]\n",
" th, tw = size\n",
" dx,dy = h-th, w-tw\n",
" \n",
" if len(input.shape) == 2: \n",
" gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]\n",
" \n",
" if len(input.shape) == 4:\n",
" gg = input[:,:,dx//2:dx//2+th, dy//2:dy//2+tw]\n",
" return gg\n",
"\n",
"def to_class_labels(softmax_distibutions):\n",
" return torch.argmax(softmax_distibutions, dim=1).cpu()"
]
},
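{
"cell_type": "markdown",
"id": "1c0ffee1-0a0b-4c0d-8e0f-a00000000001",
"metadata": {},
"source": [
"A quick sanity check (my addition, not in the original notebook): `unpad_zeros` should invert `pad_zeros` when the padding is symmetric, e.g. 28 → 448 → 28."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c0ffee1-0a0b-4c0d-8e0f-a00000000002",
"metadata": {},
"outputs": [],
"source": [
"# Round-trip check on a dummy batch: pad to the mask size, then crop back\n",
"x = torch.rand(2, 1, 28, 28)\n",
"padded = pad_zeros(x, (448, 448))\n",
"cropped = unpad_zeros(padded, (28, 28))\n",
"assert padded.shape == (2, 1, 448, 448)\n",
"assert torch.allclose(cropped, x)"
]
},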
{
"cell_type": "code",
"execution_count": 3,
"id": "3d755b73-bac0-47ce-a2fb-6ade6428445e",
"metadata": {},
"outputs": [],
"source": [
"def imshow(tensor, figsize=None, title=\"\", **args):\n",
" tensor = tensor.cpu().detach() if isinstance(tensor, torch.Tensor) else tensor\n",
" tensor = list(tensor) if isinstance(tensor, torch.nn.modules.container.ParameterList) else tensor\n",
" \n",
" figsize = figsize if figsize else (13*0.8,5*0.8)\n",
" \n",
" if type(tensor) is list:\n",
" for idx, el in enumerate(tensor):\n",
" imshow(el, figsize=figsize, title=title, **args)\n",
" plt.suptitle(\"{} {}\".format(idx, title))\n",
" return\n",
" if len(tensor.shape)==4:\n",
" for idx, el in enumerate(torch.squeeze(tensor, dim=1)):\n",
" imshow(el, figsize=figsize, title=title, **args)\n",
" plt.suptitle(\"{} {}\".format(idx, title))\n",
" return\n",
" \n",
" if tensor.dtype == torch.complex64:\n",
" f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]})\n",
" real_im = ax[0].imshow(tensor.real, **args)\n",
" imag_im = ax[3].imshow(tensor.imag, **args)\n",
" box = ax[1].get_position()\n",
" box.x0 = box.x0 - 0.02\n",
" box.x1 = box.x1 - 0.03\n",
" ax[1].set_position(box)\n",
" box = ax[4].get_position()\n",
" box.x0 = box.x0 - 0.02\n",
" box.x1 = box.x1 - 0.03\n",
" ax[4].set_position(box)\n",
" ax[0].set_title(\"real\");\n",
" ax[3].set_title(\"imag\");\n",
" f.colorbar(real_im, ax[1]);\n",
" f.colorbar(imag_im, ax[4]);\n",
" f.suptitle(title)\n",
" ax[2].remove()\n",
" return f, ax\n",
" else:\n",
" f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize)\n",
" im = ax[0].imshow(tensor, **args)\n",
" f.colorbar(im, ax[1])\n",
" f.suptitle(title)\n",
" return f, ax"
]
},
{
"cell_type": "markdown",
"id": "cb1d1521-bade-47f6-a2e8-1d12c857e2f8",
"metadata": {},
"source": [
"### Параметры данных и модели"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b9ce4212-977c-4621-850b-db595cd00ab2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'class_slots': 16,\n",
" 'classes': 16,\n",
" 'image_size': 28,\n",
" 'kernel_size': 28,\n",
" 'layers': 1,\n",
" 'metric': 0.001,\n",
" 'pixel_size_meters': 3.6e-05,\n",
" 'propagation_distance': 300,\n",
" 'resolution_scale_factor': 2,\n",
" 'test_class_instances': 100,\n",
" 'test_data_path': './assets/quickdraw16_test.npy',\n",
" 'tile_size_scale_factor': 2,\n",
" 'train_class_instances': 8000,\n",
" 'train_data_path': './assets/quickdraw16_train.npy',\n",
" 'wavelength': 5.32e-07}\n"
]
}
],
"source": [
"CONFIG = type('', (), {})() # object for parameters\n",
"\n",
"# свойства входных данных\n",
"CONFIG.classes = 16\n",
"CONFIG.image_size = 28\n",
"CONFIG.train_class_instances = 8000\n",
"CONFIG.test_class_instances = 100\n",
"CONFIG.train_data_path = './assets/quickdraw16_train.npy'\n",
"CONFIG.test_data_path = './assets/quickdraw16_test.npy'\n",
"\n",
"# свойства модели оптической системы\n",
"CONFIG.kernel_size = 28\n",
"CONFIG.tile_size_scale_factor = 2\n",
"CONFIG.resolution_scale_factor = 2 \n",
"CONFIG.class_slots = 16\n",
"CONFIG.wavelength = 532e-9\n",
"# CONFIG.refractive_index = 1.5090\n",
"CONFIG.propagation_distance = 300\n",
"CONFIG.metric = 1e-3\n",
"CONFIG.pixel_size_meters = 36e-6\n",
"CONFIG.layers = 1\n",
"\n",
"pprint(CONFIG.__dict__)"
]
},
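{
"cell_type": "markdown",
"id": "2d0ffee2-1b1c-4d2e-9f30-b00000000001",
"metadata": {},
"source": [
"For reference (my addition): the 448 × 448 phase-mask size reported later follows directly from these parameters; the cell below repeats the arithmetic from `OpticalSystem.__init__`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d0ffee2-1b1c-4d2e-9f30-b00000000002",
"metadata": {},
"outputs": [],
"source": [
"# Derived geometry, same arithmetic as in OpticalSystem.__init__\n",
"tile_size = CONFIG.kernel_size * CONFIG.tile_size_scale_factor # 28 * 2 = 56\n",
"tiles_per_dim = math.ceil(math.sqrt(CONFIG.class_slots)) # ceil(sqrt(16)) = 4\n",
"phase_mask_size = tile_size * tiles_per_dim * CONFIG.resolution_scale_factor # 56 * 4 * 2 = 448\n",
"tile_size, tiles_per_dim, phase_mask_size"
]
},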
{
"cell_type": "markdown",
"id": "41470dee-088e-4d24-8a1f-4b18636b4dac",
"metadata": {},
"source": [
"### Обучающие и тестовые данные"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bb84c1e5-0201-4815-bf88-d5512364e731",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([128000, 1, 28, 28]), torch.Size([1600, 1, 28, 28]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data = torch.tensor(np.load(CONFIG.train_data_path), dtype=torch.float32)\n",
"test_data = torch.tensor(np.load(CONFIG.test_data_path), dtype=torch.float32)\n",
"train_data = rearrange(train_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
"test_data = rearrange(test_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
"train_data.shape, test_data.shape"
]
},
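{
"cell_type": "markdown",
"id": "3e0ffee3-2c2d-4e3f-a041-c00000000001",
"metadata": {},
"source": [
"To verify the `b (h w) -> b 1 h w` reshape (my addition), display the first training example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e0ffee3-2c2d-4e3f-a041-c00000000002",
"metadata": {},
"outputs": [],
"source": [
"# The first training example as a 28x28 image\n",
"imshow(train_data[0, 0], title=\"train_data[0]\");"
]
},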
{
"cell_type": "code",
"execution_count": 6,
"id": "2684aad3-16db-41a9-914b-762151d865a3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([128000, 16]), torch.Size([1600, 16]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)\n",
"test_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)\n",
"\n",
"train_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.train_class_instances)\n",
"test_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.test_class_instances)\n",
"\n",
"train_targets.shape, test_targets.shape"
]
},
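{
"cell_type": "markdown",
"id": "4f0ffee4-3d3e-4f40-b152-d00000000001",
"metadata": {},
"source": [
"A consistency check (my addition): `to_class_labels` applied to the one-hot targets must reproduce the integer labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f0ffee4-3d3e-4f40-b152-d00000000002",
"metadata": {},
"outputs": [],
"source": [
"# argmax over one-hot rows must match the integer labels\n",
"assert torch.equal(to_class_labels(train_targets), train_labels)\n",
"assert torch.equal(to_class_labels(test_targets), test_labels)"
]
},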
{
"cell_type": "markdown",
"id": "037548dd-9834-4c6d-b1dd-ced4d5f4adaa",
"metadata": {},
"source": [
"### Модель системы"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f116e1f4-e477-4168-bac6-9a21f2ac4190",
"metadata": {},
"outputs": [],
"source": [
"class OpticalSystem(Module):\n",
" def __init__(self,\n",
" layers,\n",
" kernel_size_pixels,\n",
" tile_size_scale_factor,\n",
" resolution_scale_factor,\n",
" class_slots,\n",
" classes,\n",
" wavelength = 532e-9, \n",
" # refractive_index = 1.5090, \n",
" propagation_distance = 300,\n",
" pixel_size_meters = 36e-6,\n",
" metric = 1e-3\n",
" ):\n",
" \"\"\"\"\"\"\n",
" super().__init__()\n",
" self.layers = layers\n",
" self.kernel_size_pixels = kernel_size_pixels\n",
" self.tile_size_scale_factor = tile_size_scale_factor\n",
" self.resolution_scale_factor = resolution_scale_factor\n",
" self.class_slots = class_slots\n",
" self.classes = classes\n",
" self.wavelength = wavelength\n",
" # self.refractive_index = refractive_index\n",
" self.propagation_distance = propagation_distance\n",
" self.pixel_size_meters = pixel_size_meters\n",
" self.metric = metric\n",
"\n",
" assert(self.class_slots >= self.classes)\n",
" self.empty_class_slots = self.class_slots - self.classes \n",
" \n",
" self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor\n",
" self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)\n",
" self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor\n",
" \n",
" self.height_maps = []\n",
" for i in range(self.layers):\n",
" heights = nn.Parameter(torch.ones([self.phase_mask_size, self.phase_mask_size], dtype=torch.float32))\n",
" torch.nn.init.uniform_(heights, a=0.5*self.wavelength, b=1.5*self.wavelength)\n",
" self.height_maps.append(heights)\n",
" self.height_maps = torch.nn.ParameterList(self.height_maps)\n",
" \n",
" A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric\n",
" B = A*self.phase_mask_size/self.tile_size \n",
" x = torch.linspace(-B, B, self.phase_mask_size+1)[:-1]\n",
" x, y = torch.meshgrid(x, x, indexing='ij')\n",
" kx = torch.linspace(-torch.pi*self.phase_mask_size/2/B, torch.pi*self.phase_mask_size/2/B, self.phase_mask_size+1)[:-1]\n",
" Kx, Ky = torch.meshgrid(kx, kx, indexing='ij')\n",
" vv = torch.arange(0,self.phase_mask_size)\n",
" vv = (-1)**vv\n",
" a, b = torch.meshgrid(vv, vv, indexing='ij')\n",
" lambda1 = self.wavelength / self.metric\n",
" \n",
" self.U = nn.Parameter((Kx**2 + Ky**2).float())\n",
" self.vv = nn.Parameter((a*b).float())\n",
" self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))\n",
" self.D = nn.Parameter(torch.exp(-1j*(x**2 + y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))\n",
" self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))\n",
" self.U.requires_grad = False\n",
" self.vv.requires_grad = False\n",
" self.D.requires_grad = True\n",
" self.coef.requires_grad = False\n",
" \n",
"\n",
" def propagation(self, field, propagation_distance):\n",
" F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)\n",
" return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv\n",
" \n",
" def opt_conv(self, inputs, heights):\n",
" result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
" result = result * self.D\n",
" result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
" amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
" return amplitude\n",
" \n",
" def forward(self, image):\n",
" \"\"\"\n",
" Алгоритм:\n",
" 1. Входное изображение увеличивается в self.resolution_scale_factor. [28,28] -> [56,56]\n",
" 2. Полученное изображение дополняется 0 до размера self.phase_mask_size. [56,56] -> [448, 448]\n",
" 3. Моделируется прохождение света через транспаранты\n",
" 4. Выходное изображение нарезается в набор областей self.tiles_per_dim x self.tiles_per_dim\n",
" 5. Области преобразуются в вектор длины self.class_slots операцией max и затем нормируется (нужна ли нормировка?)\n",
" 6. Вектор максимальных значений преобразутся в распределение вероятностей функцией softmax\n",
" \"\"\"\n",
" # 1\n",
" image = resize(\n",
" image, \n",
" size=(image.shape[-2]*self.resolution_scale_factor,\n",
" image.shape[-1]*self.resolution_scale_factor),\n",
" interpolation=InterpolationMode.NEAREST\n",
" )\n",
" # 2\n",
" image = pad_zeros(\n",
" image, \n",
" size = (self.phase_mask_size , \n",
" self.phase_mask_size ),\n",
" )\n",
" # 3 \n",
" x = image \n",
" for i, plate_heights in enumerate(self.height_maps): \n",
" x = self.opt_conv(x, plate_heights)\n",
" convolved = x\n",
" # 4\n",
" grid_to_depth = rearrange(\n",
" convolved,\n",
" \"b 1 (m ht) (n wt) -> b (m n) ht wt\",\n",
" ht = self.tile_size*self.resolution_scale_factor,\n",
" wt = self.tile_size*self.resolution_scale_factor,\n",
" m = self.tiles_per_dim,\n",
" n = self.tiles_per_dim\n",
" )\n",
" # 5\n",
" grid_to_depth = unpad_zeros(grid_to_depth, \n",
" (self.kernel_size_pixels*self.resolution_scale_factor, \n",
" self.kernel_size_pixels*self.resolution_scale_factor))\n",
" max_pool = torch.nn.functional.max_pool2d(\n",
" grid_to_depth,\n",
" kernel_size = self.kernel_size_pixels*self.resolution_scale_factor\n",
" ) \n",
" max_pool = rearrange(max_pool, \"b class_slots 1 1 -> b class_slots\", class_slots=self.class_slots)\n",
" max_pool /= max_pool.max(dim=1, keepdims=True).values\n",
" # 6\n",
" softmax = torch.nn.functional.softmax(max_pool, dim=1)\n",
" return softmax, convolved\n",
" \n",
" def __repr__(self):\n",
" tmp = {}\n",
" for k,v in self.__dict__.items():\n",
" if not k[0] == '_':\n",
" tmp[k] = v\n",
" tmp.update(self.__dict__['_modules'])\n",
" tmp.update({k:f\"{v.dtype} {v.shape}\" for k,v in self.__dict__['_parameters'].items()})\n",
" return pformat(tmp, indent=2)"
]
},
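{
"cell_type": "markdown",
"id": "5a0ffee5-4e4f-4051-c263-e00000000001",
"metadata": {},
"source": [
"My reading of `propagation` (not stated in the original): with scale factor $s$ = `resolution_scale_factor`, distance $z$ and wavenumber $k = 2\\pi/\\lambda$, the spectrum is multiplied by\n",
"$$H(k_x, k_y) = e^{ikz}\\,\\exp\\!\\left(-\\frac{iz}{s\\,k}\\left(k_x^2 + k_y^2\\right)\\right),$$\n",
"which for $s = 2$ coincides with the standard Fresnel (paraxial) transfer function. The sketch below (my addition) runs steps 4-5 of `forward` in isolation on a dummy tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a0ffee5-4e4f-4051-c263-e00000000002",
"metadata": {},
"outputs": [],
"source": [
"# Steps 4-5 of forward(): slice a 448x448 map into a 4x4 grid of 112x112 tiles,\n",
"# crop each tile's central 56x56 window, then take per-tile maxima\n",
"dummy = torch.rand(1, 1, 448, 448) # stands in for the convolved amplitude\n",
"tiles = rearrange(dummy, \"b 1 (m ht) (n wt) -> b (m n) ht wt\", m=4, n=4, ht=112, wt=112)\n",
"tiles = unpad_zeros(tiles, (56, 56))\n",
"slot_maxima = F.max_pool2d(tiles, kernel_size=56)\n",
"rearrange(slot_maxima, \"b s 1 1 -> b s\").shape # (1, 16): one score per class slot"
]
},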
{
"cell_type": "markdown",
"id": "200b86bb-188c-4034-8a98-1bcfb5b92d9e",
"metadata": {},
"source": [
"### Создание экземпляра модели, оптимизатора, функции потерь"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "12c87b56-bdcf-4e9c-a171-b096ff244376",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{ 'D': 'torch.complex64 torch.Size([448, 448])',\n",
" 'U': 'torch.float32 torch.Size([448, 448])',\n",
" 'class_slots': 16,\n",
" 'classes': 16,\n",
" 'coef': 'torch.complex64 torch.Size([1])',\n",
" 'empty_class_slots': 0,\n",
" 'height_maps': ParameterList( (0): Parameter containing: [torch.cuda.FloatTensor of size 448x448 (GPU 0)]),\n",
" 'k': 'torch.float32 torch.Size([1])',\n",
" 'kernel_size_pixels': 28,\n",
" 'layers': 1,\n",
" 'metric': 0.001,\n",
" 'phase_mask_size': 448,\n",
" 'pixel_size_meters': 3.6e-05,\n",
" 'propagation_distance': 300,\n",
" 'resolution_scale_factor': 2,\n",
" 'tile_size': 56,\n",
" 'tile_size_scale_factor': 2,\n",
" 'tiles_per_dim': 4,\n",
" 'training': False,\n",
" 'vv': 'torch.float32 torch.Size([448, 448])',\n",
" 'wavelength': 5.32e-07}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = OpticalSystem(\n",
" layers = CONFIG.layers,\n",
" kernel_size_pixels = CONFIG.kernel_size,\n",
" tile_size_scale_factor = CONFIG.tile_size_scale_factor,\n",
" resolution_scale_factor = CONFIG.resolution_scale_factor,\n",
" class_slots = CONFIG.class_slots,\n",
" classes = CONFIG.classes,\n",
" wavelength = CONFIG.wavelength, \n",
" propagation_distance = CONFIG.propagation_distance,\n",
" pixel_size_meters = CONFIG.pixel_size_meters,\n",
" metric = CONFIG.metric\n",
")\n",
"# comment to train from scratch\n",
"# model.load_state_dict(torch.load(CONFIG.phasemask_model_1_path))\n",
"model.eval()\n",
"model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "43dcf5cb-a10a-4f02-811a-1c365e2e1bed",
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(params=model.cuda().parameters(), \n",
" lr=1e-2)\n",
"scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n",
"loss_function = torch.nn.CrossEntropyLoss()"
]
},
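{
"cell_type": "markdown",
"id": "6b0ffee6-5f50-4162-d374-f00000000001",
"metadata": {},
"source": [
"A note on the schedule (my addition): `ExponentialLR` with `gamma = 0.999` is stepped once per batch, so over the ~4188 steps of the run below the learning rate decays to roughly $10^{-2} \\cdot 0.999^{4188} \\approx 1.5 \\cdot 10^{-4}$, consistent with the `lr` shown in the training progress bar."
]
},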
{
"cell_type": "markdown",
"id": "d6f489d1-0d5d-43ec-bcaf-f01a82c28bf1",
"metadata": {},
"source": [
"### Обучение"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3851a9d-1867-4c5e-b4ec-3598448a054f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 97%|█████████▋| 4071/4188 [11:23<00:19, 5.97it/s, loss: 2.736057e+00, acc: 0.69, lr: 1.704265e-04, passes_through_dataset: 24]"
]
}
],
"source": [
"# training loop\n",
"batch_size = 764\n",
"max_passes_through_dataset = 25\n",
"epochs = int(train_data.shape[0]/batch_size*max_passes_through_dataset)\n",
"ppp = tqdm.trange(epochs)\n",
"\n",
"def init_batch_generator(train_data, train_labels, batch_size):\n",
" \"\"\"\n",
" Возвращает функцию, вызов которой возвращает следующие batch_size\n",
" примеров и им соответствуюющих меток из train_data, train_labels.\n",
" \n",
" Примеры выбираются последовательно, по кругу. Массивы с входными \n",
" примерами и метками классов перемешиваются в начале каждого круга.\n",
" \"\"\"\n",
" def f():\n",
" i = 0\n",
" rnd_indx = torch.randperm(train_data.shape[0])\n",
" train_data_shuffled = train_data[rnd_indx]\n",
" train_labels_shuffled = train_labels[rnd_indx]\n",
" while True:\n",
" if i + batch_size > train_data.shape[0]:\n",
" i = 0\n",
" rnd_indx = torch.randperm(train_data.shape[0])\n",
" train_data_shuffled = train_data[rnd_indx]\n",
" train_labels_shuffled = train_labels[rnd_indx]\n",
" \n",
" batch_inputs = train_data_shuffled[i:i+batch_size]\n",
" batch_targets = train_labels_shuffled[i:i+batch_size]\n",
" i = i + batch_size\n",
" yield batch_inputs, batch_targets\n",
" return f()\n",
"\n",
"batch_iterator = init_batch_generator(train_data, train_targets, batch_size)\n",
"i = 0\n",
"\n",
"for epoch in ppp:\n",
" batch_inputs, batch_targets = next(batch_iterator)\n",
" batch_inputs = batch_inputs.cuda()\n",
" batch_targets = batch_targets.cuda()\n",
" i = i + batch_size\n",
" passes_through_dataset = i//train_data.shape[0]\n",
" # apply model\n",
" predicted, convolved = model(batch_inputs)\n",
"\n",
" # correct model\n",
" loss_value = loss_function(predicted, batch_targets)\n",
"\n",
" loss_value.backward()\n",
" optimizer.step()\n",
"\n",
" # для небольших батчей следует уменьшать частоту вывода \n",
" if epoch % 2 == 0:\n",
" acc = accuracy_score(to_class_labels(batch_targets), to_class_labels(predicted))\n",
" ppp.set_postfix_str(\"loss: {:e}, acc: {:.2f}, lr: {:e}, passes_through_dataset: {}\".format(loss_value, acc, scheduler.get_last_lr()[0], passes_through_dataset))\n",
" \n",
" if (scheduler.get_last_lr()[0] > 1e-13):\n",
" scheduler.step()"
]
},
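{
"cell_type": "markdown",
"id": "7c0ffee7-6061-4273-e485-a10000000001",
"metadata": {},
"source": [
"To make the commented-out `load_state_dict` call above usable, save a checkpoint after training. A minimal sketch (the file name is an assumption; `CONFIG.phasemask_model_1_path` is never defined in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c0ffee7-6061-4273-e485-a10000000002",
"metadata": {},
"outputs": [],
"source": [
"# Persist all trainable parameters (hypothetical path)\n",
"torch.save(model.state_dict(), \"phasemask_model_1.pt\")"
]
},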
{
"cell_type": "markdown",
"id": "ae4259bb-199b-43ee-951a-4d481dd1cddc",
"metadata": {},
"source": [
"### Тест"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f68545e2-7f65-464e-94ef-cee7ffd0236a",
"metadata": {},
"outputs": [],
"source": [
"inputs = test_data\n",
"targets = test_targets\n",
"batch_size = 64\n",
"\n",
"predicted = []\n",
"batch_start = 0\n",
"while batch_start < test_data.shape[0]:\n",
" batch_end = min(batch_start + batch_size, test_data.shape[0])\n",
" batch_input = inputs[batch_start:batch_end].cuda() \n",
" batch_output, _ = model(batch_input)\n",
" predicted.append(batch_output.detach().cpu())\n",
" batch_start = batch_end\n",
"\n",
"predicted = torch.concat(predicted)\n",
"\n",
"test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
"\"Accuracy on test dataset: \", test_acc"
]
},
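{
"cell_type": "markdown",
"id": "8d0ffee8-7172-4384-f596-b20000000001",
"metadata": {},
"source": [
"`confusion_matrix` is imported at the top but never used; a natural follow-up (my addition) is to look at the per-class error structure on the test set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d0ffee8-7172-4384-f596-b20000000002",
"metadata": {},
"outputs": [],
"source": [
"# Rows: true classes, columns: predicted classes\n",
"cm = confusion_matrix(to_class_labels(targets), to_class_labels(predicted))\n",
"imshow(torch.tensor(cm), title=\"Confusion matrix (true x predicted)\");"
]
},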
{
"cell_type": "code",
"execution_count": null,
"id": "296f2dab-6f4a-40b1-81ea-a1e7ed7e72ba",
"metadata": {},
"outputs": [],
"source": [
"imshow(model.height_maps, figsize=(10,10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a4cd483-be14-4222-b3c6-0fadd0b1c017",
"metadata": {},
"outputs": [],
"source": [
"class_id = 3\n",
"image = test_data[test_labels==class_id][:1]\n",
"imshow(image, title=f\"Input images\")\n",
"softmax, convolved = model(image.cuda())\n",
"\n",
"for idx, psf in enumerate(convolved):\n",
" psf = psf.squeeze()\n",
" f, ax = imshow(psf, figsize=(5,5), title=f\"Result of optical convolution with phase plate for image {idx}\")\n",
" ax[0].hlines(np.arange(0, psf.shape[0], psf.shape[0]//model.tiles_per_dim), 0, psf.shape[1]-1)\n",
" ax[0].vlines(np.arange(0, psf.shape[1], psf.shape[1]//model.tiles_per_dim), 0, psf.shape[0]-1)\n",
" y,x = (psf==torch.max(psf)).nonzero()[0]\n",
" ax[0].text(x,y, \"max\", color='white');"
]
},
{
"cell_type": "markdown",
"id": "63f39f87-5a98-454b-9f1c-fdc712b11f0b",
"metadata": {},
"source": [
"### Сохранение рельефа"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f29aec08-f946-4044-b19d-8718c4673f7c",
"metadata": {},
"outputs": [],
"source": [
"# from PIL import Image\n",
"\n",
"# for idx, heights in enumerate(model.height_maps):\n",
"# m = heights.abs().mean()\n",
"# s = heights.abs().std()\n",
"# m1, m2 = heights.abs().min(), heights.abs().max()\n",
"# ar = heights.abs().cpu().detach().numpy() \n",
"# print(ar.dtype)\n",
"# im = ar\n",
"# im = im - im.min()\n",
"# im = im / im.max()\n",
"# im = im * 255\n",
"# name_im = f\"phasemask_{idx}.png\"\n",
"# name_np = f\"phasemask_{idx}\"\n",
"# result = Image.fromarray(im.astype(np.uint8))\n",
"# result.save(name_im)\n",
"# np.save(name_np, ar)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}