diff --git a/src/train_gpt2.py b/src/train_gpt2.py
index a5aa3d7..75de26f 100644
--- a/src/train_gpt2.py
+++ b/src/train_gpt2.py
@@ -127,7 +127,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 eval_iters = 200
 layers_num = 2
 h_dim = 64
-max_seq_len = 64
+max_seq_len = 512
 num_heads = 1
 dropout_rate = 0.1
 pixel_size = 3.6e-6
@@ -187,20 +187,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)
 
 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+
+    print() # Final newline
     return np.exp(total_loss_sum / total_tokens_count)
 
 #################################### Model #########################################mo
@@ -220,7 +244,10 @@ m = MODEL_CLASS(
 m = m.to(device)
 model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
 writer.add_text('model', model_description, 0)
-# TODO for all experiments
+
+print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
+print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")
+
 optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
 
 #################################### Checkpoint Function #########################################
@@ -293,7 +320,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -311,13 +338,13 @@ for i in range(max_iters):
         )
 
 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
 
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
 
diff --git a/src/train_optics_trainable_focal_dist_lens_64.py b/src/train_optics_trainable_focal_dist_lens_64.py
index 87a9558..883e705 100644
--- a/src/train_optics_trainable_focal_dist_lens_64.py
+++ b/src/train_optics_trainable_focal_dist_lens_64.py
@@ -268,6 +268,10 @@ print("Logs dir:", logs_dir)
 shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
 script_snapshot_path.chmod(0o500) # with read-only permission
 
+# Create standalone checkpoints directory with a timestamped name
+checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
+print("Checkpoints dir:", checkpoints_dir)
 
 #################################### Dataset #########################################
@@ -303,18 +307,48 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)
 
 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
+    model.eval()
     stride = max(1, len(data) // 10000)
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, loss = model(x[None,...], y[None,...])
-        losses.append(loss.item())
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
-    return np.exp(np.mean(losses))
-
-#################################### Model #########################################mo
+    total_loss_sum = 0.0
+    total_tokens_count = 0
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
+        total_loss_sum += mean_loss.item() * num_tokens
+        total_tokens_count += num_tokens
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+
+    print() # Final newline
+    return np.exp(total_loss_sum / total_tokens_count)
+
+#################################### Model #########################################
+
 def complete(m, start_idxs=[0], max_new_tokens=100):
     start_idx = torch.tensor([start_idxs]).to(device)
     generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
@@ -336,6 +370,37 @@ writer.add_text('model', model_description, 0)
 optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
 
+#################################### Checkpoint Function #########################################
+
+def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
+    """Save model checkpoint with complete training state"""
+    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
+    torch.save({
+        'step': step,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'loss': loss,
+        'config': config,
+        'wtoi': wtoi,
+        'itow': itow,
+    }, checkpoint_path)
+
+# Training config for checkpointing
+training_config = {
+    'vocab_size': vocab_size,
+    'layers_num': layers_num,
+    'h_dim': h_dim,
+    'max_seq_len': max_seq_len,
+    'num_heads': num_heads,
+    'dropout_rate': dropout_rate,
+    'batch_size': batch_size,
+    'learning_rate': learning_rate,
+    'gradient_accumulation_steps': gradient_accumulation_steps,
+    'pixel_size': pixel_size,
+    'max_iters': max_iters,
+}
+
+#################################### Train #########################################
 
 m.eval()
 
 task_prompts = [
@@ -373,22 +438,32 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
         task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
         writer.add_text('completions/task', task_results, i)
         m.log_trainable_optic_params(writer, i)
+        save_checkpoint(
+            model=m,
+            optimizer=optimizer,
+            step=i,
+            loss=accumulated_loss,
+            config=training_config,
+            wtoi=wtoi,
+            itow=itow,
+            checkpoint_dir=checkpoints_dir
+        )
 
 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
 
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
 
@@ -397,4 +472,19 @@ print(completion)
 writer.add_text('completions', completion, i+1)
 task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
 print(task_results)
-writer.add_text('completions/task', task_results, i+1)
\ No newline at end of file
+writer.add_text('completions/task', task_results, i+1)
+
+m.log_trainable_optic_params(writer, max_iters)
+
+# Save final checkpoint
+save_checkpoint(
+    model=m,
+    optimizer=optimizer,
+    step=max_iters,
+    loss=accumulated_loss,
+    config=training_config,
+    wtoi=wtoi,
+    itow=itow,
+    checkpoint_dir=checkpoints_dir
+)
+print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")
\ No newline at end of file
diff --git a/src/train_optics_trainable_lens_128.py b/src/train_optics_trainable_lens_128.py
index 8c3f2dc..9a551d3 100644
--- a/src/train_optics_trainable_lens_128.py
+++ b/src/train_optics_trainable_lens_128.py
@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
 
+batch_size = 50
+gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 128
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0
+
 ############################### MODEL #############################################################
 
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 
 ###################################################################################################
 
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 128
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
-
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)
 
 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+
+    print() # Final newline
     return np.exp(total_loss_sum / total_tokens_count)
 
 #################################### Model #########################################mo
@@ -416,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -435,13 +458,13 @@ for i in range(max_iters):
         )
 
 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
 
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
 
diff --git a/src/train_optics_trainable_lens_256.py b/src/train_optics_trainable_lens_256.py
index 59e5f31..c802f4d 100644
--- a/src/train_optics_trainable_lens_256.py
+++ b/src/train_optics_trainable_lens_256.py
@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
 
+batch_size = 50
+gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 256
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0
+
 ############################### MODEL #############################################################
 
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 
 ###################################################################################################
 
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 256
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
-
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)
 
 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+
+    print() # Final newline
     return np.exp(total_loss_sum / total_tokens_count)
 
 #################################### Model #########################################mo
@@ -416,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -435,13 +458,13 @@ for i in range(max_iters):
         )
 
 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
 
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
 
diff --git a/src/train_optics_trainable_lens_512.py b/src/train_optics_trainable_lens_512.py
index 9d081e2..b090e0c 100644
--- a/src/train_optics_trainable_lens_512.py
+++ b/src/train_optics_trainable_lens_512.py
@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
 
+batch_size = 50
+gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 512
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0
+
 ############################### MODEL #############################################################
 
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 
 ###################################################################################################
 
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 512
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
-
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)
 
 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+
+    print() # Final newline
     return np.exp(total_loss_sum / total_tokens_count)
 
 #################################### Model #########################################mo
@@ -416,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -435,13 +458,13 @@ for i in range(max_iters):
         )
 
 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
 
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
 
diff --git a/src/train_optics_trainable_lens_64.py b/src/train_optics_trainable_lens_64.py
index bbe5a73..751147b 100644
--- a/src/train_optics_trainable_lens_64.py
+++ b/src/train_optics_trainable_lens_64.py
@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
 
+batch_size = 50
+gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 64
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0
+
 ############################### MODEL #############################################################
 
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 
 ###################################################################################################
 
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 64
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
-
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,23 +305,48 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)
 
 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+
+    print() # Final newline
     return np.exp(total_loss_sum / total_tokens_count)
 
-#################################### Model #########################################mo
+#################################### Model #########################################
+
 def complete(m, start_idxs=[0], max_new_tokens=100):
     start_idx = torch.tensor([start_idxs]).to(device)
     generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
@@ -415,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -434,13 +458,13 @@ for i in range(max_iters):
         )
 
 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
 
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
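
For reference, the batched perplexity evaluation that all of these scripts now share can be read as the standalone sketch below. The model(x, y) -> (logits, mean_loss) interface, the stride heuristic, and the token-weighted averaging are taken directly from the diffs; the wrapper name batched_perplexity and the explicit max_seq_len/device arguments are illustrative only.

import numpy as np
import torch

@torch.no_grad()
def batched_perplexity(model, data, max_seq_len, device, batch_size=32):
    """Sketch of the evaluation loop introduced above (not the scripts' exact code)."""
    model.eval()
    # Subsample roughly 10k window start positions across the token stream.
    stride = max(1, len(data) // 10000)
    starts = list(range(0, len(data) - max_seq_len - 1, stride))

    total_loss, total_tokens = 0.0, 0
    for i in range(0, len(starts), batch_size):
        batch = starts[i:i + batch_size]
        x = torch.stack([data[s:s + max_seq_len] for s in batch]).to(device)
        y = torch.stack([data[s + 1:s + max_seq_len + 1] for s in batch]).to(device)
        _, mean_loss = model(x, y)  # assumed to return the mean cross-entropy over the batch
        total_loss += mean_loss.item() * y.numel()  # re-weight by token count before averaging
        total_tokens += y.numel()
    return np.exp(total_loss / total_tokens)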