"""Compare simulated optical matrix products against exact torch matmuls.

For every sequence length in ``test_lengths`` the script builds two
``omm.OpticalMul`` simulators (one for the ``q @ k`` product, one for the
``scores @ v`` product), feeds them uniformly random tensors, and prints the
``cko`` deviation between the exact and the simulated results.
"""
import torch
import torch.nn as nn
import optical_matrix_multiplication as omm
import matplotlib.pyplot as plt

device = 'cpu'

h_dim = 64
pixel_size = 3.6e-6
batch_size = 100
test_lengths = [59, 64, 128, 256, 512]


def _build_optical_mul(rows, columns, distance, lens_size):
    """Create an ``omm.OpticalMul`` for a right-hand matrix of ``rows`` x ``columns``.

    Args:
        rows: Row count of the right-hand matrix.
        columns: Column count of the right-hand matrix.
        distance: Value forwarded as ``omm.Config(distance=...)``.
        lens_size: Value forwarded as ``omm.Config(lens_size=...)``.

    Returns:
        The configured ``omm.OpticalMul`` instance. The physical width and
        height are the element counts scaled by the module-level
        ``pixel_size``; every split factor is fixed to 2.
    """
    return omm.OpticalMul(
        omm.Config(
            right_matrix_count_columns=columns,
            right_matrix_count_rows=rows,
            right_matrix_width=pixel_size * columns,
            right_matrix_height=pixel_size * rows,
            min_height_gap=pixel_size,
            right_matrix_split_x=2,
            right_matrix_split_y=2,
            left_matrix_split_x=2,
            left_matrix_split_y=2,
            result_matrix_split=2,
            distance=distance,
            lens_size=lens_size,
        )
    )


def cko(x, y):
    """Return the deviation between ``x`` and ``y`` as a percentage.

    Both tensors are squared element-wise and divided by their own mean; the
    result is the root of the mean squared difference of those normalised
    tensors, multiplied by 100.

    Args:
        x: Reference tensor.
        y: Tensor compared against ``x``; must broadcast with ``x``.

    Returns:
        A zero-dimensional tensor holding the deviation.
    """
    x = x**2
    y = y**2
    return (((x / x.mean() - y / y.mean())**2).mean())**0.5 * 100


for max_seq_len in test_lengths:
    # Lengths below 512 use the short distance and the base lens size; from
    # 512 onward the distance is longer and the lens is twice as large.
    if max_seq_len < 512:
        distance, lens_size = 0.01, 8192
    else:
        distance, lens_size = 0.15, 8192 * 2

    # Right-hand operand of q @ k is (h_dim x max_seq_len); for scores @ v it
    # is (max_seq_len x h_dim).
    sim_scores = _build_optical_mul(h_dim, max_seq_len, distance, lens_size)
    sim_output = _build_optical_mul(max_seq_len, h_dim, distance, lens_size)

    sim_scores = sim_scores.to(device=device)
    sim_output = sim_output.to(device=device)

    # Tensors are drawn in the original order (q, k, scores, v) so the random
    # stream is consumed identically.
    q = torch.rand((batch_size, 1, max_seq_len, h_dim), device=device)
    k = torch.rand((batch_size, 1, max_seq_len, h_dim), device=device).transpose(-2, -1)
    true_scores = q @ k
    opt_scores = sim_scores(q, k)
    CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy()

    scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len), device=device)
    v = torch.rand((batch_size, 1, max_seq_len, h_dim), device=device)
    true_o = scores @ v
    opt_o = sim_output(scores, v)
    CKO_o = cko(true_o, opt_o).detach().cpu().numpy()

    print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}")
    # The left operand here is `scores`, so its shape is reported directly
    # (same dimensions as `true_scores`, which the script printed before).
    print(f"CKO sim_output[{max_seq_len},{h_dim}] [{scores.shape[-2]}, {scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}")