diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index d7759d1..4c3ccb4 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -10,7 +10,6 @@ from torch.utils.tensorboard import SummaryWriter from typing import Optional import matplotlib.pyplot as plt - class OpticalMul(_nn.Module): """ Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. @@ -123,119 +122,6 @@ class OpticalMul(_nn.Module): return self.prepare_out(vec_field) -class TrainableScalarOpticalMul(_nn.Module): - """ - Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. - """ - def __init__(self, config: _Config): - """ - Конструктор класса. - - Args: - config: конфигурация расчётной системы. - """ - super().__init__() - - prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) - prop_two = _PropCrossLens(config.first_lens_plane, config) - prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropCylindLens(config.matrix_plane, config) - prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) - prop_six = _PropCrossLens(config.second_lens_plane, config).T - prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - - self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four - self._propagator_two: _Prop = prop_five + prop_six + prop_seven - - kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) - kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) - self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) - self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) - - self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) - self.k = nn.Parameter(_torch.tensor(1)) - - def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: - """ - Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. - - Args: - data: матрица комплексной амплитуды распределений световых полей. - - Returns: - Матрицы содержащие вектора левой матрицы. - """ - data = data.cfloat().flip(-1) - data = data.unsqueeze(-2) - data = _torch.kron(data.contiguous(), self._kron_vec_utils) - return data - - def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: - """ - Метод подготовки правой матрицы к подаче на вход системы. - - Args: - data: матрица комплексной амплитуды распределения светового поля. - - Returns: - Матрица - оптический элемент в центре модели. - """ - if (data.dim() > 4) and data.size(-1) == 2: - data = _torch.view_as_complex(data) - - data = data.cfloat().transpose(-2, -1) - data = data.unsqueeze(-3) - data = _torch.kron(data.contiguous(), self._kron_mat_utils) - return data - - def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: - """ - Метод получения результата матричного умножения. - - Args: - data: матрицы выходого распределения светового поля системы. - - Returns: - Вектор столбец (амплитудное распределение). - """ - ### Закоментированная часть кода - более физически корректный вариант работы модели, - ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения - field = field.abs().squeeze(-1) #**2 - field = self._avg_pool(field) - return field.flip(-1) #**0.5 - - def forward(self, - input: _torch.Tensor, - other: _torch.Tensor) -> _torch.Tensor: - """ - Метод выполения матричного умножения. - - Args: - input: матрица (B, C, H, W). - other: матрица (B, C, W, K). - - Returns: - Рензультат матричного умножения (B, C, H, K). - - Example: - >>> mul = OpticalMul(...) - >>> A = torch.rand((1, 1, 256, 256)) > 0.5 - >>> B = torch.rand((1, 1, 256, 256)) > 0.5 - >>> mul(A, B).shape - torch.Size([1, 1, 256, 256]) - >>> A = torch.rand((1, 1, 64, 256)) > 0.5 - >>> B = torch.rand((1, 1, 256, 128)) > 0.5 - >>> mul(A, B).shape - torch.Size([1, 1, 64, 128]) - """ - vec_field = self.prepare_vector(input) - mat_field = self.prepare_matrix(other) - - vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) - vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1)) - - return self.k * self.prepare_out(vec_field) - class TrainableLensOpticalMul(_nn.Module): """ Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. @@ -533,120 +419,4 @@ class TrainableFocalDistLensOpticalMul(_nn.Module): # Send the figure to the "Plots" or "Images" tab in TensorBoard writer.add_figure(f"{tag}/phase_profile", fig, global_step) - plt.close(fig) # Important: prevent memory leaks - - -class TrainableScalarAndLensOpticalMul(_nn.Module): - """ - Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. - """ - def __init__(self, config: _Config): - """ - Конструктор класса. - - Args: - config: конфигурация расчётной системы. - """ - super().__init__() - - prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) - prop_two = _PropCrossLens(config.first_lens_plane, config) - prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config) - prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) - prop_six = _PropCrossLens(config.second_lens_plane, config).T - prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - - self._propagator_one: _Prop = prop_one + prop_two + prop_three - self._propagator_two: _Prop = prop_four - self._propagator_three: _Prop = prop_five + prop_six + prop_seven - - kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) - kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) - self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) - self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) - - self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) - self.k = nn.Parameter(torch.tensor(1)) - - def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: - """ - Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. - - Args: - data: матрица комплексной амплитуды распределений световых полей. - - Returns: - Матрицы содержащие вектора левой матрицы. - """ - data = data.cfloat().flip(-1) - data = data.unsqueeze(-2) - data = _torch.kron(data.contiguous(), self._kron_vec_utils) - return data - - def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: - """ - Метод подготовки правой матрицы к подаче на вход системы. - - Args: - data: матрица комплексной амплитуды распределения светового поля. - - Returns: - Матрица - оптический элемент в центре модели. - """ - if (data.dim() > 4) and data.size(-1) == 2: - data = _torch.view_as_complex(data) - - data = data.cfloat().transpose(-2, -1) - data = data.unsqueeze(-3) - data = _torch.kron(data.contiguous(), self._kron_mat_utils) - return data - - def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: - """ - Метод получения результата матричного умножения. - - Args: - data: матрицы выходого распределения светового поля системы. - - Returns: - Вектор столбец (амплитудное распределение). - """ - ### Закоментированная часть кода - более физически корректный вариант работы модели, - ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения - field = field.abs().squeeze(-1) #**2 - field = self._avg_pool(field) - return field.flip(-1) #**0.5 - - def forward(self, - input: _torch.Tensor, - other: _torch.Tensor) -> _torch.Tensor: - """ - Метод выполения матричного умножения. - - Args: - input: матрица (B, C, H, W). - other: матрица (B, C, W, K). - - Returns: - Рензультат матричного умножения (B, C, H, K). - - Example: - >>> mul = OpticalMul(...) - >>> A = torch.rand((1, 1, 256, 256)) > 0.5 - >>> B = torch.rand((1, 1, 256, 256)) > 0.5 - >>> mul(A, B).shape - torch.Size([1, 1, 256, 256]) - >>> A = torch.rand((1, 1, 64, 256)) > 0.5 - >>> B = torch.rand((1, 1, 256, 128)) > 0.5 - >>> mul(A, B).shape - torch.Size([1, 1, 64, 128]) - """ - vec_field = self.prepare_vector(input) - mat_field = self.prepare_matrix(other) - - vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) - vec_field = self._propagator_two(vec_field, mat_field.shape[-2:]) - vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) - - return self.k * self.prepare_out(vec_field) \ No newline at end of file + plt.close(fig) # Important: prevent memory leaks \ No newline at end of file