From ae581a64e6b1a9350adbdc726bc5a860cb6f5370 Mon Sep 17 00:00:00 2001
From: dino
Date: Tue, 13 Sep 2022 13:04:49 +0200
Subject: [PATCH] vae

---
 ML/Pytorch/more_advanced/VAE/model.py | 45 ++++++++++
 ML/Pytorch/more_advanced/VAE/train.py | 86 +++++++++++++++++++
 .../transformer_from_scratch.py       | 18 ++--
 3 files changed, 140 insertions(+), 9 deletions(-)
 create mode 100644 ML/Pytorch/more_advanced/VAE/model.py
 create mode 100644 ML/Pytorch/more_advanced/VAE/train.py

diff --git a/ML/Pytorch/more_advanced/VAE/model.py b/ML/Pytorch/more_advanced/VAE/model.py
new file mode 100644
index 0000000..c6dd5af
--- /dev/null
+++ b/ML/Pytorch/more_advanced/VAE/model.py
@@ -0,0 +1,45 @@
+import torch
+from torch import nn
+
+
+class VariationalAutoEncoder(nn.Module):
+    def __init__(self, input_dim, h_dim=200, z_dim=20):
+        super().__init__()
+        # encoder
+        self.img_2hid = nn.Linear(input_dim, h_dim)
+        self.hid_2mu = nn.Linear(h_dim, z_dim)
+        self.hid_2sigma = nn.Linear(h_dim, z_dim)  # unconstrained output, used directly as the std
+
+        # decoder
+        self.z_2hid = nn.Linear(z_dim, h_dim)
+        self.hid_2img = nn.Linear(h_dim, input_dim)
+
+        self.relu = nn.ReLU()
+
+    def encode(self, x):
+        h = self.relu(self.img_2hid(x))
+        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
+        return mu, sigma
+
+    def decode(self, z):
+        h = self.relu(self.z_2hid(z))
+        return torch.sigmoid(self.hid_2img(h))  # outputs in [0, 1] to match the BCE loss
+
+    def forward(self, x):
+        mu, sigma = self.encode(x)
+        epsilon = torch.randn_like(sigma)
+        z_new = mu + sigma * epsilon  # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
+        x_reconstructed = self.decode(z_new)
+        return x_reconstructed, mu, sigma
+
+
+if __name__ == "__main__":
+    x = torch.randn(4, 28*28)
+    vae = VariationalAutoEncoder(input_dim=784)
+    x_reconstructed, mu, sigma = vae(x)
+    print(x_reconstructed.shape)  # torch.Size([4, 784])
+    print(mu.shape)  # torch.Size([4, 20])
+    print(sigma.shape)  # torch.Size([4, 20])
+
+
+
diff --git a/ML/Pytorch/more_advanced/VAE/train.py b/ML/Pytorch/more_advanced/VAE/train.py
new file mode 100644
index 0000000..eb0949f
--- /dev/null
+++ b/ML/Pytorch/more_advanced/VAE/train.py
@@ -0,0 +1,86 @@
+import torch
+import torchvision.datasets as datasets  # Standard datasets
+from tqdm import tqdm
+from torch import nn, optim
+from model import VariationalAutoEncoder
+from torchvision import transforms
+from torchvision.utils import save_image
+from torch.utils.data import DataLoader
+
+# Configuration
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+INPUT_DIM = 784
+H_DIM = 200
+Z_DIM = 20
+NUM_EPOCHS = 10
+BATCH_SIZE = 32
+LR_RATE = 3e-4  # Karpathy constant
+
+# Dataset Loading
+dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
+train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
+model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
+optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
+loss_fn = nn.BCELoss(reduction="sum")
+
+# Start Training
+for epoch in range(NUM_EPOCHS):
+    loop = tqdm(enumerate(train_loader))
+    for i, (x, _) in loop:
+        # Forward pass
+        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)  # flatten 28x28 images to 784-dim vectors
+        x_reconstructed, mu, sigma = model(x)
+
+        # Compute loss
+        reconstruction_loss = loss_fn(x_reconstructed, x)
+        kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))  # closed-form KL; see the note after the patch
+
+        # Backprop
+        loss = reconstruction_loss + kl_div  # negative ELBO
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        loop.set_postfix(loss=loss.item())
+
+
+model = model.to("cpu")
+def inference(digit, num_examples=1):
+    """
+    Generates num_examples images of a particular digit.
+    Specifically, we extract one example image of each
+    digit; once we have the (mu, sigma) encoding for
+    each digit we can sample from that distribution.
+
+    After sampling we run the decoder part of the VAE
+    to generate new examples.
+    """
+    images = []
+    idx = 0
+    for x, y in dataset:  # collect one example image of each digit 0-9
+        if y == idx:
+            images.append(x)
+            idx += 1
+        if idx == 10:
+            break
+
+    encodings_digit = []
+    for d in range(10):
+        with torch.no_grad():
+            mu, sigma = model.encode(images[d].view(1, 784))
+        encodings_digit.append((mu, sigma))
+
+    mu, sigma = encodings_digit[digit]
+    for example in range(num_examples):
+        epsilon = torch.randn_like(sigma)
+        z = mu + sigma * epsilon
+        out = model.decode(z)
+        out = out.view(-1, 1, 28, 28)
+        save_image(out, f"generated_{digit}_ex{example}.png")
+
+for idx in range(10):
+    inference(idx, num_examples=5)
+
+
diff --git a/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py b/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
index 80396df..dc1e489 100644
--- a/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
+++ b/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
@@ -23,10 +23,10 @@ class SelfAttention(nn.Module):
             self.head_dim * heads == embed_size
         ), "Embedding size needs to be divisible by heads"
 
-        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
+        self.values = nn.Linear(embed_size, embed_size)
+        self.keys = nn.Linear(embed_size, embed_size)
+        self.queries = nn.Linear(embed_size, embed_size)
+        self.fc_out = nn.Linear(embed_size, embed_size)
 
     def forward(self, values, keys, query, mask):
         # Get number of training examples
@@ -34,14 +34,14 @@ class SelfAttention(nn.Module):
 
         value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
 
+        values = self.values(values)  # (N, value_len, embed_size)
+        keys = self.keys(keys)  # (N, key_len, embed_size)
+        queries = self.queries(query)  # (N, query_len, embed_size)
+
         # Split the embedding into self.heads different pieces
         values = values.reshape(N, value_len, self.heads, self.head_dim)
         keys = keys.reshape(N, key_len, self.heads, self.head_dim)
-        query = query.reshape(N, query_len, self.heads, self.head_dim)
-
-        values = self.values(values)  # (N, value_len, heads, head_dim)
-        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
-        queries = self.queries(query)  # (N, query_len, heads, heads_dim)
+        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
 
         # Einsum does matrix mult. for query*keys for each training example
         # with every other training example, don't be confused by einsum
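
A note on the KL term in train.py: for the diagonal Gaussian posterior q(z|x) = N(mu, sigma^2) produced by the encoder and the standard normal prior p(z) = N(0, I), the KL divergence has a closed form,

    KL(q \,\|\, p) = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right),

which is exactly what kl_div computes, with d = Z_DIM. Added to the summed BCE reconstruction term, loss is the negative ELBO.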
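
inference() generates variations of a specific digit by sampling around that digit's encoded (mu, sigma). The other standard way to use a trained VAE is to draw the latent directly from the prior N(0, I) and decode it, which yields unconditional samples. A minimal sketch of that, assuming the trained `model` from train.py (z_dim=20) is in scope and already on the CPU; it is illustrative and not part of the patch:

import torch
from torchvision.utils import save_image

# Assumes `model` is the trained VariationalAutoEncoder from train.py
# (z_dim=20), already moved to the CPU.
with torch.no_grad():
    z = torch.randn(16, 20)  # latents drawn from the prior N(0, I)
    out = model.decode(z)  # (16, 784), pixel values in [0, 1]
    save_image(out.view(-1, 1, 28, 28), "prior_samples.png", nrow=4)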
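
On the transformer_from_scratch change: the old code reshaped into heads first and then applied a per-head head_dim-to-head_dim projection, so each head only ever saw its own slice of the embedding. The patched order, projecting across the full embedding and then splitting into heads, matches the standard multi-head attention formulation. A small shape check, using arbitrary example dimensions not taken from the patch:

import torch
from torch import nn

N, seq_len, embed_size, heads = 2, 5, 256, 8  # illustrative sizes
head_dim = embed_size // heads  # 32

x = torch.randn(N, seq_len, embed_size)
proj = nn.Linear(embed_size, embed_size)  # mixes the full embedding

out = proj(x)  # (N, seq_len, embed_size)
out = out.reshape(N, seq_len, heads, head_dim)  # split into heads
print(out.shape)  # torch.Size([2, 5, 8, 32])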