mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
93 lines
2.3 KiB
Python
93 lines
2.3 KiB
Python
import torch
|
|
import torchvision
|
|
from dataset import CarvanaDataset
|
|
from torch.utils.data import DataLoader
|
|
|
|
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
|
|
print("=> Saving checkpoint")
|
|
torch.save(state, filename)
|
|
|
|
def load_checkpoint(checkpoint, model):
|
|
print("=> Loading checkpoint")
|
|
model.load_state_dict(checkpoint["state_dict"])
|
|
|
|
def get_loaders(
|
|
train_dir,
|
|
train_maskdir,
|
|
val_dir,
|
|
val_maskdir,
|
|
batch_size,
|
|
train_transform,
|
|
val_transform,
|
|
num_workers=4,
|
|
pin_memory=True,
|
|
):
|
|
train_ds = CarvanaDataset(
|
|
image_dir=train_dir,
|
|
mask_dir=train_maskdir,
|
|
transform=train_transform,
|
|
)
|
|
|
|
train_loader = DataLoader(
|
|
train_ds,
|
|
batch_size=batch_size,
|
|
num_workers=num_workers,
|
|
pin_memory=pin_memory,
|
|
shuffle=True,
|
|
)
|
|
|
|
val_ds = CarvanaDataset(
|
|
image_dir=val_dir,
|
|
mask_dir=val_maskdir,
|
|
transform=val_transform,
|
|
)
|
|
|
|
val_loader = DataLoader(
|
|
val_ds,
|
|
batch_size=batch_size,
|
|
num_workers=num_workers,
|
|
pin_memory=pin_memory,
|
|
shuffle=False,
|
|
)
|
|
|
|
return train_loader, val_loader
|
|
|
|
def check_accuracy(loader, model, device="cuda"):
|
|
num_correct = 0
|
|
num_pixels = 0
|
|
dice_score = 0
|
|
model.eval()
|
|
|
|
with torch.no_grad():
|
|
for x, y in loader:
|
|
x = x.to(device)
|
|
y = y.to(device).unsqueeze(1)
|
|
preds = torch.sigmoid(model(x))
|
|
preds = (preds > 0.5).float()
|
|
num_correct += (preds == y).sum()
|
|
num_pixels += torch.numel(preds)
|
|
dice_score += (2 * (preds * y).sum()) / (
|
|
(preds + y).sum() + 1e-8
|
|
)
|
|
|
|
print(
|
|
f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
|
|
)
|
|
print(f"Dice score: {dice_score/len(loader)}")
|
|
model.train()
|
|
|
|
def save_predictions_as_imgs(
|
|
loader, model, folder="saved_images/", device="cuda"
|
|
):
|
|
model.eval()
|
|
for idx, (x, y) in enumerate(loader):
|
|
x = x.to(device=device)
|
|
with torch.no_grad():
|
|
preds = torch.sigmoid(model(x))
|
|
preds = (preds > 0.5).float()
|
|
torchvision.utils.save_image(
|
|
preds, f"{folder}/pred_{idx}.png"
|
|
)
|
|
torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
|
|
|
|
model.train() |