diff --git a/src/main.py b/src/main.py index 52fd104..d6a06c4 100644 --- a/src/main.py +++ b/src/main.py @@ -30,7 +30,7 @@ models = { } batch_size = 50 -gradient_accumulation_steps = 2 # check this impl for correctness https://unsloth.ai/blog/gradient +gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient max_iters = 40000 eval_interval = 300 learning_rate = 1e-3 @@ -38,7 +38,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' eval_iters = 200 layers_num = 2 h_dim = 64 -max_seq_len = 256 +max_seq_len = 512 num_heads = 1 dropout_rate = 0.1 pixel_size = 3.6e-6 @@ -149,7 +149,7 @@ for i in range(max_iters): m.eval() ppl = perplexity(model=m, data=val_data) writer.add_scalar('val_perplexity', ppl.item(), i) - print(f"\rPerplexity at {i}: {ppl}") + print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) m.eval() diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index 473cd74..11f52dc 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -114,12 +114,12 @@ class OpticalMul(_nn.Module): """ vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) - if self.trainable_cylind_lens: vec_field = self._propagator_one(vec_field) vec_field = self._propagator_between(vec_field) else: vec_field = self._propagator_one(vec_field) + vec_field = self._propagator_two(vec_field * mat_field) return self.prepare_out(vec_field) \ No newline at end of file diff --git a/src/optical_matrix_multiplication/parallel.py b/src/optical_matrix_multiplication/parallel.py index b707415..7621c6f 100644 --- a/src/optical_matrix_multiplication/parallel.py +++ b/src/optical_matrix_multiplication/parallel.py @@ -129,34 +129,31 @@ class ScatterDataParallel(_nn.Module): Оптимизированный forward для attention матриц. Особенности: - - Scatter по batch dimension (0) вместо произвольного dim + - Scatter по batch dimension (1) вместо произвольного dim - Оба тензора scatter'ятся для согласованности размерностей - - Поддержка многомерных attention тензоров [batch, heads, seq, dim] + - Поддержка многомерных attention тензоров [batch, heads, seq, dim] ?? ''' # Определяем dimension для scatter на основе структуры тензоров if input.dim() >= 3 and other.dim() >= 3: # Для attention матриц scatter по batch dimension - scatter_dim = 0 + scatter_dim = 1 else: # Для обычных 2D матриц используем dim из kwargs или по умолчанию 2 scatter_dim = kwargs.get('dim', 2) - + # Подготовка модуля и данных self.module = self.module.to(self.devices[0]) + replicas = _nn.parallel.replicate(self.module, self.devices) # Scatter ОБОИХ тензоров для согласованности размерностей scattered_input = _nn.parallel.scatter(input, self.devices, scatter_dim) scattered_other = _nn.parallel.scatter(other, self.devices, scatter_dim) - - # Создаем реплики модуля - replicas = _nn.parallel.replicate(self.module, self.devices) - + # Формируем входные данные для каждого устройства # Убедимся, что все списки одинаковой длины min_len = min(len(scattered_input), len(scattered_other), len(replicas)) stacked_input = [(scattered_input[i], scattered_other[i]) for i in range(min_len)] - # Параллельное вычисление outputs = _nn.parallel.parallel_apply(replicas[:min_len], stacked_input) diff --git a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py index 68ee2c4..eaa061f 100644 --- a/src/optical_matrix_multiplication/propagator.py +++ b/src/optical_matrix_multiplication/propagator.py @@ -16,7 +16,7 @@ class Propagator(_ABC, _nn.Module): operator_X: оператор отображающий распроcтранение светового поля вдоль оси абсцисс operator_Y: оператор отображающий распроcтранение светового поля вдоль оси ординат """ - def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False, diagonal = False): + def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False): super(Propagator, self).__init__() operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) @@ -24,12 +24,10 @@ class Propagator(_ABC, _nn.Module): self._operator_X = _nn.Parameter(operator_X) self._operator_Y = _nn.Parameter(operator_Y) self._trainable = trainable - self._diagonal = diagonal else: self.register_buffer('_operator_X', operator_X, persistent=True) self.register_buffer('_operator_Y', operator_Y, persistent=True) self._trainable = trainable - self._diagonal = diagonal @property def operator_X(self) -> _torch.Tensor: @@ -111,13 +109,13 @@ class Propagator(_ABC, _nn.Module): Распределение комплексной амплитуды светового поля, после распространения. """ - if self._diagonal: + if self._trainable: return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X) else: return self.operator_Y @ field @ self.operator_X def __repr__(self): - return f"Diag: {self._diagonal} Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" + return f"Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" class PropagatorLens(Propagator): """ @@ -167,7 +165,8 @@ class PropagatorСylindLens(PropagatorLens): представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase, trainable = False): + config: _ConfigOpticBase, + trainable = False): """ Конструктор класса цилиндрической линзы. @@ -177,10 +176,14 @@ class PropagatorСylindLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) - super(PropagatorСylindLens, self).__init__(operator_X, - operator_Y, - trainable, - diagonal=True) + if trainable: + super(PropagatorСylindLens, self).__init__(operator_X, + operator_Y, + trainable) + else: + super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), + _torch.diag_embed(operator_Y), + trainable) class PropagatorSinc(Propagator): """ diff --git a/src/optics_char_gpt2_nokoef_newf.py b/src/optics_char_gpt2_nokoef_newf.py index f865c18..d99f149 100644 --- a/src/optics_char_gpt2_nokoef_newf.py +++ b/src/optics_char_gpt2_nokoef_newf.py @@ -190,6 +190,8 @@ class OpticGPT2NOKoefNewF(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,