|
|
|
@ -0,0 +1,700 @@
|
|
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "f0d9d491-74b8-4567-b2ba-79e99ab499ee",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Пример на данных QuckDraw \n",
|
|
|
|
|
"https://github.com/googlecreativelab/quickdraw-dataset \n",
|
|
|
|
|
"Используется урезанная версия с 16 классами "
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 1,
|
|
|
|
|
"id": "c4455b6a-5ed9-4499-98b4-42e18fa1f6d8",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
|
|
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"from torch import flip\n",
|
|
|
|
|
"from torch.nn import Module\n",
|
|
|
|
|
"from torch import nn\n",
|
|
|
|
|
"from torch.nn.functional import conv2d\n",
|
|
|
|
|
"import torch.nn.functional as F\n",
|
|
|
|
|
"from torchvision.transforms.functional import resize, InterpolationMode\n",
|
|
|
|
|
"from einops import rearrange\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"import math\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import tqdm\n",
|
|
|
|
|
"from pprint import pprint, pformat"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "32211b41-628a-4376-adf3-e93cc58bc2b4",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Служебные функции"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 2,
|
|
|
|
|
"id": "9aefa1b1-e4da-4e02-b691-7745f9759e5a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# from utils import circular_aperture, imshow, pad_zeros, to_class_labels\n",
|
|
|
|
|
"def pad_zeros(input, size):\n",
|
|
|
|
|
" h, w = input.shape[-2:]\n",
|
|
|
|
|
" th, tw = size\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if len(input.shape) == 2:\n",
|
|
|
|
|
" gg = torch.zeros(size, device=input.device)\n",
|
|
|
|
|
" x, y = int(th/2 - h/2), int(tw/2 - w/2)\n",
|
|
|
|
|
" gg[x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if len(input.shape) == 4:\n",
|
|
|
|
|
" gg = torch.zeros(input.shape[:2] + size, device=input.device)\n",
|
|
|
|
|
" x, y = int(th/2 - h/2), int(tw/2 - w/2)\n",
|
|
|
|
|
" gg[:,:,x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:,:,:]\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" return gg\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"def unpad_zeros(input, size):\n",
|
|
|
|
|
" h, w = input.shape[-2:]\n",
|
|
|
|
|
" th, tw = size\n",
|
|
|
|
|
" dx,dy = h-th, w-tw\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if len(input.shape) == 2: \n",
|
|
|
|
|
" gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if len(input.shape) == 4:\n",
|
|
|
|
|
" gg = input[:,:,dx//2:dx//2+th, dy//2:dy//2+tw]\n",
|
|
|
|
|
" return gg\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"def to_class_labels(softmax_distibutions):\n",
|
|
|
|
|
" return torch.argmax(softmax_distibutions, dim=1).cpu()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 3,
|
|
|
|
|
"id": "3d755b73-bac0-47ce-a2fb-6ade6428445e",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def imshow(tensor, figsize=None, title=\"\", **args):\n",
|
|
|
|
|
" tensor = tensor.cpu().detach() if isinstance(tensor, torch.Tensor) else tensor\n",
|
|
|
|
|
" tensor = list(tensor) if isinstance(tensor, torch.nn.modules.container.ParameterList) else tensor\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" figsize = figsize if figsize else (13*0.8,5*0.8)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if type(tensor) is list:\n",
|
|
|
|
|
" for idx, el in enumerate(tensor):\n",
|
|
|
|
|
" imshow(el, figsize=figsize, title=title, **args)\n",
|
|
|
|
|
" plt.suptitle(\"{} {}\".format(idx, title))\n",
|
|
|
|
|
" return\n",
|
|
|
|
|
" if len(tensor.shape)==4:\n",
|
|
|
|
|
" for idx, el in enumerate(torch.squeeze(tensor, dim=1)):\n",
|
|
|
|
|
" imshow(el, figsize=figsize, title=title, **args)\n",
|
|
|
|
|
" plt.suptitle(\"{} {}\".format(idx, title))\n",
|
|
|
|
|
" return\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if tensor.dtype == torch.complex64:\n",
|
|
|
|
|
" f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]})\n",
|
|
|
|
|
" real_im = ax[0].imshow(tensor.real, **args)\n",
|
|
|
|
|
" imag_im = ax[3].imshow(tensor.imag, **args)\n",
|
|
|
|
|
" box = ax[1].get_position()\n",
|
|
|
|
|
" box.x0 = box.x0 - 0.02\n",
|
|
|
|
|
" box.x1 = box.x1 - 0.03\n",
|
|
|
|
|
" ax[1].set_position(box)\n",
|
|
|
|
|
" box = ax[4].get_position()\n",
|
|
|
|
|
" box.x0 = box.x0 - 0.02\n",
|
|
|
|
|
" box.x1 = box.x1 - 0.03\n",
|
|
|
|
|
" ax[4].set_position(box)\n",
|
|
|
|
|
" ax[0].set_title(\"real\");\n",
|
|
|
|
|
" ax[3].set_title(\"imag\");\n",
|
|
|
|
|
" f.colorbar(real_im, ax[1]);\n",
|
|
|
|
|
" f.colorbar(imag_im, ax[4]);\n",
|
|
|
|
|
" f.suptitle(title)\n",
|
|
|
|
|
" ax[2].remove()\n",
|
|
|
|
|
" return f, ax\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize)\n",
|
|
|
|
|
" im = ax[0].imshow(tensor, **args)\n",
|
|
|
|
|
" f.colorbar(im, ax[1])\n",
|
|
|
|
|
" f.suptitle(title)\n",
|
|
|
|
|
" return f, ax"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "cb1d1521-bade-47f6-a2e8-1d12c857e2f8",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Параметры данных и модели"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 4,
|
|
|
|
|
"id": "b9ce4212-977c-4621-850b-db595cd00ab2",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"{'class_slots': 16,\n",
|
|
|
|
|
" 'classes': 16,\n",
|
|
|
|
|
" 'image_size': 28,\n",
|
|
|
|
|
" 'kernel_size': 28,\n",
|
|
|
|
|
" 'layers': 1,\n",
|
|
|
|
|
" 'metric': 0.001,\n",
|
|
|
|
|
" 'pixel_size_meters': 3.6e-05,\n",
|
|
|
|
|
" 'propagation_distance': 300,\n",
|
|
|
|
|
" 'resolution_scale_factor': 2,\n",
|
|
|
|
|
" 'test_class_instances': 100,\n",
|
|
|
|
|
" 'test_data_path': './assets/quickdraw16_test.npy',\n",
|
|
|
|
|
" 'tile_size_scale_factor': 2,\n",
|
|
|
|
|
" 'train_class_instances': 8000,\n",
|
|
|
|
|
" 'train_data_path': './assets/quickdraw16_train.npy',\n",
|
|
|
|
|
" 'wavelength': 5.32e-07}\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"CONFIG = type('', (), {})() # object for parameters\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# свойства входных данных\n",
|
|
|
|
|
"CONFIG.classes = 16\n",
|
|
|
|
|
"CONFIG.image_size = 28\n",
|
|
|
|
|
"CONFIG.train_class_instances = 8000\n",
|
|
|
|
|
"CONFIG.test_class_instances = 100\n",
|
|
|
|
|
"CONFIG.train_data_path = './assets/quickdraw16_train.npy'\n",
|
|
|
|
|
"CONFIG.test_data_path = './assets/quickdraw16_test.npy'\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# свойства модели оптической системы\n",
|
|
|
|
|
"CONFIG.kernel_size = 28\n",
|
|
|
|
|
"CONFIG.tile_size_scale_factor = 2\n",
|
|
|
|
|
"CONFIG.resolution_scale_factor = 2 \n",
|
|
|
|
|
"CONFIG.class_slots = 16\n",
|
|
|
|
|
"CONFIG.wavelength = 532e-9\n",
|
|
|
|
|
"# CONFIG.refractive_index = 1.5090\n",
|
|
|
|
|
"CONFIG.propagation_distance = 300\n",
|
|
|
|
|
"CONFIG.metric = 1e-3\n",
|
|
|
|
|
"CONFIG.pixel_size_meters = 36e-6\n",
|
|
|
|
|
"CONFIG.layers = 1\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"pprint(CONFIG.__dict__)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "41470dee-088e-4d24-8a1f-4b18636b4dac",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Обучающие и тестовые данные"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 5,
|
|
|
|
|
"id": "bb84c1e5-0201-4815-bf88-d5512364e731",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"(torch.Size([128000, 1, 28, 28]), torch.Size([1600, 1, 28, 28]))"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 5,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"train_data = torch.tensor(np.load(CONFIG.train_data_path), dtype=torch.float32)\n",
|
|
|
|
|
"test_data = torch.tensor(np.load(CONFIG.test_data_path), dtype=torch.float32)\n",
|
|
|
|
|
"train_data = rearrange(train_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
|
|
|
|
|
"test_data = rearrange(test_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
|
|
|
|
|
"train_data.shape, test_data.shape"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 6,
|
|
|
|
|
"id": "2684aad3-16db-41a9-914b-762151d865a3",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"(torch.Size([128000, 16]), torch.Size([1600, 16]))"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 6,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"train_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)\n",
|
|
|
|
|
"test_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"train_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.train_class_instances)\n",
|
|
|
|
|
"test_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.test_class_instances)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"train_targets.shape, test_targets.shape"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "037548dd-9834-4c6d-b1dd-ced4d5f4adaa",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Модель системы"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 7,
|
|
|
|
|
"id": "f116e1f4-e477-4168-bac6-9a21f2ac4190",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"class OpticalSystem(Module):\n",
|
|
|
|
|
" def __init__(self,\n",
|
|
|
|
|
" layers,\n",
|
|
|
|
|
" kernel_size_pixels,\n",
|
|
|
|
|
" tile_size_scale_factor,\n",
|
|
|
|
|
" resolution_scale_factor,\n",
|
|
|
|
|
" class_slots,\n",
|
|
|
|
|
" classes,\n",
|
|
|
|
|
" wavelength = 532e-9, \n",
|
|
|
|
|
" # refractive_index = 1.5090, \n",
|
|
|
|
|
" propagation_distance = 300,\n",
|
|
|
|
|
" pixel_size_meters = 36e-6,\n",
|
|
|
|
|
" metric = 1e-3\n",
|
|
|
|
|
" ):\n",
|
|
|
|
|
" \"\"\"\"\"\"\n",
|
|
|
|
|
" super().__init__()\n",
|
|
|
|
|
" self.layers = layers\n",
|
|
|
|
|
" self.kernel_size_pixels = kernel_size_pixels\n",
|
|
|
|
|
" self.tile_size_scale_factor = tile_size_scale_factor\n",
|
|
|
|
|
" self.resolution_scale_factor = resolution_scale_factor\n",
|
|
|
|
|
" self.class_slots = class_slots\n",
|
|
|
|
|
" self.classes = classes\n",
|
|
|
|
|
" self.wavelength = wavelength\n",
|
|
|
|
|
" # self.refractive_index = refractive_index\n",
|
|
|
|
|
" self.propagation_distance = propagation_distance\n",
|
|
|
|
|
" self.pixel_size_meters = pixel_size_meters\n",
|
|
|
|
|
" self.metric = metric\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" assert(self.class_slots >= self.classes)\n",
|
|
|
|
|
" self.empty_class_slots = self.class_slots - self.classes \n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor\n",
|
|
|
|
|
" self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)\n",
|
|
|
|
|
" self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.height_maps = []\n",
|
|
|
|
|
" for i in range(self.layers):\n",
|
|
|
|
|
" heights = nn.Parameter(torch.ones([self.phase_mask_size, self.phase_mask_size], dtype=torch.float32))\n",
|
|
|
|
|
" torch.nn.init.uniform_(heights, a=0.5*self.wavelength, b=1.5*self.wavelength)\n",
|
|
|
|
|
" self.height_maps.append(heights)\n",
|
|
|
|
|
" self.height_maps = torch.nn.ParameterList(self.height_maps)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric\n",
|
|
|
|
|
" B = A*self.phase_mask_size/self.tile_size \n",
|
|
|
|
|
" x = torch.linspace(-B, B, self.phase_mask_size+1)[:-1]\n",
|
|
|
|
|
" x, y = torch.meshgrid(x, x, indexing='ij')\n",
|
|
|
|
|
" kx = torch.linspace(-torch.pi*self.phase_mask_size/2/B, torch.pi*self.phase_mask_size/2/B, self.phase_mask_size+1)[:-1]\n",
|
|
|
|
|
" Kx, Ky = torch.meshgrid(kx, kx, indexing='ij')\n",
|
|
|
|
|
" vv = torch.arange(0,self.phase_mask_size)\n",
|
|
|
|
|
" vv = (-1)**vv\n",
|
|
|
|
|
" a, b = torch.meshgrid(vv, vv, indexing='ij')\n",
|
|
|
|
|
" lambda1 = self.wavelength / self.metric\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.U = nn.Parameter((Kx**2 + Ky**2).float())\n",
|
|
|
|
|
" self.vv = nn.Parameter((a*b).float())\n",
|
|
|
|
|
" self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))\n",
|
|
|
|
|
" self.D = nn.Parameter(torch.exp(-1j*(x**2 + y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))\n",
|
|
|
|
|
" self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))\n",
|
|
|
|
|
" self.U.requires_grad = False\n",
|
|
|
|
|
" self.vv.requires_grad = False\n",
|
|
|
|
|
" self.D.requires_grad = True\n",
|
|
|
|
|
" self.coef.requires_grad = False\n",
|
|
|
|
|
" \n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def propagation(self, field, propagation_distance):\n",
|
|
|
|
|
" F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)\n",
|
|
|
|
|
" return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def opt_conv(self, inputs, heights):\n",
|
|
|
|
|
" result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
|
|
|
|
|
" result = result * self.D\n",
|
|
|
|
|
" result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
|
|
|
|
|
" amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
|
|
|
|
|
" return amplitude\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def forward(self, image):\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" Алгоритм:\n",
|
|
|
|
|
" 1. Входное изображение увеличивается в self.resolution_scale_factor. [28,28] -> [56,56]\n",
|
|
|
|
|
" 2. Полученное изображение дополняется 0 до размера self.phase_mask_size. [56,56] -> [448, 448]\n",
|
|
|
|
|
" 3. Моделируется прохождение света через транспаранты\n",
|
|
|
|
|
" 4. Выходное изображение нарезается в набор областей self.tiles_per_dim x self.tiles_per_dim\n",
|
|
|
|
|
" 5. Области преобразуются в вектор длины self.class_slots операцией max и затем нормируется (нужна ли нормировка?)\n",
|
|
|
|
|
" 6. Вектор максимальных значений преобразутся в распределение вероятностей функцией softmax\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" # 1\n",
|
|
|
|
|
" image = resize(\n",
|
|
|
|
|
" image, \n",
|
|
|
|
|
" size=(image.shape[-2]*self.resolution_scale_factor,\n",
|
|
|
|
|
" image.shape[-1]*self.resolution_scale_factor),\n",
|
|
|
|
|
" interpolation=InterpolationMode.NEAREST\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" # 2\n",
|
|
|
|
|
" image = pad_zeros(\n",
|
|
|
|
|
" image, \n",
|
|
|
|
|
" size = (self.phase_mask_size , \n",
|
|
|
|
|
" self.phase_mask_size ),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" # 3 \n",
|
|
|
|
|
" x = image \n",
|
|
|
|
|
" for i, plate_heights in enumerate(self.height_maps): \n",
|
|
|
|
|
" x = self.opt_conv(x, plate_heights)\n",
|
|
|
|
|
" convolved = x\n",
|
|
|
|
|
" # 4\n",
|
|
|
|
|
" grid_to_depth = rearrange(\n",
|
|
|
|
|
" convolved,\n",
|
|
|
|
|
" \"b 1 (m ht) (n wt) -> b (m n) ht wt\",\n",
|
|
|
|
|
" ht = self.tile_size*self.resolution_scale_factor,\n",
|
|
|
|
|
" wt = self.tile_size*self.resolution_scale_factor,\n",
|
|
|
|
|
" m = self.tiles_per_dim,\n",
|
|
|
|
|
" n = self.tiles_per_dim\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" # 5\n",
|
|
|
|
|
" grid_to_depth = unpad_zeros(grid_to_depth, \n",
|
|
|
|
|
" (self.kernel_size_pixels*self.resolution_scale_factor, \n",
|
|
|
|
|
" self.kernel_size_pixels*self.resolution_scale_factor))\n",
|
|
|
|
|
" max_pool = torch.nn.functional.max_pool2d(\n",
|
|
|
|
|
" grid_to_depth,\n",
|
|
|
|
|
" kernel_size = self.kernel_size_pixels*self.resolution_scale_factor\n",
|
|
|
|
|
" ) \n",
|
|
|
|
|
" max_pool = rearrange(max_pool, \"b class_slots 1 1 -> b class_slots\", class_slots=self.class_slots)\n",
|
|
|
|
|
" max_pool /= max_pool.max(dim=1, keepdims=True).values\n",
|
|
|
|
|
" # 6\n",
|
|
|
|
|
" softmax = torch.nn.functional.softmax(max_pool, dim=1)\n",
|
|
|
|
|
" return softmax, convolved\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def __repr__(self):\n",
|
|
|
|
|
" tmp = {}\n",
|
|
|
|
|
" for k,v in self.__dict__.items():\n",
|
|
|
|
|
" if not k[0] == '_':\n",
|
|
|
|
|
" tmp[k] = v\n",
|
|
|
|
|
" tmp.update(self.__dict__['_modules'])\n",
|
|
|
|
|
" tmp.update({k:f\"{v.dtype} {v.shape}\" for k,v in self.__dict__['_parameters'].items()})\n",
|
|
|
|
|
" return pformat(tmp, indent=2)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "200b86bb-188c-4034-8a98-1bcfb5b92d9e",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Создание экземпляра модели, оптимизатора, функции потерь"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 8,
|
|
|
|
|
"id": "12c87b56-bdcf-4e9c-a171-b096ff244376",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"{ 'D': 'torch.complex64 torch.Size([448, 448])',\n",
|
|
|
|
|
" 'U': 'torch.float32 torch.Size([448, 448])',\n",
|
|
|
|
|
" 'class_slots': 16,\n",
|
|
|
|
|
" 'classes': 16,\n",
|
|
|
|
|
" 'coef': 'torch.complex64 torch.Size([1])',\n",
|
|
|
|
|
" 'empty_class_slots': 0,\n",
|
|
|
|
|
" 'height_maps': ParameterList( (0): Parameter containing: [torch.cuda.FloatTensor of size 448x448 (GPU 0)]),\n",
|
|
|
|
|
" 'k': 'torch.float32 torch.Size([1])',\n",
|
|
|
|
|
" 'kernel_size_pixels': 28,\n",
|
|
|
|
|
" 'layers': 1,\n",
|
|
|
|
|
" 'metric': 0.001,\n",
|
|
|
|
|
" 'phase_mask_size': 448,\n",
|
|
|
|
|
" 'pixel_size_meters': 3.6e-05,\n",
|
|
|
|
|
" 'propagation_distance': 300,\n",
|
|
|
|
|
" 'resolution_scale_factor': 2,\n",
|
|
|
|
|
" 'tile_size': 56,\n",
|
|
|
|
|
" 'tile_size_scale_factor': 2,\n",
|
|
|
|
|
" 'tiles_per_dim': 4,\n",
|
|
|
|
|
" 'training': False,\n",
|
|
|
|
|
" 'vv': 'torch.float32 torch.Size([448, 448])',\n",
|
|
|
|
|
" 'wavelength': 5.32e-07}"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 8,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"model = OpticalSystem(\n",
|
|
|
|
|
" layers = CONFIG.layers,\n",
|
|
|
|
|
" kernel_size_pixels = CONFIG.kernel_size,\n",
|
|
|
|
|
" tile_size_scale_factor = CONFIG.tile_size_scale_factor,\n",
|
|
|
|
|
" resolution_scale_factor = CONFIG.resolution_scale_factor,\n",
|
|
|
|
|
" class_slots = CONFIG.class_slots,\n",
|
|
|
|
|
" classes = CONFIG.classes,\n",
|
|
|
|
|
" wavelength = CONFIG.wavelength, \n",
|
|
|
|
|
" propagation_distance = CONFIG.propagation_distance,\n",
|
|
|
|
|
" pixel_size_meters = CONFIG.pixel_size_meters,\n",
|
|
|
|
|
" metric = CONFIG.metric\n",
|
|
|
|
|
")\n",
|
|
|
|
|
"# comment to train from scratch\n",
|
|
|
|
|
"# model.load_state_dict(torch.load(CONFIG.phasemask_model_1_path))\n",
|
|
|
|
|
"model.eval()\n",
|
|
|
|
|
"model.cuda()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 9,
|
|
|
|
|
"id": "43dcf5cb-a10a-4f02-811a-1c365e2e1bed",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"optimizer = torch.optim.Adam(params=model.cuda().parameters(), \n",
|
|
|
|
|
" lr=1e-2)\n",
|
|
|
|
|
"scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n",
|
|
|
|
|
"loss_function = torch.nn.CrossEntropyLoss()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "d6f489d1-0d5d-43ec-bcaf-f01a82c28bf1",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Обучение"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "c3851a9d-1867-4c5e-b4ec-3598448a054f",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
" 97%|█████████▋| 4071/4188 [11:23<00:19, 5.97it/s, loss: 2.736057e+00, acc: 0.69, lr: 1.704265e-04, passes_through_dataset: 24]"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# training loop\n",
|
|
|
|
|
"batch_size = 764\n",
|
|
|
|
|
"max_passes_through_dataset = 25\n",
|
|
|
|
|
"epochs = int(train_data.shape[0]/batch_size*max_passes_through_dataset)\n",
|
|
|
|
|
"ppp = tqdm.trange(epochs)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"def init_batch_generator(train_data, train_labels, batch_size):\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" Возвращает функцию, вызов которой возвращает следующие batch_size\n",
|
|
|
|
|
" примеров и им соответствуюющих меток из train_data, train_labels.\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" Примеры выбираются последовательно, по кругу. Массивы с входными \n",
|
|
|
|
|
" примерами и метками классов перемешиваются в начале каждого круга.\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" def f():\n",
|
|
|
|
|
" i = 0\n",
|
|
|
|
|
" rnd_indx = torch.randperm(train_data.shape[0])\n",
|
|
|
|
|
" train_data_shuffled = train_data[rnd_indx]\n",
|
|
|
|
|
" train_labels_shuffled = train_labels[rnd_indx]\n",
|
|
|
|
|
" while True:\n",
|
|
|
|
|
" if i + batch_size > train_data.shape[0]:\n",
|
|
|
|
|
" i = 0\n",
|
|
|
|
|
" rnd_indx = torch.randperm(train_data.shape[0])\n",
|
|
|
|
|
" train_data_shuffled = train_data[rnd_indx]\n",
|
|
|
|
|
" train_labels_shuffled = train_labels[rnd_indx]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" batch_inputs = train_data_shuffled[i:i+batch_size]\n",
|
|
|
|
|
" batch_targets = train_labels_shuffled[i:i+batch_size]\n",
|
|
|
|
|
" i = i + batch_size\n",
|
|
|
|
|
" yield batch_inputs, batch_targets\n",
|
|
|
|
|
" return f()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"batch_iterator = init_batch_generator(train_data, train_targets, batch_size)\n",
|
|
|
|
|
"i = 0\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"for epoch in ppp:\n",
|
|
|
|
|
" batch_inputs, batch_targets = next(batch_iterator)\n",
|
|
|
|
|
" batch_inputs = batch_inputs.cuda()\n",
|
|
|
|
|
" batch_targets = batch_targets.cuda()\n",
|
|
|
|
|
" i = i + batch_size\n",
|
|
|
|
|
" passes_through_dataset = i//train_data.shape[0]\n",
|
|
|
|
|
" # apply model\n",
|
|
|
|
|
" predicted, convolved = model(batch_inputs)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" # correct model\n",
|
|
|
|
|
" loss_value = loss_function(predicted, batch_targets)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" loss_value.backward()\n",
|
|
|
|
|
" optimizer.step()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" # для небольших батчей следует уменьшать частоту вывода \n",
|
|
|
|
|
" if epoch % 2 == 0:\n",
|
|
|
|
|
" acc = accuracy_score(to_class_labels(batch_targets), to_class_labels(predicted))\n",
|
|
|
|
|
" ppp.set_postfix_str(\"loss: {:e}, acc: {:.2f}, lr: {:e}, passes_through_dataset: {}\".format(loss_value, acc, scheduler.get_last_lr()[0], passes_through_dataset))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if (scheduler.get_last_lr()[0] > 1e-13):\n",
|
|
|
|
|
" scheduler.step()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "ae4259bb-199b-43ee-951a-4d481dd1cddc",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Тест"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "f68545e2-7f65-464e-94ef-cee7ffd0236a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"inputs = test_data\n",
|
|
|
|
|
"targets = test_targets\n",
|
|
|
|
|
"batch_size = 64\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"predicted = []\n",
|
|
|
|
|
"batch_start = 0\n",
|
|
|
|
|
"while batch_start < test_data.shape[0]:\n",
|
|
|
|
|
" batch_end = min(batch_start + batch_size, test_data.shape[0])\n",
|
|
|
|
|
" batch_input = inputs[batch_start:batch_end].cuda() \n",
|
|
|
|
|
" batch_output, _ = model(batch_input)\n",
|
|
|
|
|
" predicted.append(batch_output.detach().cpu())\n",
|
|
|
|
|
" batch_start = batch_end\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"predicted = torch.concat(predicted)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
|
|
|
|
|
"\"Accuracy on test dataset: \", test_acc"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "296f2dab-6f4a-40b1-81ea-a1e7ed7e72ba",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"imshow(model.height_maps, figsize=(10,10))"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "9a4cd483-be14-4222-b3c6-0fadd0b1c017",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"class_id = 3\n",
|
|
|
|
|
"image = test_data[test_labels==class_id][:1]\n",
|
|
|
|
|
"imshow(image, title=f\"Input images\")\n",
|
|
|
|
|
"softmax, convolved = model(image.cuda())\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"for idx, psf in enumerate(convolved):\n",
|
|
|
|
|
" psf = psf.squeeze()\n",
|
|
|
|
|
" f, ax = imshow(psf, figsize=(5,5), title=f\"Result of optical convolution with phase plate for image {idx}\")\n",
|
|
|
|
|
" ax[0].hlines(np.arange(0, psf.shape[0], psf.shape[0]//model.tiles_per_dim), 0, psf.shape[1]-1)\n",
|
|
|
|
|
" ax[0].vlines(np.arange(0, psf.shape[1], psf.shape[1]//model.tiles_per_dim), 0, psf.shape[0]-1)\n",
|
|
|
|
|
" y,x = (psf==torch.max(psf)).nonzero()[0]\n",
|
|
|
|
|
" ax[0].text(x,y, \"max\", color='white');"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "63f39f87-5a98-454b-9f1c-fdc712b11f0b",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### Сохранение рельефа"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "f29aec08-f946-4044-b19d-8718c4673f7c",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# from PIL import Image\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# for idx, heights in enumerate(model.height_maps):\n",
|
|
|
|
|
"# m = heights.abs().mean()\n",
|
|
|
|
|
"# s = heights.abs().std()\n",
|
|
|
|
|
"# m1, m2 = heights.abs().min(), heights.abs().max()\n",
|
|
|
|
|
"# ar = heights.abs().cpu().detach().numpy() \n",
|
|
|
|
|
"# print(ar.dtype)\n",
|
|
|
|
|
"# im = ar\n",
|
|
|
|
|
"# im = im - im.min()\n",
|
|
|
|
|
"# im = im / im.max()\n",
|
|
|
|
|
"# im = im * 255\n",
|
|
|
|
|
"# name_im = f\"phasemask_{idx}.png\"\n",
|
|
|
|
|
"# name_np = f\"phasemask_{idx}\"\n",
|
|
|
|
|
"# result = Image.fromarray(im.astype(np.uint8))\n",
|
|
|
|
|
"# result.save(name_im)\n",
|
|
|
|
|
"# np.save(name_np, ar)"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "Python 3 (ipykernel)",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.7.13"
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 5
|
|
|
|
|
}
|