import torch as _torch import torch.nn as _nn from .config import Config as _Config from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop class OpticalMul(_nn.Module): """ Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. """ def __init__(self, config: _Config): """ Конструктор класса. Args: config: конфигурация расчётной системы. """ super(OpticalMul, self).__init__() prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) prop_four = _PropСylindLens(config.matrix_plane, config) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four self._propagator_two: _Prop = prop_five + prop_six + prop_seven kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: """ Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. Args: data: матрица комплексной амплитуды распределений световых полей. Returns: Матрицы содержащие вектора левой матрицы. """ data = data.cfloat().flip(-1) data = data.unsqueeze(-2) data = _torch.kron(data.contiguous(), self._kron_vec_utils) return data def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: """ Метод подготовки правой матрицы к подаче на вход системы. Args: data: матрица комплексной амплитуды распределения светового поля. Returns: Матрица - оптический элемент в центре модели. """ if (data.dim() > 4) and data.size(-1) == 2: data = _torch.view_as_complex(data) data = data.cfloat().transpose(-2, -1) data = data.unsqueeze(-3) data = _torch.kron(data.contiguous(), self._kron_mat_utils) return data def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: """ Метод получения результата матричного умножения. Args: data: матрицы выходого распределения светового поля системы. Returns: Вектор столбец (амплитудное распределение). """ ### Закоментированная часть кода - более физически корректный вариант работы модели, ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения field = field.abs().squeeze(-1) #**2 field = self._avg_pool(field) return field.flip(-1) #**0.5 def forward(self, input: _torch.Tensor, other: _torch.Tensor) -> _torch.Tensor: """ Метод выполения матричного умножения. Args: input: матрица (B, C, H, W). other: матрица (B, C, W, K). Returns: Рензультат матричного умножения (B, C, H, K). Example: >>> mul = OpticalMul(...) >>> A = torch.rand((1, 1, 256, 256)) > 0.5 >>> B = torch.rand((1, 1, 256, 256)) > 0.5 >>> mul(A, B).shape torch.Size([1, 1, 256, 256]) >>> A = torch.rand((1, 1, 64, 256)) > 0.5 >>> B = torch.rand((1, 1, 256, 128)) > 0.5 >>> mul(A, B).shape torch.Size([1, 1, 64, 128]) """ vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) vec_field = self._propagator_one(vec_field) vec_field = self._propagator_two(vec_field * mat_field) return self.prepare_out(vec_field)