You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
3.0 KiB
Python
63 lines
3.0 KiB
Python
2 years ago
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
from IPython.display import display, clear_output
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
def print_results(epoch, logs):
    """Render the latest train/val metrics as a styled table in the notebook.

    Relies on `logs` insertion order being
    (loss, accuracy, val_loss, val_accuracy) so that reshaping to 2x2
    yields rows 'train' and 'val' with columns 'loss' and 'accuracy'.
    """
    values = np.asarray(list(logs.values())).reshape(2, 2)
    table = pd.DataFrame(values,
                         columns=['loss', 'accuracy'],
                         index=['train', 'val'])
    # Caption the table with the epoch number, and redraw in place so the
    # notebook shows a single, continuously-updating metrics table.
    styled = table.style.set_caption(f"{epoch}")
    clear_output(wait=True)
    display(styled)
|
||
|
|
||
|
def train(X_train, y_train, X_val, y_val, model, epochs=50, lr=1e-2, verbose=False):
    """Train a binary classifier with BCE loss and Adam, tracking metrics per epoch.

    Args:
        X_train, y_train: training features and labels (array-like; labels in {0, 1}).
        X_val, y_val: validation features and labels.
        model: an nn.Module whose forward output is a probability (sigmoid already
            applied — required by nn.BCELoss).
        epochs: number of full passes over the training data.
        lr: Adam learning rate. BUG FIX: this was previously ignored — the
            optimizer hard-coded lr=1e-3.
        verbose: if True, redraw a per-epoch metrics table via print_results.

    Returns:
        A tuple ``(final_metrics, history)`` where ``final_metrics`` holds the
        last epoch's loss/accuracy/val_loss/val_accuracy and ``history`` holds
        per-epoch lists of the same keys.
    """
    # Generalized: fall back to CPU when CUDA is unavailable. Identical behavior
    # on CUDA machines; previously the hard-coded .cuda() calls crashed elsewhere.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    X_val = torch.tensor(X_val, dtype=torch.float32).to(device)

    history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': []}

    model = torch.compile(model).to(device)
    criterion = nn.BCELoss()
    # BUG FIX: was torch.optim.Adam(model.parameters(), lr=1e-3), silently
    # ignoring the `lr` argument.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_dataset = torch.utils.data.TensorDataset(X_train, torch.tensor(y_train))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4096, shuffle=True)
    val_dataset = torch.utils.data.TensorDataset(X_val, torch.tensor(y_val))
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4096, shuffle=False)

    for epoch in range(epochs):
        # --- Train ---
        model.train()
        running_loss = 0.0
        running_accuracy = 0.0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            y_pred = model(X_batch.to(device)).cpu()
            loss = criterion(y_pred, y_batch.float())
            loss.backward()
            optimizer.step()
            # Accumulate the sum of per-sample losses / correct predictions;
            # divided by the dataset size below to get epoch averages.
            running_loss += loss.item() * X_batch.size(0)
            # NOTE(review): assumes y_pred and y_batch have matching shapes;
            # a (N,1) vs (N,) pair would broadcast to (N,N) here — confirm
            # against the model's output shape.
            running_accuracy += ((y_pred > 0.5) == y_batch).sum().item()
        running_loss /= X_train.shape[0]
        running_accuracy /= X_train.shape[0]

        # --- Validate ---
        model.eval()
        val_running_loss = 0.0
        val_running_accuracy = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                y_pred = model(X_batch.to(device)).cpu()
                loss = criterion(y_pred, y_batch.float())
                val_running_loss += loss.item() * X_batch.size(0)
                val_running_accuracy += ((y_pred > 0.5) == y_batch).sum().item()
        val_running_loss /= X_val.shape[0]
        val_running_accuracy /= X_val.shape[0]

        history['loss'].append(running_loss)
        history['accuracy'].append(running_accuracy)
        history['val_loss'].append(val_running_loss)
        history['val_accuracy'].append(val_running_accuracy)

        # BUG FIX: print_results was previously invoked twice per epoch when
        # verbose (once before and once after the history updates).
        if verbose:
            print_results(epoch, {'loss': running_loss,
                                  'accuracy': running_accuracy,
                                  'val_loss': val_running_loss,
                                  'val_accuracy': val_running_accuracy})

    return {'loss': running_loss,
            'accuracy': running_accuracy,
            'val_loss': val_running_loss,
            'val_accuracy': val_running_accuracy}, history
|