You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
655 lines
213 KiB
Plaintext
655 lines
213 KiB
Plaintext
4 months ago
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "58f00965-3512-486d-b5ab-7111a17d7c85",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from utils import to_class_labels, imshow\n",
|
||
|
"import torch\n",
|
||
|
"from einops import rearrange\n",
|
||
|
"import numpy as np\n",
|
||
|
"import math\n",
|
||
|
"from pprint import pprint, pformat\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
|
||
|
"import models\n",
|
||
|
"from pathlib import Path\n",
|
||
|
"import inspect"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "674b3e45-d2b9-4b61-88c4-a7a516d6789d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"{'batch_size': 152,\n",
|
||
|
" 'class_slots': 16,\n",
|
||
|
" 'classes': 16,\n",
|
||
|
" 'dataset_name': 'quickdraw',\n",
|
||
|
" 'experiment_dir': PosixPath('experiments/OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05'),\n",
|
||
|
" 'image_size': 28,\n",
|
||
|
" 'kernel_size_pixels': 28,\n",
|
||
|
" 'layers': 1,\n",
|
||
|
" 'loss_plot_path': PosixPath('experiments/OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05/OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05_loss.png'),\n",
|
||
|
" 'max_passes_through_dataset': 20,\n",
|
||
|
" 'metric': 0.001,\n",
|
||
|
" 'mlp_layers': 2,\n",
|
||
|
" 'model_class': <class 'models.OpticalSystem'>,\n",
|
||
|
" 'model_path': PosixPath('experiments/OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05/OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05.pt'),\n",
|
||
|
" 'name_id': 'OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05',\n",
|
||
|
" 'pixel_size_meters': 3.6e-05,\n",
|
||
|
" 'propagation_distance': 300,\n",
|
||
|
" 'resolution_scale_factor': 2,\n",
|
||
|
" 'test_batch_size': 64,\n",
|
||
|
" 'test_class_instances': 100,\n",
|
||
|
" 'test_data_path': './assets/quickdraw16_test.npy',\n",
|
||
|
" 'tile_size_scale_factor': 2,\n",
|
||
|
" 'train_class_instances': 8000,\n",
|
||
|
" 'train_data_path': './assets/quickdraw16_train.npy',\n",
|
||
|
" 'wavelength': 5.32e-07}\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"CONFIG = type('', (), {})() # object for parameters\n",
|
||
|
"\n",
|
||
|
"CONFIG.model_class = models.OpticalSystem\n",
|
||
|
"CONFIG.layers = 1\n",
|
||
|
"CONFIG.mlp_layers = 2\n",
|
||
|
"CONFIG.batch_size = 764//5\n",
|
||
|
"CONFIG.max_passes_through_dataset = 20\n",
|
||
|
"CONFIG.test_batch_size = 64\n",
|
||
|
"\n",
|
||
|
"# свойства входных данных\n",
|
||
|
"CONFIG.dataset_name = 'quickdraw'\n",
|
||
|
"CONFIG.classes = 16\n",
|
||
|
"CONFIG.image_size = 28\n",
|
||
|
"CONFIG.train_class_instances = 8000\n",
|
||
|
"CONFIG.test_class_instances = 100\n",
|
||
|
"CONFIG.train_data_path = './assets/quickdraw16_train.npy'\n",
|
||
|
"CONFIG.test_data_path = './assets/quickdraw16_test.npy'\n",
|
||
|
"\n",
|
||
|
"# свойства модели оптической системы\n",
|
||
|
"CONFIG.kernel_size_pixels = 28\n",
|
||
|
"CONFIG.tile_size_scale_factor = 2\n",
|
||
|
"CONFIG.resolution_scale_factor = 2 \n",
|
||
|
"CONFIG.class_slots = 16\n",
|
||
|
"CONFIG.wavelength = 532e-9\n",
|
||
|
"# CONFIG.refractive_index = 1.5090\n",
|
||
|
"CONFIG.propagation_distance = 300\n",
|
||
|
"CONFIG.metric = 1e-3\n",
|
||
|
"CONFIG.pixel_size_meters = 36e-6\n",
|
||
|
"\n",
|
||
|
"CONFIG.name_id = f\"{CONFIG.model_class.__name__}_{CONFIG.dataset_name}_{CONFIG.layers}_{CONFIG.mlp_layers}_{CONFIG.classes}_{CONFIG.max_passes_through_dataset}_{CONFIG.kernel_size_pixels}\" + \\\n",
|
||
|
" f\"_{CONFIG.resolution_scale_factor}_{CONFIG.tile_size_scale_factor}_{CONFIG.propagation_distance}_{CONFIG.wavelength}_{CONFIG.metric}_{CONFIG.pixel_size_meters}\"\n",
|
||
|
"CONFIG.experiment_dir = Path(f'./experiments/{CONFIG.name_id}/')\n",
|
||
|
"CONFIG.experiment_dir.mkdir(parents=True, exist_ok=True)\n",
|
||
|
"CONFIG.model_path = CONFIG.experiment_dir / f\"{CONFIG.name_id}.pt\"\n",
|
||
|
"CONFIG.loss_plot_path = CONFIG.experiment_dir / f\"{CONFIG.name_id}_loss.png\"\n",
|
||
|
"\n",
|
||
|
"def init_from_config(config):\n",
|
||
|
" init_arg_names = list(inspect.signature(CONFIG.model_class).parameters.keys())\n",
|
||
|
" return CONFIG.model_class(**{k:CONFIG.__dict__[k] for k in init_arg_names})\n",
|
||
|
"\n",
|
||
|
"pprint(CONFIG.__dict__)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "f468894b-ffa6-45a3-872c-95a57e4e263a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(torch.Size([128000, 16]), torch.Size([1600, 16]))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_data = torch.tensor(np.load(CONFIG.train_data_path), dtype=torch.float32)\n",
|
||
|
"test_data = torch.tensor(np.load(CONFIG.test_data_path), dtype=torch.float32)\n",
|
||
|
"train_data = rearrange(train_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
|
||
|
"test_data = rearrange(test_data, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
|
||
|
"train_data.shape, test_data.shape\n",
|
||
|
"\n",
|
||
|
"train_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)\n",
|
||
|
"test_targets = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)\n",
|
||
|
"\n",
|
||
|
"train_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.train_class_instances)\n",
|
||
|
"test_labels = torch.repeat_interleave(torch.arange(CONFIG.classes), CONFIG.test_class_instances)\n",
|
||
|
"\n",
|
||
|
"train_targets.shape, test_targets.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"id": "d79ff533-d7ee-4c37-866d-f3cc7211648d",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"inputs = test_data\n",
|
||
|
"targets = test_targets"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"id": "c4695c40-a01b-4ce6-b14c-9edc994a60ee",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"model = init_from_config(CONFIG)\n",
|
||
|
"# comment to train from scratch\n",
|
||
|
"model.load_state_dict(torch.load(CONFIG.model_path))\n",
|
||
|
"model.eval()\n",
|
||
|
"model.cuda();"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "7d6820d6-599c-4142-a879-f1b943549698",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### amplitude noise"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"id": "91941235-bc74-4f47-a3fa-47221eddda65",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import types\n",
|
||
|
"\n",
|
||
|
"def to_class_labels(softmax_distibutions):\n",
|
||
|
" return torch.argmax(softmax_distibutions, dim=1).cpu()\n",
|
||
|
"\n",
|
||
|
"def test(inputs, targets, noise_percent=0.1):\n",
|
||
|
" def opt_conv(self, inputs, heights):\n",
|
||
|
" result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
|
||
|
" result = result * heights \n",
|
||
|
" result.real += torch.randn_like(result.real)*(result.real.abs().max()*noise_percent)\n",
|
||
|
" result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
|
||
|
" amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
|
||
|
" return amplitude\n",
|
||
|
"\n",
|
||
|
" model.opt_conv = types.MethodType(opt_conv, model) \n",
|
||
|
" \n",
|
||
|
" predicted = []\n",
|
||
|
" batch_start = 0\n",
|
||
|
" while batch_start < test_data.shape[0]:\n",
|
||
|
" batch_end = min(batch_start + CONFIG.test_batch_size, test_data.shape[0])\n",
|
||
|
" batch_input = inputs[batch_start:batch_end].cuda() \n",
|
||
|
" with torch.inference_mode():\n",
|
||
|
" batch_output, _ = model(batch_input)\n",
|
||
|
" predicted.append(batch_output.cpu())\n",
|
||
|
" batch_start = batch_end\n",
|
||
|
" \n",
|
||
|
" predicted = torch.concat(predicted)\n",
|
||
|
" \n",
|
||
|
" test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
|
||
|
" return test_acc"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"id": "e8b4f58c-4c01-4b94-a794-7acdb2e82721",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"x = np.arange(0, 1.0, 0.01)\n",
|
||
|
"y_amp = []\n",
|
||
|
"for p in x:\n",
|
||
|
" y_amp.append(np.mean([test(inputs, targets, noise_percent=p) for _ in range(5)]))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"id": "f4eaeb67-4b9a-4ea0-bbd9-d6f8ca1a4f21",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAABsZUlEQVR4nO3deVwU9f8H8Nfuwi73DcshgoCKeGGopOaRopSm2WmXGqWVSll08utQ6/vNNLPT0q95dJiapnaZSiiVJ955IAqCeHDKDQrL7uf3B7G5Asqx7MDyej4ePL7fnZ2Zfc/boX0x85kZmRBCgIiIiMhMyKUugIiIiMiYGG6IiIjIrDDcEBERkVlhuCEiIiKzwnBDREREZoXhhoiIiMwKww0RERGZFYYbIiIiMisMN0RERGRWGG6IGiEhIQEymQwJCQn6aY8//jj8/f1N8vnp6emQyWRYuXKlST6vIWQyGWbPnn3T+WbPng2ZTGYwzd/fH48//niDPmfYsGEYNmxY4wskPXPqoSl/76jtYbihVu/zzz+HTCZDeHi41KU0SHl5OWbPnm0QgKhhTp48idmzZyM9PV3qUmr5+eefMXbsWKjVaiiVSri4uGDIkCH44IMPUFxcLHV5RHQNC6kLILqZVatWwd/fH4mJiUhJSUFQUJDUJRlYunQpdDqd/nV5eTnmzJkDAGbzV3JLSU5Ohlz+799YJ0+exJw5czBs2LBaf5Vv27bNxNVV0+l0ePLJJ7Fy5Ur07NkT06dPh6+vL0pKSrBnzx688cYb2Lx5M+Lj4yWprzGk6mFLuP73juhaDDfUqqWlpWH37t3YsGEDnn76aaxatQqzZs2SuiwDlpaWUpfQZqlUqgbPq1QqW7CS+s2fPx8rV67ECy+8gA8++MDg1NrMmTORmZmJr7/+WpLaGkuqHrYE/t7RjfC0FLVqq1atgrOzM8aMGYP7778fq1atqjVPzTiUBQsWYNGiRQgICICNjQ1GjRqF8+fPQwiBd955Bx06dIC1tTXuvvtu5OfnG6zD398fd911F7Zt24bQ0FBYWVkhJCQEGzZsuGmN1577T09Ph7u7OwBgzpw5kMlkBmNS6hvzUNf4gcLCQjz++ONwdHSEk5MTJk+ejMLCwjprOHXqFO6//364uLjAysoKffv2xU8//XTT2gFgwYIFGDhwIFxdXWFtbY2wsDCsX7++1nwVFRV44YUX4O7uDnt7e4wbNw4XLlyoc507d+5Ev379YGVlhcDAQCxZsqTO+a4dc7Ny5Uo88MADAIDbb79d37ua03vX9i47OxsWFhb6I2TXSk5Ohkwmw2effaafVlhYiOeffx6+vr5QqVQICgrCvHnzbvqXf3l5OebNm4fu3bvj/fffrzVmCAC8vLzw6quvGkxbsWIFhg8fDg8PD6hUKoSEhOCLL76otWx945WuH4uk0WgwZ84cdO7cGVZWVnB1dcVtt92GuLg4/TxZWVmIiopChw4doFKp4OXlhbvvvtvgFN/1+19lZSXeeusthIWFwdHREba2thg8eDB27NhhUM+1v2P/+9//EBgYCJVKhX79+mH//v037CFQ/W8rk8mwa9cuxMTEwN3dHba2trjnnnuQm5tba/7PP/8c3bt3h0qlgre3N2bMmFFr36/rd2bNmjUICwuDvb09HBwc0LNnT3z88ccG8zR1X6C2hUduqFVbtWoV7r33XiiVSjz88MP44osvsH//fvTr16/OeSsrK/Hss88iPz8f8+fPx4MPPojhw4cjISEBr776KlJSUvDpp5/ipZdewvLlyw2WP3PmDCZMmIBnnnkGkydPxooVK/DAAw9gy5YtGDlyZIPqdXd3xxdffIFp06bhnnvuwb333gsA6NWrV6O2WwiBu+++Gzt37sQzzzyDbt26YePGjZg8eXKteU+cOIFBgwbBx8cHr732GmxtbfH9999j/Pjx+OGHH3DPPffc8LM+/vhjjBs3Do8++igqKyuxZs0aPPDAA/
jll18wZswY/XxTpkzBt99+i0ceeQQDBw7E9u3bDd6vcezYMYwaNQru7u6YPXs2qqqqMGvWLKjV6hvWMWTIEDz33HP45JNP8H//93/o1q0bAOj/91pqtRpDhw7F999/X+tI3tq1a6FQKPRBqby8HEOHDsXFixfx9NNPo2PHjti9ezdiY2ORmZmJjz76qN6adu7cicLCQrz00ktQKBQ3rP9aX3zxBbp3745x48bBwsICP//8M6ZPnw6dTocZM2Y0eD01Zs+ejblz52LKlCno378/iouLceDAARw6dEi/b9533304ceIEnn32Wfj7+yMnJwdxcXHIyMiod+BtcXExvvzySzz88MOYOnUqSkpKsGzZMkRGRiIxMRGhoaEG83/33XcoKSnB008/DZlMhvnz5+Pee+/F2bNnG3Qk5dlnn4WzszNmzZqF9PR0fPTRR4iOjsbatWsNtnXOnDmIiIjAtGnTkJycrP+937VrV72fExcXh4cffhgjRozAvHnzAABJSUnYtWsXZs6cCaB5+wK1MYKolTpw4IAAIOLi4oQQQuh0OtGhQwcxc+ZMg/nS0tIEAOHu7i4KCwv102NjYwUA0bt3b6HRaPTTH374YaFUKsXVq1f10/z8/AQA8cMPP+inFRUVCS8vL9GnTx/9tB07dggAYseOHfppkydPFn5+fvrXubm5AoCYNWtWrW0aOnSoGDp0aK3p169j06ZNAoCYP3++flpVVZUYPHiwACBWrFihnz5ixAjRs2dPg+3R6XRi4MCBonPnzrU+63rl5eUGrysrK0WPHj3E8OHD9dOOHDkiAIjp06cbzPvII4/U2tbx48cLKysrce7cOf20kydPCoVCIa7/T46fn5+YPHmy/vW6detq9bfG9b1bsmSJACCOHTtmMF9ISIhB7e+8846wtbUVp0+fNpjvtddeEwqFQmRkZNT6rBoff/yxACA2bdpkML2qqkrk5uYa/Oh0Ov371/dUCCEiIyNFQECAwbT69pPr+9K7d28xZsyYeussKCgQAMT7779f7zxC1O5hVVWVqKioqLUutVotnnjiCf20mt8xV1dXkZ+fr5/+448/CgDi559/vuHnrlixQgAQERERBn164YUXhEKh0P/e5uTkCKVSKUaNGiW0Wq1+vs8++0wAEMuXL9dPu/53ZubMmcLBwUFUVVXVW0dz9gVqW3hailqtVatWQa1W4/bbbwdQfQh/woQJWLNmDbRaba35H3jgATg6Oupf11xd9dhjj8HCwsJgemVlJS5evGiwvLe3t8FRDgcHB0yaNAmHDx9GVlaWUbftZjZv3gwLCwtMmzZNP02hUODZZ581mC8/Px/bt2/Hgw8+iJKSEuTl5SEvLw+XL19GZGQkzpw5U2s7r2dtba3//wUFBSgqKsLgwYNx6NAhg3oA4LnnnjNY9vnnnzd4rdVqsXXrVowfPx4dO3bUT+/WrRsiIyMbtvENdO+998LCwsLgr/7jx4/j5MmTmDBhgn7aunXrMHjwYDg7O+v7k5eXh4iICGi1Wvz555/1fkbNVVB2dnYG048dOwZ3d3eDn8uXL+vfv7anRUVFyMvLw9ChQ3H27FkUFRU1eludnJxw4sQJnDlzps73ra2toVQqkZCQgIKCggavV6FQ6Mfh6HQ65Ofno6qqCn379jX4968xYcIEODs7618PHjwYAHD27NkGfd5TTz1lcGpv8ODB0Gq1OHfuHADg999/R2VlJZ5//nmDgeZTp06Fg4MDfv3113rX7eTkhLKyMoNTdddrzr5AbQvDDbVKWq0Wa9aswe233460tDSkpKQgJSUF4eHhyM7OrvPKlGu/TAHog46vr2+d06//EggKCqo1pqJLly4AYPJLk8+dOwcvL69aX6pdu3Y1eJ2SkgIhBN58881aX7Y1p2tycnJu+Fm//PILbr31VlhZWcHFxUV/au3aL+Fz585BLpcjMDDwhvXk5ubiypUr6Ny5c63PuX7e5nJzc8OIESPw/fff66etXbsWFh
YW+tOBQPXpxi1bttTqT0REBIAb98fe3h4AUFpaajA9KCgIcXFxiIuLw8SJE2stt2vXLkRERMDW1hZOTk5wd3fH//3
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.plot(x,y_amp)\n",
|
||
|
"plt.xlabel(\"percent of input amplitude\")\n",
|
||
|
"plt.ylabel(\"accuracy\")\n",
|
||
|
"plt.title(\"Amplitude additive Gaussian noise\")\n",
|
||
|
"plt.grid()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "9ddb5d78-ebda-4e08-b14a-a2515e618169",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### phase noise"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "2397c16b-c560-432c-af7f-4b7065d78afa",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import types\n",
|
||
|
"\n",
|
||
|
"def to_class_labels(softmax_distibutions):\n",
|
||
|
" return torch.argmax(softmax_distibutions, dim=1).cpu()\n",
|
||
|
"\n",
|
||
|
"def test(inputs, targets, noise_percent=0.1):\n",
|
||
|
" def opt_conv(self, inputs, heights):\n",
|
||
|
" result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
|
||
|
" result = result * heights \n",
|
||
|
" result.imag += torch.randn_like(result.imag)*(result.imag.abs().max()*noise_percent)\n",
|
||
|
" result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
|
||
|
" amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
|
||
|
" return amplitude\n",
|
||
|
"\n",
|
||
|
" model.opt_conv = types.MethodType(opt_conv, model) \n",
|
||
|
" \n",
|
||
|
" predicted = []\n",
|
||
|
" batch_start = 0\n",
|
||
|
" while batch_start < test_data.shape[0]:\n",
|
||
|
" batch_end = min(batch_start + CONFIG.test_batch_size, test_data.shape[0])\n",
|
||
|
" batch_input = inputs[batch_start:batch_end].cuda() \n",
|
||
|
" with torch.inference_mode():\n",
|
||
|
" batch_output, _ = model(batch_input)\n",
|
||
|
" predicted.append(batch_output.cpu())\n",
|
||
|
" batch_start = batch_end\n",
|
||
|
" \n",
|
||
|
" predicted = torch.concat(predicted)\n",
|
||
|
" \n",
|
||
|
" test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
|
||
|
" return test_acc"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"id": "914fc19a-654c-484c-816f-2774e4ff6f54",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"x = np.arange(0, 1.0, 0.01)\n",
|
||
|
"y_phase = []\n",
|
||
|
"for p in x:\n",
|
||
|
" y_phase.append(np.mean([test(inputs, targets, noise_percent=p) for _ in range(5)]))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"id": "a31e2b8d-33f3-487c-a254-1687256508e1",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAABp4ElEQVR4nO3deVhU9f4H8PfMwMyw77sIiqKiKIpLuC8oLi1aubSoUdktpSxuN7NF0xbLzLq/sizL5Xq7rlezm6YgSabihuKKuCCLKPsOwgwz5/cHMTkCyjLMgeH9eh6ee+fMOWc+58OJeXvO95wjEQRBABEREZGJkIpdABEREZEhMdwQERGRSWG4ISIiIpPCcENEREQmheGGiIiITArDDREREZkUhhsiIiIyKQw3REREZFIYboiIiMikMNwQAVi/fj0kEglOnjwpdilGU7PNKSkp953X19cXzzzzjO51bGwsJBIJYmNj77tsSkoKJBIJ1q9f3+Ra2ztT66FEIsF7770ndhlkwhhuyKTVfIHX/CiVSvj7+yMiIgJZWVlil2dS/vOf/+CLL74Qu4xaiouL8eGHH6J///6ws7ODQqGAj48Ppk+fjt27d4tdHhG1ADOxCyAyhqVLl6JTp06oqKjAoUOH8M0332DPnj04f/48LC0txS6vzRk+fDhu374NuVyum/af//wH58+fx6uvvqo3r4+PD27fvg1zc3MjVwlcvXoVYWFhSE1NxZQpUzBr1ixYW1sjPT0de/bswYMPPoh//etfmDlzptFrawwxe9gSbt++DTMzfv1Qy+HeRe3ChAkT0L9/fwDA888/DycnJ6xcuRK7du3CE088IXJ1bY9UKoVSqWzQvDVHzIytqqoKU6ZMQVZWFn7//XcMGTJE7/3FixcjKioKGo3G6LU1llg9bCmmtC3UOvG0FLVLo0ePBgBcv35db3plZSUiIyPh4uICKysrTJkyBTk5OXrz7Nq1C5MmTYKnpycUCgX8/Pzw/vvv1/qSvHLlCh577DG4u7tDqVSiQ4cOmDFjBoqKivTm+/e//43g4GBYWFjA0dERM2bMQHp6+n23ITU1FXPnzkW3bt1gYWEBJycnTJ06tc4xNBcuXMDo0aNhYWGBDh064IMPPoBWq601nyAI+OCDD9ChQwdYWlpi1KhRuHDhQq357h5zM3LkSOzevRupqam6U4C+vr4Aao8XWbFiBSQSCVJTU2utd+HChZDL5SgoKNBNO3bsGMaPHw87OztYWlpixIgROHz48H37s23bNpw/fx7vvvturWBTY9y4cZgwYYLudX5+Pl5//XUEBgbC2toatra2mDBhAs6cOaO3XH3jleoai9SQ/SA6OhpDhw6Fvb09rK2t0a1bN7z11lu69+sac3P27Fk888wz6Ny5M5RKJdzd3fHss88iLy9Pr6b33nsPEokEV69exTPPPAN7e3vY2dkhPDwc5eXl9+3jyJEj0atXL1y8eBGjRo2CpaUlvLy8sHz58lrzZmdn47nnnoObmxuUSiX69OmDDRs21Jrv7jE3JSUlePXVV+Hr6wuFQgFXV1eMHTsWp06d0luuqfsCtT88ckPt0rVr1wAATk5OetNffvllODg4YPHixUhJScEXX3yBiIgIbNmyRTfP+vXrYW1tjcjISFhbW+O3337DokWLUFxcjE8//RQAoFKpEBYWhsrKSrz88stwd3dHRkYGfvnlFxQWFsLOzg4A8OGHH+Ldd9/FtGnT8PzzzyMnJwdffvklhg8fjtOnT8Pe3r7ebThx4gSOHDmCGTNmoEOHDkhJScE333yDkSNH4uLFi7rTbZmZmRg1ahSqqqrw5ptvwsrKCt999x0sLCxqrXPRokX44IMPMHHiREycOBGnTp3CuHHjoFKp7tnPt99+G0VFRbhx4wY+//xzAIC1tXWd806bNg1vvPEGtm7din/84x96723duhXjxo2Dg4MDAOC3337DhAkTEBwcjMWLF0MqlWLdunUYPX
o0/vjjDwwcOLDemv73v/8BAJ5++ul71n6n5ORk/PTTT5g6dSo6deqErKwsfPvttxgxYgQuXrwIT0/PBq8LaNh+cOHCBTz44IPo3bs3li5dCoVCgatXr973Szs6OhrJyckIDw+Hu7s7Lly4gO+++w4XLlzA0aNHIZFI9OafNm0aOnXqhGXLluHUqVP4/vvv4erqik8++eS+21FQUIDx48fj0UcfxbRp07B9+3YsWLAAgYGBunB4+/ZtjBw5ElevXkVERAQ6deqEbdu24ZlnnkFhYSHmz59f7/pffPFFbN++HREREQgICEBeXh4OHTqExMRE9OvXD0Dz9gVqhwQiE7Zu3ToBgLB//34hJydHSE9PFzZv3iw4OTkJFhYWwo0bN/TmCw0NFbRarW751157TZDJZEJhYaFuWnl5ea3P+dvf/iZYWloKFRUVgiAIwunTpwUAwrZt2+qtLSUlRZDJZMKHH36oN/3cuXOCmZlZrel3q6uOuLg4AYDwr3/9Szft1VdfFQAIx44d003Lzs4W7OzsBADC9evXddPkcrkwadIkvR689dZbAgBh9uzZumkHDhwQAAgHDhzQTZs0aZLg4+NTq6br168LAIR169bppoWEhAjBwcF68x0/flyvdq1WK3Tt2lUICwvTq6e8vFzo1KmTMHbs2Hv2p2/fvoK9vX2t6aWlpUJOTo7up6ioSPdeRUWFoNFoatWvUCiEpUuX6qbV7C81vatxd18ash98/vnnAgAhJyen3nnq6mFdv/9NmzYJAISDBw/qpi1evFgAIDz77LN6806ZMkVwcnKq9zNrjBgxotY+VVlZKbi7uwuPPfaYbtoXX3whABD+/e9/66apVCohJCREsLa2FoqLi3XTAQiLFy/WvbazsxPmzZtXbw3N3Reo/eFpKWoXQkND4eLiAm9vb8yYMQPW1tbYuXMnvLy89OZ74YUX9P7FO2zYMGg0Gr1TKHce8SgpKUFubi6GDRuG8vJyXLp0CQB0R2b27dtX76H/HTt2QKvVYtq0acjNzdX9uLu7o2vXrjhw4MA9t+nOOtRqNfLy8tClSxfY29vrHc7fs2cPHnjgAb1/2bq4uOCpp57SW9/+/fuhUqnw8ssv6/Xg7gHChjB9+nTEx8frjqABwJYtW6BQKPDII48AABISEnDlyhU8+eSTyMvL0/WnrKwMY8aMwcGDB+s8tVajuLi4zqNHb7/9NlxcXHQ/Tz75pO49hUIBqbT6z6JGo0FeXp7uNNHdp0gaoiH7Qc3RuV27dt1ze+525++/oqICubm5eOCBBwCgzlpffPFFvdfDhg1DXl4eiouL7/tZ1tbWekfA5HI5Bg4ciOTkZN20PXv2wN3dXW8Mm7m5OV555RWUlpbi999/r3f99vb2OHbsGG7evFnn+83dF6j9YbihdmHVqlWIjo7GgQMHcPHiRSQnJyMsLKzWfB07dtR7XXN65M4xIBcuXMCUKVNgZ2cHW1tbuLi46P7w14yj6NSpEyIjI/H999/D2dkZYWFhWLVqld44iytXrkAQBHTt2lXvy9bFxQWJiYnIzs6+5zbdvn0bixYtgre3NxQKBZydneHi4oLCwkK9z0lNTUXXrl1rLd+tWze91zUB7u55XVxcdH0wlKlTp0IqlepO9wmCgG3btmHChAmwtbUFUN0fAJg9e3at/nz//feorKysNX7pTjY2NigtLa01fe7cuYiOjkZ0dDTc3Nz03tNqtfj888/RtWtXvZ6ePXv2np9Vn4bsB9OnT8eQIUPw/PPPw83NDTNmzMDWrVvv+2Wdn5+P+fPnw83NDRYWFnBxcUGnTp0AoM5aG7Jv16dDhw61TnM5ODjoLVuzn9WEwxo9evTQvV+f5cuX4/z58/D29sbAgQPx3nvv6QWn5u4L1P5wzA21CwMHDtRdLXUvMpmszumCIAAACgsLMWLECNja2mLp0qXw8/ODUqnEqVOnsGDBAr0vpM8++wzPPPMMdu3ahaioKLzyyitYtmwZjh49ig
4dOkCr1UIikeDXX3+t83PrG7NS4+WXX8a6devw6quvIiQkBHZ2dpBIJJgxY0ar/1esp6cnhg0bhq1bt+Ktt97C0aN
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.plot(x,y_phase)\n",
|
||
|
"plt.xlabel(\"percent of input phase\")\n",
|
||
|
"plt.ylabel(\"accuracy\")\n",
|
||
|
"plt.title(\"Phase additive Gaussian noise\")\n",
|
||
|
"plt.grid()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "4afaf4b3-cc22-433e-a601-d97995d99d69",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### both noise"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"id": "ac9c2a17-1e22-489b-ad82-e3fb45435862",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import types\n",
|
||
|
"\n",
|
||
|
"def to_class_labels(softmax_distibutions):\n",
|
||
|
" return torch.argmax(softmax_distibutions, dim=1).cpu()\n",
|
||
|
"\n",
|
||
|
"def test(inputs, targets, noise_percent=0.1):\n",
|
||
|
" def opt_conv(self, inputs, heights):\n",
|
||
|
" result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
|
||
|
" result = result * heights \n",
|
||
|
" result.real += torch.randn_like(result.real)*(result.real.abs().max()*noise_percent)\n",
|
||
|
" result.imag += torch.randn_like(result.imag)*(result.imag.abs().max()*noise_percent)\n",
|
||
|
" result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
|
||
|
" amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
|
||
|
" return amplitude\n",
|
||
|
"\n",
|
||
|
" model.opt_conv = types.MethodType(opt_conv, model) \n",
|
||
|
" \n",
|
||
|
" predicted = []\n",
|
||
|
" batch_start = 0\n",
|
||
|
" while batch_start < test_data.shape[0]:\n",
|
||
|
" batch_end = min(batch_start + CONFIG.test_batch_size, test_data.shape[0])\n",
|
||
|
" batch_input = inputs[batch_start:batch_end].cuda() \n",
|
||
|
" with torch.inference_mode():\n",
|
||
|
" batch_output, _ = model(batch_input)\n",
|
||
|
" predicted.append(batch_output.cpu())\n",
|
||
|
" batch_start = batch_end\n",
|
||
|
" \n",
|
||
|
" predicted = torch.concat(predicted)\n",
|
||
|
" \n",
|
||
|
" test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
|
||
|
" return test_acc"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"id": "cee148b3-af05-468e-b9ff-7987d7fe236e",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"x = np.arange(0, 1.0, 0.01)\n",
|
||
|
"y_ampphase = []\n",
|
||
|
"for p in x:\n",
|
||
|
" y_ampphase.append(np.mean([test(inputs, targets, noise_percent=p) for _ in range(5)]))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"id": "f2c54d4d-9b8e-4ea2-bc4e-f7423856790e",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAABwkklEQVR4nO3dd3hTZfsH8G+SJmnTvQeUlpZZVqEIsgShUGUoqIigjKqIQBWt+iKvA3GAIirvqwiCMl7kJyCCCwRKpQ6mjCKrhUKhUDrpnkmT5/dHaSS0hY40adLv57q4NE+ec8597pw2d895nnMkQggBIiIiIishNXcARERERMbE4oaIiIisCosbIiIisiosboiIiMiqsLghIiIiq8LihoiIiKwKixsiIiKyKixuiIiIyKqwuCEiIiKrwuKGLFpcXBwkEgni4uL0bdOmTUNgYKBJtn/p0iVIJBKsXbvWJNu7WdW+b9myxeTbNpeaPu/aDBkyBEOGDNG/ru9nJZFI8NZbbzUoTqpkTTkMDAzEtGnTzB0G1RGLG6q3zz//HBKJBH379jV3KHVSUlKCt956q05fiOYSFhaGWbNmmTuMFmfHjh3N8su3vLwcn376KQYOHAhXV1coFAr4+fnhgQcewDfffAOtVmvuEImaNRtzB0CWZ8OGDQgMDMThw4eRlJSEdu3amTskA6tWrYJOp9O/LikpwYIFCwDA4C/55iItLQ3Hjx/H22+/be5QrFpAQABKS0shl8v1bTt27MCyZctqLHBKS0thY2P6X5FZWVm4//77cfToUUREROD111+Hm5sb0tPTsWfPHkyaNAlJSUl44403TB5bfZkrh00hMTERUinPB1gK6zjqyGSSk5Oxf/9+bN26FTNmzMCGDRswf/58c4dl4OYvL0vwyy+/wNbWFkOHDjV3KFZNIpHA1ta2zv3r09eYJk+ejOPHj+O7777DQw89ZPDevHnzcOTIESQmJpoltvoyVw6bglKpNHcIVA8sQ6leNmzYAFdXV4waNQqPPPIINmzYUK1P1diGJUuWYNmyZQgKCoJKpcKIESNw5coVCCHwzjvvoHXr1rCzs8ODDz6InJwcg3UEBgZi9OjR2L17N0JDQ2Fra4uQkBBs3br1jjHePObm0qVL8PT0BAAsWLAAEonEYBzAreMyalpHlby8PEybNg3Ozs5wcXHB1KlTkZeXV2MMCQkJeOSRR+Dm5gZbW1v07t0bP/74Y419t2/fjnvvvRd2dnb6mLp27YqjR4+if//+sLOzQ9u2bbFixYoal9fpdHjvvffQunVr2NraYtiwYUhKSjLo88cff2D8+PFo06YNlEol/P398eKLL6K0tNSgX3p6OiIjI9G6dWsolUr4+vriwQcfxKVLlwz6/fLLLxg0aBDs7e3h6OiIUaNG4fTp0zXGd7OcnBy8/PLL6NatGxwcHODk5IT7778fJ06cqNb36tWrGDt2LOzt7eHl5YUXX3wR5eXlNa535cqVCA4Ohp2dHfr06YM//vijWp9bx9xMmzYNy5YtAwD9cSGRSPT9bz5OtmzZAolEgt9++63aer/44gtIJBKcOnVK31afz/9mBw4cwK5du/DMM89UK2yq9O7dG48//rj+tVqtxptvvomwsDA4OzvD3t4egwYNwt69ew2Wq228Uk1jkepyHBw5cgQRERHw8PDQH6NPPvmkwbpvHXNz+fJlzJo1Cx07doSdnR3c3d0xfvz4asfX2rVrIZFIsG/fPkRHR8PT0xP29vYYN24csrKy7pjHadOmwcHBAampqRg7diwcHBzg6emJl19+udolveLiYrz00kvw9/eHUqlEx44dsWTJEgghDPrdOuZGo9FgwYIFaN++PWxtbeHu7o6BAwciJibGYLmGHgvUODxzQ/WyYcMGPPTQQ1AoFJg4cSKWL1+Ov/76C3fddVeNfdVqNZ577jnk5ORg8eLFePTRRzF06FDExcVh7ty5SEpKwqeffoqXX34Zq1
evNlj+/PnzmDBhAp599llMnToVa9aswfjx47Fz504MHz68TvF6enpi+fLlmDlzJsaNG6f/wujevXu99lsIgQcffBB//vknnn32WXTu3Bnbtm3D1KlTq/U9ffo0BgwYgFatWuHVV1+Fvb09Nm/ejLFjx+K7777DuHHj9H01Gg327NmDhQsXGqwjNzcXI0eOxKOPPoqJEydi8+bNmDlzJhQKRbUvkPfffx9SqRQvv/wy8vPzsXjxYjz++OM4dOiQvs+3336LkpISzJw5E+7u7jh8+DA+/fRTXL16Fd9++62+38MPP4zTp0/jueeeQ2BgIDIzMxETE4OUlBR9sbd+/XpMnToVERER+OCDD1BSUoLly5dj4MCBOH78+G0Hc1+8eBHff/89xo8fj7Zt2yIjIwNffPEFBg8ejDNnzsDPzw9A5eWMYcOGISUlBc8//zz8/Pywfv16/Prrr9XW+dVXX2HGjBno378/XnjhBVy8eBEPPPAA3Nzc4O/vX2ssM2bMwLVr1xATE4P169fX2g8ARo0aBQcHB2zevBmDBw82eG/Tpk3o0qULunbtCqB+n/+tfvrpJwDAE088cdt4blZQUIAvv/wSEydOxPTp01FYWIivvvoKEREROHz4MEJDQ+u8rip3Og4yMzMxYsQIeHp64tVXX4WLiwsuXbp0xz8+/vrrL+zfvx+PPfYYWrdujUuXLmH58uUYMmQIzpw5A5VKZdD/ueeeg6urK+bPn49Lly5h6dKliIqKwqZNm+64D1qtFhEREejbty+WLFmCPXv24KOPPkJwcDBmzpwJoPLn+oEHHsDevXvx1FNPITQ0FLt27cIrr7yC1NRUfPLJJ7Wu/6233sKiRYvw9NNPo0+fPigoKMCRI0dw7Ngx/e+nxhwL1EiCqI6OHDkiAIiYmBghhBA6nU60bt1azJkzx6BfcnKyACA8PT1FXl6evn3evHkCgOjRo4fQaDT69okTJwqFQiHKysr0bQEBAQKA+O677/Rt+fn5wtfXV/Ts2VPftnfvXgFA7N27V982depUERAQoH+dlZUlAIj58+dX26fBgweLwYMHV2u/dR3ff/+9ACAWL16sb6uoqBCDBg0SAMSaNWv07cOGDRPdunUz2B+dTif69+8v2rdvb7Cd2NhYAUAkJycbxARAfPTRR/q28vJyERoaKry8vIRarTbY986dO4vy8nJ93//85z8CgDh58qS+raSkpNo+Llq0SEgkEnH58mUhhBC5ubkCgPjwww+r9a1SWFgoXFxcxPTp0w3a09PThbOzc7X2W5WVlQmtVmvQlpycLJRKpXj77bf1bUuXLhUAxObNm/VtxcXFol27dgaft1qtFl5eXiI0NNQgBytXrhQADD7bquPy5s9q9uzZorZfg7ceMxMnThReXl6ioqJC35aWliakUqlB7PX5/G81btw4AcDg50YIIUpLS0VWVpb+X25urv69iooKg30XovKz9Pb2Fk8++aS+raafFSGq56Uux8G2bdsEAPHXX3/ddn9uzWFNx+GBAwcEAPG///1P37ZmzRoBQISHhwudTqdvf/HFF4VMJquWn1tNnTpVADD4XIQQomfPniIsLEz/uurn+t133zXo98gjjwiJRCKSkpL0bQEBAWLq1Kn61z169BCjRo26bRyNORaocXhZiupsw4YN8Pb2xr333gug8pTzhAkTsHHjxhpnb4wfPx7Ozs7611Wzq5544gmDQYZ9+/aFWq1GamqqwfJ+fn4Gf9k4OTlhypQpOH78ONLT0426b3eyY8cO2NjY6P/iAwCZTIbnnnvOoF9OTg5+/fVXPProoygsLER2djays7Nx/fp1RERE4Pz58wb7uWPHDoSEhFQ722FjY4MZM2boXysUCsyYMQOZmZk4evSoQd/IyEgoFAr960GDBgGoPEtSpeqSF1B5Gj47Oxv9+/eHEALHjx/X91EoFIiLi0Nubm6NeYiJiUFeXh4mTpyo37fs7GzIZDL07du32qWQWymVSv2gTK
1Wi+vXr8PBwQEdO3bEsWPHDPLi6+uLRx55RN+mUqnwzDPPGKzvyJEjyMzMxLPPPmuQg6rLh8Y0YcIEZGZmGlzW2bJ
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.plot(x,y_ampphase)\n",
|
||
|
"plt.xlabel(\"percent of input amplitude and phase\")\n",
|
||
|
"plt.ylabel(\"accuracy\")\n",
|
||
|
"plt.title(\"Amplitude/phase additive Gaussian noise\")\n",
|
||
|
"plt.grid()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "15e629ea-519a-482b-a8eb-83d988a2008e",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### one layer at a time "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "68257ed6-f844-44e1-a864-cf1e87ccc79e",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import types\n",
|
||
|
"from utils import pad_zeros, unpad_zeros\n",
|
||
|
"from torchvision.transforms.functional import resize, InterpolationMode\n",
|
||
|
"from einops import rearrange\n",
|
||
|
"\n",
|
||
|
"def to_class_labels(softmax_distibutions):\n",
|
||
|
" return torch.argmax(softmax_distibutions, dim=1).cpu()\n",
|
||
|
"\n",
|
||
|
"def test(inputs, targets, noise_percent=0.1, noise_layer_id=0):\n",
|
||
|
" def forward(self, image):\n",
|
||
|
" image = resize(\n",
|
||
|
" image, \n",
|
||
|
" size=(image.shape[-2]*self.resolution_scale_factor,\n",
|
||
|
" image.shape[-1]*self.resolution_scale_factor),\n",
|
||
|
" interpolation=InterpolationMode.NEAREST\n",
|
||
|
" )\n",
|
||
|
" # debug_out.append(image)\n",
|
||
|
" # 2\n",
|
||
|
" image = pad_zeros(\n",
|
||
|
" image, \n",
|
||
|
" size = (self.phase_mask_size, \n",
|
||
|
" self.phase_mask_size),\n",
|
||
|
" )\n",
|
||
|
" # 3 \n",
|
||
|
" x = image \n",
|
||
|
" for i, plate_heights in enumerate(self.height_maps): \n",
|
||
|
" x = self.opt_conv(x, plate_heights, i)\n",
|
||
|
" convolved = x\n",
|
||
|
" # 4\n",
|
||
|
" grid_to_depth = rearrange(\n",
|
||
|
" convolved,\n",
|
||
|
" \"b 1 (m ht) (n wt) -> b (m n) ht wt\",\n",
|
||
|
" ht = self.tile_size*self.resolution_scale_factor,\n",
|
||
|
" wt = self.tile_size*self.resolution_scale_factor,\n",
|
||
|
" m = self.tiles_per_dim,\n",
|
||
|
" n = self.tiles_per_dim\n",
|
||
|
" )\n",
|
||
|
" # 5\n",
|
||
|
" grid_to_depth = unpad_zeros(grid_to_depth, \n",
|
||
|
" (self.kernel_size_pixels*self.resolution_scale_factor, \n",
|
||
|
" self.kernel_size_pixels*self.resolution_scale_factor))\n",
|
||
|
" max_pool = torch.nn.functional.max_pool2d(\n",
|
||
|
" grid_to_depth,\n",
|
||
|
" kernel_size = self.kernel_size_pixels*self.resolution_scale_factor\n",
|
||
|
" ) \n",
|
||
|
" max_pool = rearrange(max_pool, \"b class_slots 1 1 -> b class_slots\", class_slots=self.class_slots)\n",
|
||
|
" max_pool /= max_pool.max(dim=1, keepdims=True).values\n",
|
||
|
" # 6\n",
|
||
|
" # CrossEntropy already has logsoftmax\n",
|
||
|
" # softmax = torch.nn.functional.log_softmax(max_pool, dim=1)\n",
|
||
|
" return max_pool, convolved\n",
|
||
|
"\n",
|
||
|
" def opt_conv(self, inputs, heights, i):\n",
|
||
|
" result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)\n",
|
||
|
" result = result * heights\n",
|
||
|
" if i == noise_layer_id:\n",
|
||
|
" result.real += torch.randn_like(result.real)*(result.real.max()*noise_percent)\n",
|
||
|
" result.imag += torch.randn_like(result.imag)*(result.imag.max()*noise_percent)\n",
|
||
|
" result = self.propagation(field=result, propagation_distance=self.propagation_distance)\n",
|
||
|
" amplitude = torch.sqrt(result.real**2 + result.imag**2)\n",
|
||
|
" return amplitude\n",
|
||
|
"\n",
|
||
|
" model.forward = types.MethodType(forward, model)\n",
|
||
|
" model.opt_conv = types.MethodType(opt_conv, model) \n",
|
||
|
" \n",
|
||
|
" predicted = []\n",
|
||
|
" batch_start = 0\n",
|
||
|
" while batch_start < test_data.shape[0]:\n",
|
||
|
" batch_end = min(batch_start + CONFIG.test_batch_size, test_data.shape[0])\n",
|
||
|
" batch_input = inputs[batch_start:batch_end].cuda() \n",
|
||
|
" with torch.inference_mode():\n",
|
||
|
" batch_output, _ = model(batch_input)\n",
|
||
|
" predicted.append(batch_output.cpu())\n",
|
||
|
" batch_start = batch_end\n",
|
||
|
" \n",
|
||
|
" predicted = torch.concat(predicted)\n",
|
||
|
" \n",
|
||
|
" test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
|
||
|
" return test_acc"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"id": "30526a54-7653-4112-8bc1-c97611f7fca7",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"if CONFIG.layers > 1:\n",
"    x = np.arange(0, 1.0, 0.01)\n",
"    # y_layer[i][k] is the test accuracy with Gaussian noise of level x[k]\n",
"    # injected only at optical layer i, averaged over 5 independent noise draws.\n",
"    # Iteration order is layer -> noise level -> repeat.\n",
"    y_layer = [\n",
"        [\n",
"            np.mean([test(inputs, targets, noise_percent=p, noise_layer_id=i) for _ in range(5)])\n",
"            for p in x\n",
"        ]\n",
"        for i in range(CONFIG.layers)\n",
"    ]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"id": "c2aca86e-3690-4a3c-a7f9-c7b7f71dd0a4",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"if CONFIG.layers > 1:\n",
"    # Noise is injected into one optical layer at a time (noise_layer_id). Its std is\n",
"    # noise_percent times the max of the real (resp. imaginary) field part, so the\n",
"    # x values are fractions in [0, 1), not percentages.\n",
"    fig, ax = plt.subplots()\n",
"    for i in range(CONFIG.layers):\n",
"        ax.plot(x, y_layer[i], label=f'layer {i}')\n",
"    ax.set_xlabel(\"noise std, fraction of max real/imag field component\")\n",
"    ax.set_ylabel(\"accuracy\")\n",
"    ax.set_title(\"Accuracy vs. additive Gaussian noise in a single optical layer\")\n",
"    ax.grid()\n",
"    ax.legend();"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "76d3a16e-4f14-4301-9851-0bbeefc92396",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"id": "20449921-ed54-41f8-b21b-0042403a45c5",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/tmp/ipykernel_22518/2043102426.py:13: RuntimeWarning: divide by zero encountered in divide\n",
|
||
|
" plt.gca().set_xticks(x_label_values, [f\"{tt:.2f}\" for tt in 20*np.log(1/x_label_values)]);\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlEAAAHwCAYAAACG1DoIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAD0KUlEQVR4nOzdd1QUV/vA8e/u0rsKCCLSRLF3sSEYUaxRY4klIvZYYov1TexRf3ajJmpMLDEajcSWWLH33sUuiB0VEAGpO78/eJnXdUEXWET0fs7xHHfmzp1n7u4sz869c0chSZKEIAiCIAiCkC3K/A5AEARBEAShIBJJlCAIgiAIQg6IJEoQBEEQBCEHRBIlCIIgCIKQAyKJEgRBEARByAGRRAmCIAiCIOSASKIEQRAEQRByQCRRgiAIgiAIOSCSKEEQBEEQhBwQSZQgvMP+/ftRKBTs379fr/UqFAomTJig1zr17dSpU9SpUwdzc3MUCgXnz5/P75AEPXF1daVFixb5HYbwDuIc/LCJJEr4qKxYsQKFQiH/MzAwwMnJiaCgIB48ePDe49m2bdsHnyhlJSUlhfbt2xMVFcXcuXNZtWoVLi4u+R3WR+nhw4dMmDBB738gQ0NDmTBhAuHh4XqtV9CUV+0szsEPn0F+ByAIeWHSpEm4ubmRmJjI8ePHWbFiBYcPH+by5cuYmJi8tzi2bdvGTz/9lGki9erVKwwMPtxT8Pbt29y9e5elS5fSq1ev/A7no/bw4UMmTpyIq6srlStX1lu9oaGhTJw4ET8/P1xdXfVWr6Apr9pZnIMfvg/3G1wQcqFp06ZUr14dgF69emFra8v06dPZsmULHTp0yOfo0r3PZC4nIiMjAbCxscnfQPJAfHw85ubm+R2GILzVx3wOfixEd57wSfDx8QHSf9m97tq1a7Rr147ChQtjYmJC9erV2bJlyzvrO3ToEO3bt6dEiRIYGxvj7OzM0KFDefXqlVwmKCiIn376CUCjizFDZmOizp07R9OmTbGyssLCwoKGDRty/PhxjTIZXZZHjhxh2LBh2NnZYW5uTps2bXj69KlO7bF37158fHwwNzfHxsaGVq1acfXqVY3YfX19AWjfvj0KhQI/P78s64uKimL48OFUqFABCwsLrKysaNq0KRcuXNAqm5iYyIQJEyhVqhQmJiY4OjryxRdfaLw3arWaH3/8kQoVKmBiYoKdnR1NmjTh9OnTAISHh6NQKFixYoVW/W+264QJE1AoFISGhtK5c2cKFSpEvXr1ALh48SJBQUG4u7tjYmKCg4MDPXr04Pnz51r1PnjwgJ49e1KsWDGMjY1xc3OjX79+JCcnc+fOHRQKBXPnztXa7ujRoygUCv78889M227//v3UqFEDgO7du8ufk9ePbf369VSrVg1TU1NsbW356quv3tk9vWLFCtq3bw9AgwYN5HrfHNt3+PBhatasiYmJCe7u7vz+++9adcXExDBkyBCcnZ0xNjamZMmSTJ8+HbVa/dYYADZv3kzz5s3ldvPw8GDy5MmkpaVplPPz86N8+fJcvHgRX19fzMzMKFmyJMHBwQAcOHAAb29vTE1NKV26NLt379baly7nT8bnIbP2UigUGl1yGePG3tZGurbzm/R9Dgr5Q1yJEj4JGV+MhQoVkpdduXKFunXr4uTkxOjRozE3N+evv/6idevW/P3337Rp0ybL+tavX09CQgL9+vWjSJEinDx5kgULFnD//n3Wr18PQN++fXn48CEhISGsWrXqnTFeuXIFHx8frKysGDlyJIaGhixZsgQ/Pz/5D8jrvvnmGwoVKsT48eMJDw9n3rx5DBw4kHXr1r11P7t376Zp06a4u7szYcIEXr16xYIFC6hbty5nz57F1dWVvn374uTkxNSpUxk0aBA1atSgaNGiWdZ5584dNm3aRPv27XFzc+PJkycsWbIEX19fQkNDKVasGABpaWm0aNGCPXv20LFjRwYPHszLly8JCQnh8u
XLeHh4ANCzZ09WrFhB06ZN6dWrF6mpqRw6dIjjx4/LVxizq3379nh6ejJ16lQkSQIgJCSEO3fu0L17dxwcHLhy5Qq//PILV65c4fjx4/If24cPH1KzZk1iYmLo06cPXl5ePHjwgODgYBISEnB3d6du3bqsXr2aoUOHaux39erVWFpa0qpVq0zjKlOmDJMmTWLcuHH06dNHTvjr1KkDpP+R7t69OzVq1GDatGk8efKEH3/8kSNHjnDu3Lksr1LUr1+fQYMGMX/+fP7zn/9QpkwZeX8Zbt26Rbt27ejZsyfdunVj2bJlBAUFUa1aNcqVKwdAQkICvr6+PHjwgL59+1KiRAmOHj3KmDFjePToEfPmzXtru69YsQILCwuGDRuGhYUFe/fuZdy4ccTGxjJz5kyNstHR0bRo0YKOHTvSvn17Fi1aRMeOHVm9ejVDhgzh66+/pnPnzsycOZN27dpx7949LC0tgeyfP7p6Vxvp0s5vyotzUMgnkiB8RJYvXy4B0u7du6WnT59K9+7dk4KDgyU7OzvJ2NhYunfvnly2YcOGUoUKFaTExER5mVqtlurUqSN5enrKy/bt2ycB0r59++RlCQkJWvueNm2apFAopLt378rLBgwYIGV1mgHS+PHj5detW7eWjIyMpNu3b8vLHj58KFlaWkr169fXOkZ/f39JrVbLy4cOHSqpVCopJibmrW1UuXJlyd7eXnr+/Lm87MKFC5JSqZQCAwO1jnv9+vVvrU+SJCkxMVFKS0vTWBYWFiYZGxtLkyZNkpctW7ZMAqQ5c+Zo1ZFxLHv37pUAadCgQVmWCQsLkwBp+fLlWmXebNfx48dLgNSpUyetspm9j3/++acESAcPHpSXBQYGSkqlUjp16lSWMS1ZskQCpKtXr8rrkpOTJVtbW6lbt25a273u1KlTmR5PcnKyZG9vL5UvX1569eqVvPzff/+VAGncuHFvrXf9+vVan90MLi4uWscZGRkpGRsbS99++628bPLkyZK5ubl048YNje1Hjx4tqVQqKSIi4q0xZNbGffv2lczMzDTOPV9fXwmQ1qxZIy+7du2aBEhKpVI6fvy4vHznzp1a7aXr+ZPxeXhTxnkVFhYmL9O1jd7WzpnJi3NQyB+iO0/4KPn7+2NnZ4ezszPt2rXD3NycLVu2ULx4cSC9+2nv3r106NCBly9f8uzZM549e8bz588JCAjg5s2bb+0uMTU1lf8fHx/Ps2fPqFOnDpIkce7cuWzHm5aWxq5du2jdujXu7u7yckdHRzp37szhw4eJjY3V2KZPnz4a3RI+Pj6kpaVx9+7dLPfz6NEjzp8/T1BQEIULF5aXV6xYkUaNGrFt27Zsxw5gbGyMUqmUj+X58+dYWFhQunRpzp49K5f7+++/sbW15ZtvvtGqI+NY/v77bxQKBePHj8+yTE58/fXXWstefx8TExN59uwZtWrVApDjVqvVbNq0iZYtW2Z6FSwjpg4dOmBiYsLq1avldTt37uTZs2d89dVXOYr59OnTREZG0r9/f40xdM2bN8fLy4utW7fmqN4MZcuWla98AdjZ2VG6dGnu3LkjL1u/fj0+Pj4UKlRIPk+ePXuGv78/aWlpHDx48K37eL2NM841Hx8fEhISuHbtmkZZCwsLOnbsKL8uXbo0NjY2lClTRuNKUsb/M+LMyfmjK13aKDvy6hwU8odIooSP0k8//URISAjBwcE0a9aMZ8+eYWxsLK+/desWkiQxduxY7OzsNP5l/PHOGNSZmYiICPlL0MLCAjs7O3n8wosXL7Id79OnT0lISKB06dJa68qUKYNarebevXsay0uUKKHxOqOrMjo6Osv9ZCRYWe3n2bNnxMfHZzt+tVrN3Llz8fT0xNjYGFtbW+zs7Lh48aJGe9y+fZvSpUu/9a7E27dvU6xYMY0/MPrg5uamtSwqKorBgwdTtGhRTE1NsbOzk8tlxP306VNiY2MpX778W+u3sbGhZcuWrFmzRl62evVqnJyc+Oyzz3
IU89veLy8vr7cmzLp48zME6Z+j1z9DN2/eZMeOHVrnib+/P/D28wTSu9natGmDtbU1VlZW2NnZyUnlm+dK8eLFtRJ
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig, ax = plt.subplots()\n",
"ax.plot(x, y_amp, label='amplitude all layers')\n",
"ax.plot(x, y_phase, label='phase all layers')\n",
"ax.plot(x, y_ampphase, label='amplitude and phase all layers')\n",
"if CONFIG.layers > 1:\n",
"    for i in range(CONFIG.layers):\n",
"        ax.plot(x, y_layer[i], label=f'amplitude and phase at layer {i}')\n",
"ax.set_xlabel(\"Signal to noise ratio, dB\")\n",
"ax.set_ylabel(\"accuracy\")\n",
"ax.set_title(f\"Relation of accuracy to the amount of\\nadded Gaussian noise in frequency plane\\nfor the system with {CONFIG.layers} optical layer{'s' if CONFIG.layers >1 else ''}\")\n",
"ax.grid()\n",
"ax.legend()\n",
"# Curves are plotted against the noise fraction x; tick labels show the equivalent\n",
"# SNR in decibels, 20*log10(1/x) (base-10 log -- natural log would not be dB).\n",
"# x == 0 means no noise, i.e. infinite SNR, so it is labelled explicitly instead\n",
"# of dividing by zero.\n",
"x_label_values = np.concatenate([x[::10], [x[-1]]])\n",
"snr_db_labels = [\"inf\" if v <= 0 else f\"{20*np.log10(1/v):.2f}\" for v in x_label_values]\n",
"ax.set_xticks(x_label_values, snr_db_labels);\n",
"# Keep the automatic y ticks (minus the top two and the first) and add the min/max\n",
"# accuracy of the amplitude+phase curve as explicit ticks.\n",
"yticks = np.append(ax.get_yticks()[:-2], [np.array(y_ampphase).min(), np.array(y_ampphase).max()])\n",
"ax.set_yticks(yticks[1:]);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "c1ee6441-6e86-4722-b6d4-2d4b1e026897",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "cfffa146-d7ba-4c51-b42e-4f4be1bd3b9d",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "1c59c430-9f30-4aa1-a3fc-9d77174faa98",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "740d676a-0c60-4953-996f-4770aaa57159",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.14"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|