You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

379 lines
18 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch as _torch
import torch.nn as _nn
import numpy as _np
from scipy.special import fresnel as _fresnel
from .config import ConfigOpticBase as _ConfigOpticBase, ConfigDesignPlane as _ConfigDesignPlane
from typing import Tuple as _Tuple, Sequence as _Sequence
from abc import ABC as _ABC
import collections as _collections
import copy as _copy
class Propagator(_ABC, _nn.Module):
"""
Абстрактный класс вычисления распространения светового поля в среде.
Поля:
operator_X: оператор отображающий распроранение светового поля вдоль оси абсцисс
operator_Y: оператор отображающий распроранение светового поля вдоль оси ординат
"""
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor):
super(Propagator, self).__init__()
operator_X: _torch.Tensor = _torch.view_as_real(operator_X)
operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y)
self.register_buffer('_operator_X', operator_X, persistent=True)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
@property
def operator_X(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси абсцисс
"""
return _torch.view_as_complex(self._operator_X)
@property
def operator_Y(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси ординат
"""
return _torch.view_as_complex(self._operator_Y)
def __operator_multiplication(self, first_X: _torch.Tensor,
second_X: _torch.Tensor,
first_Y: _torch.Tensor,
second_Y: _torch.Tensor)-> _Tuple[_torch.Tensor, _torch.Tensor]:
operator_Y = second_Y @ first_Y
operator_X = first_X @ second_X
return operator_X, operator_Y
def cat(self, propagators: _Sequence['Propagator']) -> 'Propagator':
"""
Метод схлопывания операторов распространения.
Args:
propagators: последовательность для схлопывания
Returns:
новый пропогатор, заменяющих собой серию предыдущих
Warning:
порядок расположения пропагаторов в последовательности важен,
идёт от первого к последниму
"""
operator_X: _torch.Tensor
operator_Y: _torch.Tensor
if not isinstance(propagators, _collections.abc.Sequence):
operator_X, operator_Y = self.__operator_multiplication(self.operator_X,
propagators.operator_X,
self.operator_Y,
propagators.operator_Y)
else:
size = len(propagators)
operator_X = self.operator_X
operator_Y = self.operator_Y
for i in range(size):
operator_X, operator_Y = self.__operator_multiplication(operator_X,
propagators[i].operator_X,
operator_Y,
propagators[i].operator_Y)
return Propagator(operator_X, operator_Y)
def __add__(self, propagator: 'Propagator') -> 'Propagator':
"""
Метод схлопывания двух пропагаторов.
Args:
propagator: пропагатор с которым нужно произвести схлопывание
Returns:
новый пропогатор, заменяющих собой оба предыдущих
Warning:
операция не комутативная
"""
return self.cat(propagator)
@staticmethod
def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
start = (total_rows - num_to_take) // 2
end = start + num_to_take
return slice(start, end)
def forward(self,
field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
"""
Метод распространения светового поля в среде.
Args:
field: распределение комплексной амплитуды светового поля.
Returns:
Распределение комплексной амплитуды светового поля,
после распространения.
"""
if (resul_shape is not None):
field_shape = field.shape[-2:]
operator_Y_shape = self.operator_Y.shape[-2:]
operator_X_shape = self.operator_X.shape[-2:]
slice_one = Propagator.__slice_calculation(operator_Y_shape[0], resul_shape[0])
slice_two = Propagator.__slice_calculation(operator_Y_shape[1], field_shape[0])
slice_three = Propagator.__slice_calculation(operator_X_shape[0], field_shape[1])
slice_four = Propagator.__slice_calculation(operator_X_shape[1], resul_shape[1])
return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
return self.operator_Y @ field @ self.operator_X
def __repr__(self):
return f"Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
class PropagatorLens(Propagator):
"""
Абстрактный класс распространения света в тонком оптическом элементе.
"""
def transpose(self) -> 'PropagatorLens':
"""
Метод транспонирования тонкого оптического элемента.
Returns:
Новый элемент, транспонированный относительно оригинального.
"""
obj = Propagator.__new__(PropagatorLens)
Propagator.__init__(obj, self.operator_Y, self.operator_X)
return obj
@property
def T(self) -> 'PropagatorLens':
"""
Returns:
Новый элемент, транспонированный относительно текущего.
"""
return self.transpose()
class PropagatorCrossLens(PropagatorLens):
"""
Класс распространения света в скрещенной линзе,
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase):
"""
Конструктор класса скрещенной линзы.
Args:
plane: данные о расчётной плоскости элемента.
config: данные о световом поле модели.
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2)
super(PropagatorCrossLens, self).__init__(
_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorCylindLens(PropagatorLens):
"""
Класс распространения света в цилиндрической линзе,
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase):
"""
Конструктор класса цилиндрической линзы.
Args:
plane: данные о расчётной плоскости элемента.
config: данные о световом поле модели.
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)
super(PropagatorCylindLens, self).__init__(
_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorSinc(Propagator):
"""
Класс распространения света свободном пространстве
с использованием разложения по базисным sinc функциям.
"""
def __init__(self, first_plane: _ConfigDesignPlane,
second_plane: _ConfigDesignPlane,
config: _ConfigOpticBase):
"""
Конструктор класса распространения в свободном пространстве.
Args:
first_plane: данные о начальной расчётной плоскости.
second_plane: данные о конечной расчётной плоскости.
config: данные о световом поле модели.
"""
operator_X, operator_Y = self.__get_operators(first_plane,
second_plane,
config)
super(PropagatorSinc, self).__init__(operator_X, operator_Y)
def __get_operator_for_dim(self,
pixel_size_in: float,
pixel_size_out: float,
difference: float,
config: _ConfigOpticBase) -> _torch.Tensor:
bndW = 0.5 / pixel_size_in
eikz = (_np.exp(1j * config.K * config.distance)**0.5)
sq2p = (2 / _np.pi)**0.5
sqzk = ((2 * config.distance / config.K)**0.5)
mu1 = -_np.pi * sqzk * bndW - difference / sqzk
mu2 = _np.pi * sqzk * bndW - difference / sqzk
S1, C1 = _fresnel(mu1 * sq2p)
S2, C2 = _fresnel(mu2 * sq2p)
return (((pixel_size_in * pixel_size_out)**0.5 / _np.pi) / sqzk * eikz
* _np.exp(0.5j * difference**2 * config.K / config.distance)
* (C2 - C1 - 1j * (S2 - S1)) / sq2p)
def __get_operators(self,
first_plane: _ConfigDesignPlane,
second_plane: _ConfigDesignPlane,
config: _ConfigOpticBase) -> _Tuple[_torch.Tensor, _torch.Tensor]:
difference_x = first_plane.linspace_by_x[None, :] - second_plane.linspace_by_x[:, None]
difference_y = first_plane.linspace_by_y[None, :] - second_plane.linspace_by_y[:, None]
operator_X = self.__get_operator_for_dim(first_plane.pixel_size_by_x,
second_plane.pixel_size_by_x,
difference_x,
config).transpose(-2, -1)
operator_Y = self.__get_operator_for_dim(first_plane.pixel_size_by_y,
second_plane.pixel_size_by_y,
difference_y,
config)
return operator_X, operator_Y
#######################################################################################################################
class PropagatorTrainableCylindLens(_ABC, _nn.Module):
"""
Класс распространения света в обучаемой цилиндрической линзе,
представленной тонким прозрачным оптическим элементом.
"""
def __init__(self,
plane: _ConfigDesignPlane,
config: _ConfigOpticBase
):
super().__init__()
# non smooth profile after training. better to train only focal length?
self._operator_X_phi = _nn.Parameter(config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat))
operator_Y = _torch.view_as_real(operator_Y)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
@property
def operator_X(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси абсцисс
"""
return _torch.diag_embed(_torch.exp(-1j * self._operator_X_phi))
@property
def operator_Y(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси ординат
"""
return _torch.view_as_complex(self._operator_Y)
@staticmethod
def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
start = (total_rows - num_to_take) // 2
end = start + num_to_take
return slice(start, end)
def forward(self,
field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
"""
Метод распространения светового поля в среде.
Args:
field: распределение комплексной амплитуды светового поля.
Returns:
Распределение комплексной амплитуды светового поля,
после распространения.
"""
if (resul_shape is not None):
field_shape = field.shape[-2:]
operator_Y_shape = self.operator_Y.shape[-2:]
operator_X_shape = self.operator_X.shape[-2:]
slice_one = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0])
slice_two = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0])
slice_three = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1])
slice_four = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1])
return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
return self.operator_Y @ field @ self.operator_X
class PropagatorTrainableFocalDistCylindLens(_ABC, _nn.Module):
"""
Класс распространения света в обучаемой цилиндрической линзе,
представленной тонким прозрачным оптическим элементом.
"""
def __init__(self,
plane: _ConfigDesignPlane,
config: _ConfigOpticBase
):
super().__init__()
self._distance = _nn.Parameter(_torch.tensor(config.distance))
self.register_buffer('_K', _torch.tensor(config.K), persistent=True)
self.register_buffer('_linspace_by_x', plane.linspace_by_x.detach().clone(), persistent=True)
operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat))
operator_Y = _torch.view_as_real(operator_Y)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
@property
def operator_X(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси абсцисс
"""
return _torch.diag_embed(_torch.exp(-1j * self._K / self._distance * self._linspace_by_x**2))
@property
def operator_Y(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси ординат
"""
return _torch.view_as_complex(self._operator_Y)
@staticmethod
def __slice_calculation(total_rows: int, num_to_take: int) -> slice:
start = (total_rows - num_to_take) // 2
end = start + num_to_take
return slice(start, end)
def forward(self,
field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor:
"""
Метод распространения светового поля в среде.
Args:
field: распределение комплексной амплитуды светового поля.
Returns:
Распределение комплексной амплитуды светового поля,
после распространения.
"""
if (resul_shape is not None):
field_shape = field.shape[-2:]
operator_Y_shape = self.operator_Y.shape[-2:]
operator_X_shape = self.operator_X.shape[-2:]
slice_one = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0])
slice_two = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0])
slice_three = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1])
slice_four = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1])
return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four]
return self.operator_Y @ field @ self.operator_X