import torch
from torch import nn
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange
import numpy as np
from pprint import pformat

from utils import pad_zeros, unpad_zeros


class OpticalSystem(nn.Module):
    """Diffractive optical network: trainable phase plates simulated with
    FFT-based free-space propagation, read out by a tiled max operation."""

    def __init__(self,
                 layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength=532e-9,
                 # refractive_index=1.5090,
                 propagation_distance=300,
                 pixel_size_meters=36e-6,
                 metric=1e-3):
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        # self.refractive_index = refractive_index
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert self.class_slots >= self.classes
        self.empty_class_slots = self.class_slots - self.classes

        # Detector-plane geometry: a tiles_per_dim x tiles_per_dim grid of
        # tiles, one tile per class slot.
        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor

        # Spatial and frequency grids (in units of `metric`).
        self.A = self.pixel_size_meters * self.kernel_size_pixels / self.resolution_scale_factor / self.metric
        self.B = self.A * self.phase_mask_size / self.tile_size
        x = torch.linspace(-self.B, self.B, self.phase_mask_size + 1)[:-1]
        kx = torch.linspace(-torch.pi * self.phase_mask_size / 2 / self.B,
                            torch.pi * self.phase_mask_size / 2 / self.B,
                            self.phase_mask_size + 1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')

        # Checkerboard of (-1)^(i+j); multiplying by it before and after the
        # FFT is equivalent to centering the spectrum (fftshift).
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1) ** vv
        self.a, self.b = torch.meshgrid(vv, vv, indexing='ij')

        lambda1 = self.wavelength / self.metric
        self.U = nn.Parameter((self.Kx ** 2 + self.Ky ** 2).float())
        self.vv = nn.Parameter((self.a * self.b).float())
        self.k = nn.Parameter(torch.tensor([2 * torch.pi / lambda1]))
        # Constant phase factor exp(1j * z * k); detached so it is a leaf
        # tensor and not part of the autograd graph of `k`.
        self.coef = nn.Parameter((1j * self.propagation_distance * self.k).detach())
        self.U.requires_grad = False
        self.vv.requires_grad = False
        self.coef.requires_grad = False

        # Trainable complex transmission of each plate, initialized near the
        # identity: unit real part, small random imaginary part.
        self.height_maps = []
        for i in range(self.layers):
            # heights = nn.Parameter(torch.exp(-1j * (self.x**2 + self.y**2) / self.resolution_scale_factor / self.propagation_distance * self.k))
            h = torch.rand_like(self.U) * 1e-1
            heights = nn.Parameter(torch.complex(torch.ones_like(h), h))
            # torch.nn.init.uniform_(heights, a=-1, b=1)
            self.height_maps.append(heights)
        self.height_maps = torch.nn.ParameterList(self.height_maps)

    def propagation(self, field, propagation_distance):
        # Free-space propagation: multiply the (shifted) spectrum by the
        # quadratic-phase transfer function F, then transform back.
        F = torch.exp(self.coef) * torch.exp(
            -1j * propagation_distance * self.U / self.resolution_scale_factor / self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv

    def opt_conv(self, inputs, heights):
        # Propagate to the plate, apply its complex transmission, propagate
        # onward, and return the field amplitude at the output plane.
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * heights
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        amplitude = torch.sqrt(result.real ** 2 + result.imag ** 2)
        return amplitude

    def forward(self, image):
        """
        Algorithm:
        1. The input image is upsampled by self.resolution_scale_factor: [28, 28] -> [56, 56].
        2. The result is zero-padded to self.phase_mask_size: [56, 56] -> [448, 448].
        3. Light propagation through the phase plates is simulated.
        4. The output image is sliced into a grid of self.tiles_per_dim x self.tiles_per_dim tiles.
        5. The tiles are reduced to a vector of length self.class_slots by a max operation, then normalized.
        """
        # 1
        image = resize(
            image,
            size=(image.shape[-2] * self.resolution_scale_factor,
                  image.shape[-1] * self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        # 2
        image = pad_zeros(
            image,
            size=(self.phase_mask_size, self.phase_mask_size),
        )
        # 3
        x = image
        for plate_heights in self.height_maps:
            x = self.opt_conv(x, plate_heights)
        convolved = x
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=self.tile_size * self.resolution_scale_factor,
            wt=self.tile_size * self.resolution_scale_factor,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim
        )
        # 5
        grid_to_depth = unpad_zeros(
            grid_to_depth,
            (self.kernel_size_pixels * self.resolution_scale_factor,
             self.kernel_size_pixels * self.resolution_scale_factor)
        )
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size=self.kernel_size_pixels * self.resolution_scale_factor
        )
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots",
                             class_slots=self.class_slots)
        max_pool = max_pool / max_pool.max(dim=1, keepdim=True).values
        return max_pool, convolved

    def __repr__(self):
        tmp = {k: v for k, v in self.__dict__.items() if not k.startswith('_')}
        tmp.update(self.__dict__['_modules'])
        tmp.update({k: f"{v.dtype} {v.shape}"
                    for k, v in self.__dict__['_parameters'].items()})
        return pformat(tmp, indent=2)

    def forward_debug(self, image):
        """Same algorithm as `forward`, but records every intermediate tensor
        in `debug_out` and additionally applies a softmax (step 6) to the
        normalized scores."""
        debug_out = []
        # 1
        image = resize(
            image,
            size=(image.shape[-2] * self.resolution_scale_factor,
                  image.shape[-1] * self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        debug_out.append(image)
        # 2
        print(image.shape, (self.phase_mask_size, self.phase_mask_size))
        image = pad_zeros(
            image,
            size=(self.phase_mask_size, self.phase_mask_size),
        )
        debug_out.append(image)
        # 3
        x = image
        for plate_heights in self.height_maps:
            x = self.opt_conv(x, plate_heights)
        convolved = x
        debug_out.append(convolved)
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=self.tile_size * self.resolution_scale_factor,
            wt=self.tile_size * self.resolution_scale_factor,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim
        )
        debug_out.append(grid_to_depth)
        # 5
        print(grid_to_depth.shape,
              (self.kernel_size_pixels * self.resolution_scale_factor,
               self.kernel_size_pixels * self.resolution_scale_factor))
        grid_to_depth = unpad_zeros(
            grid_to_depth,
            (self.kernel_size_pixels * self.resolution_scale_factor,
             self.kernel_size_pixels * self.resolution_scale_factor)
        )
        debug_out.append(grid_to_depth)
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size=self.kernel_size_pixels * self.resolution_scale_factor
        )
        debug_out.append(max_pool)
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots",
                             class_slots=self.class_slots)
        max_pool = max_pool / max_pool.max(dim=1, keepdim=True).values
        debug_out.append(max_pool)
        # 6
        softmax = torch.nn.functional.softmax(max_pool, dim=1)
        return softmax, convolved, debug_out
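

# A minimal usage sketch (the hyperparameters here are illustrative
# assumptions, not the original training configuration). With
# kernel_size_pixels=28, tile_size_scale_factor=2, resolution_scale_factor=2
# and class_slots=16: tile_size = 56, tiles_per_dim = 4 and
# phase_mask_size = 56 * 4 * 2 = 448, matching the
# [28,28] -> [56,56] -> [448,448] shapes in the `forward` docstring.
#
#   optics = OpticalSystem(
#       layers=2,
#       kernel_size_pixels=28,
#       tile_size_scale_factor=2,
#       resolution_scale_factor=2,
#       class_slots=16,
#       classes=10,
#   )
#   batch = torch.rand(8, 1, 28, 28)
#   scores, detector_plane = optics(batch)
#   scores.shape          # torch.Size([8, 16])
#   detector_plane.shape  # torch.Size([8, 1, 448, 448])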


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        assert num_layers > 1
        self.num_layers = num_layers
        # Linear -> GELU blocks, followed by a final linear projection.
        layers = [nn.Linear(input_dim, hidden_dim), nn.GELU()]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
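

# Shape sketch for the MLP readout used by OpticalSystemMLP below (numbers are
# illustrative): nn.Linear acts on the last dimension, so the head maps each
# flattened detector tile to a single score. OpticalSystemMLP instantiates it
# with input_dim=(kernel_size_pixels * resolution_scale_factor)**2 and
# output_dim=1.
#
#   head = MLP(input_dim=56 * 56, hidden_dim=56, output_dim=1, num_layers=3)
#   tiles = torch.rand(8, 16, 56 * 56)  # (batch, class_slots, flattened tile)
#   head(tiles).shape                   # torch.Size([8, 16, 1])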


class OpticalSystemMLP(nn.Module):
    """Same optical front end as `OpticalSystem`, but the tiled max-pool
    readout is replaced by an MLP that scores each detector tile."""

    def __init__(self,
                 layers,
                 mlp_layers,
                 kernel_size_pixels,
                 tile_size_scale_factor,
                 resolution_scale_factor,
                 class_slots,
                 classes,
                 wavelength=532e-9,
                 # refractive_index=1.5090,
                 propagation_distance=300,
                 pixel_size_meters=36e-6,
                 metric=1e-3):
        super().__init__()
        self.layers = layers
        self.kernel_size_pixels = kernel_size_pixels
        self.tile_size_scale_factor = tile_size_scale_factor
        self.resolution_scale_factor = resolution_scale_factor
        self.class_slots = class_slots
        self.classes = classes
        self.wavelength = wavelength
        # self.refractive_index = refractive_index
        self.propagation_distance = propagation_distance
        self.pixel_size_meters = pixel_size_meters
        self.metric = metric

        assert self.class_slots >= self.classes
        self.empty_class_slots = self.class_slots - self.classes

        # Detector-plane geometry (identical to OpticalSystem).
        self.tile_size = self.kernel_size_pixels * self.tile_size_scale_factor
        self.tiles_per_dim = np.ceil(np.sqrt(self.class_slots)).astype(np.int32)
        self.phase_mask_size = self.tile_size * self.tiles_per_dim * self.resolution_scale_factor

        self.A = self.pixel_size_meters * self.kernel_size_pixels / self.resolution_scale_factor / self.metric
        self.B = self.A * self.phase_mask_size / self.tile_size
        x = torch.linspace(-self.B, self.B, self.phase_mask_size + 1)[:-1]
        kx = torch.linspace(-torch.pi * self.phase_mask_size / 2 / self.B,
                            torch.pi * self.phase_mask_size / 2 / self.B,
                            self.phase_mask_size + 1)[:-1]
        self.x, self.y = torch.meshgrid(x, x, indexing='ij')
        self.Kx, self.Ky = torch.meshgrid(kx, kx, indexing='ij')

        # Checkerboard of (-1)^(i+j), used as an fftshift in `propagation`.
        vv = torch.arange(0, self.phase_mask_size)
        vv = (-1) ** vv
        self.a, self.b = torch.meshgrid(vv, vv, indexing='ij')

        lambda1 = self.wavelength / self.metric
        self.U = nn.Parameter((self.Kx ** 2 + self.Ky ** 2).float())
        self.vv = nn.Parameter((self.a * self.b).float())
        self.k = nn.Parameter(torch.tensor([2 * torch.pi / lambda1]))
        # Constant phase factor exp(1j * z * k); detached so it is a leaf tensor.
        self.coef = nn.Parameter((1j * self.propagation_distance * self.k).detach())
        self.U.requires_grad = False
        self.vv.requires_grad = False
        self.coef.requires_grad = False

        # Trainable complex transmission of each plate.
        self.height_maps = []
        for i in range(self.layers):
            # heights = nn.Parameter(torch.exp(-1j * (self.x**2 + self.y**2) / self.resolution_scale_factor / self.propagation_distance * self.k))
            h = torch.rand_like(self.U) * 1e-1
            heights = nn.Parameter(torch.complex(torch.ones_like(h), h))
            # torch.nn.init.uniform_(heights, a=-1, b=1)
            self.height_maps.append(heights)
        self.height_maps = torch.nn.ParameterList(self.height_maps)

        # MLP readout: one score per flattened detector tile.
        self.mlp = MLP(
            input_dim=(self.kernel_size_pixels * self.resolution_scale_factor) ** 2,  # self.class_slots,
            hidden_dim=self.kernel_size_pixels * self.resolution_scale_factor,
            output_dim=1,
            num_layers=mlp_layers
        )

    def propagation(self, field, propagation_distance):
        F = torch.exp(self.coef) * torch.exp(
            -1j * propagation_distance * self.U / self.resolution_scale_factor / self.k)
        return torch.fft.ifft2(torch.fft.fft2(field * self.vv) * F) * self.vv

    def opt_conv(self, inputs, heights):
        result = self.propagation(field=inputs, propagation_distance=self.propagation_distance)
        result = result * heights
        result = self.propagation(field=result, propagation_distance=self.propagation_distance)
        amplitude = torch.sqrt(result.real ** 2 + result.imag ** 2)
        return amplitude

    def forward(self, image):
        """
        Algorithm:
        1. The input image is upsampled by self.resolution_scale_factor: [28, 28] -> [56, 56].
        2. The result is zero-padded to self.phase_mask_size: [56, 56] -> [448, 448].
        3. Light propagation through the phase plates is simulated.
        4. The output image is sliced into a grid of self.tiles_per_dim x self.tiles_per_dim tiles.
        5. Each tile is cropped, flattened, and scored by the MLP; the scores are normalized by their maximum.
        """
        # 1
        image = resize(
            image,
            size=(image.shape[-2] * self.resolution_scale_factor,
                  image.shape[-1] * self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        # 2
        image = pad_zeros(
            image,
            size=(self.phase_mask_size, self.phase_mask_size),
        )
        # 3
        x = image
        for plate_heights in self.height_maps:
            x = self.opt_conv(x, plate_heights)
        convolved = x
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=self.tile_size * self.resolution_scale_factor,
            wt=self.tile_size * self.resolution_scale_factor,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim
        )
        # 5
        grid_to_depth = unpad_zeros(
            grid_to_depth,
            (self.kernel_size_pixels * self.resolution_scale_factor,
             self.kernel_size_pixels * self.resolution_scale_factor)
        )
        grid_to_depth = rearrange(
            grid_to_depth,
            "b mn ht wt -> b mn (ht wt)",
            ht=self.kernel_size_pixels * self.resolution_scale_factor,
            wt=self.kernel_size_pixels * self.resolution_scale_factor,
            mn=self.class_slots
        )
        scores = self.mlp(grid_to_depth).abs()
        # Note: normalized by the global batch maximum, unlike the per-sample
        # normalization in OpticalSystem.forward.
        scores = scores / scores.max()
        scores = scores.squeeze(-1)
        return scores, convolved

    def __repr__(self):
        tmp = {k: v for k, v in self.__dict__.items() if not k.startswith('_')}
        tmp.update(self.__dict__['_modules'])
        tmp.update({k: f"{v.dtype} {v.shape}"
                    for k, v in self.__dict__['_parameters'].items()})
        return pformat(tmp, indent=2)

    def forward_debug(self, image):
        # Note: this debug path uses the tiled max-pool readout from
        # OpticalSystem rather than the MLP head.
        debug_out = []
        # 1
        image = resize(
            image,
            size=(image.shape[-2] * self.resolution_scale_factor,
                  image.shape[-1] * self.resolution_scale_factor),
            interpolation=InterpolationMode.NEAREST
        )
        debug_out.append(image)
        # 2
        print(image.shape, (self.phase_mask_size, self.phase_mask_size))
        image = pad_zeros(
            image,
            size=(self.phase_mask_size, self.phase_mask_size),
        )
        debug_out.append(image)
        # 3
        x = image
        for plate_heights in self.height_maps:
            x = self.opt_conv(x, plate_heights)
        convolved = x
        debug_out.append(convolved)
        # 4
        grid_to_depth = rearrange(
            convolved,
            "b 1 (m ht) (n wt) -> b (m n) ht wt",
            ht=self.tile_size * self.resolution_scale_factor,
            wt=self.tile_size * self.resolution_scale_factor,
            m=self.tiles_per_dim,
            n=self.tiles_per_dim
        )
        debug_out.append(grid_to_depth)
        # 5
        print(grid_to_depth.shape,
              (self.kernel_size_pixels * self.resolution_scale_factor,
               self.kernel_size_pixels * self.resolution_scale_factor))
        grid_to_depth = unpad_zeros(
            grid_to_depth,
            (self.kernel_size_pixels * self.resolution_scale_factor,
             self.kernel_size_pixels * self.resolution_scale_factor)
        )
        debug_out.append(grid_to_depth)
        max_pool = torch.nn.functional.max_pool2d(
            grid_to_depth,
            kernel_size=self.kernel_size_pixels * self.resolution_scale_factor
        )
        debug_out.append(max_pool)
        max_pool = rearrange(max_pool, "b class_slots 1 1 -> b class_slots",
                             class_slots=self.class_slots)
        max_pool = max_pool / max_pool.max(dim=1, keepdim=True).values
        debug_out.append(max_pool)
        return max_pool, convolved, debug_out
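

if __name__ == "__main__":
    # Smoke test under the same illustrative hyperparameters as the commented
    # OpticalSystem example above (assumes utils.pad_zeros / unpad_zeros pad
    # and crop to the requested spatial size).
    model = OpticalSystemMLP(
        layers=2,
        mlp_layers=3,
        kernel_size_pixels=28,
        tile_size_scale_factor=2,
        resolution_scale_factor=2,
        class_slots=16,
        classes=10,
    )
    batch = torch.rand(4, 1, 28, 28)
    scores, detector_plane = model(batch)
    print(scores.shape, detector_plane.shape)  # (4, 16) and (4, 1, 448, 448)

    # Sanity check: the transfer function F is pure phase (|F| = 1), so the
    # FFT-based propagation should conserve total intensity (Parseval).
    field = torch.rand(1, 1, model.phase_mask_size, model.phase_mask_size,
                       dtype=torch.cfloat)
    out = model.propagation(field, model.propagation_distance)
    print(torch.allclose(field.abs().pow(2).sum(), out.abs().pow(2).sum(),
                         rtol=1e-3))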