From 1ace114d0c39f28c22222833f9fa6ccfca22c24c Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 14 Feb 2026 14:48:59 +0000 Subject: [PATCH] optic mul error test --- src/basic_optic_mm_test.py | 93 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 src/basic_optic_mm_test.py diff --git a/src/basic_optic_mm_test.py b/src/basic_optic_mm_test.py new file mode 100644 index 0000000..bc0a883 --- /dev/null +++ b/src/basic_optic_mm_test.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +import optical_matrix_multiplication as omm +import matplotlib.pyplot as plt +device = 'cpu' + +h_dim = 64 +pixel_size = 3.6e-6 +batch_size = 100 +test_lengths = [59, 64, 128, 256, 512] + +for max_seq_len in test_lengths: + if max_seq_len < 512: + sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * h_dim, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + lens_size = 8192) + ) + sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * h_dim, + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + lens_size = 8192) + ) + else: + sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * h_dim, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * h_dim, + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + + def cko(x,y): + x = x**2 + y = y**2 + return (((x / x.mean() - y / y.mean())**2).mean())**0.5 * 100 + + sim_scores = sim_scores.to(device=device) + sim_output = sim_output.to(device=device) + + q = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device) + k = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device).transpose(-2, -1) + true_scores = q @ k + opt_scores = sim_scores(q, k) + CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy() + + scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len)).to(device=device) + v = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device) + true_o = scores @ v + opt_o = sim_output(scores, v) + CKO_o = cko(true_o, opt_o).detach().cpu().numpy() + + print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}") + print(f"CKO sim_output[{max_seq_len},{h_dim}] [{true_scores.shape[-2]}, {true_scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}") \ No newline at end of file