You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

456 lines
405 KiB
Plaintext

4 months ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "eb488c3c-8526-472a-9355-ecbe3e778446",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{'classes': 16,\n",
" 'classifier_optimization_batch_size': 2560,\n",
" 'classifier_optimization_epochs': 10000,\n",
" 'image_size': 28,\n",
" 'kernel_size': 33,\n",
" 'learning_rate': 5e-05,\n",
" 'pad_amount': 64,\n",
" 'print_every': 10,\n",
" 'save_every': 1000,\n",
" 'summary_every': 10,\n",
" 'test_class_instances': 100,\n",
" 'test_data_path': './assets/quickdraw16_test.npy',\n",
" 'tile_size': 66,\n",
" 'tiles_per_dim': 4,\n",
" 'train_class_instances': 8000,\n",
" 'train_data_path': './assets/quickdraw16_train.npy',\n",
" 'verbose': True}\n",
"\n"
]
}
],
"source": [
"# --- Imports ---------------------------------------------------------------\n",
"# NOTE(review): several imports appear unused in this notebook (flip, nn,\n",
"# resize, fft2/ifft2, confusion_matrix, timeit, functools, datetime,\n",
"# argparse, sys); kept as-is since the notebook may be a trimmed version of a\n",
"# script that used them.\n",
"import torch\n",
"from torch import flip\n",
"from torch.nn import Module\n",
"from torch import nn\n",
"from torch.nn.functional import conv2d\n",
"from einops import rearrange\n",
"from torchvision.transforms.functional import resize\n",
"from torch.fft import fft2, ifft2\n",
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
"import tqdm\n",
"import numpy as np\n",
"import math\n",
"import timeit\n",
"import matplotlib.pyplot as plt\n",
"import functools\n",
"from datetime import datetime\n",
"import argparse\n",
"import sys\n",
"import os\n",
"# restrict CUDA to the first GPU; must be set before torch initializes CUDA\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
"from pprint import pprint\n",
"# NOTE(review): wildcard import pollutes the namespace -- imshow, pad_zeros\n",
"# and to_class_labels used below presumably come from here; prefer explicit\n",
"# imports. TODO confirm against utils.py.\n",
"from utils import *\n",
"\n",
"# --- Configuration ---------------------------------------------------------\n",
"CONFIG = type('', (), {})()  # anonymous namespace object for params\n",
"\n",
"# dataset layout: 16 classes arranged on a square (4x4) grid for display\n",
"CONFIG.classes = 16\n",
"CONFIG.tiles_per_dim = int(math.sqrt(CONFIG.classes))\n",
"\n",
"CONFIG.image_size = 28\n",
"CONFIG.train_class_instances = 8000\n",
"CONFIG.test_class_instances = 100\n",
"CONFIG.train_data_path = './assets/quickdraw16_train.npy'\n",
"CONFIG.test_data_path = './assets/quickdraw16_test.npy'\n",
"\n",
"# optimization hyperparameters\n",
"CONFIG.classifier_optimization_batch_size = 1280*2\n",
"CONFIG.classifier_optimization_epochs = 10000\n",
"CONFIG.learning_rate = 5e-5\n",
"CONFIG.pad_amount = 64\n",
"CONFIG.kernel_size = 33\n",
"CONFIG.tile_size = CONFIG.kernel_size * 2\n",
"\n",
"CONFIG.classifier_model_path = './classifier.weights'\n",
"\n",
"# logging / checkpoint cadence\n",
"CONFIG.summary_every = 10\n",
"CONFIG.print_every = 10\n",
"CONFIG.save_every = 1000\n",
"CONFIG.verbose = True\n",
"\n",
"print()\n",
"pprint(CONFIG.__dict__)\n",
"print()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a8360b11-52e1-45a7-bf9f-0ef14fab903d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([128000, 1, 28, 28]), torch.Size([1600, 1, 28, 28]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# change to your directory\n",
"def _load_split(path):\n",
"    \"\"\"Load a flat (N, H*W) .npy file as a float32 tensor of shape (N, 1, H, W).\"\"\"\n",
"    flat = torch.tensor(np.load(path), dtype=torch.float32)\n",
"    return rearrange(flat, \"b (h w) -> b 1 h w\", h=CONFIG.image_size, w=CONFIG.image_size)\n",
"\n",
"train_data = _load_split(CONFIG.train_data_path)\n",
"test_data = _load_split(CONFIG.test_data_path)\n",
"train_data.shape, test_data.shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d4bfc116-f863-409b-a099-3ead5db95efc",
"metadata": {},
"outputs": [],
"source": [
"# One-hot label matrices: eye(classes) yields one row per class, and\n",
"# repeat_interleave duplicates each row once per instance, matching the\n",
"# class-blocked ordering of the data files (all of class 0, then class 1, ...).\n",
"train_labels = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.train_class_instances, dim=0)\n",
"test_labels = torch.eye(CONFIG.classes).repeat_interleave(repeats=CONFIG.test_class_instances, dim=0)"
]
},
{
"cell_type": "markdown",
"id": "bb8d9cf6-2565-4d58-b725-2593164c5b0f",
"metadata": {},
"source": [
"# Examples of input data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c01f42d8-da62-41d8-aabe-4e0668af4a9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.Tensor'>\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsAAAAFjCAYAAAAzT70yAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9eZSc13nei/723t9cY1fPjZEAwZkUKYqiKFGyIilSjh3Hg5zo2De27EwrupSybCW+Ho4Sy/axdeO17rWzTmTn3HsSOYmja8eJHSW0ImuwZpGiREmkOEEkiBk9DzV/0977/vEVGgBHAGwADer7rVUL6Orqqt1VX3U937uf93mFtdZSUlJSUlJSUlJS8n2CvNILKCkpKSkpKSkpKbmclAK4pKSkpKSkpKTk+4pSAJeUlJSUlJSUlHxfUQrgkpKSkpKSkpKS7ytKAVxSUlJSUlJSUvJ9RSmAS0pKSkpKSkpKvq8oBXBJSUlJSUlJScn3FaUALikpKSkpKSkp+b6iFMAlJSUlJSUlJSXfV5QCuKSkpKSkpKSk5PuKKyqAP/rRj7J3716CIODuu+/moYceupLLKSkpKSkpKSkp+T5AWGvtlXjgP/mTP+FnfuZn+Df/5t9w991383u/93v86Z/+KQcPHmRqauolf9YYw6lTp6jVagghLtOKS0pKSrYeay3dbpe5uTmkLDflSkpKXpw4jknT9Eov4xw8zyMIgiu9jAvmigngu+++m7vuuot//a//NVCI2l27dvGBD3yAX/7lX37Jnz1x4gS7du26HMssKSkpuSwcP36cnTt3XulllJSUbFPiOOaaPVUWlvSVXso5zMzMcPjw4atOBDtX4kHTNOXhhx/mV37lVzavk1Lyjne8gwceeOB5t0+ShCRJNr8+rdnv5QdxcC/9gktKSkouETkZX+GT1Gq1K72UkpKSbUyapiwsaQ4/vId6bXvsFnW6hmvuPEqapqUAPh9WVlbQWjM9PX3O9dPT0zz11FPPu/1HPvIRfv3Xf/151zu4OKIUwK96hEAoBUKCFIXtRUqE44BSiMDHjtWxnoMYJLDegTzHDoeYOL7Sqy+5XAgBQiIrEWJmEht6xfHQ7kGWYuNkex4Poz240s5VUlJyPtRrctsI4KuZq+IZ/JVf+RXa7fbm5fjx41d6SSWXEaEUIgyRYYCMIkS1gqhVEeNjMDNBtneatde2WLynQee2CczOKZgaR9RqhSgq+b5AKIVwHeT4GN1bJ1m+a4zeLZPYHZMwNYGoVsrjoaSk5KpHW7OtLhfCRz7yEe666y5qtRpTU1P86I/+KAcPHjznNm9961sRQpxz+cf/+B+fc5tjx47xQz/0Q0RRxNTUFL/4i79InucXtJYrUgGemJhAKcXi4uI51y8uLjIzM/O82/u+j+/7l2t5JefLaTEh5Ogfsfl/AKQY3ayo2J7zM1KeW/GSZ/1fPOe8zHMRgQ9CYB1V3JeS6IqP8RTpmEfckqQ1kLkkaPooV+IM4uK+7PbyS5WMEALhuOdU9QEwZtPmhNZYPXr9XqpdQQiE7yMCH1ONGI5JknGBzBVBxUNZi3TL3aKSkpKrH4PFcEXat57Hha7ji1/8Ivfddx933XUXeZ7zq7/6q7zzne/kiSeeoFKpbN7uH/7Df8hv/MZvbH4dRdHm/7XW/NAP/RAzMzN87WtfY35+np/5mZ/BdV1++7d/+7zXckUEsOd53HnnnXzuc5/jR3/0R4GiCe5zn/sc73//+6/EkkouBCGQYYjwfXAcRDXCug42cDGRh5UCHTroQGIcQR5ItCewCvIQjCMwLugArATjWYxX7ARbBUiLFcDoYqUFBVbYzesQFhyLUBbHTxlvrjHppZxYa7J+cwWn7zP1cERtOMQOY2ySYC/w7LDk0qJaY2Q37iZteuhAkEUSK8EZWtyBQaaG4GQPtbqBzTJMu4vNXrj7WdVqxK8/QGePx3BKkL2mx1yrw9FnpzBOSLDhU7MWsb
xSHgclJSVXNQbDhdVdLx2nV9LpdM65/sUKl5/61KfO+foP//APmZqa4uGHH+Ytb3nL5vVRFL1gQRTg05/+NE888QSf/exnmZ6e5vbbb+c3f/M3+aVf+iU+/OEP43neea39ighggA9+8IO8973v5XWvex2vf/3r+b3f+z36/T4/93M/d6WWVHK+CImIQkQlwkYB6XgFHSiymkPclBgXspogq4FxLVnDYMMcFWga9T51L2Mi7LG/ukKkUnZ46+x2VwlERk3GRCJHCUsg7Eu2OKpRBVkCCoEUgkfSkM9fdxOHBhN8c3gLtUfrhX9Y61L4bDfGGqzeGjKYEWQ1gx1PEMpi1z3cDYUzcBgP6lQkiEGCGAxfVACLSsTKbT791w6ZnWjzoWv/grv9dX65+XY+v3476bIiWItwnru7UFJSUnKVoa1FX5kAr+dxeh3PTeb6tV/7NT784Q+/7M+3220AWq3WOdf/p//0n/ijP/ojZmZm+OEf/mH++T//55tV4AceeIBbb731nD6yd73rXbzvfe/j8ccf54477jivtV8xAfye97yH5eVl/sW/+BcsLCxw++2386lPfep5jXEllxkhEJ6HbDYQngdKYl1n1GAkiq8dSV4LyEOF8SXxmEJ7gjwSpDWwDmRVi64arLKoWoYfpHiOphUNCZ2Mhhfjyxxf5CgMqVUAuHYkUi1k4vnWBY3AWHHO12fTNQHTbpssVHy1ZUj2tHDbFZSU2OUVrLFgSkvEFUMIpO8jPA/TrBC3BMlkjqjkVBtDlLB0BaSuSz6QdHc4IOq4HU1gLKLTxSYpptc71xLhOOQRNBt9dtY2mFEdxlREVSVYWew0aE/iVSvY1EW4RQMlQiBct7BfjKw1CAG5LqwYeY7t9jDD4UtbMEpKSkouE9vRAnH8+HHq9frm9edjWzXG8PM///O86U1v4pZbbtm8/qd+6qfYs2cPc3NzPProo/zSL/0SBw8e5M/+7M8AWFhYeMEQhdPfO1+umAAGeP/7319aHrYRp1MV5N5dLPy1SYZTgrxiyZs5uBYV5ARBhpSGwO3jKY2rNBNuiidzApVTcVJcqamohKpKMAgyo8hscelmAZmVDLXLo+0d5EaSGUWSF4fiaXFrgUwrtBFYK9BGYgGtJVpLrBUYXVyHFZz24V87t8zP7fwqe6srPHP3JN+c3oNZrbDzryJqXwebpph2p6wGXyGk78O1e0mmKqxf5zH25gV+au6J4vgwCo3EFRolDAPtcfiucVaGVY6uNIkeniNcstSOJ7gPP43pdjeTH2zoM9yZ8w+u+SY7vVXmnBzwyaxC5iA1JE0H98bdIATJmEceSXJfkDQF2i8sOVnNYiW4fYHTB7drmfpGB/HYM2f8yKUQLikpKTmHer1+jgA+H+677z4ee+wxvvKVr5xz/T/6R/9o8/+33nors7OzvP3tb+fQoUPs379/S9YLV1gAl2wjhCgixZQin6iycathet8K19TXeMvY95h0uux3l9njaFwh8YWLHFVf1QtsK5/uDE1szuFcs6CrLOd1Hh3sYiOPWBj6zHfrJLli0A/QfQeMQGgBBoQWyEQgLAhdXLACmYGTAxZkDsJQ3H4kgJ9+3RTR7oTr3VX+8cwXODI+wZc3ruObh26j+niEGCro9aEUwFcE4XkkUxW6uz16e+Ef73mQf9Q4xYm8x8GsQWYd9jjr7HEclBCb22t/3p/ln+c/QnzCR1iP1mMe9ArxK6TAei5+a8iP1B4lEjAmQ7Q1GCthdIxkkWAwG2AcwWBKktYhr1iyqRQnyqnXBtzeWsaRmmfbEyyt1+ivBdSPR4RPFbsgmw15JSUlJVcIg0VvswrwhfL+97+f+++/ny996UsvOwDo7rvvBuCZZ55h//79zMzM8NBDD51zm9OhCi/mG34hSgG8Bah6HTHRwroOIsshyyHPMe0OZjC40su7YKwYNZydxWmrgTsSu4t6yMAKNozHE8kMGzoiMS4D46GtJDEOQ+2RGIfFuEY39R
lmLuvdiDxTmKGD7CtEDioWeEOBGAnZ0xeZnS1wLZz9RhNg3KKxzjqQe8U2t3INGzpi2RTbL5NOh2m/Q9qEdEcTp50
"text/plain": [
"<Figure size 1040x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show one sample per class, tiled into a tiles_per_dim x tiles_per_dim grid.\n",
"# The training set is class-blocked: class i occupies rows\n",
"# [i*train_class_instances, (i+1)*train_class_instances).\n",
"examples = []\n",
"for i in range(CONFIG.classes):\n",
"    # FIX: use the config constant instead of the hard-coded block size 8000,\n",
"    # which silently duplicated CONFIG.train_class_instances\n",
"    examples.append(train_data[i*CONFIG.train_class_instances + 2, 0])\n",
"\n",
"examples = torch.stack(examples)\n",
"examples = rearrange(examples, \"(a b) h w -> (a h) (b w)\", a=CONFIG.tiles_per_dim, b=CONFIG.tiles_per_dim)\n",
"\n",
"imshow(examples);"
]
},
{
"cell_type": "markdown",
"id": "e7c47a7f-601d-434d-a350-2e2956b250d3",
"metadata": {},
"source": [
"# Baseline classifier"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "99b615ad-e107-420b-89ba-4e63b219f6a9",
"metadata": {},
"outputs": [],
"source": [
"class CNNClassificator(Module):\n",
"    \"\"\"Single-layer CNN: one non-negative conv filter per class followed by\n",
"    global max pooling, so each filter's maximum response is that class logit.\n",
"\n",
"    Args:\n",
"        input_size: spatial size of the input (int or (H, W) pair); used\n",
"            directly as the max-pool kernel so pooling is global.\n",
"        kernel_size: side length of each square conv filter.\n",
"        filters: number of filters == number of classes (one logit each).\n",
"    \"\"\"\n",
"    def __init__(self, input_size, kernel_size, filters):\n",
"        super().__init__()\n",
"        self.conv_weights = torch.nn.Parameter(\n",
"            torch.empty([filters, 1, kernel_size, kernel_size])\n",
"        )\n",
"        torch.nn.init.xavier_normal_(self.conv_weights)\n",
"        self.max_pool_layer = torch.nn.MaxPool2d(kernel_size=input_size)\n",
"        # NOTE: never applied in forward() -- CrossEntropyLoss expects raw\n",
"        # logits; kept so the module repr/interface stays unchanged.\n",
"        self.softmax = torch.nn.Softmax(dim=1)\n",
"\n",
"    def forward(self, x):\n",
"        # abs() keeps the effective filters non-negative without reparameterizing\n",
"        x = torch.nn.functional.conv2d(x, self.conv_weights.abs(), padding='same')\n",
"        x = self.max_pool_layer(x)\n",
"        # BUGFIX: squeeze only the spatial dims. The previous bare .squeeze()\n",
"        # would also drop the batch dim for single-image batches, returning a\n",
"        # 1-D tensor and breaking downstream batch-indexed code.\n",
"        x = x.squeeze(-1).squeeze(-1)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c1dcf589-6923-4ef6-9913-19c82ea37d9d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CNNClassificator(\n",
" (max_pool_layer): MaxPool2d(kernel_size=torch.Size([28, 28]), stride=torch.Size([28, 28]), padding=0, dilation=1, ceil_mode=False)\n",
" (softmax): Softmax(dim=1)\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classificator = CNNClassificator(\n",
"    # (H, W) of the images; passed straight through as the global max-pool kernel\n",
"    input_size = train_data.shape[-2:], \n",
"    kernel_size = CONFIG.kernel_size, \n",
"    filters = CONFIG.classes\n",
")\n",
"# uncomment to resume from previously saved weights instead of training\n",
"# classificator.load_state_dict(torch.load(CONFIG.classifier_model_path))\n",
"classificator.eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "039566c4-d448-491c-afb8-3c8cef67f59a",
"metadata": {},
"outputs": [],
"source": [
"# Move the model and the full training set to the GPU.\n",
"# NOTE(review): `inputs`/`targets` are rebound to the *test* set in a later\n",
"# cell -- re-running cells out of order will silently train on test data;\n",
"# stage-specific names (train_inputs/test_inputs) would be safer.\n",
"classificator = classificator.cuda()\n",
"inputs = train_data.cuda()\n",
"targets = train_labels.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "979aa425-7a38-42d8-8eff-6f608f0f109a",
"metadata": {},
"outputs": [],
"source": [
"# FIX: use the configured learning rate instead of the hard-coded 1e-5, which\n",
"# silently disagreed with CONFIG.learning_rate (5e-5) printed at the top.\n",
"optimizer = torch.optim.AdamW(params=classificator.parameters(), lr=CONFIG.learning_rate)\n",
"# multiplies the LR by 0.99 on each scheduler.step() (stepped every 100 epochs)\n",
"scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)\n",
"# expects raw logits; applies log-softmax internally\n",
"loss_function = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "78fc33f0-74b9-4068-a351-e852dae7dd61",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [03:45<00:00, 44.30it/s, loss: 7.228, acc: 0.75, lr: 3.697296e-06]\n"
]
}
],
"source": [
"# training loop\n",
"\n",
"epochs = CONFIG.classifier_optimization_epochs\n",
"batch_size = CONFIG.classifier_optimization_batch_size\n",
"progress = tqdm.trange(epochs)\n",
"\n",
"for epoch in progress:\n",
"    # sample a random batch (with replacement) from the full training set\n",
"    batch_idxs = torch.randint(low=0, high=train_data.shape[0], size=[batch_size])\n",
"    batch_inputs = inputs[batch_idxs]\n",
"    batch_targets = targets[batch_idxs]\n",
"\n",
"    # forward pass\n",
"    predicted = classificator(batch_inputs)\n",
"\n",
"    # backward pass\n",
"    # BUGFIX: gradients must be reset every iteration; without zero_grad()\n",
"    # they accumulate across all 10000 steps and the updates grow unbounded.\n",
"    optimizer.zero_grad()\n",
"    loss_value = loss_function(predicted, batch_targets)\n",
"    loss_value.backward()\n",
"    optimizer.step()\n",
"\n",
"    # report batch loss/accuracy every 100 epochs\n",
"    if epoch % 100 == 0:\n",
"        acc = accuracy_score(to_class_labels(batch_targets), to_class_labels(predicted))\n",
"        progress.set_postfix_str(\"loss: {:.3f}, acc: {:.2f}, lr: {:e}\".format(loss_value, acc, scheduler.get_last_lr()[0]))\n",
"\n",
"    # decay the LR every 100 epochs until it reaches the 1e-6 floor\n",
"    if (scheduler.get_last_lr()[0] > 1e-6) and (epoch % 100 == 0):\n",
"        scheduler.step()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5e67af52-1bb8-4e05-a0f5-99fb7ccaf62e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Accuracy on test dataset: ', 0.73875)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate on the held-out test set.\n",
"inputs = test_data.cuda()\n",
"targets = test_labels.cuda()\n",
"# FIX: inference only -- no_grad() avoids building an autograd graph over the\n",
"# entire test set, which wastes GPU memory for no benefit.\n",
"with torch.no_grad():\n",
"    predicted = classificator(inputs)\n",
"test_acc = accuracy_score(to_class_labels(targets), to_class_labels(predicted))\n",
"\"Accuracy on test dataset: \", test_acc"
]
},
{
"cell_type": "markdown",
"id": "fe99adff-29fe-46bd-90c1-218aa58c79d0",
"metadata": {},
"source": [
"# Learned weights"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "88a0f7f9-0728-4dd9-b8df-9bc1e14f7f5a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 1, 264, 264])\n",
"<class 'torch.Tensor'>\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQUAAAU1CAYAAACqeqKtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd5jU1f3+/9fSi+6itKWD9CYoZV0R6UXAiIoCQQWsiZVgSeLHoCYmJpaIHTX2rgliQ0RBBGUBaSoCCkjviOzS+++P/HznvO5hV83XiOt5Pq7LK+fmDLMzw7zLvjPnnrSDBw8eNAAAAAAAAADRKHK4HwAAAAAAAACAHxcXBQEAAAAAAIDIcFEQAAAAAAAAiAwXBQEAAAAAAIDIcFEQAAAAAAAAiAwXBQEAAAAAAIDIcFEQAAAAAAAAiAwXBQEAAAAAAIDIcFEQAAAAAAAAiAwXBQEAAL6Hjh07WseOHf+rv5uWlmY33XRTgbeZNGmSpaWl2T//+c//6mccTk888YSlpaXZsmXLDvdDAQAAwLfgoiAAAPjJ+vTTT61fv35Wq1YtK1WqlFWrVs26detm99577//0586fP99uuukmLm4BAADgZ6vY4X4AAAAAhzJ16lTr1KmT1axZ0y666CLLzMy0lStX2rRp0+zuu++2K6644n/2s+fPn28333yzdezY0WrXru3mxo8f/z/7uQAAAMCPhYuCAADgJ+nPf/6zZWRk2EcffWTlypVzcxs2bDg8D8rMSpQocdh+9g/h4MGDtmvXLitduvThfigAAAA4jFg+DAAAfpKWLFliTZs2TbkgaGZWqVIll9PS0uzyyy+3Z5991ho2bGilSpWyVq1a2eTJk93tli9fbpdeeqk1bNjQSpcubeXLl7ezzjrLLRN+4okn7KyzzjIzs06dOllaWpqlpaXZpEmTzCy1U3DPnj02YsQIa9WqlWVkZFjZsmWtffv29t577/0gr4OZ2e7du61Pnz6WkZFhU6dONTOzAwcO2MiRI61p06ZWqlQpq1y5sl1yySX29ddfu79bu3Zt69Onj7399tvWunVrK126tD300ENJd+FLL71kf/7zn6169epWqlQp69Kliy1evDjlMUyfPt169uxpGRkZVqZMGevQoYN9+OGHP9hzBAAAwI+LTwoCAICfpFq1allOTo7NmzfPmjVr9q23f//99+3FF1+0K6+80kqWLGkPPPCA9ezZ02bMmJH8/Y8++simTp1qAwYMsOrVq9uyZcvswQcftI4dO9r8+fOtTJkydvLJJ9uVV15p99xzj11//fXWuHFjM7Pkf1VeXp794x//sIEDB9pFF11kW7dutUcffdR69OhhM2bMsJYtW/4/vQ47d+600047zWbOnGnvvvuutWnTxszMLrnkEnviiSds6NChduWVV9rSpUvtvvvuszlz5tiHH35oxYsXT+7j888/t4EDB9oll1xiF110kTVs2DCZ++tf/2pFihSxa665xnJzc+22226zQYMG2fTp05PbTJw40U455RRr1aqV3XjjjVakSBF7/PHHrXPnzjZlyhRr27bt/9NzBAAAwI+Pi4IAAOAn6ZprrrFTTjnFWrZsaW3btrX27dtbly5drFOnTu6C1zfmzZtnM2fOtFatWpmZ2YABA6xhw4Y2YsQIGz16tJmZ9e7d2/r16+f+3qmnnmrZ2dn2r3/9y84991w75phjrH379nbPPfdYt27dvvWbho866ihbtmyZW1Z80UUXWaNGjezee++1Rx999L9+DbZt22Z9+vSxzz77zCZOnJhcYPzggw/sH//4hz377LP2y1/+Mrl9p06drGfPnvbyyy+7P1+8eLGNGzfOevTokfzZN5983LVrl82dOzd5/EcddZRdddVVycXYgwcP2q9+9Svr1KmTvfXWW5aWlmZm/74o2bRpU7vhhhvoWQQAACiEWD4MAAB+krp162Y5OTn2i1/8wj7++GO77bbbrEePHlatWjV77b
XXUm6fnZ2dXBA0M6tZs6addtpp9vbbb9v+/fvNzFyP3t69e+2rr76yevXqWbly5Wz27Nn/1eMsWrRockHtwIEDtnnzZtu3b5+1bt36v75PM7Pc3Fzr3r27LVy40CZNmuQ+cfjyyy9bRkaGdevWzTZt2pT816pVKzviiCNSli7XqVPHXRAMDR061F3QbN++vZmZffnll2ZmNnfuXFu0aJH98pe/tK+++ir5Wdu3b7cuXbrY5MmT7cCBA//18wQAAMDhwScFAQDAT1abNm1s9OjRtmfPHvv444/tlVdesbvuusv69etnc+fOtSZNmiS3rV+/fsrfb9Cgge3YscM2btxomZmZtnPnTrv11lvt8ccft9WrV9vBgweT2+bm5v7Xj/PJJ5+0O++80xYuXGh79+5N/rxOnTr/9X0OGzbMdu3aZXPmzLGmTZu6uUWLFllubm5Kt+I39ItYCnocNWvWdPmoo44yM0u6CRctWmRmZoMHD873PnJzc5O/BwAAgMKBi4IAAOAnr0SJEtamTRtr06aNNWjQwIYOHWovv/yy3Xjjjd/rfq644gp7/PHHbdiwYZadnW0ZGRmWlpZmAwYM+K8/7fbMM8/YkCFDrG/fvnbttddapUqVrGjRonbrrbfakiVL/qv7NDM77bTT7IUXXrC//vWv9tRTT1mRIv9Z4HHgwAGrVKmSPfvss4f8uxUrVnS5oG8aLlq06CH//JsLpt+8Lrfffnu+/YhHHHFEvvcPAACAnyYuCgIAgEKldevWZma2du1a9+fffKIt9MUXX1iZMmWSi2T//Oc/bfDgwXbnnXcmt9m1a5dt2bLF/b1vevO+i3/+8592zDHH2OjRo93f+74XLFXfvn2te/fuNmTIEDvyyCPtwQcfTObq1q1r7777rrVr167AC34/hLp165qZWXp6unXt2vV/+rMAAADw46FTEAAA/CS99957bnnvN8aOHWtm5r5B18wsJyfHdfitXLnSXn31VevevXvyabiiRYum3Oe9996bdA5+o2zZsmZmKRcLD+Wb+w7vd/r06ZaTk/Otf/fbnHfeeXbPPffYqFGj7Le//W3y52effbbt37/f/vSnP6X8nX379n2nx/1dtWrVyurWrWt33HGHbdu2LWV+48aNP9jPAgAAwI+HTwoCAICfpCuuuMJ27Nhhp59+ujVq1Mj27NljU6dOtRdffNFq165tQ4cOdbdv1qyZ9ejRw6688korWbKkPfDAA2ZmdvPNNye36dOnjz399NOWkZFhTZo0sZycHHv33XetfPny7r5atmxpRYsWtb/97W+Wm5trJUuWtM6dOx+yw69Pnz42evRoO/3006137962dOlSGzVqlDVp0uSQF9G+r8svv9zy8vLs//7v/ywjI8Ouv/5669Chg11yySV266232ty5c6179+5WvHhxW7Rokb388st29913p3zL8n+rSJEi9o9//MNOOeUUa9q0qQ0dOtSqVatmq1evtvfee8/S09Pt9ddf/0F+FgAAAH48XBQEAAA/SXfccYe9/PLLNnbsWHv44Ydtz549VrNmTbv00kvthhtusHLlyrnbd+jQwbKzs+3mm2+2FStWWJMmTeyJJ56wY489NrnN3XffbUWLFrVnn33Wdu3aZe3atbN333035Zt5MzMzbdSoUXbrrbfaBRdcYPv377f33nvvkBcFhwwZYuvWrbOHHnrI3n77bWvSpIk988wz9vLLL9ukSZN+kNfi+uuvt9zc3OTC4GWXXWajRo2yVq1a2UMPPWTXX3+9FStWzGrXrm3nnHOOtWvX7gf5ud/o2LGj5eTk2J/+9Ce77777bNu2bZaZmWlZWVl2ySWX/KA/CwAAAD+OtIOHWpcDAABQiKSlpdlll11m99133+F+KAAAAEChQKcgAAAAAAAAEBkuCgIAAAAAAACR4aIgAAAAAAAAEBm+aAQAABR6VCQDAAAA3w+fFAQAAAAAAAAiw0VBAAAAAAAAIDJcFAQAAAAAAAAiw0VBAA
AAAAAAIDJcFAQAAAAAAAAiw0VBAAAAAAAAIDJcFAQAAAAAAAAiw0VBAAAAAAAAIDJcFAQAAAAAAAAiw0VBAAAAAAA
"text/plain": [
"<Figure size 1500x1500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Tile the per-class filters (zero-padded to tile_size) into a single\n",
"# tiles_per_dim x tiles_per_dim mosaic kernel of shape (1, 1, a*tile, b*tile).\n",
"spatial_kernel = rearrange(\n",
"    pad_zeros(classificator.conv_weights.abs(),\n",
"              size=(CONFIG.tile_size, CONFIG.tile_size)),\n",
"    \"(a b) 1 h w -> 1 1 (a h) (b w)\", a=int(math.sqrt(CONFIG.classes)), b=int(math.sqrt(CONFIG.classes))\n",
")\n",
"# normalize so the mosaic integrates to 1\n",
"spatial_kernel = spatial_kernel/spatial_kernel.sum(dim=[2,3], keepdim=True) # conservation of energy\n",
"# detach from the autograd graph -- used for visualization/convolution only\n",
"spatial_kernel = spatial_kernel.detach()\n",
"\n",
"print(spatial_kernel.shape)\n",
"\n",
"imshow(spatial_kernel[0,0], cmap='gray', title=\"Spatial kernel\", figsize=(15,15));"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fb077015-1dd0-47f2-af62-9864f97cab0b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.Tensor'>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_10191/3562513726.py:4: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at /opt/conda/conda-bld/pytorch_1712608935911/work/aten/src/ATen/native/Convolution.cpp:1031.)\n",
" out = conv2d(spatial_kernel, image, padding='same')[0,0].cpu()\n",
"/tmp/ipykernel_10191/3562513726.py:4: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608935911/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
" out = conv2d(spatial_kernel, image, padding='same')[0,0].cpu()\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq4AAAGHCAYAAABxtNBlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOx9eZxlVXXut/be59xbQw803YCN0EAziiKESYYWVAgi4kOjKBoFgokxD4wiRokvMmjkGY1oFDFGBUNe4kASNQ6gOIJDoqLGiVEMCjI1NF3Tvefsvdf7Y+29z7lV1dBV3dDV7f5+v/pV1al7z7DvPnW+/a1vrUXMzMjIyMjIyMjIyMhY4FBb+gQyMjIyMjIyMjIyNgaZuGZkZGRkZGRkZGwVyMQ1IyMjIyMjIyNjq0AmrhkZGRkZGRkZGVsFMnHNyMjIyMjIyMjYKpCJa0ZGRkZGRkZGxlaBTFwzMjIyMjIyMjK2CmTimpGRkZGRkZGRsVUgE9eMjIyMjIyMjIytApm4Pk4gIlx44YWPy7GuueYaHHjggeh2uyAirFu3bk7vv/DCC0FEA9t22203nHHGGfM6n2OPPRbHHnvsvN67Icx2jtsyNmX8N4Svf/3rICJ8/etf36z7fazwq1/9CkSEK6+8cqNf+653veuxP7HHAVdeeSWICL/61a/m/N4zzjgDu+2226O+brfddsNzn/vcuZ/cAsDj+f81IyNjy2KbIK7xn3r8MsZg5513xhlnnIG77rprS5/erPj2t7+NCy+8cM6k8tGwdu1anHrqqRgaGsJll12Gq666CiMjI5v1GJuKu+++GxdeeCF+9KMfbelT+Z3ABz7wgY0ie1sjvvCFL2xThOXtb387Pv3pT2/p08jIyMhYsDBb+gQ2Jy6++GLsvvvu6PV6+O53v4srr7wSN9xwA37605+i2+1u6dMbwLe//W1cdNFFOOOMM7B06dLNtt/vfe97GBsbw1vf+lYcd9xxm22/N998M5Sa3zrnS1/60sDvd999Ny666CLstttuOPDAAzfD2WU8Ej7wgQ9g+fLlMxTbpz/96ZiamkJZllvmxOaIVatWYWpqCkVRpG1f+MIXcNlll20z5PXtb387XvjCF+KUU04Z2P7yl78cL3nJS9DpdLbMiWVkZGQsEGxTxPXEE0/EIYccAgB45StfieXLl+Md73gHPvvZz+LUU0/dwmf3+OC+++4DgM1KhgFs0gNzayFGv2tQSi24Bd0jgYi2qvPdnNBaQ2u9pU9jkzAxMbHgoj8ZGRlbH7YJq8CGsGbNGgDA7bffPrD9pptuwgtf+EIsW7YM3W4XhxxyCD772c8OvKaua1x00UXYa6+90O12sf322+Poo4/Gl7/85fSaDXk3H81TduGFF+INb3gDAGD33XdPFodH86996lOfwsEHH4yhoSEsX74cf/iHfzhghTj22GNx+umnAwAOPfRQENGj+iJvuOEGHHrooeh2u1i9ejX+/u//ftbXzeax/O///m8cc8wxGBoawhOf+ES87W1vwxVXXDHjWtrj9PWvfx2HHnooAODMM89M1x5D2ddffz1e9KIXYdddd0Wn08Euu+yC173udZiamnrE63gk/Od//iee85znYLvttsPIyAgOOOAAvPe97x14zVe/+lWsWbMGIyMjWLp0Kf7X//pf+MUvfjHwmuirve2225JSvmTJEpx55pmYnJxMr3vyk5+MZzzjGTPOw3uPnXfeGS984QvTtomJCbz+9a/HLrvsgk6ng3322Qfvete7wMyPeE0b8vhO90Lutttu+NnPfoZvfOMbaazbn8VsHtdHm2eAzPHR0VHcddddOOWUUzA6OooVK1bgvPPOg3PuEc/93HPPxfbbbz9wjeeccw6ICH/3d3+Xtt17770gIlx++eUAZnpczzjjDFx22WUAMGAVmo4PfehDWL16NTqdDg499FB873vfe8TzAzbu/o9j8Mtf/hInnHACRkZGsHLlSlx88c
UzPr93vetdOPLII7H99ttjaGgIBx98MK6++uqB1xARJiYm8LGPfSxdS7znZvO4fuYzn8FJJ52ElStXotPpYPXq1XjrW9/6qOM/F3zsYx+DMSb9vwLkfnr2s5+NJUuWYHh4GMcccwy+9a1vDbwvzs+f//zneOlLX4rtttsORx99NIDGS3vDDTfgsMMOQ7fbxR577IF//Md/nHH8devW4bWvfW26P/bcc0+84x3vgPd+s11jRkbG1oVtSnGdjvhPfrvttkvbfvazn+Goo47CzjvvjDe96U0YGRnBJz/5SZxyyin413/9Vzz/+c8HIP94L7nkErzyla/EYYcdhvXr1+P73/8+brzxRhx//PGbdF4veMELcMstt+Bf/uVfcOmll2L58uUAgBUrVmzwPVdeeSXOPPNMHHroobjkkktw77334r3vfS++9a1v4Yc//CGWLl2KN7/5zdhnn33woQ99KNkmVq9evcF9/uQnP8Hv//7vY8WKFbjwwgthrcUFF1yAHXfc8VGv4a677sIznvEMEBHOP/98jIyM4MMf/vCjKrP77bcfLr74YrzlLW/Bn/zJn6TFxZFHHglASNPk5CRe/epXY/vtt8d//dd/4X3vex9+85vf4FOf+tSjntd0fPnLX8Zzn/tcPOEJT8Cf//mfY6eddsIvfvELfO5zn8Of//mfAwCuu+46nHjiidhjjz1w4YUXYmpqCu973/tw1FFH4cYbb5yxCDn11FOx++6745JLLsGNN96ID3/4w9hhhx3wjne8AwDw4he/GBdeeCHuuece7LTTTul9N9xwA+6++2685CUvAQAwM573vOfha1/7Gs466ywceOCBuPbaa/GGN7wBd911Fy699NI5X+90vOc978E555yD0dFRvPnNbwaAR/x8N2aeRTjncMIJJ+Dwww/Hu971Llx33XX427/9W6xevRqvfvWrN3iMNWvW4NJLL8XPfvYzPPnJTwYgCxalFK6//nq85jWvSdsAsTTMhle96lW4++678eUvfxlXXXXVrK/553/+Z4yNjeFVr3oViAh/8zd/gxe84AX45S9/OWA5mI6Nvf+dc3j2s5+Npz3tafibv/kbXHPNNbjgggtgrcXFF1+cXvfe974Xz3ve8/Cyl70MVVXh4x//OF70ohfhc5/7HE466SQAwFVXXZWO9yd/8icA8Ij375VXXonR0VGce+65GB0dxVe/+lW85S1vwfr16/HOd75zg+/bWHzoQx/Cn/7pn+Iv//Iv8ba3vQ2ALPBOPPFEHHzwwbjggguglMIVV1yBZz7zmbj++utx2GGHDezjRS96Efbaay+8/e1vHyDzt912G174whfirLPOwumnn46PfvSjOOOMM3DwwQdj//33BwBMTk7imGOOwV133YVXvepV2HXXXfHtb38b559/Pn7729/iPe95zyZfY0ZGxlYI3gZwxRVXMAC+7rrr+P777+df//rXfPXVV/OKFSu40+nwr3/96/TaZz3rWfyUpzyFe71e2ua95yOPPJL32muvtO2pT30qn3TSSY943GOOOYaPOeaYGdtPP/10XrVq1cA2AHzBBRek39/5zncyAL7jjjse9fqqquIddtiBn/zkJ/PU1FTa/rnPfY4B8Fve8pa0LY7F9773vUfd7ymnnMLdbpf/53/+J237+c9/zlprnj41Vq1axaeffnr6/ZxzzmEi4h/+8Idp29q1a3nZsmUzrmv6OH3ve99jAHzFFVfMOKfJyckZ2y655BImooHzvOCCC2ac43RYa3n33XfnVatW8UMPPTTwN+99+vnAAw/kHXbYgdeuXZu2/fjHP2alFL/iFa+Yccw/+qM/GtjX85//fN5+++3T7zfffDMD4Pe9730Dr/uzP/szHh0dTdf46U9/mgHw2972toHXvfCFL2Qi4ttuuy1tmz7+G7r++Pm3x3///fefdZ5+7WtfYwD8ta99jZnnNs9OP/10BsAXX3zxwD4POuggPvjgg2ccq4377ruPAfAHPvABZm
Zet24dK6X4RS96Ee+4447pda95zWt42bJl6bO64447Zsyb//2///es4xBfu/322/ODDz6Ytn/mM59hAPwf//Efj3i
"text/plain": [
"<Figure size 1040x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"class_id = 2\n",
"# switch kernel and weights because image is smaller than one filter\n",
"# NOTE(review): `inputs` here is the test set placed on the GPU by the earlier\n",
"# evaluation cell; class_id*test_class_instances picks the first image of that\n",
"# class since the test set is class-blocked -- this cell breaks if run before\n",
"# the evaluation cell.\n",
"image = inputs[class_id*CONFIG.test_class_instances:class_id*CONFIG.test_class_instances+1]\n",
"out = conv2d(spatial_kernel, image, padding='same')[0,0].cpu()\n",
"out = out.abs()\n",
"# normalize to unit sum for display\n",
"out = out/out.sum(dim=[0,1], keepdim=True)\n",
"\n",
"f, ax = imshow(out.abs(), title=\"Result of digital convolution with spatial kernel\")\n",
"# annotate the location of the maximum response\n",
"y,x = (out==torch.max(out)).nonzero()[0]\n",
"ax[0].text(x,y, \"max\", color='white');"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "25d8a6c1-4c77-482c-a9c6-94de11f3eee8",
"metadata": {},
"outputs": [],
"source": [
"torch.save(classificator.state_dict(), CONFIG.classifier_model_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea950ec2-a471-4bdc-990d-2d5eee8a0153",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}