noise experiment

master
Vladimir Protsenko 4 months ago
parent ebbda00894
commit 57e2d3d939

4
.gitattributes vendored

@ -0,0 +1,4 @@
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text

167
.gitignore vendored

@ -0,0 +1,167 @@
tmp/
.ipynb_checkpoints
.Trash-1000
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -0,0 +1,34 @@
!!python/object:__main__.
batch_size: 152
class_slots: 16
classes: 16
dataset_name: quickdraw
experiment_dir: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystemMLP_quickdraw_1_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
image_size: 28
kernel_size_pixels: 28
layers: 1
loss_plot_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystemMLP_quickdraw_1_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystemMLP_quickdraw_1_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05_loss.png
max_passes_through_dataset: 50
metric: 0.001
mlp_layers: 2
model_class: !!python/name:models.OpticalSystemMLP ''
model_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystemMLP_quickdraw_1_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystemMLP_quickdraw_1_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05.pt
name_id: OpticalSystemMLP_quickdraw_1_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
pixel_size_meters: 3.6e-05
propagation_distance: 300
resolution_scale_factor: 2
test_batch_size: 64
test_class_instances: 100
test_data_path: ./assets/quickdraw16_test.npy
tile_size_scale_factor: 2
train_class_instances: 8000
train_data_path: ./assets/quickdraw16_train.npy
wavelength: 5.32e-07

@ -0,0 +1,34 @@
!!python/object:__main__.
batch_size: 152
class_slots: 16
classes: 16
dataset_name: quickdraw
experiment_dir: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystemMLP_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
image_size: 28
kernel_size_pixels: 28
layers: 4
loss_plot_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystemMLP_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystemMLP_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05_loss.png
max_passes_through_dataset: 50
metric: 0.001
mlp_layers: 2
model_class: !!python/name:models.OpticalSystemMLP ''
model_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystemMLP_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystemMLP_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05.pt
name_id: OpticalSystemMLP_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
pixel_size_meters: 3.6e-05
propagation_distance: 300
resolution_scale_factor: 2
test_batch_size: 64
test_class_instances: 100
test_data_path: ./assets/quickdraw16_test.npy
tile_size_scale_factor: 2
train_class_instances: 8000
train_data_path: ./assets/quickdraw16_train.npy
wavelength: 5.32e-07

@ -0,0 +1,34 @@
!!python/object:__main__.
batch_size: 152
class_slots: 16
classes: 16
dataset_name: quickdraw
experiment_dir: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05
image_size: 28
kernel_size_pixels: 28
layers: 1
loss_plot_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05_loss.png
max_passes_through_dataset: 20
metric: 0.001
mlp_layers: 2
model_class: !!python/name:models.OpticalSystem ''
model_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05.pt
name_id: OpticalSystem_quickdraw_1_2_16_20_28_2_2_300_5.32e-07_0.001_3.6e-05
pixel_size_meters: 3.6e-05
propagation_distance: 300
resolution_scale_factor: 2
test_batch_size: 64
test_class_instances: 100
test_data_path: ./assets/quickdraw16_test.npy
tile_size_scale_factor: 2
train_class_instances: 8000
train_data_path: ./assets/quickdraw16_train.npy
wavelength: 5.32e-07

@ -0,0 +1,34 @@
!!python/object:__main__.
batch_size: 152
class_slots: 16
classes: 16
dataset_name: quickdraw
experiment_dir: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystem_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
image_size: 28
kernel_size_pixels: 28
layers: 4
loss_plot_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystem_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystem_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05_loss.png
max_passes_through_dataset: 50
metric: 0.001
mlp_layers: 2
model_class: !!python/name:models.OpticalSystem ''
model_path: !!python/object/apply:pathlib.PosixPath
- experiments
- OpticalSystem_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
- OpticalSystem_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05.pt
name_id: OpticalSystem_quickdraw_4_2_16_50_28_2_2_300_5.32e-07_0.001_3.6e-05
pixel_size_meters: 3.6e-05
propagation_distance: 300
resolution_scale_factor: 2
test_batch_size: 64
test_class_instances: 100
test_data_path: ./assets/quickdraw16_test.npy
tile_size_scale_factor: 2
train_class_instances: 8000
train_data_path: ./assets/quickdraw16_train.npy
wavelength: 5.32e-07

@ -0,0 +1,419 @@
import torch
from torch import nn
from utils import pad_zeros, unpad_zeros
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange
import numpy as np
import math
from pprint import pprint, pformat
class OpticalSystem(nn.Module):
    """All-optical classifier built from trainable diffractive phase plates.

    The input image is propagated through ``layers`` complex "height map"
    plates using an FFT-based free-space propagation model; the resulting
    intensity image is cut into a grid of tiles (one per class slot) and
    max-pooled into one score per slot.
    """
    def __init__(self,
                 layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength = 532e-9,
                 # refractive_index = 1.5090,
                 propagation_distance = 300,
                 pixel_size_meters = 36e-6,
                 metric = 1e-3
                 ):
        """Precompute the simulation grids and initialize the phase plates.

        Args:
            layers: number of diffractive plates the light passes through.
            kernel_size_pixels: side length of the (square) input image in pixels.
            tile_size_scale_factor: per-class tile side = kernel_size_pixels * this.
            resolution_scale_factor: upsampling factor of the simulation grid.
            class_slots: number of output tiles; must be >= classes.
            classes: number of real classes.
            wavelength: light wavelength in meters.
            propagation_distance: plate-to-plate distance (assumed to be in
                `metric` units — TODO confirm against the propagation formula).
            pixel_size_meters: physical pixel pitch in meters.
            metric: unit scale (1e-3 -> lengths expressed in millimeters).
        """
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        # self.refractive_index = refractive_index
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric
        assert(self.class_slots >= self.classes)
        self.empty_class_slots = self.class_slots - self.classes
        # Tile side (in input-resolution pixels) reserved for each class slot.
        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        # Smallest square grid of tiles that can hold all class slots.
        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)
        # Side of the simulated field / phase masks, in simulation pixels.
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor
        # Physical half-extents (in `metric` units) of the kernel window (A)
        # and of the whole simulated field (B).
        self.A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric
        self.B = self.A*self.phase_mask_size/self.tile_size
        # Spatial and spatial-frequency sampling grids; the endpoint is dropped
        # to keep the grid consistent with FFT periodicity.
        x = torch.linspace(-self.B, self.B, self.phase_mask_size+1)[:-1]
        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/self.B, torch.pi*self.phase_mask_size/2/self.B, self.phase_mask_size+1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')
        # (-1)^(i+j) checkerboard used to re-center the FFT spectrum
        # (multiplicative fftshift trick in propagation()).
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1)**vv
        self.a, self.b = torch.meshgrid(vv, vv, indexing='ij')
        lambda1 = self.wavelength / self.metric
        # Constant tensors stored as Parameters (gradients disabled below) so
        # they are moved along with the module by .to()/.cuda().
        self.U = nn.Parameter((self.Kx**2 + self.Ky**2).float())   # |k|^2 grid
        self.vv = nn.Parameter((self.a*self.b).float())            # checkerboard
        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))  # wavenumber
        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))  # constant phase factor
        self.U.requires_grad = False
        self.vv.requires_grad = False
        self.coef.requires_grad = False
        # NOTE(review): self.k stays trainable (requires_grad not disabled) — confirm intended.
        # Trainable complex transmission maps, one per diffractive layer,
        # initialized near identity (real part 1, small random imaginary part).
        self.height_maps = []
        for i in range(self.layers):
            # heights = nn.Parameter(torch.exp(-1j*(self.x**2 + self.y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))
            h = torch.rand_like(self.U)*1e-1
            heights = nn.Parameter(torch.complex(torch.ones_like(h),h))
            # torch.nn.init.uniform_(heights, a=-1, b=1)
            self.height_maps.append(heights)
        self.height_maps = torch.nn.ParameterList(self.height_maps)
    def propagation(self, field, propagation_distance):
        """Propagate a complex field over `propagation_distance` with an
        FFT-based transfer function (quadratic-phase approximation)."""
        F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv
    def opt_conv(self, inputs, heights):
        """One optical layer: propagate, apply the plate's complex mask,
        propagate again; return the field amplitude (magnitude)."""
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * heights
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        amplitude = torch.sqrt(result.real**2 + result.imag**2)
        return amplitude
    def forward(self, image):
        """
        Algorithm:
        1. The input image is upscaled by self.resolution_scale_factor: [28,28] -> [56,56].
        2. The result is zero-padded up to self.phase_mask_size: [56,56] -> [448,448].
        3. Light propagation through the phase plates is simulated.
        4. The output image is sliced into a self.tiles_per_dim x self.tiles_per_dim grid of tiles.
        5. Tiles are reduced to a vector of length self.class_slots by max pooling, then normalized.

        Returns (scores, convolved): per-slot scores and the raw optical output.
        """
        # 1
        image = resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        # 2
        image = pad_zeros(
            image,
            size = (self.phase_mask_size,
                    self.phase_mask_size),
        )
        # 3
        x = image
        for i, plate_heights in enumerate(self.height_maps):
            x = self.opt_conv(x, plate_heights)
        convolved = x
        # 4: tile grid -> channel dimension (assumes a single input channel).
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht = self.tile_size*self.resolution_scale_factor,
            wt = self.tile_size*self.resolution_scale_factor,
            m = self.tiles_per_dim,
            n = self.tiles_per_dim
        )
        # 5: crop each tile to the kernel window, then one max per slot.
        grid_to_depth = unpad_zeros(grid_to_depth,
                                    (self.kernel_size_pixels*self.resolution_scale_factor,
                                     self.kernel_size_pixels*self.resolution_scale_factor))
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size = self.kernel_size_pixels*self.resolution_scale_factor
        )
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        # Normalize scores by the per-sample maximum (in-place division).
        max_pool /= max_pool.max(dim=1, keepdims=True).values
        return max_pool, convolved
    def __repr__(self):
        # Pretty-print public attributes plus submodules and parameter summaries.
        tmp = {}
        for k,v in self.__dict__.items():
            if not k[0] == '_':
                tmp[k] = v
        tmp.update(self.__dict__['_modules'])
        tmp.update({k:f"{v.dtype} {v.shape}" for k,v in self.__dict__['_parameters'].items()})
        return pformat(tmp, indent=2)
    def forward_debug(self, image):
        """
        Same pipeline as forward(), but collects every intermediate tensor in
        `debug_out` and applies a final softmax.

        Algorithm:
        1. The input image is upscaled by self.resolution_scale_factor: [28,28] -> [56,56].
        2. The result is zero-padded up to self.phase_mask_size: [56,56] -> [448,448].
        3. Light propagation through the phase plates is simulated.
        4. The output image is sliced into a self.tiles_per_dim x self.tiles_per_dim grid of tiles.
        5. Tiles are reduced to a vector of length self.class_slots by max pooling, then normalized.

        Returns (softmax, convolved, debug_out).
        """
        debug_out = []
        # 1
        image = resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        debug_out.append(image)
        # 2
        print(image.shape, (self.phase_mask_size, self.phase_mask_size ))
        image = pad_zeros(
            image,
            size = (self.phase_mask_size ,
                    self.phase_mask_size ),
        )
        debug_out.append(image)
        # 3
        x = image
        for i, plate_heights in enumerate(self.height_maps):
            x = self.opt_conv(x, plate_heights)
        convolved = x
        debug_out.append(convolved)
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht = self.tile_size*self.resolution_scale_factor,
            wt = self.tile_size*self.resolution_scale_factor,
            m = self.tiles_per_dim,
            n = self.tiles_per_dim
        )
        debug_out.append(grid_to_depth)
        # 5
        print(grid_to_depth.shape, (self.kernel_size_pixels*self.resolution_scale_factor, self.kernel_size_pixels*self.resolution_scale_factor))
        grid_to_depth = unpad_zeros(grid_to_depth,
                                    (self.kernel_size_pixels*self.resolution_scale_factor,
                                     self.kernel_size_pixels*self.resolution_scale_factor))
        debug_out.append(grid_to_depth)
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size = self.kernel_size_pixels*self.resolution_scale_factor
        )
        debug_out.append(max_pool)
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        max_pool /= max_pool.max(dim=1, keepdims=True).values
        debug_out.append(max_pool)
        # 6: debug path additionally applies softmax (forward() does not).
        softmax = torch.nn.functional.softmax(max_pool, dim=1)
        return softmax, convolved, debug_out
class MLP(nn.Module):
    """Plain feed-forward network: GELU-activated hidden layers, linear head.

    Layout for num_layers=N: Linear(in, hid), GELU, [Linear(hid, hid), GELU]
    repeated N-2 times, then Linear(hid, out).
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        assert num_layers > 1
        self.num_layers = num_layers
        # Build the layer sizes, then pair consecutive sizes into Linear
        # modules; every layer except the last is followed by a GELU.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        modules = []
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx < num_layers - 1:
                modules.append(nn.GELU())
        self.body = nn.Sequential(*modules)
    def forward(self, x):
        """Apply the network to x (shape [..., input_dim])."""
        return self.body(x)
class OpticalSystemMLP(nn.Module):
    """Diffractive optical front end followed by a small MLP read-out.

    Identical optical simulation to OpticalSystem, but instead of max-pooling
    each class tile, every tile is flattened and scored by a shared MLP.
    """
    def __init__(self,
                 layers,
                 mlp_layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength = 532e-9,
                 # refractive_index = 1.5090,
                 propagation_distance = 300,
                 pixel_size_meters = 36e-6,
                 metric = 1e-3
                 ):
        """Precompute the simulation grids, init phase plates and the MLP head.

        Args:
            layers: number of diffractive plates the light passes through.
            mlp_layers: depth of the per-tile MLP read-out head.
            kernel_size_pixels: side length of the (square) input image in pixels.
            tile_size_scale_factor: per-class tile side = kernel_size_pixels * this.
            resolution_scale_factor: upsampling factor of the simulation grid.
            class_slots: number of output tiles; must be >= classes.
            classes: number of real classes.
            wavelength: light wavelength in meters.
            propagation_distance: plate-to-plate distance (assumed to be in
                `metric` units — TODO confirm against the propagation formula).
            pixel_size_meters: physical pixel pitch in meters.
            metric: unit scale (1e-3 -> lengths expressed in millimeters).
        """
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        # self.refractive_index = refractive_index
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric
        assert(self.class_slots >= self.classes)
        self.empty_class_slots = self.class_slots - self.classes
        # Tile side (in input-resolution pixels) reserved for each class slot.
        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        # Smallest square grid of tiles that can hold all class slots.
        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)
        # Side of the simulated field / phase masks, in simulation pixels.
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor
        # Physical half-extents (in `metric` units) of the kernel window (A)
        # and of the whole simulated field (B).
        self.A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric
        self.B = self.A*self.phase_mask_size/self.tile_size
        # Spatial and spatial-frequency sampling grids; the endpoint is dropped
        # to keep the grid consistent with FFT periodicity.
        x = torch.linspace(-self.B, self.B, self.phase_mask_size+1)[:-1]
        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/self.B, torch.pi*self.phase_mask_size/2/self.B, self.phase_mask_size+1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')
        # (-1)^(i+j) checkerboard used to re-center the FFT spectrum
        # (multiplicative fftshift trick in propagation()).
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1)**vv
        self.a, self.b = torch.meshgrid(vv, vv, indexing='ij')
        lambda1 = self.wavelength / self.metric
        # Constant tensors stored as Parameters (gradients disabled below) so
        # they are moved along with the module by .to()/.cuda().
        self.U = nn.Parameter((self.Kx**2 + self.Ky**2).float())   # |k|^2 grid
        self.vv = nn.Parameter((self.a*self.b).float())            # checkerboard
        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))  # wavenumber
        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))  # constant phase factor
        self.U.requires_grad = False
        self.vv.requires_grad = False
        self.coef.requires_grad = False
        # NOTE(review): self.k stays trainable (requires_grad not disabled) — confirm intended.
        # Trainable complex transmission maps, one per diffractive layer,
        # initialized near identity (real part 1, small random imaginary part).
        self.height_maps = []
        for i in range(self.layers):
            # heights = nn.Parameter(torch.exp(-1j*(self.x**2 + self.y**2)/self.resolution_scale_factor/self.propagation_distance*self.k))
            h = torch.rand_like(self.U)*1e-1
            heights = nn.Parameter(torch.complex(torch.ones_like(h),h))
            # torch.nn.init.uniform_(heights, a=-1, b=1)
            self.height_maps.append(heights)
        self.height_maps = torch.nn.ParameterList(self.height_maps)
        # Shared read-out head: one scalar score per flattened tile.
        self.mlp = MLP(
            input_dim=(self.kernel_size_pixels*self.resolution_scale_factor)**2, #self.class_slots,
            hidden_dim=self.kernel_size_pixels*self.resolution_scale_factor,
            output_dim=1,
            num_layers=mlp_layers
        )
    def propagation(self, field, propagation_distance):
        """Propagate a complex field over `propagation_distance` with an
        FFT-based transfer function (quadratic-phase approximation)."""
        F = torch.exp(self.coef)*torch.exp(-1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv
    def opt_conv(self, inputs, heights):
        """One optical layer: propagate, apply the plate's complex mask,
        propagate again; return the field amplitude (magnitude)."""
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * heights
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        amplitude = torch.sqrt(result.real**2 + result.imag**2)
        return amplitude
    def forward(self, image):
        """
        Algorithm:
        1. The input image is upscaled by self.resolution_scale_factor: [28,28] -> [56,56].
        2. The result is zero-padded up to self.phase_mask_size: [56,56] -> [448,448].
        3. Light propagation through the phase plates is simulated.
        4. The output image is sliced into a self.tiles_per_dim x self.tiles_per_dim grid of tiles.
        5. Each tile is cropped, flattened and scored by the shared MLP.

        Returns (scores, convolved): per-slot scores and the raw optical output.
        """
        # 1
        image = resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        # debug_out.append(image)
        # 2
        image = pad_zeros(
            image,
            size = (self.phase_mask_size,
                    self.phase_mask_size),
        )
        # 3
        x = image
        for i, plate_heights in enumerate(self.height_maps):
            x = self.opt_conv(x, plate_heights)
        convolved = x
        # 4: tile grid -> channel dimension (assumes a single input channel).
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht = self.tile_size*self.resolution_scale_factor,
            wt = self.tile_size*self.resolution_scale_factor,
            m = self.tiles_per_dim,
            n = self.tiles_per_dim
        )
        # 5: crop tiles to the kernel window, flatten, score with the MLP.
        grid_to_depth = unpad_zeros(grid_to_depth,
                                    (self.kernel_size_pixels*self.resolution_scale_factor,
                                     self.kernel_size_pixels*self.resolution_scale_factor))
        grid_to_depth = rearrange(
            grid_to_depth,
            "b mn ht wt -> b mn (ht wt)",
            ht = self.kernel_size_pixels*self.resolution_scale_factor,
            wt = self.kernel_size_pixels*self.resolution_scale_factor,
            mn = self.class_slots
        )
        scores = self.mlp(grid_to_depth).abs()
        # NOTE(review): normalized by the global (batch-wide) max, not per-sample — confirm intended.
        scores = scores/scores.max()
        scores = scores.squeeze(-1)
        return scores, convolved
    def __repr__(self):
        # Pretty-print public attributes plus submodules and parameter summaries.
        tmp = {}
        for k,v in self.__dict__.items():
            if not k[0] == '_':
                tmp[k] = v
        tmp.update(self.__dict__['_modules'])
        tmp.update({k:f"{v.dtype} {v.shape}" for k,v in self.__dict__['_parameters'].items()})
        return pformat(tmp, indent=2)
    def forward_debug(self, image):
        """Debug variant collecting intermediates in `debug_out`.

        NOTE(review): this path scores slots with the max-pool head (as in
        OpticalSystem), not with self.mlp — its outputs differ from forward().
        Returns (max_pool, convolved, debug_out).
        """
        debug_out = []
        # 1
        image = resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        debug_out.append(image)
        # 2
        print(image.shape, (self.phase_mask_size, self.phase_mask_size ))
        image = pad_zeros(
            image,
            size = (self.phase_mask_size ,
                    self.phase_mask_size ),
        )
        debug_out.append(image)
        # 3
        x = image
        for i, plate_heights in enumerate(self.height_maps):
            x = self.opt_conv(x, plate_heights)
        convolved = x
        debug_out.append(convolved)
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht = self.tile_size*self.resolution_scale_factor,
            wt = self.tile_size*self.resolution_scale_factor,
            m = self.tiles_per_dim,
            n = self.tiles_per_dim
        )
        debug_out.append(grid_to_depth)
        # 5
        print(grid_to_depth.shape, (self.kernel_size_pixels*self.resolution_scale_factor, self.kernel_size_pixels*self.resolution_scale_factor))
        grid_to_depth = unpad_zeros(grid_to_depth,
                                    (self.kernel_size_pixels*self.resolution_scale_factor,
                                     self.kernel_size_pixels*self.resolution_scale_factor))
        debug_out.append(grid_to_depth)
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size = self.kernel_size_pixels*self.resolution_scale_factor
        )
        debug_out.append(max_pool)
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        max_pool /= max_pool.max(dim=1, keepdims=True).values
        debug_out.append(max_pool)
        return max_pool, convolved, debug_out

@ -1,40 +1,44 @@
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from datetime import datetime
def imshow(tensor, figsize=None, title="", **args): def imshow(tensor, figsize=None, title="", **args):
tensor = tensor.cpu().detach() if isinstance(tensor, torch.Tensor) else tensor
tensor = list(tensor) if isinstance(tensor, torch.nn.modules.container.ParameterList) else tensor
figsize = figsize if figsize else (13*0.8,5*0.8) figsize = figsize if figsize else (13*0.8,5*0.8)
if type(tensor) is list: if type(tensor) is list:
outs = []
for idx, el in enumerate(tensor): for idx, el in enumerate(tensor):
imshow(el, figsize=figsize, title=title, **args) f, ax = imshow(el, figsize=figsize, title=title, **args)
plt.suptitle("{} {}".format(idx, title)) plt.suptitle("{} {}".format(idx, title))
return outs.append([f, ax])
return outs
if len(tensor.shape)==4: if len(tensor.shape)==4:
outs = []
for idx, el in enumerate(torch.squeeze(tensor, dim=1)): for idx, el in enumerate(torch.squeeze(tensor, dim=1)):
imshow(el, figsize=figsize, title=title, **args) f, ax = imshow(el, figsize=figsize, title=title, **args)
plt.suptitle("{} {}".format(idx, title)) plt.suptitle("{} {}".format(idx, title))
return outs.append([f, ax])
return outs
print(type(tensor))
tensor = tensor.detach().cpu() if type(tensor) == torch.Tensor else tensor
if tensor.dtype == torch.complex64: if tensor.dtype == torch.complex64:
f, ax = plt.subplots(1, 5, figsize=figsize, gridspec_kw={'width_ratios': [46.5,3,1,46.5,3]}) f, ax = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'width_ratios': [46.5,46.5]})
real_im = ax[0].imshow(tensor.real, **args) real_im = ax[0].imshow(tensor.real, **args)
imag_im = ax[3].imshow(tensor.imag, **args) imag_im = ax[1].imshow(tensor.imag, **args)
box = ax[1].get_position()
box.x0 = box.x0 - 0.02
box.x1 = box.x1 - 0.03
ax[1].set_position(box)
box = ax[4].get_position()
box.x0 = box.x0 - 0.02
box.x1 = box.x1 - 0.03
ax[4].set_position(box)
ax[0].set_title("real"); ax[0].set_title("real");
ax[3].set_title("imag"); ax[1].set_title("imag");
f.colorbar(real_im, ax[1]); divider = make_axes_locatable(ax[0])
f.colorbar(imag_im, ax[4]); cax = divider.append_axes("right", size="5%", pad=0.05)
f.colorbar(real_im, cax);
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.05)
f.colorbar(imag_im, cax);
f.suptitle(title) f.suptitle(title)
ax[2].remove() f.tight_layout()
return f, ax return f, ax
else: else:
f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize) f, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [95,5]}, figsize=figsize)
@ -43,6 +47,7 @@ def imshow(tensor, figsize=None, title="", **args):
f.suptitle(title) f.suptitle(title)
return f, ax return f, ax
def perm_roll(im, axis, amount):
    """Cyclically shift `im` by `amount` positions along `axis`."""
    # Roll the index vector, then gather — equivalent to rolling the tensor.
    idx = torch.arange(im.shape[axis], device=im.device)
    shifted = torch.roll(idx, amount, dims=0)
    return torch.index_select(im, axis, shifted)
@ -61,30 +66,19 @@ def shift_right(im):
def pad_zeros(input, size):
    """Center `input` inside a zero tensor whose last two dims are `size`.

    Reconstructed post-image of a garbled side-by-side diff: the generic
    `...`-indexed branch replaces the dead 2-D/4-D special cases, so any
    number of leading batch/channel dims is supported.

    Args:
        input: tensor of shape [..., h, w].
        size: (target_h, target_w) tuple; target dims must be >= (h, w).

    Returns:
        Tensor of shape input.shape[:-2] + size with `input` pasted centrally.
        NOTE(review): output is allocated with the default float dtype, not
        input.dtype — confirm callers never pass complex/int input.
    """
    h, w = input.shape[-2:]
    th, tw = size
    out = torch.zeros(input.shape[:-2] + size, device=input.device)
    # Top-left corner of the centered window.
    x, y = int(th/2 - h/2), int(tw/2 - w/2)
    out[..., x:int(th/2 + h/2), y:int(tw/2 + w/2)] = input[..., :, :]
    return out
def unpad_zeros(input, size):
    """Inverse of pad_zeros: centrally crop the last two dims of `input` to `size`.

    Reconstructed post-image of a garbled side-by-side diff; the unused
    `dx, dy` locals from the pre-image were dropped.

    Args:
        input: tensor of shape [..., h, w] with h >= size[0], w >= size[1].
        size: (target_h, target_w) tuple.

    Returns:
        Centered view of shape input.shape[:-2] + size.
    """
    h, w = input.shape[-2:]
    th, tw = size
    # int(th/2 + h/2) == int(h/2 + th/2): a symmetric, centered window.
    return input[..., int(h/2 - th/2):int(th/2 + h/2), int(w/2 - tw/2):int(tw/2 + w/2)]


def to_class_labels(softmax_distibutions):
    """Return the argmax class index per row, moved to CPU."""
    return torch.argmax(softmax_distibutions, dim=1).cpu()
def circular_aperture(h, w, r=None, is_inv=False): def circular_aperture(h, w, r=None, is_inv=False):
if r is None: if r is None:
@ -96,6 +90,3 @@ def circular_aperture(h, w, r=None, is_inv=False):
else: else:
circle_aperture = torch.where(circle_dist<r, torch.ones_like(circle_dist), torch.zeros_like(circle_dist)) circle_aperture = torch.where(circle_dist<r, torch.ones_like(circle_dist), torch.zeros_like(circle_dist))
return circle_aperture return circle_aperture
def to_class_labels(softmax_distibutions):
return torch.argmax(softmax_distibutions, dim=1).cpu()
Loading…
Cancel
Save