import torch as _torch import torch.nn as _nn import numpy as _np from scipy.special import fresnel as _fresnel from .config import ConfigOpticBase as _ConfigOpticBase, ConfigDesignPlane as _ConfigDesignPlane from typing import Tuple as _Tuple, Sequence as _Sequence from abc import ABC as _ABC import collections as _collections import copy as _copy class Propagator(_ABC, _nn.Module): """ Абстрактный класс вычисления распространения светового поля в среде. Поля: operator_X: оператор отображающий распроcтранение светового поля вдоль оси абсцисс operator_Y: оператор отображающий распроcтранение светового поля вдоль оси ординат """ def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor): super(Propagator, self).__init__() operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) self.register_buffer('_operator_X', operator_X, persistent=True) self.register_buffer('_operator_Y', operator_Y, persistent=True) @property def operator_X(self) -> _torch.Tensor: """ Returns: оператор отображающий распроcтранение светового поля вдоль оси абсцисс """ return _torch.view_as_complex(self._operator_X) @property def operator_Y(self) -> _torch.Tensor: """ Returns: оператор отображающий распроcтранение светового поля вдоль оси ординат """ return _torch.view_as_complex(self._operator_Y) def __operator_multiplication(self, first_X: _torch.Tensor, second_X: _torch.Tensor, first_Y: _torch.Tensor, second_Y: _torch.Tensor)-> _Tuple[_torch.Tensor, _torch.Tensor]: operator_Y = second_Y @ first_Y operator_X = first_X @ second_X return operator_X, operator_Y def cat(self, propagators: _Sequence['Propagator']) -> 'Propagator': """ Метод схлопывания операторов распространения. Args: propagators: последовательность для схлопывания Returns: новый пропогатор, заменяющих собой серию предыдущих Warning: порядок расположения пропагаторов в последовательности важен, идёт от первого к последниму """ operator_X: _torch.Tensor operator_Y: _torch.Tensor if not isinstance(propagators, _collections.abc.Sequence): operator_X, operator_Y = self.__operator_multiplication(self.operator_X, propagators.operator_X, self.operator_Y, propagators.operator_Y) else: size = len(propagators) operator_X = self.operator_X operator_Y = self.operator_Y for i in range(size): operator_X, operator_Y = self.__operator_multiplication(operator_X, propagators[i].operator_X, operator_Y, propagators[i].operator_Y) return Propagator(operator_X, operator_Y) def __add__(self, propagator: 'Propagator') -> 'Propagator': """ Метод схлопывания двух пропагаторов. Args: propagator: пропагатор с которым нужно произвести схлопывание Returns: новый пропогатор, заменяющих собой оба предыдущих Warning: операция не комутативная """ return self.cat(propagator) @staticmethod def __slice_calculation(total_rows: int, num_to_take: int) -> slice: start = (total_rows - num_to_take) // 2 end = start + num_to_take return slice(start, end) def forward(self, field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: """ Метод распространения светового поля в среде. Args: field: распределение комплексной амплитуды светового поля. Returns: Распределение комплексной амплитуды светового поля, после распространения. """ if (resul_shape is not None): field_shape = field.shape[-2:] operator_Y_shape = self.operator_Y.shape[-2:] operator_X_shape = self.operator_X.shape[-2:] slice_one = Propagator.__slice_calculation(operator_Y_shape[0], resul_shape[0]) slice_two = Propagator.__slice_calculation(operator_Y_shape[1], field_shape[0]) slice_three = Propagator.__slice_calculation(operator_X_shape[0], field_shape[1]) slice_four = Propagator.__slice_calculation(operator_X_shape[1], resul_shape[1]) return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] return self.operator_Y @ field @ self.operator_X def __repr__(self): return f"Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" class PropagatorLens(Propagator): """ Абстрактный класс распространения света в тонком оптическом элементе. """ def transpose(self) -> 'PropagatorLens': """ Метод транспонирования тонкого оптического элемента. Returns: Новый элемент, транспонированный относительно оригинального. """ obj = Propagator.__new__(PropagatorLens) Propagator.__init__(obj, self.operator_Y, self.operator_X) return obj @property def T(self) -> 'PropagatorLens': """ Returns: Новый элемент, транспонированный относительно текущего. """ return self.transpose() class PropagatorCrossLens(PropagatorLens): """ Класс распространения света в скрещенной линзе, представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, config: _ConfigOpticBase): """ Конструктор класса скрещенной линзы. Args: plane: данные о расчётной плоскости элемента. config: данные о световом поле модели. """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2) super(PropagatorCrossLens, self).__init__( _torch.diag_embed(operator_X), _torch.diag_embed(operator_Y)) class PropagatorCylindLens(PropagatorLens): """ Класс распространения света в цилиндрической линзе, представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, config: _ConfigOpticBase): """ Конструктор класса цилиндрической линзы. Args: plane: данные о расчётной плоскости элемента. config: данные о световом поле модели. """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) super(PropagatorCylindLens, self).__init__( _torch.diag_embed(operator_X), _torch.diag_embed(operator_Y)) class PropagatorSinc(Propagator): """ Класс распространения света свободном пространстве с использованием разложения по базисным sinc функциям. """ def __init__(self, first_plane: _ConfigDesignPlane, second_plane: _ConfigDesignPlane, config: _ConfigOpticBase): """ Конструктор класса распространения в свободном пространстве. Args: first_plane: данные о начальной расчётной плоскости. second_plane: данные о конечной расчётной плоскости. config: данные о световом поле модели. """ operator_X, operator_Y = self.__get_operators(first_plane, second_plane, config) super(PropagatorSinc, self).__init__(operator_X, operator_Y) def __get_operator_for_dim(self, pixel_size_in: float, pixel_size_out: float, difference: float, config: _ConfigOpticBase) -> _torch.Tensor: bndW = 0.5 / pixel_size_in eikz = (_np.exp(1j * config.K * config.distance)**0.5) sq2p = (2 / _np.pi)**0.5 sqzk = ((2 * config.distance / config.K)**0.5) mu1 = -_np.pi * sqzk * bndW - difference / sqzk mu2 = _np.pi * sqzk * bndW - difference / sqzk S1, C1 = _fresnel(mu1 * sq2p) S2, C2 = _fresnel(mu2 * sq2p) return (((pixel_size_in * pixel_size_out)**0.5 / _np.pi) / sqzk * eikz * _np.exp(0.5j * difference**2 * config.K / config.distance) * (C2 - C1 - 1j * (S2 - S1)) / sq2p) def __get_operators(self, first_plane: _ConfigDesignPlane, second_plane: _ConfigDesignPlane, config: _ConfigOpticBase) -> _Tuple[_torch.Tensor, _torch.Tensor]: difference_x = first_plane.linspace_by_x[None, :] - second_plane.linspace_by_x[:, None] difference_y = first_plane.linspace_by_y[None, :] - second_plane.linspace_by_y[:, None] operator_X = self.__get_operator_for_dim(first_plane.pixel_size_by_x, second_plane.pixel_size_by_x, difference_x, config).transpose(-2, -1) operator_Y = self.__get_operator_for_dim(first_plane.pixel_size_by_y, second_plane.pixel_size_by_y, difference_y, config) return operator_X, operator_Y ####################################################################################################################### class PropagatorTrainableCylindLens(_ABC, _nn.Module): """ Класс распространения света в обучаемой цилиндрической линзе, представленной тонким прозрачным оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, config: _ConfigOpticBase ): super().__init__() # non smooth profile after training. better to train only focal length? self._operator_X_phi = _nn.Parameter(config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)) operator_Y = _torch.view_as_real(operator_Y) self.register_buffer('_operator_Y', operator_Y, persistent=True) @property def operator_X(self) -> _torch.Tensor: """ Returns: оператор отображающий распроcтранение светового поля вдоль оси абсцисс """ return _torch.diag_embed(_torch.exp(-1j * self._operator_X_phi)) @property def operator_Y(self) -> _torch.Tensor: """ Returns: оператор отображающий распроcтранение светового поля вдоль оси ординат """ return _torch.view_as_complex(self._operator_Y) @staticmethod def __slice_calculation(total_rows: int, num_to_take: int) -> slice: start = (total_rows - num_to_take) // 2 end = start + num_to_take return slice(start, end) def forward(self, field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: """ Метод распространения светового поля в среде. Args: field: распределение комплексной амплитуды светового поля. Returns: Распределение комплексной амплитуды светового поля, после распространения. """ if (resul_shape is not None): field_shape = field.shape[-2:] operator_Y_shape = self.operator_Y.shape[-2:] operator_X_shape = self.operator_X.shape[-2:] slice_one = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0]) slice_two = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0]) slice_three = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1]) slice_four = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1]) return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] return self.operator_Y @ field @ self.operator_X class PropagatorTrainableFocalDistCylindLens(_ABC, _nn.Module): """ Класс распространения света в обучаемой цилиндрической линзе, представленной тонким прозрачным оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, config: _ConfigOpticBase ): super().__init__() self._distance = _nn.Parameter(_torch.tensor(config.distance)) self.register_buffer('_K', _torch.tensor(config.K), persistent=True) self.register_buffer('_linspace_by_x', plane.linspace_by_x.detach().clone(), persistent=True) operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)) operator_Y = _torch.view_as_real(operator_Y) self.register_buffer('_operator_Y', operator_Y, persistent=True) @property def operator_X(self) -> _torch.Tensor: """ Returns: оператор отображающий распроcтранение светового поля вдоль оси абсцисс """ return _torch.diag_embed(_torch.exp(-1j * self._K / self._distance * self._linspace_by_x**2)) @property def operator_Y(self) -> _torch.Tensor: """ Returns: оператор отображающий распроcтранение светового поля вдоль оси ординат """ return _torch.view_as_complex(self._operator_Y) @staticmethod def __slice_calculation(total_rows: int, num_to_take: int) -> slice: start = (total_rows - num_to_take) // 2 end = start + num_to_take return slice(start, end) def forward(self, field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: """ Метод распространения светового поля в среде. Args: field: распределение комплексной амплитуды светового поля. Returns: Распределение комплексной амплитуды светового поля, после распространения. """ if (resul_shape is not None): field_shape = field.shape[-2:] operator_Y_shape = self.operator_Y.shape[-2:] operator_X_shape = self.operator_X.shape[-2:] slice_one = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0]) slice_two = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0]) slice_three = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1]) slice_four = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1]) return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] return self.operator_Y @ field @ self.operator_X