checkpointing, batched perplexity.

pull/1/head
Vladimir Protsenko 1 month ago
parent 064b9e14c8
commit 85249dfbba

@@ -127,7 +127,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 eval_iters = 200
 layers_num = 2
 h_dim = 64
-max_seq_len = 64
+max_seq_len = 512
 num_heads = 1
 dropout_rate = 0.1
 pixel_size = 3.6e-6
@@ -187,20 +187,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)

 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+    print()  # Final newline
     return np.exp(total_loss_sum / total_tokens_count)

 #################################### Model #########################################mo
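Review note on the batched rewrite above: since every stacked row is exactly max_seq_len tokens, the model's mean_loss is already the token-weighted mean for its batch, so exp(total_loss_sum / total_tokens_count) reproduces the old one-window-at-a-time result exactly. (The np.mean(losses) variant replaced in the other files below averaged per-window losses instead; with equal-length windows the two coincide, with ragged windows they would not.) A minimal sketch of that identity, with illustrative numbers:

import numpy as np

# Three hypothetical eval windows: per-window mean loss and token count.
window_losses, window_tokens = [2.0, 4.0, 3.0], [64, 64, 64]
token_mean = sum(l * n for l, n in zip(window_losses, window_tokens)) / sum(window_tokens)
assert np.isclose(np.exp(token_mean), np.exp(np.mean(window_losses)))  # equal lengths: identical
# With unequal counts, e.g. window_tokens = [64, 128, 64], the token-weighted
# mean is 3.25 rather than 3.0, and only the weighted version is correct.

The call stays drop-in, e.g. perplexity(model=m, data=val_data, batch_size=batch_size); batch_size only trades GPU memory for evaluation speed.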
@@ -220,7 +244,10 @@ m = MODEL_CLASS(
 m = m.to(device)

 model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}'
 writer.add_text('model', model_description, 0)
+
+# TODO for all experiments
+print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.")
+print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.")

 optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
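A quick check on what those prints report. The 8-tokens-per-parameter budget is this repo's own rule of thumb (Chinchilla-style estimates run closer to 20 tokens per parameter). Note that the second print divides the token budget by batch_size alone; if one iteration consumes batch_size * max_seq_len tokens, the iteration estimate comes out far smaller. Illustrative arithmetic, assuming a hypothetical ~0.8M-parameter model with batch_size=50 and max_seq_len=128:

params = 800_000                                   # hypothetical parameter count
min_tokens = params * 8                            # 6_400_000-token budget
printed_iters = min_tokens // 50                   # 128_000, the batch_size-only divisor above
tokens_per_iter = 50 * 128                         # batch_size * max_seq_len = 6_400
token_based_iters = min_tokens // tokens_per_iter  # 1_000 iterations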
 #################################### Checkpoint Function #########################################
@@ -293,7 +320,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -311,13 +338,13 @@ for i in range(max_iters):
         )

 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")

@@ -268,6 +268,10 @@ print("Logs dir:", logs_dir)
 shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository
 script_snapshot_path.chmod(0o500) # with read-only permission
+
+# Create a standalone, timestamped checkpoints directory for this run
+checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
+print("Checkpoints dir:", checkpoints_dir)
 #################################### Dataset #########################################
@@ -303,18 +307,48 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)

 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
+    model.eval()
     stride = max(1, len(data) // 10000)
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, loss = model(x[None,...], y[None,...])
-        losses.append(loss.item())
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
-    return np.exp(np.mean(losses))
+    total_loss_sum = 0.0
+    total_tokens_count = 0
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
+        total_loss_sum += mean_loss.item() * num_tokens
+        total_tokens_count += num_tokens
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+    print()  # Final newline
+    return np.exp(total_loss_sum / total_tokens_count)

-#################################### Model #########################################mo
+#################################### Model #########################################

 def complete(m, start_idxs=[0], max_new_tokens=100):
     start_idx = torch.tensor([start_idxs]).to(device)
     generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
@@ -336,6 +370,37 @@ writer.add_text('model', model_description, 0)
 optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

+#################################### Checkpoint Function #########################################
+def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
+    """Save model checkpoint with complete training state"""
+    checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt'
+    torch.save({
+        'step': step,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'loss': loss,
+        'config': config,
+        'wtoi': wtoi,
+        'itow': itow,
+    }, checkpoint_path)
+
+# Training config for checkpointing
+training_config = {
+    'vocab_size': vocab_size,
+    'layers_num': layers_num,
+    'h_dim': h_dim,
+    'max_seq_len': max_seq_len,
+    'num_heads': num_heads,
+    'dropout_rate': dropout_rate,
+    'batch_size': batch_size,
+    'learning_rate': learning_rate,
+    'gradient_accumulation_steps': gradient_accumulation_steps,
+    'pixel_size': pixel_size,
+    'max_iters': max_iters,
+}
+
 #################################### Train #########################################
 m.eval()
 task_prompts = [
@@ -373,22 +438,32 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
         task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
         writer.add_text('completions/task', task_results, i)
         m.log_trainable_optic_params(writer, i)
+        save_checkpoint(
+            model=m,
+            optimizer=optimizer,
+            step=i,
+            loss=accumulated_loss,
+            config=training_config,
+            wtoi=wtoi,
+            itow=itow,
+            checkpoint_dir=checkpoints_dir
+        )

 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
@@ -397,4 +472,19 @@ print(completion)
 writer.add_text('completions', completion, i+1)
 task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
 print(task_results)
 writer.add_text('completions/task', task_results, i+1)
+m.log_trainable_optic_params(writer, max_iters)
+
+# Save final checkpoint
+save_checkpoint(
+    model=m,
+    optimizer=optimizer,
+    step=max_iters,
+    loss=accumulated_loss,
+    config=training_config,
+    wtoi=wtoi,
+    itow=itow,
+    checkpoint_dir=checkpoints_dir
+)
+print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}")

@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
+
+batch_size = 50
+gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 128
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0

 ############################### MODEL #############################################################
 def new_formula(sim, tensor_1, tensor_2):
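On the comment "check this impl for correctness" carried along in this move: the linked unsloth post describes a common accumulation pitfall where averaging the loss inside each micro-batch and then summing gradients over-weights micro-batches with fewer tokens. With fixed-length sequences, as here, the naive per-micro-batch mean is already correct; a token-weighted sketch for the general case, with illustrative names that are not from this repo:

def accumulation_step(model, optimizer, micro_batches):
    # micro_batches: list of (x, y) pairs forming one optimizer step
    optimizer.zero_grad()
    total_tokens = sum(y.numel() for _, y in micro_batches)
    for x, y in micro_batches:
        _, mean_loss = model(x, y)  # mean over this micro-batch's tokens only
        # Renormalize by the whole window's token count so every token counts equally;
        # gradients sum across the backward() calls.
        (mean_loss * (y.numel() / total_tokens)).backward()
    optimizer.step()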
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 ###################################################################################################
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 128
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)

 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+    print()  # Final newline
     return np.exp(total_loss_sum / total_tokens_count)

 #################################### Model #########################################mo
@@ -416,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -435,13 +458,13 @@ for i in range(max_iters):
         )

 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")

@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
+
+batch_size = 50
+gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 256
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0

 ############################### MODEL #############################################################
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 ###################################################################################################
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 256
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)

 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+    print()  # Final newline
     return np.exp(total_loss_sum / total_tokens_count)

 #################################### Model #########################################mo
@@ -416,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -435,13 +458,13 @@ for i in range(max_iters):
         )

 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")

@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
+
+batch_size = 50
+gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 512
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0

 ############################### MODEL #############################################################
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 ###################################################################################################
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 512
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)

 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+    print()  # Final newline
     return np.exp(total_loss_sum / total_tokens_count)

 #################################### Model #########################################mo
@@ -416,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -435,13 +458,13 @@ for i in range(max_iters):
         )

 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")

@@ -12,6 +12,21 @@ import shutil
 seed = 1337
 torch.manual_seed(seed)
+
+batch_size = 50
+gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
+max_iters = int(4e4) #40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 64
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+assert batch_size % gradient_accumulation_steps == 0

 ############################### MODEL #############################################################
 def new_formula(sim, tensor_1, tensor_2):
@@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module):
 ###################################################################################################
-batch_size = 50
-gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient
-max_iters = int(4e4) #40000
-eval_interval = 300
-learning_rate = 1e-3
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-eval_iters = 200
-layers_num = 2
-h_dim = 64
-max_seq_len = 64
-num_heads = 1
-dropout_rate = 0.1
-pixel_size = 3.6e-6
-assert batch_size % gradient_accumulation_steps == 0
-# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64
 MODEL_CLASS = OpticGPT2TrainableScalarAndLens
 train_data_path = Path("./data/wiki.train.tokens")
 val_data_path = Path("./data/wiki.valid.tokens")
@@ -306,23 +305,48 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long)
 test_data = torch.tensor(encode(test_text), dtype=torch.long)

 @torch.no_grad()
-def perplexity(model, data):
+def perplexity(model, data, batch_size=32):
     model.eval()
     stride = max(1, len(data) // 10000)
     total_loss_sum = 0.0
     total_tokens_count = 0
-    losses = []
-    for i in range(0, len(data)-max_seq_len-1, stride):
-        x = data[i:(i+max_seq_len)].to(device)
-        y = data[(i+1):(i+max_seq_len+1)].to(device)
-        logits, mean_loss = model(x[None,...], y[None,...])
-        num_tokens = y.numel()
+
+    # Precompute all valid start positions
+    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
+    total_sequences = len(start_positions)
+
+    # Process sequences in batches
+    for i in range(0, total_sequences, batch_size):
+        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]
+
+        # Efficiently stack sequences into batch tensors
+        x_batch = torch.stack([
+            data[start:start + max_seq_len]
+            for start in batch_starts
+        ]).to(device)
+        y_batch = torch.stack([
+            data[start + 1:start + max_seq_len + 1]
+            for start in batch_starts
+        ]).to(device)
+
+        # Forward pass (model should return mean loss averaged over all tokens in batch)
+        _, mean_loss = model(x_batch, y_batch)
+
+        # Accumulate weighted loss (mean_loss is already averaged over tokens)
+        num_tokens = y_batch.numel()
         total_loss_sum += mean_loss.item() * num_tokens
         total_tokens_count += num_tokens
-        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+
+        # Progress update
+        processed = min(i + batch_size, total_sequences)
+        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)
+    print()  # Final newline
     return np.exp(total_loss_sum / total_tokens_count)

-#################################### Model #########################################mo
+#################################### Model #########################################

 def complete(m, start_idxs=[0], max_new_tokens=100):
     start_idx = torch.tensor([start_idxs]).to(device)
     generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
@@ -415,7 +439,7 @@ for i in range(max_iters):
     print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
     if i % 5000 == 0:
         m.eval()
-        ppl = perplexity(model=m, data=val_data)
+        ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
         writer.add_scalar('val_perplexity', ppl.item(), i)
         print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}")
         writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i)
@@ -434,13 +458,13 @@ for i in range(max_iters):
         )

 m.eval()
-ppl = perplexity(model=m, data=val_data)
+ppl = perplexity(model=m, data=val_data, batch_size=batch_size)
 print(f"\r{i+1}/{max_iters} {accumulated_loss}")
 print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
 writer.add_scalar('val_perplexity', ppl.item(), i+1)
 writer.add_scalar('loss', accumulated_loss, i)
-ppl = perplexity(model=m, data=test_data)
+ppl = perplexity(model=m, data=test_data, batch_size=batch_size)
 writer.add_scalar('test_perplexity', ppl.item(), i+1)
 print(f"\rTest Perplexity at {i}: {ppl}")
