You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
419 lines
18 KiB
Python
419 lines
18 KiB
Python
import torch
|
|
from torch import nn
|
|
from utils import pad_zeros, unpad_zeros
|
|
from torchvision.transforms.functional import resize, InterpolationMode
|
|
from einops import rearrange
|
|
import numpy as np
|
|
import math
|
|
from pprint import pprint, pformat
|
|
|
|
class OpticalSystem(nn.Module):
    """Simulated diffractive optical classifier read out by max-pooled detector tiles.

    An input image is upscaled, zero-padded to a square simulation grid of side
    ``phase_mask_size`` and passed through ``layers`` trainable complex
    transmission masks (``height_maps``), with FFT-based free-space propagation
    before and after each mask. The output amplitude plane is cut into a
    ``tiles_per_dim x tiles_per_dim`` grid of tiles; the first ``class_slots``
    tiles act as detectors and their peak amplitudes form the class scores.
    """

    def __init__(self,
                 layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength=532e-9,
                 propagation_distance=300,
                 pixel_size_meters=36e-6,
                 metric=1e-3
                 ):
        """Build the simulation grid, the frozen propagation constants and the masks.

        Args:
            layers: number of trainable transmission masks.
            kernel_size_pixels: side, in input pixels, of the detector window
                cropped (via ``unpad_zeros``) out of every tile.
            tile_size_scale_factor: tile side is
                ``kernel_size_pixels * tile_size_scale_factor`` input pixels.
            resolution_scale_factor: upsampling factor applied to the input
                image, and therefore to the whole simulation grid.
            class_slots: number of detector tiles read out; must be >= ``classes``.
            classes: number of real classes; the surplus is kept in
                ``empty_class_slots``.
            wavelength: wavelength in meters.
            propagation_distance: distance travelled before and after each
                mask, presumably in units of ``metric`` meters (mm by
                default) -- TODO confirm.
            pixel_size_meters: physical input pixel pitch in meters.
            metric: length unit in meters; wavelength and pixel pitch are
                divided by it to get simulation units.

        Raises:
            AssertionError: if ``class_slots < classes``.
        """
        super().__init__()

        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert self.class_slots >= self.classes
        self.empty_class_slots = self.class_slots - self.classes

        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        # Smallest square grid holding at least class_slots tiles. A plain int
        # (rather than np.int32) keeps every derived size a Python int.
        self.tiles_per_dim = math.ceil(math.sqrt(self.class_slots))
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor

        # Spatial half-extent B of the grid, in units of `metric`.
        self.A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric
        self.B = self.A*self.phase_mask_size/self.tile_size

        # Spatial and centred spatial-frequency axes (endpoint dropped so the
        # grid has exactly phase_mask_size samples).
        x = torch.linspace(-self.B, self.B, self.phase_mask_size+1)[:-1]
        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/self.B,
                            torch.pi*self.phase_mask_size/2/self.B,
                            self.phase_mask_size+1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')

        # (-1)^i checkerboard factors used by `propagation` to centre the spectrum.
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1)**vv
        self.a, self.b = torch.meshgrid(vv, vv, indexing='ij')
        lambda1 = self.wavelength / self.metric

        # Physical constants of the simulation: stored as Parameters (so they
        # appear in state_dict / parameters()) but all frozen.
        self.U = nn.Parameter((self.Kx**2 + self.Ky**2).float())
        self.vv = nn.Parameter((self.a*self.b).float())
        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))
        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))

        self.U.requires_grad = False
        self.vv.requires_grad = False
        # The wavenumber is a physical constant; leaving it trainable let the
        # optimizer drift it while `coef` (derived from it above) stayed stale.
        self.k.requires_grad = False
        self.coef.requires_grad = False

        height_maps = []
        for _ in range(self.layers):
            # Start near unit transmission: 1 + i*h with h ~ U[0, 0.1).
            h = torch.rand_like(self.U)*1e-1
            height_maps.append(nn.Parameter(torch.complex(torch.ones_like(h), h)))
        self.height_maps = torch.nn.ParameterList(height_maps)

    def propagation(self, field, propagation_distance):
        """Propagate ``field`` through free space using a Fourier-domain transfer function.

        The field is multiplied by the (-1)^(i+j) checkerboard ``vv`` before the
        FFT and after the inverse FFT, which centres the zero frequency (an
        fftshift for even grid sizes) so the spectrum lines up with the centred
        ``U = Kx^2 + Ky^2`` grid.

        NOTE(review): the constant phase ``exp(coef)`` was fixed at construction
        from ``self.propagation_distance``; only the frequency-dependent term
        uses the ``propagation_distance`` argument. All callers in this class
        pass ``self.propagation_distance``, so the two agree.

        Returns:
            Complex field with the same shape as ``field``.
        """
        transfer = torch.exp(self.coef)*torch.exp(
            -1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * transfer) * self.vv

    def opt_conv(self, inputs, heights):
        """Propagate, apply one transmission mask, propagate again, return the amplitude.

        Args:
            inputs: field entering the layer (real or complex).
            heights: complex transmission mask multiplied element-wise.

        Returns:
            Real, non-negative amplitude ``|field|``. Phase is discarded, so the
            next layer receives the amplitude only.
        """
        field = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        field = field * heights
        field = self.propagation(field=field, propagation_distance=self.propagation_distance)
        # Complex abs equals sqrt(re^2 + im^2), but its gradient is finite where
        # the field is exactly zero (or re^2 + im^2 underflows), whereas the
        # manual sqrt produced NaN/inf gradients there.
        return field.abs()

    def _upscale(self, image):
        """Nearest-neighbour upscale of the last two dims by ``resolution_scale_factor``."""
        return resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST,
        )

    def _pad_to_mask(self, image):
        """Zero-pad ``image`` to ``phase_mask_size x phase_mask_size`` via ``pad_zeros``."""
        return pad_zeros(image, size=(self.phase_mask_size, self.phase_mask_size))

    def _propagate_through_masks(self, field):
        """Run ``opt_conv`` once per mask in ``height_maps`` and return the final amplitude."""
        for plate_heights in self.height_maps:
            field = self.opt_conv(field, plate_heights)
        return field

    def _split_into_tiles(self, convolved):
        """Cut ``(b, 1, S, S)`` into ``(b, class_slots, tile, tile)`` detector tiles.

        The square grid holds ``tiles_per_dim ** 2`` tiles, which is more than
        ``class_slots`` when ``class_slots`` is not a perfect square. Only the
        first ``class_slots`` tiles (row-major) are detectors, so the rest are
        dropped; otherwise later reshapes expecting ``class_slots`` entries fail.
        """
        tile = self.tile_size*self.resolution_scale_factor
        tiles = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=tile,
            wt=tile,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim,
        )
        return tiles[:, :self.class_slots]

    def _detector_size(self):
        """Side length, in simulation pixels, of the detector window inside each tile."""
        return self.kernel_size_pixels*self.resolution_scale_factor

    def _crop_tiles(self, tiles):
        """Crop each tile to the detector window using ``unpad_zeros``."""
        size = self._detector_size()
        return unpad_zeros(tiles, (size, size))

    def _tile_maxima(self, tiles):
        """Max-pool every cropped tile to one value.

        The result is ``(b, class_slots, 1, 1)``, assuming ``unpad_zeros``
        crops to exactly the detector size.
        """
        return torch.nn.functional.max_pool2d(tiles, kernel_size=self._detector_size())

    def _normalize_per_sample(self, scores):
        """Scale each row of ``(b, class_slots)`` scores so that its maximum is 1.

        The peak is clamped to the smallest positive normal float, so an
        all-dark output yields zeros instead of 0/0 NaN.
        """
        peak = scores.max(dim=1, keepdim=True).values
        return scores / peak.clamp_min(torch.finfo(scores.dtype).tiny)

    def _pooled_to_scores(self, pooled):
        """Flatten ``(b, class_slots, 1, 1)`` pooled maxima into normalized ``(b, class_slots)`` scores."""
        scores = rearrange(pooled, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        return self._normalize_per_sample(scores)

    def forward(self, image):
        """Classify ``image`` with the simulated optical system.

        Algorithm:
            1. Upscale the input by ``resolution_scale_factor`` (e.g. [28, 28] -> [56, 56]).
            2. Zero-pad it to ``phase_mask_size`` (e.g. [56, 56] -> [448, 448]).
            3. Simulate light passing through the transmission masks.
            4. Cut the output plane into ``tiles_per_dim x tiles_per_dim`` tiles
               and keep the first ``class_slots`` of them.
            5. Reduce each tile to its maximum amplitude, giving a vector of
               length ``class_slots``, then normalize each sample so its peak is 1.

        Args:
            image: input batch; the tile split requires the propagated field to be
                ``(b, 1, S, S)``, so presumably ``(b, 1, H, W)`` -- confirm with callers.

        Returns:
            ``(scores, convolved)``: scores of shape ``(b, class_slots)`` and the
            output amplitude plane of shape ``(b, 1, phase_mask_size, phase_mask_size)``.
        """
        field = self._pad_to_mask(self._upscale(image))
        convolved = self._propagate_through_masks(field)
        tiles = self._crop_tiles(self._split_into_tiles(convolved))
        scores = self._pooled_to_scores(self._tile_maxima(tiles))
        return scores, convolved

    def __repr__(self):
        """Pretty-printed dict of public attributes, submodules and parameter dtype/shape."""
        summary = {name: value for name, value in vars(self).items() if not name.startswith('_')}
        summary.update(self._modules)
        summary.update({name: f"{p.dtype} {p.shape}" for name, p in self._parameters.items()})
        return pformat(summary, indent=2)

    def forward_debug(self, image):
        """Same pipeline as ``forward``, printing sizes and collecting every intermediate.

        Steps 1-5 are those of ``forward``. Additionally:
            6. A softmax over the normalized scores is returned in place of the
               raw scores (unlike ``forward``).

        Returns:
            ``(softmax, convolved, debug_out)`` where ``debug_out`` holds, in order:
            upscaled image, padded image, output amplitude, tiles, cropped tiles,
            pooled maxima ``(b, class_slots, 1, 1)`` and normalized scores.
        """
        debug_out = []

        # 1
        image = self._upscale(image)
        debug_out.append(image)

        # 2
        print(image.shape, (self.phase_mask_size, self.phase_mask_size ))
        image = self._pad_to_mask(image)
        debug_out.append(image)

        # 3
        convolved = self._propagate_through_masks(image)
        debug_out.append(convolved)

        # 4
        grid_to_depth = self._split_into_tiles(convolved)
        debug_out.append(grid_to_depth)

        # 5
        print(grid_to_depth.shape, (self.kernel_size_pixels*self.resolution_scale_factor, self.kernel_size_pixels*self.resolution_scale_factor))
        grid_to_depth = self._crop_tiles(grid_to_depth)
        debug_out.append(grid_to_depth)

        max_pool = self._tile_maxima(grid_to_depth)
        debug_out.append(max_pool)

        max_pool = self._pooled_to_scores(max_pool)
        debug_out.append(max_pool)

        # 6
        softmax = torch.nn.functional.softmax(max_pool, dim=1)

        return softmax, convolved, debug_out
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward network: ``num_layers - 1`` Linear+GELU blocks, then a final Linear.

    The first block maps ``input_dim -> hidden_dim``; every further block keeps
    ``hidden_dim``; the last Linear maps ``hidden_dim -> output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        assert num_layers > 1
        self.num_layers = num_layers

        # Width entering each hidden block: input_dim first, hidden_dim afterwards.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.GELU()))
        modules.append(nn.Linear(hidden_dim, output_dim))

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x`` (features on the last dimension)."""
        return self.body(x)
|
|
|
|
class OpticalSystemMLP(nn.Module):
    """Simulated diffractive optical system whose detector tiles are scored by an MLP.

    The optical part matches a plain max-pool readout: upscale, zero-pad to
    ``phase_mask_size``, propagate through ``layers`` trainable complex masks,
    and cut the output plane into ``tiles_per_dim x tiles_per_dim`` tiles. Each
    of the first ``class_slots`` cropped tiles is flattened and mapped to one
    score by a shared ``MLP``.
    """

    def __init__(self,
                 layers,
                 mlp_layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength=532e-9,
                 propagation_distance=300,
                 pixel_size_meters=36e-6,
                 metric=1e-3
                 ):
        """Build the simulation grid, frozen propagation constants, masks and MLP head.

        Args:
            layers: number of trainable transmission masks.
            mlp_layers: number of Linear layers in the scoring MLP (must be > 1).
            kernel_size_pixels: side, in input pixels, of the detector window
                cropped (via ``unpad_zeros``) out of every tile.
            tile_size_scale_factor: tile side is
                ``kernel_size_pixels * tile_size_scale_factor`` input pixels.
            resolution_scale_factor: upsampling factor applied to the input
                image, and therefore to the whole simulation grid.
            class_slots: number of detector tiles read out; must be >= ``classes``.
            classes: number of real classes; the surplus is kept in
                ``empty_class_slots``.
            wavelength: wavelength in meters.
            propagation_distance: distance travelled before and after each
                mask, presumably in units of ``metric`` meters (mm by
                default) -- TODO confirm.
            pixel_size_meters: physical input pixel pitch in meters.
            metric: length unit in meters; wavelength and pixel pitch are
                divided by it to get simulation units.

        Raises:
            AssertionError: if ``class_slots < classes`` or ``mlp_layers <= 1``.
        """
        super().__init__()

        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert self.class_slots >= self.classes
        self.empty_class_slots = self.class_slots - self.classes

        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        # Smallest square grid holding at least class_slots tiles. A plain int
        # (rather than np.int32) keeps every derived size a Python int.
        self.tiles_per_dim = math.ceil(math.sqrt(self.class_slots))
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor

        # Spatial half-extent B of the grid, in units of `metric`.
        self.A = self.pixel_size_meters*self.kernel_size_pixels/self.resolution_scale_factor/self.metric
        self.B = self.A*self.phase_mask_size/self.tile_size

        # Spatial and centred spatial-frequency axes (endpoint dropped so the
        # grid has exactly phase_mask_size samples).
        x = torch.linspace(-self.B, self.B, self.phase_mask_size+1)[:-1]
        kx = torch.linspace(-torch.pi*self.phase_mask_size/2/self.B,
                            torch.pi*self.phase_mask_size/2/self.B,
                            self.phase_mask_size+1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')

        # (-1)^i checkerboard factors used by `propagation` to centre the spectrum.
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1)**vv
        self.a, self.b = torch.meshgrid(vv, vv, indexing='ij')
        lambda1 = self.wavelength / self.metric

        # Physical constants of the simulation: stored as Parameters (so they
        # appear in state_dict / parameters()) but all frozen.
        self.U = nn.Parameter((self.Kx**2 + self.Ky**2).float())
        self.vv = nn.Parameter((self.a*self.b).float())
        self.k = nn.Parameter(torch.tensor([2*torch.pi/lambda1]))
        self.coef = nn.Parameter(torch.tensor([1j*self.propagation_distance*self.k]))

        self.U.requires_grad = False
        self.vv.requires_grad = False
        # The wavenumber is a physical constant; leaving it trainable let the
        # optimizer drift it while `coef` (derived from it above) stayed stale.
        self.k.requires_grad = False
        self.coef.requires_grad = False

        height_maps = []
        for _ in range(self.layers):
            # Start near unit transmission: 1 + i*h with h ~ U[0, 0.1).
            h = torch.rand_like(self.U)*1e-1
            height_maps.append(nn.Parameter(torch.complex(torch.ones_like(h), h)))
        self.height_maps = torch.nn.ParameterList(height_maps)

        # Shared scoring head: one flattened detector window -> one score.
        detector = self.kernel_size_pixels*self.resolution_scale_factor
        self.mlp = MLP(
            input_dim=detector**2,
            hidden_dim=detector,
            output_dim=1,
            num_layers=mlp_layers
        )

    def propagation(self, field, propagation_distance):
        """Propagate ``field`` through free space using a Fourier-domain transfer function.

        The field is multiplied by the (-1)^(i+j) checkerboard ``vv`` before the
        FFT and after the inverse FFT, which centres the zero frequency (an
        fftshift for even grid sizes) so the spectrum lines up with the centred
        ``U = Kx^2 + Ky^2`` grid.

        NOTE(review): the constant phase ``exp(coef)`` was fixed at construction
        from ``self.propagation_distance``; only the frequency-dependent term
        uses the ``propagation_distance`` argument. All callers in this class
        pass ``self.propagation_distance``, so the two agree.

        Returns:
            Complex field with the same shape as ``field``.
        """
        transfer = torch.exp(self.coef)*torch.exp(
            -1j*propagation_distance*self.U/self.resolution_scale_factor/self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * transfer) * self.vv

    def opt_conv(self, inputs, heights):
        """Propagate, apply one transmission mask, propagate again, return the amplitude.

        Args:
            inputs: field entering the layer (real or complex).
            heights: complex transmission mask multiplied element-wise.

        Returns:
            Real, non-negative amplitude ``|field|``. Phase is discarded, so the
            next layer receives the amplitude only.
        """
        field = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        field = field * heights
        field = self.propagation(field=field, propagation_distance=self.propagation_distance)
        # Complex abs equals sqrt(re^2 + im^2), but its gradient is finite where
        # the field is exactly zero (or re^2 + im^2 underflows), whereas the
        # manual sqrt produced NaN/inf gradients there.
        return field.abs()

    def _upscale(self, image):
        """Nearest-neighbour upscale of the last two dims by ``resolution_scale_factor``."""
        return resize(
            image,
            size=(image.shape[-2]*self.resolution_scale_factor,
                  image.shape[-1]*self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST,
        )

    def _pad_to_mask(self, image):
        """Zero-pad ``image`` to ``phase_mask_size x phase_mask_size`` via ``pad_zeros``."""
        return pad_zeros(image, size=(self.phase_mask_size, self.phase_mask_size))

    def _propagate_through_masks(self, field):
        """Run ``opt_conv`` once per mask in ``height_maps`` and return the final amplitude."""
        for plate_heights in self.height_maps:
            field = self.opt_conv(field, plate_heights)
        return field

    def _split_into_tiles(self, convolved):
        """Cut ``(b, 1, S, S)`` into ``(b, class_slots, tile, tile)`` detector tiles.

        The square grid holds ``tiles_per_dim ** 2`` tiles, which is more than
        ``class_slots`` when ``class_slots`` is not a perfect square. Only the
        first ``class_slots`` tiles (row-major) are detectors, so the rest are
        dropped; otherwise later reshapes expecting ``class_slots`` entries fail.
        """
        tile = self.tile_size*self.resolution_scale_factor
        tiles = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=tile,
            wt=tile,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim,
        )
        return tiles[:, :self.class_slots]

    def _detector_size(self):
        """Side length, in simulation pixels, of the detector window inside each tile."""
        return self.kernel_size_pixels*self.resolution_scale_factor

    def _crop_tiles(self, tiles):
        """Crop each tile to the detector window using ``unpad_zeros``."""
        size = self._detector_size()
        return unpad_zeros(tiles, (size, size))

    def _normalize_per_sample(self, scores):
        """Scale each row of ``(b, class_slots)`` scores so that its maximum is 1.

        The peak is clamped to the smallest positive normal float, so an
        all-zero row yields zeros instead of 0/0 NaN.
        """
        peak = scores.max(dim=1, keepdim=True).values
        return scores / peak.clamp_min(torch.finfo(scores.dtype).tiny)

    def forward(self, image):
        """Classify ``image``: optical propagation followed by a per-tile MLP score.

        Algorithm:
            1. Upscale the input by ``resolution_scale_factor`` (e.g. [28, 28] -> [56, 56]).
            2. Zero-pad it to ``phase_mask_size`` (e.g. [56, 56] -> [448, 448]).
            3. Simulate light passing through the transmission masks.
            4. Cut the output plane into ``tiles_per_dim x tiles_per_dim`` tiles
               and keep the first ``class_slots`` of them.
            5. Crop each tile to the detector window, flatten it, score it with
               the MLP (absolute value), and normalize each sample so its peak is 1.
               Normalization is per sample; the previous global batch maximum made
               one sample's scores depend on the rest of the batch.

        Args:
            image: input batch; the tile split requires the propagated field to be
                ``(b, 1, S, S)``, so presumably ``(b, 1, H, W)`` -- confirm with callers.

        Returns:
            ``(scores, convolved)``: scores of shape ``(b, class_slots)`` and the
            output amplitude plane of shape ``(b, 1, phase_mask_size, phase_mask_size)``.
        """
        field = self._pad_to_mask(self._upscale(image))
        convolved = self._propagate_through_masks(field)
        tiles = self._crop_tiles(self._split_into_tiles(convolved))

        size = self._detector_size()
        flattened = rearrange(
            tiles,
            "b mn ht wt -> b mn (ht wt)",
            ht=size,
            wt=size,
            mn=self.class_slots,
        )
        # (b, class_slots, 1) -> (b, class_slots)
        scores = self.mlp(flattened).abs().squeeze(-1)
        return self._normalize_per_sample(scores), convolved

    def __repr__(self):
        """Pretty-printed dict of public attributes, submodules and parameter dtype/shape."""
        summary = {name: value for name, value in vars(self).items() if not name.startswith('_')}
        summary.update(self._modules)
        summary.update({name: f"{p.dtype} {p.shape}" for name, p in self._parameters.items()})
        return pformat(summary, indent=2)

    def forward_debug(self, image):
        """Step-by-step max-pool readout that prints sizes and collects every intermediate.

        NOTE(review): unlike ``forward`` this does not use the MLP head; it scores
        tiles by max-pooling (as the plain optical system does). Confirm that is
        intended.

        Returns:
            ``(scores, convolved, debug_out)`` where ``scores`` is the per-sample
            normalized ``(b, class_slots)`` max-pool vector and ``debug_out`` holds,
            in order: upscaled image, padded image, output amplitude, tiles,
            cropped tiles, pooled maxima ``(b, class_slots, 1, 1)`` and ``scores``.
        """
        debug_out = []

        # 1
        image = self._upscale(image)
        debug_out.append(image)

        # 2
        print(image.shape, (self.phase_mask_size, self.phase_mask_size ))
        image = self._pad_to_mask(image)
        debug_out.append(image)

        # 3
        convolved = self._propagate_through_masks(image)
        debug_out.append(convolved)

        # 4
        grid_to_depth = self._split_into_tiles(convolved)
        debug_out.append(grid_to_depth)

        # 5
        print(grid_to_depth.shape, (self.kernel_size_pixels*self.resolution_scale_factor, self.kernel_size_pixels*self.resolution_scale_factor))
        grid_to_depth = self._crop_tiles(grid_to_depth)
        debug_out.append(grid_to_depth)

        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size=self._detector_size()
        )
        debug_out.append(max_pool)

        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots", class_slots=self.class_slots)
        max_pool = self._normalize_per_sample(max_pool)
        debug_out.append(max_pool)

        return max_pool, convolved, debug_out