| 
						
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -0,0 +1,700 @@
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				{
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				 "cells": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "f0d9d491-74b8-4567-b2ba-79e99ab499ee",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# Пример на данных QuckDraw \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "https://github.com/googlecreativelab/quickdraw-dataset  \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "Используется урезанная версия с 16 классами "
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 1,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "c4455b6a-5ed9-4499-98b4-42e18fa1f6d8",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "name": "stderr",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "output_type": "stream",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "text": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      "  from .autonotebook import tqdm as notebook_tqdm\n"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "import torch\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from torch import flip\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from torch.nn import Module\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from torch import nn\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from torch.nn.functional import conv2d\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "import torch.nn.functional as F\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from torchvision.transforms.functional import resize, InterpolationMode\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from einops import rearrange\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "import matplotlib.pyplot as plt\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from sklearn.metrics import confusion_matrix, accuracy_score\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "import numpy as np\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "import math\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "import tqdm\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "from pprint import pprint, pformat"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "32211b41-628a-4376-adf3-e93cc58bc2b4",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Служебные функции"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 2,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "9aefa1b1-e4da-4e02-b691-7745f9759e5a",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# from utils import circular_aperture, imshow, pad_zeros, to_class_labels\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "def pad_zeros(input, size):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  h, w = input.shape[-2:]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  th, tw = size\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  if len(input.shape) == 2:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    gg = torch.zeros(size, device=input.device)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    x, y = int(th/2 - h/2), int(tw/2 - w/2)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    gg[x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  if len(input.shape) == 4:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    gg = torch.zeros(input.shape[:2] + size, device=input.device)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    x, y = int(th/2 - h/2), int(tw/2 - w/2)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    gg[:,:,x:int(th/2 + h/2),y:int(tw/2 + w/2)] = input[:,:,:,:]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  return gg\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "def unpad_zeros(input, size):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  h, w = input.shape[-2:]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  th, tw = size\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  dx,dy = h-th, w-tw\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  if len(input.shape) == 2: \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    gg = input[int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  if len(input.shape) == 4:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    gg = input[:,:,dx//2:dx//2+th, dy//2:dy//2+tw]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  return gg\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "def to_class_labels(softmax_distibutions):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    return torch.argmax(softmax_distibutions, dim=1).cpu()"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 3,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "3d755b73-bac0-47ce-a2fb-6ade6428445e",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "def imshow(tensor, figsize=None, title=\"\", **args):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    tensor = tensor.cpu().detach() if isinstance(tensor, torch.Tensor) else tensor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    tensor = list(tensor) if isinstance(tensor, torch.nn.modules.container.ParameterList) else tensor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    figsize = figsize if figsize else (13*0.8,5*0.8)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    if type(tensor) is list:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        for idx, el in enumerate(tensor):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            imshow(el, figsize=figsize, title=title, **args)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            plt.suptitle(\"{} {}\".format(idx, title))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    if len(tensor.shape)==4:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        for idx, el in enumerate(torch.squeeze(tensor, dim=1)):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            imshow(el, figsize=figsize, title=title, **args)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            plt.suptitle(\"{} {}\".format(idx, title))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    if tensor.dtype == torch.complex64:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]})\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        real_im = ax[0].imshow(tensor.real, **args)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        imag_im = ax[3].imshow(tensor.imag, **args)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        box = ax[1].get_position()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        box.x0 = box.x0 - 0.02\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        box.x1 = box.x1 - 0.03\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        ax[1].set_position(box)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        box = ax[4].get_position()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        box.x0 = box.x0 - 0.02\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        box.x1 = box.x1 - 0.03\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        ax[4].set_position(box)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        ax[0].set_title(\"real\");\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        ax[3].set_title(\"imag\");\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f.colorbar(real_im, ax[1]);\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f.colorbar(imag_im, ax[4]);\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f.suptitle(title)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        ax[2].remove()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return f, ax\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    else:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        im = ax[0].imshow(tensor, **args)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f.colorbar(im, ax[1])\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        f.suptitle(title)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return f, ax"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "cb1d1521-bade-47f6-a2e8-1d12c857e2f8",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Параметры данных и модели"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 4,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "b9ce4212-977c-4621-850b-db595cd00ab2",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "name": "stdout",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "output_type": "stream",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "text": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      "{'class_slots': 16,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'classes': 16,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'image_size': 28,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'kernel_size': 28,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'layers': 1,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'metric': 0.001,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'pixel_size_meters': 3.6e-05,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'propagation_distance': 300,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'resolution_scale_factor': 2,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'test_class_instances': 100,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'test_data_path': './assets/quickdraw16_test.npy',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'tile_size_scale_factor': 2,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'train_class_instances': 8000,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'train_data_path': './assets/quickdraw16_train.npy',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 'wavelength': 5.32e-07}\n"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG = type('', (), {})() # object for parameters\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# свойства входных данных\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.classes = 16\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.image_size = 28\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.train_class_instances = 8000\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.test_class_instances = 100\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.train_data_path = './assets/quickdraw16_train.npy'\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.test_data_path = './assets/quickdraw16_test.npy'\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# свойства модели оптической системы\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.kernel_size = 28\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.tile_size_scale_factor = 2\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.resolution_scale_factor = 2 \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.class_slots = 16\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.wavelength = 532e-9\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# CONFIG.refractive_index = 1.5090\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.propagation_distance = 300\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.metric = 1e-3\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.pixel_size_meters = 36e-6\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "CONFIG.layers = 1\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "pprint(CONFIG.__dict__)"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "41470dee-088e-4d24-8a1f-4b18636b4dac",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Обучающие и тестовые данные"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 5,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "bb84c1e5-0201-4815-bf88-d5512364e731",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "data": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      "text/plain": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "(torch.Size([128000, 1, 28, 28]), torch.Size([1600, 1, 28, 28]))"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "execution_count": 5,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "output_type": "execute_result"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "train_data = torch.tensor(np.load(CONFIG.train_data_path), dtype=torch.float32)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "test_data = torch.tensor(np.load(CONFIG.test_data_path), dtype=torch.float32)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "train_data = rearrange(train_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "test_data = rearrange(test_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "train_data.shape, test_data.shape"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 6,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "2684aad3-16db-41a9-914b-762151d865a3",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "data": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      "text/plain": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "(torch.Size([128000, 16]), torch.Size([1600, 16]))"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "execution_count": 6,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "output_type": "execute_result"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "train_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "test_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "train_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.train_class_instances)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "test_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.test_class_instances)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "train_targets.shape, test_targets.shape"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "037548dd-9834-4c6d-b1dd-ced4d5f4adaa",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Модель системы"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 7,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "f116e1f4-e477-4168-bac6-9a21f2ac4190",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "class OpticalSystem(Module):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    def __init__(self,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 layers,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 kernel_size_pixels,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 tile_size_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 resolution_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 class_slots,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 classes,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 wavelength = 532e-9, \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 # refractive_index = 1.5090, \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 propagation_distance = 300,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 pixel_size_meters = 36e-6,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                 metric = 1e-3\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                ):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \"\"\"\"\"\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        super().__init__()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.layers = layers\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.kernel_size_pixels = kernel_size_pixels\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.tile_size_scale_factor = tile_size_scale_factor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.resolution_scale_factor = resolution_scale_factor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.class_slots = class_slots\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.classes = classes\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.wavelength = wavelength\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # self.refractive_index = refractive_index\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.propagation_distance = propagation_distance\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.pixel_size_meters = pixel_size_meters\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.metric = metric\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        assert(self.class_slots >= self.classes)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.empty_class_slots = self.class_slots - self.classes \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.height_maps = []\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        for i in range(self.layers):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            heights = nn.Parameter(torch.ones([self.phase_mask_size, self.phase_mask_size], dtype=torch.float32))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            torch.nn.init.uniform_(heights, a=0.5*self.wavelength, b=1.5*self.wavelength)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            self.height_maps.append(heights)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.height_maps = torch.nn.ParameterList(self.height_maps)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        B = A*self.phase_mask_size/self.tile_size \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        x = torch.linspace(-B, B, self.phase_mask_size+1)[:-1]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        x, y = torch.meshgrid(x, x, indexing='ij')\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/B, torch.pi*self.phase_mask_size/2/B, self.phase_mask_size+1)[:-1]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        Kx, Ky = torch.meshgrid(kx, kx, indexing='ij')\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        vv = torch.arange(0,self.phase_mask_size)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        vv = (-1)**vv\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        a, b = torch.meshgrid(vv, vv, indexing='ij')\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        lambda1 = self.wavelength / self.metric\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.U = nn.Parameter((Kx**2 + Ky**2).float())\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.vv = nn.Parameter((a*b).float())\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.D = nn.Parameter(torch.exp(-1j*(x**2 + y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.U.requires_grad = False\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.vv.requires_grad = False\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.D.requires_grad = True\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        self.coef.requires_grad = False\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    def propagation(self, field, propagation_distance):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    def opt_conv(self, inputs, heights):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        result = result * self.D\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return amplitude\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    def forward(self, image):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \"\"\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        Алгоритм:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        1. Входное изображение увеличивается в self.resolution_scale_factor. [28,28] -> [56,56]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        2. Полученное изображение дополняется 0 до размера self.phase_mask_size. [56,56] -> [448, 448]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        3. Моделируется прохождение света через транспаранты\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        4. Выходное изображение нарезается в набор областей self.tiles_per_dim x self.tiles_per_dim\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        5. Области преобразуются в вектор длины self.class_slots операцией max и затем нормируется (нужна ли нормировка?)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        6. Вектор максимальных значений преобразутся в распределение вероятностей функцией softmax\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        \"\"\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # 1\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        image = resize(\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            image, \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            size=(image.shape[-2]*self.resolution_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                  image.shape[-1]*self.resolution_scale_factor),\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            interpolation=InterpolationMode.NEAREST\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        )\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # 2\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        image = pad_zeros(\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            image, \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            size = (self.phase_mask_size , \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                    self.phase_mask_size ),\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        )\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # 3      \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        x = image \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        for i, plate_heights in enumerate(self.height_maps):  \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            x = self.opt_conv(x, plate_heights)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        convolved = x\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # 4\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        grid_to_depth = rearrange(\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            convolved,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            \"b 1 (m ht) (n wt) -> b (m n) ht wt\",\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            ht = self.tile_size*self.resolution_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            wt = self.tile_size*self.resolution_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            m = self.tiles_per_dim,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            n = self.tiles_per_dim\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        )\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # 5\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        grid_to_depth = unpad_zeros(grid_to_depth, \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                                    (self.kernel_size_pixels*self.resolution_scale_factor,  \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                                     self.kernel_size_pixels*self.resolution_scale_factor))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        max_pool = torch.nn.functional.max_pool2d(\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            grid_to_depth,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            kernel_size = self.kernel_size_pixels*self.resolution_scale_factor\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        )               \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        max_pool = rearrange(max_pool, \"b class_slots 1 1 -> b class_slots\", class_slots=self.class_slots)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        max_pool /= max_pool.max(dim=1, keepdims=True).values\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        # 6\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        softmax = torch.nn.functional.softmax(max_pool, dim=1)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return softmax, convolved\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    def __repr__(self):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        tmp = {}\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        for k,v in self.__dict__.items():\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            if not k[0] == '_':\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                tmp[k] = v\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        tmp.update(self.__dict__['_modules'])\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        tmp.update({k:f\"{v.dtype} {v.shape}\" for k,v in self.__dict__['_parameters'].items()})\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        return pformat(tmp, indent=2)"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "200b86bb-188c-4034-8a98-1bcfb5b92d9e",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Создание экземпляра модели, оптимизатора, функции потерь"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 8,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "12c87b56-bdcf-4e9c-a171-b096ff244376",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "data": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      "text/plain": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "{ 'D': 'torch.complex64 torch.Size([448, 448])',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'U': 'torch.float32 torch.Size([448, 448])',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'class_slots': 16,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'classes': 16,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'coef': 'torch.complex64 torch.Size([1])',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'empty_class_slots': 0,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'height_maps': ParameterList(  (0): Parameter containing: [torch.cuda.FloatTensor of size 448x448 (GPU 0)]),\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'k': 'torch.float32 torch.Size([1])',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'kernel_size_pixels': 28,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'layers': 1,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'metric': 0.001,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'phase_mask_size': 448,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'pixel_size_meters': 3.6e-05,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'propagation_distance': 300,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'resolution_scale_factor': 2,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'tile_size': 56,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'tile_size_scale_factor': 2,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'tiles_per_dim': 4,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'training': False,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'vv': 'torch.float32 torch.Size([448, 448])',\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				       "  'wavelength': 5.32e-07}"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "execution_count": 8,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "output_type": "execute_result"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "model = OpticalSystem(\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     layers = CONFIG.layers,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     kernel_size_pixels = CONFIG.kernel_size,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     tile_size_scale_factor = CONFIG.tile_size_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     resolution_scale_factor = CONFIG.resolution_scale_factor,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     class_slots = CONFIG.class_slots,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     classes = CONFIG.classes,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     wavelength = CONFIG.wavelength, \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     propagation_distance = CONFIG.propagation_distance,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     pixel_size_meters = CONFIG.pixel_size_meters,\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "     metric = CONFIG.metric\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    ")\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# comment to train from scratch\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# model.load_state_dict(torch.load(CONFIG.phasemask_model_1_path))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "model.eval()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "model.cuda()"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": 9,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "43dcf5cb-a10a-4f02-811a-1c365e2e1bed",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "optimizer = torch.optim.Adam(params=model.cuda().parameters(), \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                             lr=1e-2)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "loss_function = torch.nn.CrossEntropyLoss()"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "d6f489d1-0d5d-43ec-bcaf-f01a82c28bf1",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Обучение"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": null,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "c3851a9d-1867-4c5e-b4ec-3598448a054f",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "name": "stderr",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "output_type": "stream",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     "text": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				      " 97%|█████████▋| 4071/4188 [11:23<00:19,  5.97it/s, loss: 2.736057e+00, acc: 0.69, lr: 1.704265e-04, passes_through_dataset: 24]"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				     ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# training loop\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "batch_size = 764\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "max_passes_through_dataset = 25\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "epochs = int(train_data.shape[0]/batch_size*max_passes_through_dataset)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "ppp = tqdm.trange(epochs)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "def init_batch_generator(train_data, train_labels, batch_size):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \"\"\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    Возвращает функцию, вызов которой возвращает следующие batch_size\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    примеров и им соответствуюющих меток из train_data, train_labels.\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    Примеры выбираются последовательно, по кругу. Массивы с входными \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    примерами и метками классов перемешиваются в начале каждого круга.\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \"\"\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    def f():\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        i = 0\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        rnd_indx = torch.randperm(train_data.shape[0])\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        train_data_shuffled = train_data[rnd_indx]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        train_labels_shuffled = train_labels[rnd_indx]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "        while True:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            if i + batch_size > train_data.shape[0]:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                i = 0\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                rnd_indx = torch.randperm(train_data.shape[0])\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                train_data_shuffled = train_data[rnd_indx]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                train_labels_shuffled = train_labels[rnd_indx]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "                \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            batch_inputs = train_data_shuffled[i:i+batch_size]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            batch_targets = train_labels_shuffled[i:i+batch_size]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            i = i + batch_size\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "            yield batch_inputs, batch_targets\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    return f()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "batch_iterator = init_batch_generator(train_data, train_targets, batch_size)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "i = 0\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "for epoch in ppp:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  batch_inputs, batch_targets = next(batch_iterator)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  batch_inputs = batch_inputs.cuda()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  batch_targets = batch_targets.cuda()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  i = i + batch_size\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  passes_through_dataset = i//train_data.shape[0]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  # apply model\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  predicted, convolved = model(batch_inputs)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  # correct model\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  loss_value = loss_function(predicted, batch_targets)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  loss_value.backward()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  optimizer.step()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  # для небольших батчей следует уменьшать частоту вывода \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  if epoch % 2 == 0:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    acc = accuracy_score(to_class_labels(batch_targets), to_class_labels(predicted))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    ppp.set_postfix_str(\"loss: {:e}, acc: {:.2f}, lr: {:e}, passes_through_dataset: {}\".format(loss_value, acc, scheduler.get_last_lr()[0], passes_through_dataset))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "  if (scheduler.get_last_lr()[0] > 1e-13):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    scheduler.step()"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "ae4259bb-199b-43ee-951a-4d481dd1cddc",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Тест"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": null,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "f68545e2-7f65-464e-94ef-cee7ffd0236a",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "inputs = test_data\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "targets = test_targets\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "batch_size = 64\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "predicted = []\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "batch_start = 0\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "while batch_start < test_data.shape[0]:\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    batch_end = min(batch_start + batch_size, test_data.shape[0])\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    batch_input = inputs[batch_start:batch_end].cuda() \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    batch_output, _ = model(batch_input)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    predicted.append(batch_output.detach().cpu())\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    batch_start = batch_end\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "predicted = torch.concat(predicted)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\"Accuracy on test dataset: \", test_acc"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": null,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "296f2dab-6f4a-40b1-81ea-a1e7ed7e72ba",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "imshow(model.height_maps, figsize=(10,10))"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": null,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "9a4cd483-be14-4222-b3c6-0fadd0b1c017",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "class_id = 3\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "image = test_data[test_labels==class_id][:1]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "imshow(image, title=f\"Input images\")\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "softmax, convolved = model(image.cuda())\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "for idx, psf in enumerate(convolved):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    psf = psf.squeeze()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    f, ax = imshow(psf, figsize=(5,5), title=f\"Result of optical convolution with phase plate for image {idx}\")\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    ax[0].hlines(np.arange(0, psf.shape[0], psf.shape[0]//model.tiles_per_dim), 0, psf.shape[1]-1)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    ax[0].vlines(np.arange(0, psf.shape[1], psf.shape[1]//model.tiles_per_dim), 0, psf.shape[0]-1)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    y,x = (psf==torch.max(psf)).nonzero()[0]\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "    ax[0].text(x,y, \"max\", color='white');"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "markdown",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "63f39f87-5a98-454b-9f1c-fdc712b11f0b",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "### Сохранение рельефа"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "cell_type": "code",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "execution_count": null,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "id": "f29aec08-f946-4044-b19d-8718c4673f7c",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "metadata": {},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "outputs": [],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "source": [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# from PIL import Image\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "# for idx, heights in enumerate(model.height_maps):\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     m = heights.abs().mean()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     s = heights.abs().std()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     m1, m2 = heights.abs().min(), heights.abs().max()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     ar = heights.abs().cpu().detach().numpy() \n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     print(ar.dtype)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     im = ar\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     im = im - im.min()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     im = im / im.max()\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     im = im * 255\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     name_im = f\"phasemask_{idx}.png\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     name_np = f\"phasemask_{idx}\"\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     result = Image.fromarray(im.astype(np.uint8))\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     result.save(name_im)\n",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "#     np.save(name_np, ar)"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				 ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				 "metadata": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  "kernelspec": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "display_name": "Python 3 (ipykernel)",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "language": "python",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "name": "python3"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  "language_info": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "codemirror_mode": {
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "name": "ipython",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    "version": 3
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "file_extension": ".py",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "mimetype": "text/x-python",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "name": "python",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "nbconvert_exporter": "python",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "pygments_lexer": "ipython3",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				   "version": "3.7.13"
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				  }
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				 },
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				 "nbformat": 4,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				 "nbformat_minor": 5
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				}
 |