You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

419 lines
18 KiB
Python

import torch
from torch import nn
from utils import pad_zeros, unpad_zeros
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange
import numpy as np
import math
from pprint import pprint, pformat
class OpticalSystem(nn.Module):
    """Differentiable simulation of a stack of trainable optical masks.

    An input image is upscaled, zero-padded to the mask size and passed
    through ``layers`` trainable complex masks (``height_maps``), with an
    FFT-based propagation step before and after every mask.  The resulting
    amplitude image is cut into a ``tiles_per_dim`` x ``tiles_per_dim`` grid
    of tiles and the peak amplitude of each tile becomes the score of one
    class slot.
    """

    def __init__(self,
                 layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength=532e-9,
                 propagation_distance=300,
                 pixel_size_meters=36e-6,
                 metric=1e-3):
        """Build the frequency grid, the fixed constants and the trainable masks.

        Args:
            layers: Number of trainable masks in the stack.
            kernel_size_pixels: Side (in pixels, before upscaling) of the
                window every tile is cropped to before pooling.  Presumably
                equal to the input image side -- confirm with the caller.
            tile_size_scale_factor: A tile is
                ``kernel_size_pixels * tile_size_scale_factor`` pixels wide
                (before upscaling).
            resolution_scale_factor: Integer factor by which the input image
                and every tile are upscaled.
            class_slots: Number of output slots; must be ``>= classes``.
                NOTE(review): the tile grid always holds
                ``tiles_per_dim ** 2`` tiles and the final rearrange expects
                exactly ``class_slots`` of them, so this presumably has to be
                a perfect square -- confirm.
            classes: Number of real classes.
            wavelength: Wavelength; divided by ``metric`` before use.
            propagation_distance: Distance used for every propagation step.
                Used as-is, so presumably already expressed in ``metric``
                units -- confirm.
            pixel_size_meters: Pixel pitch; divided by ``metric`` before use.
            metric: Length unit that ``wavelength`` and ``pixel_size_meters``
                are converted to.
        """
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert self.class_slots >= self.classes, "class_slots must be >= classes"
        self.empty_class_slots = self.class_slots - self.classes

        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        # Converted to a plain int so that every derived size (tensor shapes,
        # einops axis lengths) is a Python int rather than a NumPy scalar.
        self.tiles_per_dim = int(np.ceil(np.sqrt(self.class_slots)))
        self.phase_mask_size = (
            self.tile_size * self.tiles_per_dim * self.resolution_scale_factor
        )

        self.A = (
            self.pixel_size_meters * self.kernel_size_pixels
            / self.resolution_scale_factor / self.metric
        )
        self.B = self.A * self.phase_mask_size / self.tile_size

        size = self.phase_mask_size
        # Spatial and spatial-frequency axes; the last sample is dropped so
        # that each axis has exactly ``phase_mask_size`` points.
        x = torch.linspace(-self.B, self.B, size + 1)[:-1]
        k_max = torch.pi * size / 2 / self.B
        kx = torch.linspace(-k_max, k_max, size + 1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')

        # (-1)**index along each axis; their product is a +/-1 checkerboard.
        signs = (-1) ** torch.arange(0, size)
        self.a, self.b = torch.meshgrid(signs, signs, indexing='ij')

        lambda1 = self.wavelength / self.metric

        # Fixed constants, stored as frozen parameters so that they follow
        # the module across devices and appear in the state dict.
        self.U = nn.Parameter((self.Kx**2 + self.Ky**2).float(), requires_grad=False)
        self.vv = nn.Parameter((self.a * self.b).float(), requires_grad=False)
        # NOTE(review): ``k`` is left trainable exactly as in the original
        # code, while ``coef`` (computed from it once) is frozen -- confirm
        # whether ``k`` was meant to be frozen as well.
        self.k = nn.Parameter(torch.tensor([2 * torch.pi / lambda1]))
        self.coef = nn.Parameter(
            torch.tensor([1j * self.propagation_distance * self.k]),
            requires_grad=False,
        )

        # Trainable complex masks, initialised to 1 + i * U(0, 0.1).
        masks = []
        for _ in range(self.layers):
            imag = torch.rand_like(self.U) * 1e-1
            masks.append(nn.Parameter(torch.complex(torch.ones_like(imag), imag)))
        self.height_maps = torch.nn.ParameterList(masks)

    def propagation(self, field, propagation_distance):
        """Propagate ``field`` over ``propagation_distance`` in the Fourier domain.

        The field is multiplied by the checkerboard ``vv`` before the FFT and
        after the inverse FFT (this acts like an fftshift, lining the spectrum
        up with the centred frequency grid ``U``), and by a transfer function
        built from ``coef``, ``U`` and ``k`` in between.

        Returns:
            Complex tensor with the same shape as ``field``.
        """
        transfer = torch.exp(self.coef) * torch.exp(
            -1j * propagation_distance * self.U / self.resolution_scale_factor / self.k
        )
        spectrum = torch.fft.fft2(field * self.vv)
        return torch.fft.ifft2(spectrum * transfer) * self.vv

    def opt_conv(self, inputs, heights):
        """Propagate, apply one mask, propagate again; return the amplitude.

        Args:
            inputs: Field entering the layer.
            heights: Complex mask multiplied element-wise with the field.

        Returns:
            Real, non-negative amplitude tensor.
        """
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * heights
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        # torch.abs instead of sqrt(re**2 + im**2): the sqrt form has an
        # infinite derivative at 0, which turns into NaN gradients wherever
        # the field is exactly zero; abs() has a zero gradient there.
        return torch.abs(result)

    @staticmethod
    def _normalize_by_peak(values, dim=None):
        """Divide ``values`` by their maximum (global, or along ``dim``).

        The divisor is clamped to the smallest positive normal number of the
        dtype so that an all-zero input yields zeros instead of NaN.  The
        division is out-of-place, so tensors sharing storage with ``values``
        are left untouched.
        """
        if dim is None:
            peak = values.max()
        else:
            peak = values.max(dim=dim, keepdim=True).values
        return values / torch.clamp(peak, min=torch.finfo(values.dtype).tiny)

    def _tiles_from_image(self, image, trace=None):
        """Run steps 1-5a of the pipeline shared by forward and forward_debug.

        Args:
            image: Input batch; the rearrange below requires a singleton
                second axis and a padded size of ``phase_mask_size``.
            trace: Optional list.  When given, every intermediate tensor is
                appended to it and the shapes are printed.

        Returns:
            ``(tiles, convolved)``: the cropped tiles and the full amplitude
            image after the last mask.
        """
        debugging = trace is not None
        scale = self.resolution_scale_factor
        window = self.kernel_size_pixels * scale
        mask_shape = (self.phase_mask_size, self.phase_mask_size)

        # 1. Upscale the image by resolution_scale_factor.
        image = resize(
            image,
            size=(image.shape[-2] * scale, image.shape[-1] * scale),
            interpolation=InterpolationMode.NEAREST,
        )
        if debugging:
            trace.append(image)
            print(image.shape, mask_shape)

        # 2. Zero-pad up to the mask size.
        image = pad_zeros(image, size=mask_shape)
        if debugging:
            trace.append(image)

        # 3. Pass the field through every mask.
        field = image
        for plate_heights in self.height_maps:
            field = self.opt_conv(field, plate_heights)
        convolved = field
        if debugging:
            trace.append(convolved)

        # 4. Cut the output plane into a grid of tiles.
        tile_side = self.tile_size * scale
        tiles = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=tile_side,
            wt=tile_side,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim,
        )
        if debugging:
            trace.append(tiles)
            print(tiles.shape, (window, window))

        # 5a. Crop every tile to the read-out window.
        tiles = unpad_zeros(tiles, (window, window))
        if debugging:
            trace.append(tiles)
        return tiles, convolved

    def forward(self, image):
        """Classify a batch of images.

        Algorithm:
        1. The input image is upscaled by self.resolution_scale_factor,
           e.g. [28, 28] -> [56, 56].
        2. The result is zero-padded to self.phase_mask_size,
           e.g. [56, 56] -> [448, 448].
        3. Light propagation through the masks is simulated.
        4. The output image is cut into self.tiles_per_dim x
           self.tiles_per_dim regions.
        5. The regions are reduced to a vector of length self.class_slots
           with a max operation, which is then normalised per sample.

        Returns:
            ``(scores, convolved)``: per-sample normalised scores of shape
            (batch, class_slots) and the amplitude image after the last mask.
        """
        tiles, convolved = self._tiles_from_image(image)
        window = self.kernel_size_pixels * self.resolution_scale_factor
        pooled = torch.nn.functional.max_pool2d(tiles, kernel_size=window)
        scores = rearrange(
            pooled, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots
        )
        return self._normalize_by_peak(scores, dim=1), convolved

    def __repr__(self):
        """Pretty-print public attributes, sub-modules and parameter shapes."""
        summary = {
            name: value
            for name, value in self.__dict__.items()
            if not name.startswith('_')
        }
        summary.update(self.__dict__['_modules'])
        summary.update(
            {
                name: f"{param.dtype} {param.shape}"
                for name, param in self.__dict__['_parameters'].items()
            }
        )
        return pformat(summary, indent=2)

    def forward_debug(self, image):
        """Same pipeline as forward, but record every intermediate tensor.

        Shapes are printed before padding and before cropping.  Unlike
        forward, the returned scores are additionally passed through a
        softmax over the class-slot axis.

        Returns:
            ``(softmax, convolved, debug_out)`` where ``debug_out`` holds, in
            order: upscaled image, padded image, amplitude after the last
            mask, tiles, cropped tiles, raw per-tile maxima, normalised
            scores.
        """
        debug_out = []
        tiles, convolved = self._tiles_from_image(image, trace=debug_out)

        # 5b. Peak of every tile.
        window = self.kernel_size_pixels * self.resolution_scale_factor
        pooled = torch.nn.functional.max_pool2d(tiles, kernel_size=window)
        debug_out.append(pooled)

        scores = rearrange(
            pooled, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots
        )
        # Out-of-place on purpose: ``scores`` may share storage with
        # ``pooled``, which has already been recorded in debug_out.
        scores = self._normalize_by_peak(scores, dim=1)
        debug_out.append(scores)

        # 6. Softmax over class slots.
        softmax = torch.nn.functional.softmax(scores, dim=1)
        return softmax, convolved, debug_out
class MLP(nn.Module):
    """Plain fully connected network.

    ``num_layers`` counts the ``nn.Linear`` layers.  Every linear layer
    except the last one is followed by a ``nn.GELU``; all intermediate
    widths are ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Assemble the layer stack in ``self.body``.

        Args:
            input_dim: Size of the last axis of the input.
            hidden_dim: Width of every intermediate layer.
            output_dim: Size of the last axis of the output.
            num_layers: Number of linear layers; must be greater than 1.
        """
        super().__init__()
        assert num_layers > 1
        self.num_layers = num_layers

        # Widths seen by the GELU-activated layers: input -> hidden -> ... -> hidden.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.GELU())
        # Output projection, with no activation after it.
        stack.append(nn.Linear(hidden_dim, output_dim))
        self.body = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.body(x)
class OpticalSystemMLP(nn.Module):
    """Optical mask stack whose output tiles are scored by a small MLP.

    The optical part matches ``forward``'s pipeline description: the image
    is upscaled, zero-padded, passed through ``layers`` trainable complex
    masks with FFT-based propagation around each, and the amplitude image is
    cut into a ``tiles_per_dim`` x ``tiles_per_dim`` grid.  Every cropped
    tile is then flattened and mapped to a single score by ``self.mlp``.
    """

    def __init__(self,
                 layers,
                 mlp_layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength=532e-9,
                 propagation_distance=300,
                 pixel_size_meters=36e-6,
                 metric=1e-3):
        """Build the frequency grid, the fixed constants, the masks and the MLP.

        Args:
            layers: Number of trainable masks in the stack.
            mlp_layers: Number of linear layers of the scoring MLP.
            kernel_size_pixels: Side (in pixels, before upscaling) of the
                window every tile is cropped to.  Presumably equal to the
                input image side -- confirm with the caller.
            tile_size_scale_factor: A tile is
                ``kernel_size_pixels * tile_size_scale_factor`` pixels wide
                (before upscaling).
            resolution_scale_factor: Integer factor by which the input image
                and every tile are upscaled.
            class_slots: Number of output slots; must be ``>= classes``.
                NOTE(review): the tile grid always holds
                ``tiles_per_dim ** 2`` tiles and the rearranges expect
                exactly ``class_slots`` of them, so this presumably has to be
                a perfect square -- confirm.
            classes: Number of real classes.
            wavelength: Wavelength; divided by ``metric`` before use.
            propagation_distance: Distance used for every propagation step.
                Used as-is, so presumably already expressed in ``metric``
                units -- confirm.
            pixel_size_meters: Pixel pitch; divided by ``metric`` before use.
            metric: Length unit that ``wavelength`` and ``pixel_size_meters``
                are converted to.
        """
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert self.class_slots >= self.classes, "class_slots must be >= classes"
        self.empty_class_slots = self.class_slots - self.classes

        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        # Converted to a plain int so that every derived size (tensor shapes,
        # einops axis lengths) is a Python int rather than a NumPy scalar.
        self.tiles_per_dim = int(np.ceil(np.sqrt(self.class_slots)))
        self.phase_mask_size = (
            self.tile_size * self.tiles_per_dim * self.resolution_scale_factor
        )

        self.A = (
            self.pixel_size_meters * self.kernel_size_pixels
            / self.resolution_scale_factor / self.metric
        )
        self.B = self.A * self.phase_mask_size / self.tile_size

        size = self.phase_mask_size
        # Spatial and spatial-frequency axes; the last sample is dropped so
        # that each axis has exactly ``phase_mask_size`` points.
        x = torch.linspace(-self.B, self.B, size + 1)[:-1]
        k_max = torch.pi * size / 2 / self.B
        kx = torch.linspace(-k_max, k_max, size + 1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')

        # (-1)**index along each axis; their product is a +/-1 checkerboard.
        signs = (-1) ** torch.arange(0, size)
        self.a, self.b = torch.meshgrid(signs, signs, indexing='ij')

        lambda1 = self.wavelength / self.metric

        # Fixed constants, stored as frozen parameters so that they follow
        # the module across devices and appear in the state dict.
        self.U = nn.Parameter((self.Kx**2 + self.Ky**2).float(), requires_grad=False)
        self.vv = nn.Parameter((self.a * self.b).float(), requires_grad=False)
        # NOTE(review): ``k`` is left trainable exactly as in the original
        # code, while ``coef`` (computed from it once) is frozen -- confirm
        # whether ``k`` was meant to be frozen as well.
        self.k = nn.Parameter(torch.tensor([2 * torch.pi / lambda1]))
        self.coef = nn.Parameter(
            torch.tensor([1j * self.propagation_distance * self.k]),
            requires_grad=False,
        )

        # Trainable complex masks, initialised to 1 + i * U(0, 0.1).
        masks = []
        for _ in range(self.layers):
            imag = torch.rand_like(self.U) * 1e-1
            masks.append(nn.Parameter(torch.complex(torch.ones_like(imag), imag)))
        self.height_maps = torch.nn.ParameterList(masks)

        # Maps one flattened, cropped tile to a single score.
        window = self.kernel_size_pixels * self.resolution_scale_factor
        self.mlp = MLP(
            input_dim=window ** 2,
            hidden_dim=window,
            output_dim=1,
            num_layers=mlp_layers,
        )

    def propagation(self, field, propagation_distance):
        """Propagate ``field`` over ``propagation_distance`` in the Fourier domain.

        The field is multiplied by the checkerboard ``vv`` before the FFT and
        after the inverse FFT (this acts like an fftshift, lining the spectrum
        up with the centred frequency grid ``U``), and by a transfer function
        built from ``coef``, ``U`` and ``k`` in between.

        Returns:
            Complex tensor with the same shape as ``field``.
        """
        transfer = torch.exp(self.coef) * torch.exp(
            -1j * propagation_distance * self.U / self.resolution_scale_factor / self.k
        )
        spectrum = torch.fft.fft2(field * self.vv)
        return torch.fft.ifft2(spectrum * transfer) * self.vv

    def opt_conv(self, inputs, heights):
        """Propagate, apply one mask, propagate again; return the amplitude.

        Args:
            inputs: Field entering the layer.
            heights: Complex mask multiplied element-wise with the field.

        Returns:
            Real, non-negative amplitude tensor.
        """
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * heights
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        # torch.abs instead of sqrt(re**2 + im**2): the sqrt form has an
        # infinite derivative at 0, which turns into NaN gradients wherever
        # the field is exactly zero; abs() has a zero gradient there.
        return torch.abs(result)

    @staticmethod
    def _normalize_by_peak(values, dim=None):
        """Divide ``values`` by their maximum (global, or along ``dim``).

        The divisor is clamped to the smallest positive normal number of the
        dtype so that an all-zero input yields zeros instead of NaN.  The
        division is out-of-place, so tensors sharing storage with ``values``
        are left untouched.
        """
        if dim is None:
            peak = values.max()
        else:
            peak = values.max(dim=dim, keepdim=True).values
        return values / torch.clamp(peak, min=torch.finfo(values.dtype).tiny)

    def _tiles_from_image(self, image, trace=None):
        """Run the optical part of the pipeline shared by forward and forward_debug.

        Args:
            image: Input batch; the rearrange below requires a singleton
                second axis and a padded size of ``phase_mask_size``.
            trace: Optional list.  When given, every intermediate tensor is
                appended to it and the shapes are printed.

        Returns:
            ``(tiles, convolved)``: the cropped tiles and the full amplitude
            image after the last mask.
        """
        debugging = trace is not None
        scale = self.resolution_scale_factor
        window = self.kernel_size_pixels * scale
        mask_shape = (self.phase_mask_size, self.phase_mask_size)

        # 1. Upscale the image by resolution_scale_factor.
        image = resize(
            image,
            size=(image.shape[-2] * scale, image.shape[-1] * scale),
            interpolation=InterpolationMode.NEAREST,
        )
        if debugging:
            trace.append(image)
            print(image.shape, mask_shape)

        # 2. Zero-pad up to the mask size.
        image = pad_zeros(image, size=mask_shape)
        if debugging:
            trace.append(image)

        # 3. Pass the field through every mask.
        field = image
        for plate_heights in self.height_maps:
            field = self.opt_conv(field, plate_heights)
        convolved = field
        if debugging:
            trace.append(convolved)

        # 4. Cut the output plane into a grid of tiles.
        tile_side = self.tile_size * scale
        tiles = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=tile_side,
            wt=tile_side,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim,
        )
        if debugging:
            trace.append(tiles)
            print(tiles.shape, (window, window))

        # 5a. Crop every tile to the read-out window.
        tiles = unpad_zeros(tiles, (window, window))
        if debugging:
            trace.append(tiles)
        return tiles, convolved

    def forward(self, image):
        """Classify a batch of images.

        Algorithm:
        1. The input image is upscaled by self.resolution_scale_factor,
           e.g. [28, 28] -> [56, 56].
        2. The result is zero-padded to self.phase_mask_size,
           e.g. [56, 56] -> [448, 448].
        3. Light propagation through the masks is simulated.
        4. The output image is cut into self.tiles_per_dim x
           self.tiles_per_dim regions.
        5. Every region is cropped, flattened and mapped to one score by
           self.mlp; the absolute scores are then normalised.

        Returns:
            ``(scores, convolved)``: scores of shape (batch, class_slots) and
            the amplitude image after the last mask.
        """
        tiles, convolved = self._tiles_from_image(image)
        window = self.kernel_size_pixels * self.resolution_scale_factor
        flat_tiles = rearrange(
            tiles,
            "b mn ht wt -> b mn (ht wt)",
            ht=window,
            wt=window,
            mn=self.class_slots,
        )
        scores = self.mlp(flat_tiles).abs()
        # NOTE(review): the peak is taken over the whole batch, so one
        # sample's scores depend on the other samples in the batch; the
        # max-pool path in forward_debug normalises per sample instead.
        # Kept as in the original -- confirm which is intended.
        scores = self._normalize_by_peak(scores)
        return scores.squeeze(-1), convolved

    def __repr__(self):
        """Pretty-print public attributes, sub-modules and parameter shapes."""
        summary = {
            name: value
            for name, value in self.__dict__.items()
            if not name.startswith('_')
        }
        summary.update(self.__dict__['_modules'])
        summary.update(
            {
                name: f"{param.dtype} {param.shape}"
                for name, param in self.__dict__['_parameters'].items()
            }
        )
        return pformat(summary, indent=2)

    def forward_debug(self, image):
        """Run the optical pipeline and record every intermediate tensor.

        Shapes are printed before padding and before cropping.

        NOTE(review): unlike forward, this path scores the tiles with a
        max-pool followed by per-sample normalisation and never calls
        ``self.mlp`` -- confirm whether that is intended.

        Returns:
            ``(scores, convolved, debug_out)`` where ``debug_out`` holds, in
            order: upscaled image, padded image, amplitude after the last
            mask, tiles, cropped tiles, raw per-tile maxima, normalised
            scores.
        """
        debug_out = []
        tiles, convolved = self._tiles_from_image(image, trace=debug_out)

        # 5b. Peak of every tile.
        window = self.kernel_size_pixels * self.resolution_scale_factor
        pooled = torch.nn.functional.max_pool2d(tiles, kernel_size=window)
        debug_out.append(pooled)

        scores = rearrange(
            pooled, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots
        )
        # Out-of-place on purpose: ``scores`` may share storage with
        # ``pooled``, which has already been recorded in debug_out.
        scores = self._normalize_by_peak(scores, dim=1)
        debug_out.append(scores)
        return scores, convolved, debug_out