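# NOTE: the text below appears to be web-UI residue from the repository page
# this file was copied from; it is not part of the module.
#   You cannot select more than 25 topics. Topics must start with a letter or
#   number, can include dashes ('-') and can be up to 35 characters long.
#   63 lines | 3.0 KiB | Python | 2 years ago
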
import pandas as pd
import numpy as np
from IPython.display import display, clear_output
import torch
import torch.nn as nn
def print_results(epoch, logs):
    """Show one epoch's metrics as a 2x2 table, replacing the previous output.

    Args:
        epoch: Value rendered as the table caption.
        logs: Mapping holding exactly four values. Taken in iteration order,
            the first two fill the 'train' row and the last two fill the
            'val' row, each as ('loss', 'accuracy').
    """
    # Only the values matter here; the keys are implied by position.
    metric_grid = np.array(list(logs.values())).reshape(2, 2)
    table = pd.DataFrame(
        data=metric_grid,
        index=['train', 'val'],
        columns=['loss', 'accuracy'],
    )
    captioned = table.style.set_caption(f"{epoch}")

    # Clear whatever was displayed before, then show the fresh table.
    clear_output(wait=True)
    display(captioned)
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, the model is switched to training mode and
    its parameters are updated after every batch. Otherwise the model is
    switched to evaluation mode and gradient tracking is disabled.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for X_batch, y_batch in loader:
            if training:
                optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            if training:
                loss.backward()
                optimizer.step()

            n_samples = X_batch.size(0)
            # The criterion returns a per-batch mean; weight it by the batch
            # size so a smaller last batch does not skew the epoch average.
            total_loss += loss.item() * n_samples
            total_correct += ((y_pred > 0.5) == y_batch).sum().item()
            total_seen += n_samples

    return total_loss / total_seen, total_correct / total_seen


def train(X_train, y_train, X_val, y_val, model, epochs=50, lr=1e-2, verbose=False, batch_size=4096):
    """Train ``model`` with Adam and binary cross-entropy, validating each epoch.

    The model output is thresholded at 0.5 for accuracy and passed to
    ``nn.BCELoss``, so it is presumably a probability with the same shape as
    the targets (confirm against the model definition).

    Args:
        X_train: Training features, convertible by ``torch.tensor`` to float32.
        y_train: Training targets, convertible by ``torch.tensor``.
        X_val: Validation features, same conventions as ``X_train``.
        y_val: Validation targets, same conventions as ``y_train``.
        model: ``nn.Module`` to train. It is wrapped with ``torch.compile``
            and moved to the selected device.
        epochs: Number of passes over the training data; must be >= 1.
        lr: Learning rate handed to Adam. This argument used to be ignored
            (1e-3 was hard-coded); it is now honoured.
        verbose: If true, display the metrics table after every epoch.
        batch_size: Mini-batch size for both the training and validation
            loaders.

    Returns:
        A tuple ``(final_metrics, history)``. ``final_metrics`` maps 'loss',
        'accuracy', 'val_loss' and 'val_accuracy' to the last epoch's values;
        ``history`` maps the same keys to lists with one entry per epoch.

    Raises:
        ValueError: If ``epochs`` is less than 1.
    """
    # Fail early with a clear message instead of hitting unbound locals at
    # the return statement when the epoch loop never runs.
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # Use the GPU when there is one, but stay usable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
    # Cast targets to float once and keep them next to the features, so the
    # loss is computed on the model's device without per-batch transfers.
    y_train = torch.tensor(y_train).float().to(device)
    y_val = torch.tensor(y_val).float().to(device)

    model = torch.compile(model).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_train, y_train),
        batch_size=batch_size,
        shuffle=True,
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_val, y_val),
        batch_size=batch_size,
        shuffle=False,
    )

    history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': []}
    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)

        metrics = {
            'loss': train_loss,
            'accuracy': train_accuracy,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
        }
        for name, value in metrics.items():
            history[name].append(value)

        if verbose:
            print_results(epoch, metrics)

    return metrics, history