mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git (synced 2026-02-20 13:50:41 +00:00)
import os

import torch
import visdom

# Run on the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32


def save_checkpoint(filename, model, optimizer, train_acc, epoch):
    # Bundle everything needed to resume training later.
    save_state = {
        "state_dict": model.state_dict(),
        "acc": train_acc,
        "epoch": epoch + 1,  # next epoch to run when resuming
        "optimizer": optimizer.state_dict(),
    }
    print()
    print("Saving current parameters")
    print("___________________________________________________________")
    torch.save(save_state, filename)
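

# Usage sketch (illustrative, not part of the original training scripts): the
# tiny nn.Linear model, the SGD optimizer, and the checkpoint file name below
# are placeholder assumptions showing how save_checkpoint is typically called
# at the end of an epoch.
def _example_save_checkpoint():
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # After epoch 4 with, say, 85% training accuracy:
    save_checkpoint("checkpoint.pth.tar", model, optimizer, train_acc=85.0, epoch=4)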


def check_accuracy(loader, model):
    if loader.dataset.train:
        print("Checking accuracy on training or validation set")
    else:
        print("Checking accuracy on test set")

    num_correct = 0
    num_samples = 0

    was_training = model.training
    model.eval()  # evaluation mode: disables dropout, uses running batchnorm stats

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)

    if was_training:
        model.train()  # restore the mode the model came in with

    acc = (float(num_correct) / num_samples) * 100.0
    print("Got %d / %d correct (%.2f)" % (num_correct, num_samples, acc))
    return acc
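

# Usage sketch (illustrative): assumes torchvision is available and uses MNIST
# purely as an example dataset; the flatten-plus-linear model is a placeholder.
# Note that check_accuracy reads loader.dataset.train, so the wrapped dataset
# needs that attribute (torchvision's MNIST exposes it).
def _example_check_accuracy():
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    dataset = datasets.MNIST(
        root="dataset/", train=True, download=True, transform=transforms.ToTensor()
    )
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    acc = check_accuracy(loader, model)
    print("Accuracy of the untrained placeholder model: %.2f%%" % acc)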


def load_model(args, model, optimizer):
    # Returns (model, optimizer, checkpoint, start_epoch, best_acc) when a
    # checkpoint is restored, and None otherwise.
    if args.resume:
        model.eval()
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets CPU-only machines load checkpoints saved on a GPU.
            checkpoint = torch.load(args.resume, map_location=device)
            start_epoch = checkpoint["epoch"]
            best_acc = checkpoint["acc"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
            return model, optimizer, checkpoint, start_epoch, best_acc
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        print("No pretrained model. Starting from scratch!")
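

# Usage sketch (illustrative): the --resume flag name mirrors what load_model
# expects (any object with a .resume attribute); the small model and optimizer
# are placeholders. With an empty --resume, load_model returns None and
# training starts from scratch.
def _example_load_model():
    import argparse

    import torch.nn as nn
    import torch.optim as optim

    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", default="", type=str, help="path to a saved checkpoint")
    args = parser.parse_args([])  # no arguments: args.resume == ""

    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    result = load_model(args, model, optimizer)
    if result is not None:
        model, optimizer, checkpoint, start_epoch, best_acc = result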


class visdom_plotting(object):
    # Thin wrapper around a Visdom client that keeps one window per curve
    # (batch loss, validation accuracy, training accuracy).
    def __init__(self):
        self.viz = visdom.Visdom()

        self.cur_batch_win = None
        self.cur_batch_win_opts = {
            "title": "Epoch Loss Trace",
            "xlabel": "Batch Number",
            "ylabel": "Loss",
            "width": 600,
            "height": 400,
        }

        self.cur_validation_acc = None
        self.cur_validation_acc_opts = {
            "title": "Validation accuracy",
            "xlabel": "Epochs",
            "ylabel": "Validation Accuracy",
            "width": 600,
            "height": 400,
        }

        self.cur_training_acc = None
        self.cur_training_acc_opts = {
            "title": "Training accuracy",
            "xlabel": "Epochs",
            "ylabel": "Train Accuracy",
            "width": 600,
            "height": 400,
        }

    def create_plot(
        self, loss_list, batch_list, validation_acc_list, epoch_list, training_acc_list
    ):
        if self.viz.check_connection():
            self.cur_batch_win = self.viz.line(
                torch.FloatTensor(loss_list),
                torch.FloatTensor(batch_list),
                win=self.cur_batch_win,
                name="current_batch_loss",
                update=(None if self.cur_batch_win is None else "replace"),
                opts=self.cur_batch_win_opts,
            )

            self.cur_validation_acc = self.viz.line(
                torch.FloatTensor(validation_acc_list),
                torch.FloatTensor(epoch_list),
                win=self.cur_validation_acc,
                name="current_validation_accuracy",
                update=(None if self.cur_validation_acc is None else "replace"),
                opts=self.cur_validation_acc_opts,
            )

            self.cur_training_acc = self.viz.line(
                torch.FloatTensor(training_acc_list),
                torch.FloatTensor(epoch_list),
                win=self.cur_training_acc,  # plot training accuracy in its own window
                name="current_training_accuracy",
                update=(None if self.cur_training_acc is None else "replace"),
                opts=self.cur_training_acc_opts,
            )
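

# Usage sketch (illustrative): assumes a Visdom server is already running
# (started with `python -m visdom.server`); the loss and accuracy lists are
# dummy numbers standing in for values collected during a real training loop.
def _example_visdom_plotting():
    plotter = visdom_plotting()
    loss_list = [2.3, 1.7, 1.2, 0.9]    # per-batch losses
    batch_list = [1, 2, 3, 4]           # batch indices
    validation_acc_list = [50.0, 65.0]  # per-epoch validation accuracy
    training_acc_list = [55.0, 70.0]    # per-epoch training accuracy
    epoch_list = [1, 2]                 # epoch indices
    plotter.create_plot(
        loss_list, batch_list, validation_acc_list, epoch_list, training_acc_list
    )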