Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git, synced 2026-02-21 19:27:58 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 76c0c3e8ea | |
| | 1982b04dc1 | |
| | 4455010336 | |
| | 96b31b195f | |
| | aeba262563 | |
| | c543867340 | |
| | 5ccb93b033 | |
| | 1b310a3422 | |
| | 5280b33a5c | |
| | 0e22471b42 | |
| | 799c95c83f | |
.dockerignore (Normal file, 9 lines)
@@ -0,0 +1,9 @@
__pycache__
.git
.gitignore
.vscode
.idea
*.pyc
*~
data/
secrets.env
Dockerfile (Normal file, 5 lines)
@@ -0,0 +1,5 @@
FROM pytorch/pytorch:2.1.1-cuda12.1-cudnn8-runtime
RUN pip install --upgrade pip
WORKDIR /code
COPY requirements.txt /code/
RUN pip install --no-cache-dir -r requirements.txt
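The Dockerfile pins the official pytorch/pytorch 2.1.1 CUDA 12.1 runtime image and installs the repo's requirements.txt (the requirements file itself is not part of this diff). A quick sanity check one might run inside the built container to confirm the bundled GPU stack; this is a generic sketch, not part of the commit:

```python
import torch

# Run inside the image, e.g. `docker run --gpus all <image> python check.py`
print(torch.__version__)          # expected: 2.1.1 for this base image
print(torch.cuda.is_available())  # True only if the container was given a GPU
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```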
BIN ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte (Normal file, binary file not shown)
BIN ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz (Normal file, binary file not shown)
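These eight binaries are the files torchvision's MNIST dataset keeps under `<root>/MNIST/raw/`. A minimal sketch of loading them, assuming torchvision is installed and the repo root is the working directory (`download=False` since the files are now checked in):

```python
from torchvision import datasets, transforms

# root points at the directory containing MNIST/raw/ from the paths above
train_ds = datasets.MNIST(
    root="ML/Pytorch/Basics/dataset",
    train=True,
    transform=transforms.ToTensor(),
    download=False,  # the raw/.gz files are already present in the repo
)
print(len(train_ds), train_ds[0][0].shape)  # 60000 torch.Size([1, 28, 28])
```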
@@ -1,6 +0,0 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -1,317 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54ecf0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "# HuggingFace Tutorial Series\n",
    "- 1. What is Huggingface?\n",
    "- 2. Common tasks we can do with HuggingFace & explain the tasks briefly, like what is question answering etc\n",
    "- 3. Using the HuggingFace Pipeline (High level feature)\n",
    "- 4. How the pipeline works at a lower level\n",
    "- 5. HuggingFace Datasets\n",
    "- 6. HuggingFace Tokenizer\n",
    "- 7. HuggingFace Evaluate\n",
    "- 8. HuggingFace Trainer\n",
    "- 9. Putting it together to finetune a news article summarizer\n",
    "- 10. Making it more general and robust with Lightning and custom data loading\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec1aae37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import datasets \n",
    "import pytorch_lightning as pl\n",
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "torch.set_float32_matmul_precision(\"medium\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fd7cb0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"t5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "418cb03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class cnn_dailymail(Dataset):\n",
    "    def __init__(self, csv_file, tokenizer, max_length=512):\n",
    "        self.data = pd.read_csv(csv_file)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        article = self.data.loc[idx, 'article']\n",
    "        highlights = self.data.loc[idx, 'highlights']\n",
    "\n",
    "        inputs = self.tokenizer(\n",
    "            article,\n",
    "            truncation=True,\n",
    "            padding='max_length',\n",
    "            max_length=self.max_length,\n",
    "            return_tensors='pt'\n",
    "        )\n",
    "        targets = self.tokenizer(\n",
    "            highlights,\n",
    "            truncation=True,\n",
    "            padding='max_length',\n",
    "            max_length=self.max_length,\n",
    "            return_tensors='pt'\n",
    "        )\n",
    "\n",
    "        return {\n",
    "            'input_ids': inputs['input_ids'].squeeze(),\n",
    "            'attention_mask': inputs['attention_mask'].squeeze(),\n",
    "            'labels': targets['input_ids'].squeeze()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa62755",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataModule(pl.LightningDataModule):\n",
    "    def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):\n",
    "        super().__init__()\n",
    "        self.train_csv = train_csv\n",
    "        self.val_csv = val_csv\n",
    "        self.test_csv = test_csv\n",
    "        self.tokenizer = tokenizer\n",
    "        self.batch_size = batch_size\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        if stage in ('fit', None):\n",
    "            self.train_dataset = cnn_dailymail(self.train_csv, self.tokenizer, self.max_length)\n",
    "            self.val_dataset = cnn_dailymail(self.val_csv, self.tokenizer, self.max_length)\n",
    "        if stage in ('test', None):\n",
    "            self.test_dataset = cnn_dailymail(self.test_csv, self.tokenizer, self.max_length)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbb699e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLightningModule(pl.LightningModule):\n",
    "    def __init__(self, model_name, learning_rate, weight_decay):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "\n",
    "        # Load the pre-trained model and tokenizer\n",
    "        self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
    "\n",
    "        # Load the ROUGE metric\n",
    "        self.metric = load_metric(\"rouge\")\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "        self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
    "        return {'loss': loss, 'logits': logits}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
    "\n",
    "        # Save logits and labels as instance attributes\n",
    "        if not hasattr(self, \"logits\"):\n",
    "            self.logits = logits\n",
    "        else:\n",
    "            self.logits = torch.cat((self.logits, logits), dim=0)\n",
    "\n",
    "        if not hasattr(self, \"labels\"):\n",
    "            self.labels = labels\n",
    "        else:\n",
    "            self.labels = torch.cat((self.labels, labels), dim=0)\n",
    "\n",
    "        return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
    "\n",
    "    def on_validation_epoch_end(self):\n",
    "        # Convert logits to predicted token IDs\n",
    "        pred_token_ids = self.logits.argmax(dim=-1)\n",
    "\n",
    "        # Decode predictions and labels using the saved instance attributes\n",
    "        decoded_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)\n",
    "        decoded_labels = tokenizer.batch_decode(self.labels, skip_special_tokens=True)\n",
    "\n",
    "        # Compute ROUGE scores\n",
    "        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
    "\n",
    "        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
    "        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
    "        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
    "\n",
    "        # Clear logits and labels instance attributes for the next validation epoch\n",
    "        del self.logits\n",
    "        del self.labels\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
    "        return optimizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd63c628",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# File paths\n",
    "train_csv = \"train.csv\"\n",
    "val_csv = \"validation.csv\"\n",
    "test_csv = \"test.csv\"\n",
    "\n",
    "# Create the data module\n",
    "dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
    "dm.setup()\n",
    "\n",
    "model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-4, weight_decay=1e-5)\n",
    "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
    "trainer.fit(model, datamodule=dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5d3d684",
   "metadata": {},
   "outputs": [],
   "source": [
    "http://localhost:18888/notebooks/cnndaily_t5_lightning_customdataloading.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0494596",
   "metadata": {},
   "source": [
    "### next steps:\n",
    "* if article is > 512, because now we are truncating maybe it causes issues if the article is much longer?\n",
    "\n",
    "#### what we've done:\n",
    "* Change the data loading so it's more general, meaning on the fly loading from disk\n",
    "* add torch.compile\n",
    "* 1. Clean up the code, make it into scripts instead of notebook -> Train for an epoch (add multi-gpu training?)\n",
    "* add tensorboard visualization\n",
    "* not use pretrained weights but from scratch to ensure that training setup works and actually improving\n",
    "* 2. Create an inference step, send in news article -> get summary, check that it works\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a2efab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9b71ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
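The deleted notebook's own next-steps list ends with "Create an inference step, send in news article -> get summary". A minimal sketch of what that step could look like with the same t5-small checkpoint; the `summarize` helper is hypothetical, not code from this repo:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def summarize(article: str, model_name: str = "t5-small", max_new_tokens: int = 128) -> str:
    """Hypothetical inference helper for the notebook's planned step 2."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()
    # T5 is trained with task prefixes; "summarize: " selects summarization
    inputs = tokenizer("summarize: " + article, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        ids = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=4)
    return tokenizer.decode(ids[0], skip_special_tokens=True)
```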
@@ -1,463 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec1aae37",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
      "2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
      "2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import datasets \n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5fd7cb0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"t5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "04530b1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the LightningDataModule\n",
    "class MyDataModule(pl.LightningDataModule):\n",
    "    def __init__(self, batch_size):\n",
    "        super().__init__()\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def prepare_data(self):\n",
    "        # Download and preprocess the data\n",
    "        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
    "        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        # Load and preprocess the data\n",
    "        train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
    "        val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
    "\n",
    "        self.train_ds = train_data.map(\n",
    "            self.preprocess_function, \n",
    "            batched=True, \n",
    "            batch_size=self.batch_size, \n",
    "            remove_columns=[\"article\", \"highlights\", \"id\"]\n",
    "        )\n",
    "\n",
    "        self.val_ds = val_data.map(\n",
    "            self.preprocess_function, \n",
    "            batched=True, \n",
    "            batch_size=self.batch_size,\n",
    "            remove_columns=[\"article\", \"highlights\", \"id\"]\n",
    "        )\n",
    "\n",
    "    def preprocess_function(self, batch):\n",
    "        inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
    "        outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
    "        batch[\"input_ids\"] = inputs.input_ids\n",
    "        batch[\"attention_mask\"] = inputs.attention_mask\n",
    "        batch[\"labels\"] = outputs.input_ids.copy()\n",
    "        return batch\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fbb699e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLightningModule(pl.LightningModule):\n",
    "    def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "        # Load the pre-trained model and tokenizer\n",
    "        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
    "\n",
    "        # Load the ROUGE metric\n",
    "        self.metric = load_metric(\"rouge\")\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "        self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
    "        return {'loss': loss, 'logits': logits}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
    "        return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        decoded_preds = []\n",
    "        decoded_labels = []\n",
    "        for output in outputs:\n",
    "            logits = output['logits']\n",
    "            labels = output['labels']\n",
    "            decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
    "            decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
    "\n",
    "        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
    "        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
    "        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
    "        return optimizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dd63c628",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "\n",
      "… [tqdm progress bars elided: mapping train split 1795/1795 batches in ~16 s, validation split 84/84 batches in ~1 s] …\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
      "\n",
      "  | Name  | Type                       | Params\n",
      "-----------------------------------------------------\n",
      "0 | model | T5ForConditionalGeneration | 60.5 M\n",
      "-----------------------------------------------------\n",
      "60.5 M    Trainable params\n",
      "0         Non-trainable params\n",
      "60.5 M    Total params\n",
      "242.026   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'size'",
     "output_type": "error",
     "traceback": [
      "---------------------------------------------------------------------------",
      "AttributeError                            Traceback (most recent call last)",
      "Cell In[8], line 5",
      "      3 trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)",
      "      4 dm = MyDataModule(batch_size=16)",
      "----> 5 trainer.fit(model, datamodule=dm)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103, in Trainer._run(self, model, ckpt_path)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182, in Trainer._run_stage(self)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195, in Trainer._run_train(self)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267, in Trainer._run_sanity_check(self)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152, in EvaluationLoop.advance(self, *args, **kwargs)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234, in EvaluationEpochLoop._evaluation_step(self, **kwargs)",
      "File ~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
|
||||
]
}
],
"source": [
"torch.set_float32_matmul_precision(\"medium\")\n",
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
"dm = MyDataModule(batch_size=16)\n",
"trainer.fit(model, datamodule=dm)"
]
},
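{
"cell_type": "code",
"execution_count": null,
"id": "added-debug-note",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): the AttributeError above says `input_ids` reached T5 as a\n",
"# plain Python list rather than a torch.Tensor, i.e. the batch was never collated\n",
"# into tensors. Assuming the dataloader yields list-valued fields, one minimal fix\n",
"# is to convert them before the forward call:\n",
"def to_tensor_batch(batch):\n",
"    # stack any list-valued field into a tensor; leave existing tensors untouched\n",
"    return {k: v if torch.is_tensor(v) else torch.as_tensor(v) for k, v in batch.items()}\n"
]
},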
{
"cell_type": "code",
"execution_count": null,
"id": "1395d5d2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "80a2efab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because it is too large
@@ -1,644 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"id": "7d5e92c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
]
}
],
"source": [
"from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
"from transformers import pipeline\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
"model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
"\n",
"pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
"example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
"\n",
"ner_entity_results = pipe(example)\n",
"print(ner_entity_results)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "bf67ee76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Turtle s Original Cara mel Chocolate P eca n luster\n"
]
}
],
"source": [
"ner_entity_results = pipe(example)\n",
"\n",
"# Initialize the entity words list with an empty string\n",
"entity_words = [\"\"]\n",
"\n",
"# Loop through each dictionary in the list and extract the entity word\n",
"for result in ner_entity_results:\n",
"    if result[\"entity\"] == \"B-FOOD\":\n",
"        entity_words.append(result[\"word\"])\n",
"    elif result[\"entity\"] == \"I-FOOD\":\n",
"        entity_words[-1] += \" \" + result[\"word\"]\n",
"\n",
"# Remove any remaining ## symbols and extra spaces\n",
"entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
"\n",
"# Join the entity words into a single string\n",
"output = \" \".join(entity_words)\n",
"\n",
"print(output)\n"
]
},
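{
"cell_type": "code",
"execution_count": null,
"id": "added-ner-aggregation",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): instead of stitching '##' word pieces together by hand,\n",
"# the pipeline can merge subwords itself via `aggregation_strategy` (available in\n",
"# recent transformers releases); entries then carry `entity_group` and a merged `word`.\n",
"pipe_grouped = pipeline(\"ner\", model=model, tokenizer=tokenizer, aggregation_strategy=\"simple\")\n",
"print(pipe_grouped(example))"
]
},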
{
"cell_type": "code",
"execution_count": null,
"id": "fc8e5ea0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"print(torch.cuda.is_available())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8a1e039",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ad73024",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"classifier = pipeline(\"zero-shot-classification\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04f7e02c",
"metadata": {},
"outputs": [],
"source": [
"classifier(\n",
"    \"This is a course about the Transformers library\",\n",
"    candidate_labels=[\"machine learning\", \"gym\", \"food\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fb246c2",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"generator = pipeline(task=\"text-generation\", model=\"bigscience/bloom-1b7\", device=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4e174f0",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer\n",
"import torch\n",
"\n",
"# Define input text and pre-trained model checkpoint\n",
"text = \"My name is wolfgang and I live in berlin\"\n",
"checkpoint = \"Jean-Baptiste/roberta-large-ner-english\"\n",
"\n",
"# Instantiate tokenizer and encode input text\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
"inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
"\n",
"# Instantiate model and generate output\n",
"model = AutoModel.from_pretrained(checkpoint)\n",
"outputs = model(**inputs)\n",
"print(outputs[0].shape)\n",
"\n",
"# Instantiate token classification model and generate predictions\n",
"model = AutoModelForTokenClassification.from_pretrained(checkpoint)\n",
"outputs = model(**inputs)\n",
"predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
"print(predictions)\n",
"print(model.config.id2label)"
]
},
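{
"cell_type": "code",
"execution_count": null,
"id": "added-label-alignment",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): the cell above prints raw probabilities and id2label but never\n",
"# pairs tokens with labels. Reusing `inputs`, `predictions` and `model` from that cell:\n",
"tokens = tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0])\n",
"label_ids = predictions.argmax(dim=-1)[0].tolist()\n",
"for token, label_id in zip(tokens, label_ids):\n",
"    print(token, model.config.id2label[label_id])"
]
},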
{
"cell_type": "code",
"execution_count": null,
"id": "8212bbaa",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
"\n",
"# prepare input\n",
"text = \"Replace me by any text you'd like.\"\n",
"encoded_input = tokenizer(text, return_tensors='pt')\n",
"\n",
"# forward pass\n",
"output = model(**encoded_input)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "314cba41",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
"\n",
"# Load the pre-trained tokenizer and model\n",
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
"\n",
"# Define the input sentence with a masked token\n",
"text = \"I want to <mask> a new car tomorrow.\"\n",
"\n",
"# Tokenize the input sentence, replacing the masked token with a special [MASK] token\n",
"encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')\n",
"\n",
"# Run the forward pass on the masked sentence (this line was missing; without it,\n",
"# `output` would still refer to the unmasked sentence from the previous cell)\n",
"output = model(**encoded_input)\n",
"\n",
"print(output.logits.shape)\n",
"print(encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id))\n",
"\n",
"# Extract the predicted probabilities for the masked token\n",
"predicted_probabilities = output.logits[0, encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id)]\n",
"predicted_probabilities = torch.nn.functional.softmax(predicted_probabilities, dim=-1)\n",
"\n",
"# Get the top-k most probable predictions for the masked token\n",
"k = 5\n",
"top_k = torch.topk(predicted_probabilities, k)\n",
"for i in range(k):\n",
"    token = tokenizer.convert_ids_to_tokens(top_k.indices[i].item())\n",
"    score = top_k.values[i].item()\n",
"    print(f\"Prediction {i+1}: '{token}' with probability {score:.5f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6187e77e",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
"\n",
"sequences = [\n",
"    \"Using a Transformer network is simple\",\n",
"    \"The quick brown fox jumps over the lazy dog\",\n",
"    \"To be or not to be, that is the question\"\n",
"]\n",
"\n",
"# Tokenize the input sequences and convert them to padded and truncated integer token IDs\n",
"inputs = tokenizer(\n",
"    sequences,\n",
"    padding=True,\n",
"    truncation=True,\n",
"    return_tensors=\"pt\"\n",
")\n",
"\n",
"# Print the resulting input IDs and attention masks\n",
"print(inputs['input_ids'])\n",
"print(inputs['attention_mask'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc259c5a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "43466db6",
"metadata": {},
"source": [
"Huggingface:\n",
"\n",
"1. Understand how to use the Pipeline (probably the most useful feature) for various tasks, and the subtasks it covers: translation, QA, zero-shot classification, sentiment analysis, token classification, etc.\n",
"2. Understand how the pipeline works in more detail by using the AutoModel classes for the various tasks, together with AutoTokenizer\n",
"3. Load dataset\n",
"4. How to finetune\n",
"5. How to evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97c474f2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ed5d8c2",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification\n",
"\n",
"# Same as before\n",
"checkpoint = \"bert-base-uncased\"\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
"sequences = [\n",
"    \"I've been waiting for a HuggingFace course my whole life.\",\n",
"    \"This course is amazing!\",\n",
"]\n",
"batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
"\n",
"# This is new\n",
"batch[\"labels\"] = torch.tensor([1, 1])\n",
"\n",
"optimizer = AdamW(model.parameters())\n",
"loss = model(**batch).loss\n",
"loss.backward()\n",
"optimizer.step()"
]
},
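{
"cell_type": "code",
"execution_count": null,
"id": "added-training-loop",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): the cell above performs a single manual optimization step.\n",
"# A minimal loop repeats the same pattern over a DataLoader; `train_loader` here is\n",
"# an assumed DataLoader yielding dicts shaped like `batch` above.\n",
"for epoch in range(3):\n",
"    for train_batch in train_loader:\n",
"        optimizer.zero_grad()\n",
"        loss = model(**train_batch).loss\n",
"        loss.backward()\n",
"        optimizer.step()"
]
},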
{
"cell_type": "code",
"execution_count": null,
"id": "c598624f",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"raw_datasets = load_dataset(\"glue\", \"mrpc\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd296227",
"metadata": {},
"outputs": [],
"source": [
"raw_train_dataset = raw_datasets[\"train\"]\n",
"raw_train_dataset[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e462947a",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
"raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
"\n",
"checkpoint = \"bert-base-uncased\"\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
"\n",
"def tokenize_function(example):\n",
"    return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
"\n",
"\n",
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
"\n",
"\n",
"from transformers import TrainingArguments, Trainer  # Trainer was missing from the imports\n",
"training_args = TrainingArguments(\"test-trainer\")\n",
"\n",
"from transformers import AutoModelForSequenceClassification\n",
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
"\n",
"import numpy as np\n",
"import evaluate\n",
"\n",
"def compute_metrics(eval_preds):\n",
"    metric = evaluate.load(\"glue\", \"mrpc\")\n",
"    logits, labels = eval_preds\n",
"    predictions = np.argmax(logits, axis=-1)\n",
"    return metric.compute(predictions=predictions, references=labels)\n",
"\n",
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
"\n",
"trainer = Trainer(\n",
"    model,\n",
"    training_args,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"validation\"],\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
")"
]
},
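{
"cell_type": "code",
"execution_count": null,
"id": "added-trainer-launch",
"metadata": {},
"outputs": [],
"source": [
"# Added (hedged): the cell above only constructs the Trainer; launching fine-tuning\n",
"# and a final evaluation on the validation split is one call each.\n",
"trainer.train()\n",
"trainer.evaluate()"
]
},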
{
"cell_type": "code",
"execution_count": null,
"id": "0e2795dc",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"training_args = TrainingArguments(\"test-trainer\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3af29cd5",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "817f644e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42819a6c",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def compute_metrics(eval_preds):\n",
"    metric = evaluate.load(\"glue\", \"mrpc\")\n",
"    logits, labels = eval_preds\n",
"    predictions = np.argmax(logits, axis=-1)\n",
"    return metric.compute(predictions=predictions, references=labels)\n",
"\n",
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
"\n",
"trainer = Trainer(\n",
"    model,\n",
"    training_args,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"validation\"],\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb5986b0",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
"from datasets import load_dataset\n",
"batch_size=32\n",
"\n",
"# Load the pre-trained model and tokenizer first: the preprocessing functions\n",
"# below close over `tokenizer`, so it must exist before the dataset is mapped\n",
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
"\n",
"# Define the generator function to preprocess the data in batches\n",
"def preprocess_generator(examples):\n",
"    for i in range(0, len(examples[\"article\"]), batch_size):\n",
"        batch = examples[\"article\"][i:i+batch_size]\n",
"        targets = examples[\"highlights\"][i:i+batch_size]\n",
"        model_inputs = tokenizer(batch, max_length=512, padding=\"max_length\", truncation=True)\n",
"        with tokenizer.as_target_tokenizer():\n",
"            model_targets = tokenizer(targets, max_length=128, padding=\"max_length\", truncation=True)\n",
"        model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
"        yield model_inputs\n",
"\n",
"def preprocess_function(examples):\n",
"    articles = [ex for ex in examples[\"article\"]]\n",
"    summaries = [ex for ex in examples[\"highlights\"]]\n",
"\n",
"    model_inputs = tokenizer(articles, max_length=512, padding=\"max_length\", truncation=True)\n",
"    with tokenizer.as_target_tokenizer():\n",
"        model_targets = tokenizer(summaries, max_length=128, padding=\"max_length\", truncation=True)\n",
"\n",
"    model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
"    return model_inputs\n",
"\n",
"# Load the dataset\n",
"raw_datasets = load_dataset(\"cnn_dailymail\", \"3.0.0\")\n",
"preprocessed_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4)\n",
"\n",
"# Define the data collator\n",
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
"\n",
"# Initialize the trainer arguments\n",
"training_args = Seq2SeqTrainingArguments(\n",
"    output_dir=\"./results\",\n",
"    evaluation_strategy = \"epoch\",\n",
"    learning_rate=2e-5,\n",
"    per_device_train_batch_size=batch_size,\n",
"    max_steps=1000,\n",
"    weight_decay=0.01,\n",
"    push_to_hub=False,\n",
")\n",
"\n",
"# Initialize the trainer (train_ds was undefined; use the mapped dataset, and give\n",
"# the trainer an eval set since evaluation_strategy=\"epoch\" requires one)\n",
"trainer = Seq2SeqTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=preprocessed_datasets[\"train\"],\n",
"    eval_dataset=preprocessed_datasets[\"validation\"],\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
")\n",
"\n",
"# Start the training\n",
"trainer.train()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d62583e",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d310a7b3",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"preprocessed_datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99d422cc",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# Load the pre-trained model and tokenizer\n",
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
"\n",
"# Define the data collator\n",
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
"\n",
"# Initialize the trainer arguments\n",
"training_args = Seq2SeqTrainingArguments(\n",
"    output_dir=\"./results\",\n",
"    learning_rate=2e-5,\n",
"    per_device_train_batch_size=batch_size,\n",
"    max_steps=5000,\n",
"    weight_decay=0.01,\n",
"    push_to_hub=False,\n",
"    evaluation_strategy = \"steps\",\n",
"    eval_steps = 50,\n",
"    predict_with_generate=True,  # so compute_metrics receives token ids, not raw logits\n",
")\n",
"\n",
"# Load the ROUGE metric\n",
"metric = load_metric(\"rouge\")\n",
"\n",
"# Define the evaluation function\n",
"def compute_metrics(pred):\n",
"    labels = pred.label_ids\n",
"    preds = pred.predictions\n",
"\n",
"    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
"    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"\n",
"    scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"\n",
"    return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
"\n",
"\n",
"# Initialize the trainer\n",
"trainer = Seq2SeqTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=preprocessed_datasets[\"train\"],\n",
"    eval_dataset=preprocessed_datasets[\"validation\"],\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
")\n",
"\n",
"# Start the training\n",
"trainer.train()"
]
},
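{
"cell_type": "code",
"execution_count": null,
"id": "added-generate-check",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): a quick qualitative check of the summarizer with generate();\n",
"# T5 checkpoints conventionally take a \"summarize: \" task prefix for this task.\n",
"article = raw_datasets[\"validation\"][0][\"article\"]\n",
"inputs = tokenizer(\"summarize: \" + article, max_length=512, truncation=True, return_tensors=\"pt\")\n",
"summary_ids = model.generate(**inputs, max_length=128, num_beams=4)\n",
"print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))"
]
},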
{
"cell_type": "code",
"execution_count": null,
"id": "a5e97b57",
"metadata": {},
"outputs": [],
"source": [
"!pip install nltk\n",
"!pip install rouge_score"
]
},
{
"cell_type": "markdown",
"id": "558c3e66",
"metadata": {},
"source": [
"# Goal:\n",
"\n",
"1. Implement the full training pipeline with HF: data loading (CNN/DailyMail dataset), model training, evaluation, etc.\n",
"* Right now: stuck on on-the-fly dataset loading; we don't want to cache the preprocessed data because it would take a lot of disk space (see the streaming sketch in the next cell).\n",
"\n",
"2. Once step 1) works, go deeper on every step: download the dataset and load it as a custom dataset instead of through the simple HuggingFace API, to make the setup more general. Compare loading it as a custom HF dataset versus a PyTorch Dataset class together with Lightning: speed difference? convenience? We also want to use the Lightning Trainer and see how to integrate it, then compare the pure-HF approach with the Lightning + HF-model approach and pick the one we like most."
]
},
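{
"cell_type": "code",
"execution_count": null,
"id": "added-streaming-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): one way to avoid caching preprocessed data to disk is the\n",
"# streaming mode of datasets, which yields and maps examples lazily on the fly.\n",
"from datasets import load_dataset\n",
"\n",
"streamed = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\", streaming=True)\n",
"streamed = streamed.map(preprocess_function, batched=True)  # lazy, writes no cache files\n",
"print(next(iter(streamed)).keys())"
]
},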
{
"cell_type": "code",
"execution_count": null,
"id": "624d49ca",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -1,317 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f54ecf0b",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"# HuggingFace Tutorial Series\n",
"- 1. What is Huggingface?\n",
"- 2. Common tasks we can do with HuggingFace & explain the tasks briefly, like what is question answering etc\n",
"- 3. Using the HuggingFace Pipeline (High level feature)\n",
"- 4. How the pipeline works at a lower level\n",
"- 5. HuggingFace Datasets\n",
"- 6. HuggingFace Tokenizer\n",
"- 7. HuggingFace Evaluate\n",
"- 8. HuggingFace Trainer\n",
"- 9. Putting it together to finetune a news article summarizer\n",
"- 10. Making it more general and robust with Lightning and custom data loading\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec1aae37",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import pandas as pd\n",
"import datasets\n",
"import pytorch_lightning as pl\n",
"from datasets import load_dataset, load_metric\n",
"from torch.utils.data import Dataset\n",
"\n",
"from transformers import (\n",
"    AutoModel,\n",
"    AutoModelForSeq2SeqLM,\n",
"    AutoTokenizer,\n",
"    DataCollatorForSeq2Seq,\n",
"    Seq2SeqTrainingArguments,\n",
"    Seq2SeqTrainer,\n",
")\n",
"\n",
"torch.set_float32_matmul_precision(\"medium\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fd7cb0c",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "418cb03a",
"metadata": {},
"outputs": [],
"source": [
"class cnn_dailymail(Dataset):\n",
"    def __init__(self, csv_file, tokenizer, max_length=512):\n",
"        self.data = pd.read_csv(csv_file)\n",
"        self.tokenizer = tokenizer\n",
"        self.max_length = max_length\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        article = self.data.loc[idx, 'article']\n",
"        highlights = self.data.loc[idx, 'highlights']\n",
"\n",
"        inputs = self.tokenizer(\n",
"            article,\n",
"            truncation=True,\n",
"            padding='max_length',\n",
"            max_length=self.max_length,\n",
"            return_tensors='pt'\n",
"        )\n",
"        targets = self.tokenizer(\n",
"            highlights,\n",
"            truncation=True,\n",
"            padding='max_length',\n",
"            max_length=self.max_length,\n",
"            return_tensors='pt'\n",
"        )\n",
"\n",
"        return {\n",
"            'input_ids': inputs['input_ids'].squeeze(),\n",
"            'attention_mask': inputs['attention_mask'].squeeze(),\n",
"            'labels': targets['input_ids'].squeeze()\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaa62755",
"metadata": {},
"outputs": [],
"source": [
"class MyDataModule(pl.LightningDataModule):\n",
"    def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):\n",
"        super().__init__()\n",
"        self.train_csv = train_csv\n",
"        self.val_csv = val_csv\n",
"        self.test_csv = test_csv\n",
"        self.tokenizer = tokenizer\n",
"        self.batch_size = batch_size\n",
"        self.max_length = max_length\n",
"\n",
"    def setup(self, stage=None):\n",
"        if stage in ('fit', None):\n",
"            self.train_dataset = cnn_dailymail(self.train_csv, self.tokenizer, self.max_length)\n",
"            self.val_dataset = cnn_dailymail(self.val_csv, self.tokenizer, self.max_length)\n",
"        if stage in ('test', None):\n",
"            self.test_dataset = cnn_dailymail(self.test_csv, self.tokenizer, self.max_length)\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n",
"\n",
"    def test_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
"    def __init__(self, model_name, learning_rate, weight_decay):\n",
"        super().__init__()\n",
"        self.model_name = model_name\n",
"        self.learning_rate = learning_rate\n",
"        self.weight_decay = weight_decay\n",
"\n",
"        # Load the pre-trained model and tokenizer\n",
"        self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
"\n",
"        # Load the ROUGE metric\n",
"        self.metric = load_metric(\"rouge\")\n",
"\n",
"    def forward(self, input_ids, attention_mask, labels=None):\n",
"        output = self.model(\n",
"            input_ids=input_ids,\n",
"            attention_mask=attention_mask,\n",
"            labels=labels,\n",
"        )\n",
"        return output.loss, output.logits\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
"        return {'loss': loss, 'logits': logits}\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
"\n",
"        # Save logits and labels as instance attributes\n",
"        if not hasattr(self, \"logits\"):\n",
"            self.logits = logits\n",
"        else:\n",
"            self.logits = torch.cat((self.logits, logits), dim=0)\n",
"\n",
"        if not hasattr(self, \"labels\"):\n",
"            self.labels = labels\n",
"        else:\n",
"            self.labels = torch.cat((self.labels, labels), dim=0)\n",
"\n",
"        return {'loss': loss, 'logits': logits, \"labels\": labels}\n",
"\n",
"    def on_validation_epoch_end(self):\n",
"        # Convert logits to predicted token IDs\n",
"        pred_token_ids = self.logits.argmax(dim=-1)\n",
"\n",
"        # Decode predictions and labels using the saved instance attributes\n",
"        decoded_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)\n",
"        decoded_labels = tokenizer.batch_decode(self.labels, skip_special_tokens=True)\n",
"\n",
"        # Compute ROUGE scores\n",
"        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"\n",
"        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
"        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
"        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
"\n",
"        # Clear logits and labels instance attributes for the next validation epoch\n",
"        del self.logits\n",
"        del self.labels\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
"        return optimizer\n"
]
},
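{
"cell_type": "code",
"execution_count": null,
"id": "added-memory-note",
"metadata": {},
"outputs": [],
"source": [
"# Added note (hedged): concatenating full `logits` across validation steps keeps a\n",
"# [num_examples, seq_len, vocab_size] float tensor alive, which can exhaust memory\n",
"# with t5-small's ~32k vocabulary. A lighter variant stores only the argmax ids:\n",
"#   in validation_step:  self.pred_ids = torch.cat((self.pred_ids, logits.argmax(-1)), dim=0)\n",
"#   and decode self.pred_ids directly in on_validation_epoch_end."
]
},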
{
"cell_type": "code",
"execution_count": null,
"id": "dd63c628",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# File paths\n",
"train_csv = \"train.csv\"\n",
"val_csv = \"validation.csv\"\n",
"test_csv = \"test.csv\"\n",
"\n",
"# Create the data module\n",
"dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
"dm.setup()\n",
"\n",
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-4, weight_decay=1e-5)\n",
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
"trainer.fit(model, datamodule=dm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5d3d684",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "a0494596",
"metadata": {},
"source": [
"### next steps:\n",
"* Handle articles longer than 512 tokens: we currently truncate, which may hurt quality when an article is much longer than the limit.\n",
"\n",
"#### what we've done:\n",
"* Change the data loading so it's more general, meaning on-the-fly loading from disk\n",
"* Add torch.compile\n",
"* 1. Clean up the code, make it into scripts instead of a notebook -> train for an epoch (add multi-gpu training?)\n",
"* Add tensorboard visualization\n",
"* Train from scratch instead of pretrained weights, to verify the training setup works and the model actually improves\n",
"* 2. Create an inference step: send in a news article -> get a summary, and check that it works (a hedged sketch follows in the next cell)\n"
]
},
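{
"cell_type": "code",
"execution_count": null,
"id": "added-inference-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (hedged): a minimal inference step matching item 2 above. It assumes\n",
"# the trained `model` (whose underlying HF model sits at `model.model`, possibly\n",
"# still wrapped by torch.compile) and the `tokenizer` from this notebook.\n",
"def summarize(article, max_input_len=512, max_summary_len=128):\n",
"    inputs = tokenizer(article, truncation=True, max_length=max_input_len, return_tensors=\"pt\")\n",
"    ids = model.model.generate(**inputs, max_length=max_summary_len, num_beams=4)\n",
"    return tokenizer.decode(ids[0], skip_special_tokens=True)"
]
},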
{
"cell_type": "code",
"execution_count": null,
"id": "80a2efab",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f9b71ab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -1,89 +0,0 @@
import pandas as pd
import pytorch_lightning as pl
from torch.utils.data import Dataset
import torch


class cnn_dailymail(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=512):
        self.data = pd.read_csv(csv_file)

        # if the csv_file is "train.csv" then only take out 10% of the data. make sure to reset indices etc
        #if csv_file == "train.csv":
        #    self.data = self.data.sample(frac=0.05, random_state=42).reset_index(drop=True)

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        article = self.data.loc[idx, "article"]
        highlights = self.data.loc[idx, "highlights"]

        inputs = self.tokenizer(
            article,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        targets = self.tokenizer(
            highlights,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": targets["input_ids"].squeeze(),
        }


class MyDataModule(pl.LightningDataModule):
    def __init__(
        self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512
    ):
        super().__init__()
        self.train_csv = train_csv
        self.val_csv = val_csv
        self.test_csv = test_csv
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

    def setup(self, stage=None):
        if stage in ("fit", None):
            self.train_dataset = cnn_dailymail(
                self.train_csv, self.tokenizer, self.max_length
            )
            self.val_dataset = cnn_dailymail(
                self.val_csv, self.tokenizer, self.max_length
            )
        if stage in ("test", None):
            self.test_dataset = cnn_dailymail(
                self.test_csv, self.tokenizer, self.max_length
            )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            pin_memory=True,
            shuffle=True,
            num_workers=6,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1
        )
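

# Added usage sketch (hedged): a quick smoke test for the module above; the CSV
# paths and tokenizer name are assumptions mirroring the notebooks in this repo.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    dm = MyDataModule("train.csv", "validation.csv", "test.csv", tokenizer, batch_size=2)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    print({k: v.shape for k, v in batch.items()})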
@@ -1,470 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "ec1aae37",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
|
||||
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
||||
"2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
|
||||
"2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
|
||||
"2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import datasets \n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"\n",
|
||||
"from transformers import (\n",
|
||||
" AutoModel,\n",
|
||||
" AutoModelForSeq2SeqLM,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" DataCollatorForSeq2Seq,\n",
|
||||
" Seq2SeqTrainingArguments,\n",
|
||||
" Seq2SeqTrainer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": 3,
"id": "5fd7cb0c",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "04530b1e",
"metadata": {},
"outputs": [],
"source": [
"# Define the LightningDataModule\n",
|
||||
"class MyDataModule(pl.LightningDataModule):\n",
|
||||
" def __init__(self, batch_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" \n",
|
||||
" def prepare_data(self):\n",
|
||||
" # Download and preprocess the data\n",
|
||||
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
|
||||
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
|
||||
" \n",
|
||||
" def setup(self, stage=None):\n",
|
||||
" # Load and preprocess the data\n",
|
||||
" train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
|
||||
" val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
|
||||
"\n",
|
||||
" self.train_ds = train_data.map(\n",
|
||||
" self.preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=self.batch_size, \n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.val_ds = val_data.map(\n",
|
||||
" self.preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=self.batch_size,\n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
" )\n",
"\n",
"    def preprocess_function(self, batch):\n",
"        inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
"        outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
"        batch[\"input_ids\"] = inputs.input_ids\n",
"        batch[\"attention_mask\"] = inputs.attention_mask\n",
"        batch[\"labels\"] = outputs.input_ids.copy()\n",
"        return batch\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
|
||||
" def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model_name = model_name\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
" self.weight_decay = weight_decay\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" \n",
|
||||
" # Load the pre-trained model and tokenizer\n",
|
||||
" self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
|
||||
"\n",
|
||||
" # Load the ROUGE metric\n",
|
||||
" self.metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, attention_mask, labels=None):\n",
|
||||
" output = self.model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" labels=labels,\n",
|
||||
" )\n",
|
||||
" return output.loss, output.logits\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" loss, logits = self(input_ids, attention_mask, labels)\n",
|
||||
" self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
|
||||
" return {'loss': loss, 'logits': logits}\n",
|
||||
" \n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" loss, logits = self(input_ids, attention_mask, labels)\n",
|
||||
" self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
|
||||
" return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
|
||||
" \n",
|
||||
" def validation_epoch_end(self, outputs):\n",
|
||||
" decoded_preds = []\n",
|
||||
" decoded_labels = []\n",
|
||||
" for output in outputs:\n",
|
||||
" logits = output['logits']\n",
|
||||
" labels = output['labels']\n",
|
||||
" decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
|
||||
" decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"        \n",
"        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"        \n",
"        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
"        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
"        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
"        \n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
"        return optimizer\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dd63c628",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"\n",
" 0%| | 0/1795 [00:00<?, ?ba/s]\u001b[A\n",
|
||||
" 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
|
||||
" 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
|
||||
" 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
|
||||
" 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
|
||||
" 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
|
||||
" 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
|
||||
" 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
|
||||
" 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
|
||||
" 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
|
||||
" 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
|
||||
" 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
|
||||
" 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
|
||||
" 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
|
||||
" 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
|
||||
" 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
|
||||
" 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
|
||||
" 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
|
||||
" 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
|
||||
" 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
|
||||
" 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
|
||||
" 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
|
||||
" 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
|
||||
" 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
|
||||
" 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
|
||||
" 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
|
||||
" 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
|
||||
" 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
|
||||
" 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
|
||||
" 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
|
||||
" 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
|
||||
" 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
|
||||
" 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
|
||||
" 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
|
||||
" 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
|
||||
" 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
|
||||
" 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
|
||||
" 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
|
||||
" 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
|
||||
" 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
|
||||
" 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
|
||||
" 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
|
||||
" 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
|
||||
" 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
|
||||
" 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
|
||||
" 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
|
||||
" 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
|
||||
" 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
|
||||
" 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
|
||||
" 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
|
||||
" 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
|
||||
" 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
|
||||
" 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
|
||||
" 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
|
||||
" 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
|
||||
" 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
|
||||
" 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
|
||||
" 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
|
||||
" 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
|
||||
" 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
|
||||
" 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
|
||||
" 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
|
||||
" 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
|
||||
" 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
|
||||
" 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
|
||||
" 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
|
||||
" 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
|
||||
" 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
|
||||
" 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
|
||||
" 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
|
||||
" 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
|
||||
" 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
|
||||
" 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
|
||||
" 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
|
||||
" 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
|
||||
" 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
|
||||
" 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
|
||||
" 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
|
||||
" 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
|
||||
" 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
|
||||
" 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
|
||||
" 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
|
||||
" 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
|
||||
" 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
|
||||
" 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
|
||||
" 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
|
||||
" 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
|
||||
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
|
||||
" 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
|
||||
" 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
|
||||
" 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
|
||||
" 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
|
||||
" 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
|
||||
" 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
|
||||
" 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
|
||||
" 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
|
||||
" 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
|
||||
" 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
|
||||
" 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
|
||||
" 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
|
||||
" 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
|
||||
" 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
|
||||
" 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
|
||||
" 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
|
||||
" 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
|
||||
" 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
|
||||
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
|
||||
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
|
||||
" 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
|
||||
" 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
|
||||
" 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
|
||||
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
|
||||
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
|
||||
" 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
|
||||
" 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
|
||||
" 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
|
||||
" 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
|
||||
" 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
|
||||
" 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
|
||||
" 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
|
||||
" 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
|
||||
" 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
|
||||
" 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
|
||||
" 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
|
||||
" 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
|
||||
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
|
||||
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
|
||||
" 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
|
||||
" 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
|
||||
" 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
|
||||
" 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
|
||||
" 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
|
||||
" 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
|
||||
" 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
|
||||
" 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
|
||||
" 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
|
||||
" 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
|
||||
" 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
|
||||
" 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
|
||||
" 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
|
||||
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
|
||||
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
|
||||
" 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
|
||||
" 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
|
||||
" 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
|
||||
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
|
||||
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
|
||||
" 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
|
||||
" 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
|
||||
" 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
|
||||
" 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
|
||||
" 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
|
||||
" 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
|
||||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
|
||||
"\n",
|
||||
" 0%| | 0/84 [00:00<?, ?ba/s]\u001b[A\n",
|
||||
" 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
|
||||
" 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
|
||||
" 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
|
||||
" 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
|
||||
" 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
|
||||
" 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
|
||||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
|
||||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
|
||||
"\n",
|
||||
" | Name | Type | Params\n",
|
||||
"-----------------------------------------------------\n",
|
||||
"0 | model | T5ForConditionalGeneration | 60.5 M\n",
|
||||
"-----------------------------------------------------\n",
|
||||
"60.5 M Trainable params\n",
|
||||
"0 Non-trainable params\n",
|
||||
"60.5 M Total params\n",
|
||||
"242.026 Total estimated model params size (MB)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "'list' object has no attribute 'size'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.set_float32_matmul_precision(\"medium\")\n",
|
||||
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
|
||||
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
|
||||
"dm = MyDataModule(batch_size=16)\n",
|
||||
"trainer.fit(model, datamodule=dm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aa7b1ab0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Recap of what we did:\n",
|
||||
"* Finetuned T5-Small on DailyCNN (summarize news articles) using HF Trainer and data loading\n",
|
||||
"* Converted to Lightning code \n",
|
||||
"\n",
|
||||
"### To do next:\n",
|
||||
"* Make it work with the evaluation somethings wrong now, don't think it's a big issue\n",
|
||||
"* Clean up the code a bit\n",
|
||||
"* Compare it with HF, add predict function, modify data loading so it's from scratch / more general way of doing it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80a2efab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,237 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5372055b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jupyterthemes.stylefx import set_nb_theme\n",
|
||||
"set_nb_theme('chesterish')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "11214a4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f45eb6b0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import datasets \n",
|
||||
"\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"\n",
|
||||
"from transformers import (\n",
|
||||
" AutoModel,\n",
|
||||
" AutoModelForMaskedLM,\n",
|
||||
" AutoModelForSeq2SeqLM,\n",
|
||||
" AutoModelForTokenClassification,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" DataCollatorForSeq2Seq,\n",
|
||||
" pipeline,\n",
|
||||
" Seq2SeqTrainingArguments,\n",
|
||||
" Seq2SeqTrainer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b2d26af4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "363045f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def preprocess_function(batch):\n",
|
||||
" inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
|
||||
" outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
|
||||
" batch[\"input_ids\"] = inputs.input_ids\n",
|
||||
" batch[\"attention_mask\"] = inputs.attention_mask\n",
|
||||
" batch[\"labels\"] = outputs.input_ids.copy()\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"# Load the dataset\n",
|
||||
"train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
|
||||
"val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
|
||||
"\n",
|
||||
"train_ds = train_data.map(\n",
|
||||
" preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=256, \n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_ds = val_data.map(\n",
|
||||
" preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=256, \n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0d58818f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MyLightningModule(pl.LightningModule):\n",
|
||||
" def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model_name = model_name\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
" self.weight_decay = weight_decay\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.num_training_steps = num_training_steps\n",
|
||||
" \n",
|
||||
" # Load the pre-trained model and tokenizer\n",
|
||||
" self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
|
||||
" self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, attention_mask, labels=None):\n",
|
||||
" output = self.model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" labels=labels,\n",
|
||||
" )\n",
|
||||
" return output.loss, output.logits\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" \n",
|
||||
" loss\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" learning_rate=1e-5,\n",
|
||||
" per_device_train_batch_size=16,\n",
|
||||
" per_device_eval_batch_size=16,\n",
|
||||
" max_steps=5000,\n",
|
||||
" weight_decay=1e-4,\n",
|
||||
" push_to_hub=False,\n",
|
||||
" evaluation_strategy = \"steps\",\n",
|
||||
" eval_steps = 50,\n",
|
||||
" generation_max_length=128,\n",
|
||||
" predict_with_generate=True,\n",
|
||||
" logging_steps=100,\n",
|
||||
" gradient_accumulation_steps=1,\n",
|
||||
" fp16=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load the ROUGE metric\n",
|
||||
"metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
"# Define the evaluation function\n",
|
||||
"def compute_metrics(pred):\n",
|
||||
" labels = pred.label_ids\n",
|
||||
" preds = pred.predictions\n",
|
||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
" scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
" return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_data,\n",
|
||||
" eval_dataset=val_data,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5148159b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Steps:\n",
|
||||
"1. Rewrite code to be more general\n",
|
||||
"\n",
|
||||
"a) Data loading should be from disk rather than their load_dataset, and should be on the fly\n",
|
||||
"\n",
|
||||
"b) Rewrite to Lightning code, Trainer etc using Lightning, compute metric fine that we use huggingface"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95e33e40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4c0348c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,644 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "7d5e92c6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
|
||||
"model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
|
||||
"example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
|
||||
"\n",
|
||||
"ner_entity_results = pipe(example)\n",
|
||||
"print(ner_entity_results)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "bf67ee76",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Turtle s Original Cara mel Chocolate P eca n luster\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ner_entity_results = pipe(example)\n",
|
||||
"\n",
|
||||
"# Initialize the entity words list with an empty string\n",
|
||||
"entity_words = [\"\"]\n",
|
||||
"\n",
|
||||
"# Loop through each dictionary in the list and extract the entity word\n",
|
||||
"for result in ner_entity_results:\n",
|
||||
" if result[\"entity\"] == \"B-FOOD\":\n",
|
||||
" entity_words.append(result[\"word\"])\n",
|
||||
" elif result[\"entity\"] == \"I-FOOD\":\n",
|
||||
" entity_words[-1] += \" \" + result[\"word\"]\n",
|
||||
"\n",
|
||||
"# Remove any remaining ## symbols and extra spaces\n",
|
||||
"entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
|
||||
"\n",
|
||||
"# Join the entity words into a single string\n",
|
||||
"output = \" \".join(entity_words)\n",
|
||||
"\n",
|
||||
"print(output)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc8e5ea0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"print(torch.cuda.is_available())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d8a1e039",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ad73024",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = pipeline(\"zero-shot-classification\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "04f7e02c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier(\n",
|
||||
" \"This is a course about the Transformers library\",\n",
|
||||
" candidate_labels=[\"machine learning\", \"gym\", \"food\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6fb246c2",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"generator = pipeline(task=\"text-generation\", model=\"bigscience/bloom-1b7\", device=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4e174f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"# Define input text and pre-trained model checkpoint\n",
|
||||
"text = \"My name is wolfgang and I live in berlin\"\n",
|
||||
"checkpoint = \"Jean-Baptiste/roberta-large-ner-english\"\n",
|
||||
"\n",
|
||||
"# Instantiate tokenizer and encode input text\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
|
||||
"\n",
|
||||
"# Instantiate model and generate output\n",
|
||||
"model = AutoModel.from_pretrained(checkpoint)\n",
|
||||
"outputs = model(**inputs)\n",
|
||||
"print(outputs[0].shape)\n",
|
||||
"\n",
|
||||
"# Instantiate token classification model and generate predictions\n",
|
||||
"model = AutoModelForTokenClassification.from_pretrained(checkpoint)\n",
|
||||
"outputs = model(**inputs)\n",
|
||||
"predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
|
||||
"print(predictions)\n",
|
||||
"print(model.config.id2label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8212bbaa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
|
||||
"\n",
|
||||
"# prepare input\n",
|
||||
"text = \"Replace me by any text you'd like.\"\n",
|
||||
"encoded_input = tokenizer(text, return_tensors='pt')\n",
|
||||
"\n",
|
||||
"# forward pass\n",
|
||||
"output = model(**encoded_input)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "314cba41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
|
||||
"\n",
|
||||
"# Load the pre-trained tokenizer and model\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
|
||||
"\n",
|
||||
"# Define the input sentence with a masked token\n",
|
||||
"text = \"I want to <mask> a new car tomorrow.\"\n",
|
||||
"\n",
|
||||
"# Tokenize the input sentence, replacing the masked token with a special [MASK] token\n",
|
||||
"encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')\n",
|
||||
"\n",
|
||||
"print(output.logits.shape)\n",
|
||||
"print(encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id))\n",
|
||||
"\n",
|
||||
"# Extract the predicted probabilities for the masked token\n",
|
||||
"predicted_probabilities = output.logits[0, encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id)]\n",
|
||||
"predicted_probabilities = torch.nn.functional.softmax(predicted_probabilities, dim=-1)\n",
|
||||
"\n",
|
||||
"# Get the top-k most probable predictions for the masked token\n",
|
||||
"k = 5\n",
|
||||
"top_k = torch.topk(predicted_probabilities, k)\n",
|
||||
"for i in range(k):\n",
|
||||
" token = tokenizer.convert_ids_to_tokens(top_k.indices[i].item())\n",
|
||||
" score = top_k.values[i].item()\n",
|
||||
" print(f\"Prediction {i+1}: '{token}' with probability {score:.5f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6187e77e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
|
||||
"\n",
|
||||
"sequences = [\n",
|
||||
" \"Using a Transformer network is simple\",\n",
|
||||
" \"The quick brown fox jumps over the lazy dog\",\n",
|
||||
" \"To be or not to be, that is the question\"\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Tokenize the input sequences and convert them to padded and truncated integer token IDs\n",
|
||||
"inputs = tokenizer(\n",
|
||||
" sequences,\n",
|
||||
" padding=True,\n",
|
||||
" truncation=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Print the resulting input IDs and attention masks\n",
|
||||
"print(inputs['input_ids'])\n",
|
||||
"print(inputs['attention_mask'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc259c5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43466db6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Huggingface:\n",
|
||||
"\n",
|
||||
"1. Understanding how to use the Pipeline (probably most useful) for various tasks, easy to use, and the different subtasks it can do like translation, QA, zero shot, sentiment analysis, token classification, etc. \n",
|
||||
"2. Understood how pipeline works in more detail by using AutoModel for various tasks as well as AutoTokenizer\n",
|
||||
"3. Load dataset\n",
|
||||
"4. How to finetune\n",
|
||||
"5. How to evaluate\n",
|
||||
"6. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97c474f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ed5d8c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification\n",
|
||||
"\n",
|
||||
"# Same as before\n",
|
||||
"checkpoint = \"bert-base-uncased\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
|
||||
"sequences = [\n",
|
||||
" \"I've been waiting for a HuggingFace course my whole life.\",\n",
|
||||
" \"This course is amazing!\",\n",
|
||||
"]\n",
|
||||
"batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
|
||||
"\n",
|
||||
"# This is new\n",
|
||||
"batch[\"labels\"] = torch.tensor([1, 1])\n",
|
||||
"\n",
|
||||
"optimizer = AdamW(model.parameters())\n",
|
||||
"loss = model(**batch).loss\n",
|
||||
"loss.backward()\n",
|
||||
"optimizer.step()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c598624f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"raw_datasets = load_dataset(\"glue\", \"mrpc\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cd296227",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"raw_train_dataset = raw_datasets[\"train\"]\n",
|
||||
"raw_train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e462947a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
|
||||
"raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
|
||||
"\n",
|
||||
"checkpoint = \"bert-base-uncased\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"\n",
|
||||
"def tokenize_function(example):\n",
|
||||
" return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
|
||||
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from transformers import TrainingArguments\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\")\n",
|
||||
"\n",
|
||||
"from transformers import AutoModelForSequenceClassification\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import evaluate\n",
|
||||
"\n",
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
||||
" logits, labels = eval_preds\n",
|
||||
" predictions = np.argmax(logits, axis=-1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)\n",
|
||||
"\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e2795dc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import TrainingArguments\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3af29cd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSequenceClassification\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "817f644e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import evaluate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "42819a6c",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
||||
" logits, labels = eval_preds\n",
|
||||
" predictions = np.argmax(logits, axis=-1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)\n",
|
||||
"\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eb5986b0",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"batch_size=32\n",
|
||||
"\n",
|
||||
"# Define the generator function to preprocess the data in batches\n",
|
||||
"def preprocess_generator(examples):\n",
|
||||
" for i in range(0, len(examples[\"article\"]), batch_size):\n",
|
||||
" batch = examples[\"article\"][i:i+batch_size]\n",
|
||||
" targets = examples[\"highlights\"][i:i+batch_size]\n",
|
||||
" model_inputs = tokenizer(batch, max_length=512, padding=\"max_length\", truncation=True)\n",
|
||||
" with tokenizer.as_target_tokenizer():\n",
|
||||
" model_targets = tokenizer(targets, max_length=128, padding=\"max_length\", truncation=True)\n",
|
||||
" model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
|
||||
" yield model_inputs\n",
|
||||
"\n",
|
||||
"def preprocess_function(examples):\n",
|
||||
" articles = [ex for ex in examples[\"article\"]]\n",
|
||||
" summaries = [ex for ex in examples[\"highlights\"]]\n",
|
||||
"\n",
|
||||
" model_inputs = tokenizer(articles, max_length=512, padding=\"max_length\", truncation=True)\n",
|
||||
" with tokenizer.as_target_tokenizer():\n",
|
||||
" model_targets = tokenizer(summaries, max_length=128, padding=\"max_length\", truncation=True)\n",
|
||||
" \n",
|
||||
" model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
|
||||
" return model_inputs\n",
|
||||
" \n",
|
||||
"# Load the dataset\n",
|
||||
"raw_datasets = load_dataset(\"cnn_dailymail\", \"3.0.0\")\n",
|
||||
"preprocessed_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4)\n",
|
||||
"\n",
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" evaluation_strategy = \"epoch\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=batch_size,\n",
|
||||
" max_steps=1000,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" push_to_hub=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_ds,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7d62583e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_metric"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d310a7b3",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preprocessed_datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "99d422cc",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=batch_size,\n",
|
||||
" max_steps=5000,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" push_to_hub=False,\n",
|
||||
" evaluation_strategy = \"steps\",\n",
|
||||
" eval_steps = 50,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load the ROUGE metric\n",
|
||||
"metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
"# Define the evaluation function\n",
|
||||
"def compute_metrics(pred):\n",
|
||||
" labels = pred.label_ids\n",
|
||||
" preds = pred.predictions\n",
|
||||
" \n",
|
||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
" \n",
|
||||
" scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
" \n",
|
||||
" return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=preprocessed_datasets[\"train\"],\n",
|
||||
" eval_dataset=preprocessed_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a5e97b57",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install nltk\n",
|
||||
"!pip install rouge_score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "558c3e66",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Goal:\n",
|
||||
"\n",
|
||||
"1. Implement full training from dataloading (dailycnn dataset), to model training, evaluation, etc, using HF. \n",
|
||||
"* Right now: stuck on on the fly dataset loading, we don't want to cache because this would take a lot of disk space etc.\n",
|
||||
"\n",
|
||||
"2. After we get step 1) working, we want to go deeper on every step, so download the dataset and load it as a custom dataset rather than using huggingface simple API, in order to make it more general. Compare with loading the ds as a custom HF dataset or using pytorch class together with lightning. Speed difference? Convenience? Also we want to use the lightning Trainer so see how we can integrate that. And then compare HF to the lightning + hf model approach and see what we like the most."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "624d49ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
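The on-the-fly loading mentioned in the Goal cell above can be sketched with the datasets streaming mode, which downloads and tokenizes examples lazily instead of caching a preprocessed copy on disk. This is a hedged sketch under those assumptions, not code from the repo (the helper name tokenize_batch is made up here):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Streaming returns an IterableDataset: nothing is cached to disk,
# examples are fetched and transformed as the dataloader consumes them.
train_stream = load_dataset("cnn_dailymail", "3.0.0", split="train", streaming=True)

def tokenize_batch(batch):  # hypothetical helper, named here for illustration
    model_inputs = tokenizer(batch["article"], max_length=512, truncation=True)
    with tokenizer.as_target_tokenizer():
        targets = tokenizer(batch["highlights"], max_length=128, truncation=True)
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs

train_stream = train_stream.map(
    tokenize_batch, batched=True, remove_columns=["article", "highlights", "id"]
)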
|
||||
@@ -1,41 +0,0 @@
|
||||
import numpy as np  # used by compute_metrics below; missing from the original script
import evaluate  # likewise missing from the original script
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import Trainer
|
||||
|
||||
raw_datasets = load_dataset("glue", "mrpc")
|
||||
checkpoint = "bert-base-uncased"
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
|
||||
|
||||
def tokenize_function(example):
|
||||
return tokenizer(example["sentence1"], example["sentence2"], truncation=True)
|
||||
|
||||
|
||||
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
|
||||
from transformers import TrainingArguments
|
||||
training_args = TrainingArguments("test-trainer")
|
||||
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
metric = evaluate.load("glue", "mrpc")
|
||||
logits, labels = eval_preds
|
||||
predictions = np.argmax(logits, axis=-1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
|
||||
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
|
||||
|
||||
trainer = Trainer(
|
||||
model,
|
||||
training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["validation"],
|
||||
data_collator=data_collator,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
@@ -1,130 +0,0 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import T5Config, T5ForConditionalGeneration
|
||||
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
Seq2SeqTrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
)
|
||||
|
||||
|
||||
class MyLightningModule(pl.LightningModule):
|
||||
def __init__(self, model_name, learning_rate, weight_decay):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Load the pre-trained model and tokenizer
|
||||
#self.model = torch.compile(
|
||||
# AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
||||
#)
|
||||
|
||||
# Create a T5-small configuration
|
||||
config = T5Config.from_pretrained("t5-small")
|
||||
|
||||
# Initialize the T5 model with random weights
|
||||
self.model = torch.compile(T5ForConditionalGeneration(config))
|
||||
|
||||
# Load the ROUGE metric
|
||||
self.metric = load_metric("rouge")
|
||||
self.logits = []
|
||||
self.labels = []
|
||||
|
||||
def forward(self, input_ids, attention_mask, labels=None):
|
||||
output = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
)
|
||||
return output.loss, output.logits
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch["attention_mask"]
|
||||
labels = batch["labels"]
|
||||
loss, logits = self(input_ids, attention_mask, labels)
|
||||
self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
|
||||
return {"loss": loss, "logits": logits}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch["attention_mask"]
|
||||
labels = batch["labels"]
|
||||
loss, logits = self(input_ids, attention_mask, labels)
|
||||
self.log("val_loss", loss, on_epoch=True, on_step=False)
|
||||
|
||||
# add logits and labels to instance attributes, but make sure to detach them
|
||||
# from the computational graph first
|
||||
self.logits.append(logits.argmax(dim=-1).detach().cpu())
|
||||
self.labels.append(labels.detach().cpu())
|
||||
return {"loss": loss, "logits": logits, "labels": labels}
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
# Concatenate tensors in logits and labels lists
|
||||
pred_token_ids = torch.cat(self.logits, dim=0)
|
||||
true_labels = torch.cat(self.labels, dim=0)
|
||||
|
||||
# Decode predictions and labels using the saved instance attributes
|
||||
decoded_preds = self.tokenizer.batch_decode(
|
||||
pred_token_ids, skip_special_tokens=True
|
||||
)
|
||||
decoded_labels = self.tokenizer.batch_decode(
|
||||
true_labels, skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Compute ROUGE scores
|
||||
scores = self.metric.compute(
|
||||
predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1"]
|
||||
)["rouge1"].mid
|
||||
|
||||
self.log("rouge1_precision", scores.precision, prog_bar=True)
|
||||
self.log("rouge1_recall", scores.recall, prog_bar=True)
|
||||
self.log("rouge1_fmeasure", scores.fmeasure, prog_bar=True)
|
||||
|
||||
# Clear logits and labels instance attributes for the next validation epoch
|
||||
self.logits.clear()
|
||||
self.labels.clear()
|
||||
|
||||
def predict(self, article: str, max_input_length: int = 512, max_output_length: int = 150) -> str:
|
||||
# Set the model to evaluation mode
|
||||
self.model.eval()
|
||||
|
||||
# Tokenize the input article
|
||||
inputs = self.tokenizer(
|
||||
article,
|
||||
max_length=max_input_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Move the input tensors to the same device as the model
|
||||
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
||||
|
||||
# Generate summary
|
||||
with torch.no_grad():
|
||||
output = self.model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
max_length=max_output_length,
|
||||
num_return_sequences=1,
|
||||
)
|
||||
|
||||
# Decode and return the summary
|
||||
summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
return summary
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
|
||||
)
|
||||
return optimizer
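For reference, the predict helper above could be exercised roughly like this (a sketch, assuming model.py is on the import path and rouge_score is installed; note the T5 weights in this version are randomly initialized, so the summary is only a smoke test):

from model import MyLightningModule

module = MyLightningModule(model_name="t5-small", learning_rate=1e-4, weight_decay=1e-5)
article = "The quick brown fox jumps over the lazy dog. " * 20
print(module.predict(article, max_input_length=512, max_output_length=64))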
|
||||
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
l = ["cat", "dog"]
|
||||
sentence = "The quick brown fox jumps over the lazy dog"
|
||||
@@ -1,67 +0,0 @@
|
||||
from dataset import MyDataModule
|
||||
from model import MyLightningModule
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
Seq2SeqTrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
)
|
||||
import torch
|
||||
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define the checkpoint callback
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
monitor="val_loss",
|
||||
dirpath="checkpoints",
|
||||
filename="my_model-{epoch:02d}-{val_loss:.2f}",
|
||||
save_top_k=-1,
|
||||
every_n_epochs=1,
|
||||
verbose=True,
|
||||
)
|
||||
logger = TensorBoardLogger("tb_logs", name="t5_dailymail")
|
||||
|
||||
model_name = "t5-small"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# File paths
|
||||
train_csv = "train.csv"
|
||||
val_csv = "validation.csv"
|
||||
test_csv = "test.csv"
|
||||
|
||||
# Create the data module
|
||||
dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=32)
|
||||
dm.setup()
|
||||
|
||||
model = MyLightningModule(
|
||||
model_name="t5-small", learning_rate=1e-4, weight_decay=1e-5
|
||||
)
|
||||
|
||||
|
||||
#checkpoint_path = "checkpoints/curr.ckpt"
|
||||
#checkpoint = torch.load(checkpoint_path)
|
||||
#model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
trainer = pl.Trainer(
|
||||
accelerator="gpu",
|
||||
devices=[0, 1],
|
||||
max_epochs=10,
|
||||
precision=16,
|
||||
logger=logger,
|
||||
callbacks=[checkpoint_callback],
|
||||
log_every_n_steps=10,
|
||||
)
|
||||
trainer.fit(model, dm)
|
||||
trainer.validate(model, dm)
|
||||
|
||||
#example = """Former President Donald Trump claims in a social media post that he will be arrested next week. The claim comes while a New York prosecutor considers charging Trump in connection with hush money paid to adult film actress Stormy Daniels but there has been no official announcement of any plans for an indictment. What we know about Trump possibly facing criminal indictment in New York City. Trump has been entangled in several criminal investigations but the case related to Daniels is the longest-running of all of them, reaching back to 2016. On his platform Truth Social on Saturday morning, Trump cited "illegal leaks" that he will be arrested Tuesday and he called for protests. Trump, who is running for president in 2024, also defended himself, saying that he has not committed a crime — though he did not disclose what he expects to be charged with — and he accused the Manhattan District Attorney's Office of being "corrupt & highly political.". 'I'M BACK!' Trump posts on Facebook, YouTube for first time in two years. The Manhattan District Attorney's Office declined to comment on whether it will soon be pursing an arrest warrant for Trump. But the Associated Press reported that law enforcement officials in New York are discussing security preparations in anticipation that Trump may be indicted in coming weeks. If it does occur, Trump would become the first former president to be indicted in U.S. history."""
|
||||
#print(len(tokenizer(example)["input_ids"]))
|
||||
#summary = model.predict(example)
|
||||
#print(summary)
|
||||
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
# Sample input batched matrices dimensions.
|
||||
n_batch, n, m, k = 10, 64, 128, 32
|
||||
|
||||
# Setup code as a string to be reused across benchmarks
|
||||
setup_code = (
|
||||
f'import torch; '
|
||||
f'x = torch.randn({n_batch}, {n}, {m}); '
|
||||
f'y = torch.randn({n_batch}, {m}, {k})')
|
||||
|
||||
# Number of threads from torch, reused in all timers.
|
||||
num_threads = torch.get_num_threads()
|
||||
|
||||
# A list of methods and their stmt strings for the benchmark
|
||||
methods = [
|
||||
('bmm', 'torch.bmm(x, y)'),
|
||||
('matmul', 'torch.matmul(x, y)'),
|
||||
('einsum', "torch.einsum('bnm,bmk->bnk', x, y)"),
|
||||
]
|
||||
|
||||
# Run each benchmark for a number of times to ensure measurement stability
|
||||
num_runs = 100
|
||||
|
||||
# Create benchmark objects and run them, collecting the results.
|
||||
results = [
|
||||
benchmark.Timer(
|
||||
stmt=stmt,
|
||||
setup=setup_code,
|
||||
num_threads=num_threads,
|
||||
label="Batched Matrix Multiplication",
|
||||
sub_label=f"Method: {label}",
|
||||
description=f"{n_batch}x{n}x{m}x{k}",
|
||||
).timeit(num_runs)
|
||||
for label, stmt in methods
|
||||
]
|
||||
|
||||
# Group the results into a Compare object and print the results table.
|
||||
benchmark.Compare(results).print()
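Note (added for context): each Timer here runs its setup string before timing and times only stmt, and benchmark.Compare groups the resulting Measurements by their label, sub_label, and description fields -- which is why those fields are set per method above, so the printed table lines up.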
|
||||
14
ML/Pytorch/linkedin_posts/tensor_operations/cat_vs_stack.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
|
||||
# ===== USING torch.cat =====
|
||||
# torch.cat concatenates tensors along an existing axis.
|
||||
t1 = torch.randn(2, 3)
|
||||
t2 = torch.randn(2, 3)
|
||||
|
||||
cat_dim0 = torch.cat((t1, t2), dim=0) # shape: (4, 3)
|
||||
cat_dim1 = torch.cat((t1, t2), dim=1) # shape: (2, 6)
|
||||
|
||||
# ===== USING torch.stack =====
|
||||
# torch.stack concatenates tensors along a new axis, not existing one.
|
||||
stack_dim0 = torch.stack((t1, t2), dim=0) # shape: 2x2x3
|
||||
stack_dim2 = torch.stack((t1, t2), dim=2) # shape: 2x3x2
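One way to see the relationship between the two (an illustrative addition, not part of the original file): stacking along a new axis is the same as unsqueezing each tensor at that axis and then concatenating.

import torch

t1, t2 = torch.randn(2, 3), torch.randn(2, 3)

# stack == unsqueeze each tensor at dim, then cat along that dim
assert torch.equal(
    torch.stack((t1, t2), dim=0),
    torch.cat((t1.unsqueeze(0), t2.unsqueeze(0)), dim=0),
)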
|
||||
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def measure_performance(matrix_type, n, device, num_runs=100):
    # NOTE: timing is done with CUDA events, so a GPU is required;
    # the "cpu" fallback above would crash on torch.cuda.Event.
    assert device == "cuda", "this benchmark requires a CUDA device"
    matrix = torch.randn(n, n, dtype=matrix_type, device=device)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# Warming up
|
||||
for _ in range(5):
|
||||
torch.matmul(matrix, matrix)
|
||||
|
||||
# Measure performance
|
||||
start_event.record()
|
||||
for _ in range(num_runs):
|
||||
torch.matmul(matrix, matrix)
|
||||
end_event.record()
|
||||
|
||||
# Synchronizes events to ensure the time is measured correctly
|
||||
torch.cuda.synchronize()
|
||||
return start_event.elapsed_time(end_event) / num_runs
|
||||
|
||||
n = 2**11
|
||||
num_runs = 100
|
||||
|
||||
# Dictionary to store execution time for different data types
|
||||
execution_times = {
|
||||
'float16': measure_performance(torch.float16, n, device, num_runs),
|
||||
'float32': measure_performance(torch.float32, n, device, num_runs),
|
||||
'float64': measure_performance(torch.float64, n, device, num_runs),
|
||||
}
|
||||
|
||||
print(f'Execution time ({device}):')
|
||||
for dtype, time in execution_times.items():
|
||||
print(f'{dtype}: {time:.6f} ms')
|
||||
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
|
||||
# Create two random matrices
|
||||
a = torch.rand(2**10, 2**10)
|
||||
b = torch.rand(2**10, 2**10)
|
||||
|
||||
# using mm (note only works for 2D tensors)
|
||||
c = torch.mm(a, b)
|
||||
|
||||
# using matmul
|
||||
c = torch.matmul(a, b)
|
||||
|
||||
# using @ operator (note exact same as matmul)
|
||||
c = a @ b
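A quick illustration of the practical difference (added here; not in the original file): torch.mm rejects anything but 2D inputs, while torch.matmul broadcasts over leading batch dimensions.

import torch

a = torch.rand(10, 4, 5)
b = torch.rand(10, 5, 6)

c = torch.matmul(a, b)  # OK: batched matmul, result shape (10, 4, 6)
# torch.mm(a, b)        # would raise a RuntimeError: mm only accepts 2D tensors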
|
||||
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
# Create a 3D tensor with shape (2, 3, 4)
|
||||
tensor = torch.rand(2, 3, 4)
|
||||
|
||||
# We'll swap the first and second dimension using torch.transpose
|
||||
transposed = tensor.transpose(0, 1) # shape: 3x2x4
|
||||
|
||||
# Now let's permute the tensor dimensions with torch.permute
|
||||
permuted = tensor.permute(2, 0, 1) # shape: 4x2x3
|
||||
permuted_like_transpose = tensor.permute(1, 0, 2) # shape: 3x2x4
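As a sanity check (illustrative addition): a single-axis swap expressed with permute matches transpose element for element.

import torch

tensor = torch.rand(2, 3, 4)
assert torch.equal(tensor.transpose(0, 1), tensor.permute(1, 0, 2))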
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,153 @@
|
||||
Summary
|
||||
=======
|
||||
|
||||
This dataset (ml-latest-small) describes 5-star rating and free-text tagging activity from [MovieLens](http://movielens.org), a movie recommendation service. It contains 100836 ratings and 3683 tag applications across 9742 movies. These data were created by 610 users between March 29, 1996 and September 24, 2018. This dataset was generated on September 26, 2018.
|
||||
|
||||
Users were selected at random for inclusion. All selected users had rated at least 20 movies. No demographic information is included. Each user is represented by an id, and no other information is provided.
|
||||
|
||||
The data are contained in the files `links.csv`, `movies.csv`, `ratings.csv` and `tags.csv`. More details about the contents and use of all these files follows.
|
||||
|
||||
This is a *development* dataset. As such, it may change over time and is not an appropriate dataset for shared research results. See available *benchmark* datasets if that is your intent.
|
||||
|
||||
This and other GroupLens data sets are publicly available for download at <http://grouplens.org/datasets/>.
|
||||
|
||||
|
||||
Usage License
|
||||
=============
|
||||
|
||||
Neither the University of Minnesota nor any of the researchers involved can guarantee the correctness of the data, its suitability for any particular purpose, or the validity of results based on the use of the data set. The data set may be used for any research purposes under the following conditions:
|
||||
|
||||
* The user may not state or imply any endorsement from the University of Minnesota or the GroupLens Research Group.
|
||||
* The user must acknowledge the use of the data set in publications resulting from the use of the data set (see below for citation information).
|
||||
* The user may redistribute the data set, including transformations, so long as it is distributed under these same license conditions.
|
||||
* The user may not use this information for any commercial or revenue-bearing purposes without first obtaining permission from a faculty member of the GroupLens Research Project at the University of Minnesota.
|
||||
* The executable software scripts are provided "as is" without warranty of any kind, either expressed or implied, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The entire risk as to the quality and performance of them is with you. Should the program prove defective, you assume the cost of all necessary servicing, repair or correction.
|
||||
|
||||
In no event shall the University of Minnesota, its affiliates or employees be liable to you for any damages arising out of the use or inability to use these programs (including but not limited to loss of data or data being rendered inaccurate).
|
||||
|
||||
If you have any further questions or comments, please email <grouplens-info@umn.edu>
|
||||
|
||||
|
||||
Citation
|
||||
========
|
||||
|
||||
To acknowledge use of the dataset in publications, please cite the following paper:
|
||||
|
||||
> F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: History and Context. ACM Transactions on Interactive Intelligent Systems (TiiS) 5, 4: 19:1–19:19. <https://doi.org/10.1145/2827872>
|
||||
|
||||
|
||||
Further Information About GroupLens
|
||||
===================================
|
||||
|
||||
GroupLens is a research group in the Department of Computer Science and Engineering at the University of Minnesota. Since its inception in 1992, GroupLens's research projects have explored a variety of fields including:
|
||||
|
||||
* recommender systems
|
||||
* online communities
|
||||
* mobile and ubiquitious technologies
|
||||
* digital libraries
|
||||
* local geographic information systems
|
||||
|
||||
GroupLens Research operates a movie recommender based on collaborative filtering, MovieLens, which is the source of these data. We encourage you to visit <http://movielens.org> to try it out! If you have exciting ideas for experimental work to conduct on MovieLens, send us an email at <grouplens-info@cs.umn.edu> - we are always interested in working with external collaborators.
|
||||
|
||||
|
||||
Content and Use of Files
|
||||
========================
|
||||
|
||||
Formatting and Encoding
|
||||
-----------------------
|
||||
|
||||
The dataset files are written as [comma-separated values](http://en.wikipedia.org/wiki/Comma-separated_values) files with a single header row. Columns that contain commas (`,`) are escaped using double-quotes (`"`). These files are encoded as UTF-8. If accented characters in movie titles or tag values (e.g. Misérables, Les (1995)) display incorrectly, make sure that any program reading the data, such as a text editor, terminal, or script, is configured for UTF-8.
|
||||
|
||||
|
||||
User Ids
|
||||
--------
|
||||
|
||||
MovieLens users were selected at random for inclusion. Their ids have been anonymized. User ids are consistent between `ratings.csv` and `tags.csv` (i.e., the same id refers to the same user across the two files).
|
||||
|
||||
|
||||
Movie Ids
|
||||
---------
|
||||
|
||||
Only movies with at least one rating or tag are included in the dataset. These movie ids are consistent with those used on the MovieLens web site (e.g., id `1` corresponds to the URL <https://movielens.org/movies/1>). Movie ids are consistent between `ratings.csv`, `tags.csv`, `movies.csv`, and `links.csv` (i.e., the same id refers to the same movie across these four data files).
|
||||
|
||||
|
||||
Ratings Data File Structure (ratings.csv)
|
||||
-----------------------------------------
|
||||
|
||||
All ratings are contained in the file `ratings.csv`. Each line of this file after the header row represents one rating of one movie by one user, and has the following format:
|
||||
|
||||
userId,movieId,rating,timestamp
|
||||
|
||||
The lines within this file are ordered first by userId, then, within user, by movieId.
|
||||
|
||||
Ratings are made on a 5-star scale, with half-star increments (0.5 stars - 5.0 stars).
|
||||
|
||||
Timestamps represent seconds since midnight Coordinated Universal Time (UTC) of January 1, 1970.
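For example, the ratings file can be loaded and the timestamps decoded like this (an illustrative snippet; it assumes pandas and that ratings.csv is in the working directory -- it is not part of the dataset distribution):

    import pandas as pd

    # Quoted fields and UTF-8 are pandas defaults, matching the encoding notes above
    ratings = pd.read_csv("ratings.csv")

    # Timestamps are seconds since the Unix epoch (UTC)
    ratings["datetime"] = pd.to_datetime(ratings["timestamp"], unit="s", utc=True)
    print(ratings.head())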
|
||||

Tags Data File Structure (tags.csv)
-----------------------------------

All tags are contained in the file `tags.csv`. Each line of this file after the header row represents one tag applied to one movie by one user, and has the following format:

userId,movieId,tag,timestamp

The lines within this file are ordered first by userId, then, within user, by movieId.

Tags are user-generated metadata about movies. Each tag is typically a single word or short phrase. The meaning, value, and purpose of a particular tag is determined by each user.

Timestamps represent seconds since midnight Coordinated Universal Time (UTC) of January 1, 1970.


Movies Data File Structure (movies.csv)
---------------------------------------

Movie information is contained in the file `movies.csv`. Each line of this file after the header row represents one movie, and has the following format:

movieId,title,genres

Movie titles are entered manually or imported from <https://www.themoviedb.org/>, and include the year of release in parentheses. Errors and inconsistencies may exist in these titles.

Genres are a pipe-separated list (see the parsing sketch after this list), and are selected from the following:

* Action
* Adventure
* Animation
* Children's
* Comedy
* Crime
* Documentary
* Drama
* Fantasy
* Film-Noir
* Horror
* Musical
* Mystery
* Romance
* Sci-Fi
* Thriller
* War
* Western
* (no genres listed)
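
As a sketch (again assuming pandas), the pipe-separated genre field expands naturally into one indicator column per genre:

```python
import pandas as pd

movies = pd.read_csv("movies.csv")

# One 0/1 column per genre; "(no genres listed)" becomes its own column.
genre_indicators = movies["genres"].str.get_dummies(sep="|")
print(genre_indicators.sum().sort_values(ascending=False))
```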

Links Data File Structure (links.csv)
---------------------------------------

Identifiers that can be used to link to other sources of movie data are contained in the file `links.csv`. Each line of this file after the header row represents one movie, and has the following format:

movieId,imdbId,tmdbId

movieId is an identifier for movies used by <https://movielens.org>. E.g., the movie Toy Story has the link <https://movielens.org/movies/1>.

imdbId is an identifier for movies used by <http://www.imdb.com>. E.g., the movie Toy Story has the link <http://www.imdb.com/title/tt0114709/>.

tmdbId is an identifier for movies used by <https://www.themoviedb.org>. E.g., the movie Toy Story has the link <https://www.themoviedb.org/movie/862>.

Use of the resources listed above is subject to the terms of each provider.
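
A hypothetical snippet for turning these identifiers into URLs (note that IMDb zero-pads the id to seven digits, as in tt0114709 above; the nullable integer type guards against movies with a missing tmdbId):

```python
import pandas as pd

links = pd.read_csv("links.csv")

# IMDb URLs zero-pad the identifier to seven digits (tt0114709 for Toy Story).
links["imdb_url"] = "http://www.imdb.com/title/tt" + links["imdbId"].astype(str).str.zfill(7) + "/"

# tmdbId may be missing for some movies, hence the nullable Int64 type.
links["tmdb_url"] = "https://www.themoviedb.org/movie/" + links["tmdbId"].astype("Int64").astype(str)
```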

Cross-Validation
----------------

Prior versions of the MovieLens dataset included either pre-computed cross-folds or scripts to perform this computation. We no longer bundle either of these features with the dataset, since most modern toolkits provide this as a built-in feature. If you wish to learn about standard approaches to cross-fold computation in the context of recommender systems evaluation, see [LensKit](http://lenskit.org) for tools, documentation, and open-source code examples.
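
For example, a minimal 5-fold split over the ratings with scikit-learn (one of many toolkits that provide this; shown only as a sketch):

```python
import pandas as pd
from sklearn.model_selection import KFold

ratings = pd.read_csv("ratings.csv")

kf = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_idx, test_idx) in enumerate(kf.split(ratings)):
    train, test = ratings.iloc[train_idx], ratings.iloc[test_idx]
    print(f"fold {fold}: {len(train)} train rows, {len(test)} test rows")
```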
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,503 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-Personalized Recommender Systems: Popularity Based"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Download and unpack the small MovieLens dataset only if it is not already present\n",
    "if not os.path.exists('movielens_small'):\n",
    "    !wget https://files.grouplens.org/datasets/movielens/ml-latest-small.zip\n",
    "    !unzip ml-latest-small.zip\n",
    "    !rm ml-latest-small.zip\n",
    "    !mv ml-latest-small movielens_small"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Damped Mean\n",
    "\n",
    "$$ s(i) = \\frac{\\sum_{u \\in U_i} r_{ui} + a \\times \\mu}{|U_i| + a} $$\n",
    "\n",
    "Where:\n",
    "- $ s(i) $: The damped mean rating for item $ i $.\n",
    "- $ \\sum_{u \\in U_i} r_{ui} $: Sum of the ratings given to item $ i $ by the users $ U_i $ who rated it.\n",
    "- $ a $: Damping factor, a value that determines the extent of smoothing.\n",
    "- $ \\mu $: Global mean rating across all items.\n",
    "- $ |U_i| $: Total number of ratings for item $ i $.\n"
   ]
  },
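  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick worked example (illustrative numbers): with damping factor $a = 10$ and global mean $\\mu \\approx 3.5$, a movie with a single 5-star rating gets $s(i) = (5 + 10 \\times 3.5) / (1 + 10) \\approx 3.64$. One perfect rating barely lifts a movie above the global mean, which is exactly the smoothing effect we want."
   ]
  },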
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "937dd4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def load_data():\n",
    "    # Load the MovieLens movie and rating tables\n",
    "    movies_df = pd.read_csv(\"movielens_small/movies.csv\")\n",
    "    ratings_df = pd.read_csv(\"movielens_small/ratings.csv\")\n",
    "    return movies_df, ratings_df\n",
    "\n",
    "def calculate_popularity(movies_df, ratings_df, damping_factor=5):\n",
    "    # Calculate the number of ratings and mean rating per movie, plus the global mean\n",
    "    num_ratings = ratings_df.groupby(\"movieId\")[\"rating\"].count()\n",
    "    mean_rating = ratings_df.groupby(\"movieId\")[\"rating\"].mean()\n",
    "    global_mean = ratings_df[\"rating\"].mean()\n",
    "\n",
    "    # Calculate the damped mean rating for each movie\n",
    "    damped_numerator = num_ratings * mean_rating + damping_factor * global_mean\n",
    "    damped_denominator = num_ratings + damping_factor\n",
    "    damped_mean_rating = damped_numerator / damped_denominator\n",
    "\n",
    "    # Add the popularity data to the movie data\n",
    "    movies_df['num_ratings'] = movies_df['movieId'].map(num_ratings)\n",
    "    movies_df['mean_rating'] = movies_df['movieId'].map(mean_rating)\n",
    "    movies_df['damped_mean_rating'] = movies_df['movieId'].map(damped_mean_rating)\n",
    "    return movies_df\n",
    "\n",
    "movies_df, ratings_df = load_data()\n",
    "movies_df = calculate_popularity(movies_df, ratings_df, damping_factor=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how sorting by num_ratings compares to sorting by mean_rating and by damped_mean_rating."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "id": "7e649c6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "      movieId                             title                       genres  num_ratings  mean_rating  damped_mean_rating\n",
       "314       356               Forrest Gump (1994)     Comedy|Drama|Romance|War        329.0     4.164134            4.144589\n",
       "277       318  Shawshank Redemption, The (1994)                  Crime|Drama        317.0     4.429022            4.400659\n",
       "257       296               Pulp Fiction (1994)  Comedy|Crime|Drama|Thriller        307.0     4.197068            4.175128\n",
       "510       593  Silence of the Lambs, The (1991)        Crime|Horror|Thriller        279.0     4.161290            4.138462\n",
       "1939     2571                Matrix, The (1999)       Action|Sci-Fi|Thriller        278.0     4.192446            4.168457"
      ]
     },
     "execution_count": 198,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "movies_df.sort_values(by=\"num_ratings\", ascending=False).head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "id": "c6ef332e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "      movieId                                  title                genres  num_ratings  mean_rating  damped_mean_rating\n",
       "7656    88448  Paper Birds (Pájaros de papel) (2010)          Comedy|Drama          1.0          5.0            3.637779\n",
       "8107   100556             Act of Killing, The (2012)           Documentary          1.0          5.0            3.637779\n",
       "9083   143031                        Jump In! (2007)  Comedy|Drama|Romance          1.0          5.0            3.637779\n",
       "9094   143511                           Human (2015)           Documentary          1.0          5.0            3.637779\n",
       "9096   143559                    L.A. Slasher (2015)  Comedy|Crime|Fantasy          1.0          5.0            3.637779"
      ]
     },
     "execution_count": 204,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "movies_df.sort_values(by=\"mean_rating\", ascending=False).head(5)"
   ]
  },
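  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sorting by raw mean rating surfaces obscure movies with a single 5-star rating. The damped mean pulls all of those down to the same value just above the global mean, so sorting by it (below) instead favors movies that are both highly and frequently rated."
   ]
  },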
  {
   "cell_type": "code",
   "execution_count": 201,
   "id": "f669fb09",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "      movieId                                              title                       genres  num_ratings  mean_rating  damped_mean_rating\n",
       "277       318                   Shawshank Redemption, The (1994)                  Crime|Drama        317.0     4.429022            4.400659\n",
       "659       858                              Godfather, The (1972)                  Crime|Drama        192.0     4.289062            4.250077\n",
       "2226     2959                                  Fight Club (1999)  Action|Crime|Drama|Thriller        218.0     4.272936            4.239103\n",
       "922      1221                     Godfather: Part II, The (1974)                  Crime|Drama        129.0     4.259690            4.205148\n",
       "46         50                         Usual Suspects, The (1995)       Crime|Mystery|Thriller        204.0     4.237745            4.203344\n",
       "224       260          Star Wars: Episode IV - A New Hope (1977)      Action|Adventure|Sci-Fi        251.0     4.231076            4.203125\n",
       "602       750  Dr. Strangelove or: How I Learned to Stop Worr...                   Comedy|War         97.0     4.268041            4.196407\n",
       "914      1213                                  Goodfellas (1990)                  Crime|Drama        126.0     4.250000            4.194967\n",
       "461       527                            Schindler's List (1993)                    Drama|War        220.0     4.225000            4.193546\n",
       "6710    58559                            Dark Knight, The (2008)      Action|Crime|Drama|IMAX        149.0     4.238255            4.191922"
      ]
     },
     "execution_count": 201,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "movies_df.sort_values(by=\"damped_mean_rating\", ascending=False).head(10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,176 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "25aa1c78",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "107e909b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the transactions data\n",
    "transactions = pd.read_csv(\"grocery_dataset.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "289a9751",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "itemDescription\n",
       "whole milk          515.0\n",
       "other vegetables    361.0\n",
       "rolls/buns          344.0\n",
       "soda                271.0\n",
       "yogurt              242.0\n",
       "dtype: float64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count how many of each item every member has bought (members x items matrix)\n",
    "member_purchases = transactions.groupby(['Member_number', 'itemDescription'])['itemDescription'].count().unstack().fillna(0)\n",
    "item_descriptions = member_purchases.columns\n",
    "\n",
    "def simple_association(item_name):\n",
    "    # Restrict to members who bought the item, then count what else they bought\n",
    "    item_basket = member_purchases[member_purchases[item_name] > 0]\n",
    "    co_purchase_counts = item_basket.sum().sort_values(ascending=False).drop(item_name)\n",
    "    return co_purchase_counts.head(5)\n",
    "\n",
    "ex_item = item_descriptions[20]\n",
    "simple_association(ex_item)"
   ]
  },
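  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Raw co-purchase counts are biased toward globally popular items (staples like whole milk co-occur with nearly everything). Lift corrects for this: for items $A$ and $B$,\n",
    "\n",
    "$$ \\text{lift}(A, B) = \\frac{P(A \\cap B)}{P(A) \\, P(B)} $$\n",
    "\n",
    "with probabilities estimated from transaction frequencies. Lift above 1 means the pair co-occurs more often than independence would predict. The cell below builds a per-transaction basket matrix and computes the full pairwise lift matrix, filtering out pairs whose joint probability is too small to be reliable."
   ]
  },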
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "190a1485",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 10 recommendations for soda:\n",
      "\n",
      "itemDescription\n",
      "oil                    1.246844\n",
      "beverages              1.162678\n",
      "sausage                1.014975\n",
      "grapes                 1.001195\n",
      "shopping bags           0.95459\n",
      "frozen meals           0.943642\n",
      "specialty bar          0.936182\n",
      "butter                 0.918418\n",
      "candy                  0.910056\n",
      "specialty chocolate    0.904846\n",
      "Name: soda, dtype: object \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Function to create a transaction matrix\n",
    "def create_transaction_matrix(transactions):\n",
    "    # Group the transactions by member number, date and item description,\n",
    "    # count the number of each item bought by each member on each date,\n",
    "    # unstack the item descriptions so rows are transactions and columns are items,\n",
    "    # fill any missing values with 0, and index by member number and date\n",
    "    basket = (transactions.groupby(['Member_number', 'Date', 'itemDescription'])['itemDescription']\n",
    "              .count().unstack().reset_index().fillna(0)\n",
    "              .set_index(['Member_number', 'Date']))\n",
    "\n",
    "    # Convert the counts to True/False: True if the item was bought in the transaction\n",
    "    return basket.applymap(lambda x: True if x > 0 else False)\n",
    "\n",
    "# Function to calculate a lift matrix\n",
    "def calculate_lift_matrix(basket_sets, min_joint_probability=0.001):\n",
    "    # Calculate the joint probability of each pair of items\n",
    "    probability_pair = pd.DataFrame(index=basket_sets.columns, columns=basket_sets.columns)\n",
    "    for item1 in basket_sets.columns:\n",
    "        for item2 in basket_sets.columns:\n",
    "            joint_probability = (basket_sets[item1] & basket_sets[item2]).sum() / len(basket_sets)\n",
    "            if joint_probability > min_joint_probability:\n",
    "                probability_pair.loc[item1, item2] = joint_probability\n",
    "            else:\n",
    "                probability_pair.loc[item1, item2] = 0\n",
    "\n",
    "    # Set the diagonal of the joint probability matrix to 0\n",
    "    np.fill_diagonal(probability_pair.values, 0)\n",
    "\n",
    "    # Calculate the individual probability of each item\n",
    "    probability_item = basket_sets.mean()\n",
    "\n",
    "    # Calculate the product of the individual probabilities\n",
    "    probability_product = np.outer(probability_item, probability_item)\n",
    "\n",
    "    # Calculate the lift of each pair of items\n",
    "    lift_matrix = probability_pair.divide(probability_product, fill_value=0)\n",
    "\n",
    "    return lift_matrix\n",
    "\n",
    "# Function to recommend items\n",
    "def recommend_items(lift_matrix, item, num_recommendations=10):\n",
    "    # Sort the lift values for the given item in descending order\n",
    "    # and take the top num_recommendations items\n",
    "    recommended_for_item = lift_matrix[item].sort_values(ascending=False).head(num_recommendations)\n",
    "\n",
    "    # Print the recommended items\n",
    "    print(f\"Top {num_recommendations} recommendations for {item}:\\n\")\n",
    "    print(recommended_for_item, \"\\n\\n\")\n",
    "\n",
    "# Create transaction matrix\n",
    "basket_sets = create_transaction_matrix(transactions)\n",
    "\n",
    "# Calculate the lift matrix\n",
    "lift_matrix = calculate_lift_matrix(basket_sets)\n",
    "\n",
    "# Recommend items for 'soda'\n",
    "recommend_items(lift_matrix, 'soda')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,148 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b0c33033",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Posts sorted by Reddit's 'Hot' score:\n",
      "   post_id                                          post_title  upvotes  downvotes  age_in_seconds  age_in_hours  reddit_hot  hacker_news\n",
      "9       10                     Difference between CNN and RNN?      350         50          256000     71.111111    8.166010     0.042205\n",
      "8        9               The future of quantum computing in AI      600         50          128000     35.555556    5.584807     0.227638\n",
      "7        8               Experience with multi-modal learning?      450         50           64000     17.777778    4.024282     0.559318\n",
      "6        7   Looking for resources on probabilistic program...      700         50           32000      8.888889    3.524024     2.416714\n",
      "4        5              Tips for handling imbalanced datasets?     1100        100            8000      2.222222    3.177778    18.779258\n",
      "2        3   Has anyone tried the new reinforcement learnin...      900        100            2000      0.555556    2.947534    38.776074\n",
      "3        4   Discussion: Evolutionary algorithms vs gradien...      800        100            4000      1.111111    2.933987    24.453093\n",
      "5        6      Which GPU is best for neural network training?      300         50           16000      4.444444    2.753496     2.886859\n",
      "0        1               How do I start with machine learning?      600        100             500      0.138889    2.710081    36.655710\n",
      "1        2      Best practices for deep learning optimization?      400         50            1000      0.277778    2.566290    24.588946\n",
      "\n",
      "Posts sorted by Hacker News score:\n",
      "   post_id                                          post_title  upvotes  downvotes  age_in_seconds  age_in_hours  reddit_hot  hacker_news\n",
      "2        3   Has anyone tried the new reinforcement learnin...      900        100            2000      0.555556    2.947534    38.776074\n",
      "0        1               How do I start with machine learning?      600        100             500      0.138889    2.710081    36.655710\n",
      "1        2      Best practices for deep learning optimization?      400         50            1000      0.277778    2.566290    24.588946\n",
      "3        4   Discussion: Evolutionary algorithms vs gradien...      800        100            4000      1.111111    2.933987    24.453093\n",
      "4        5              Tips for handling imbalanced datasets?     1100        100            8000      2.222222    3.177778    18.779258\n",
      "5        6      Which GPU is best for neural network training?      300         50           16000      4.444444    2.753496     2.886859\n",
      "6        7   Looking for resources on probabilistic program...      700         50           32000      8.888889    3.524024     2.416714\n",
      "7        8               Experience with multi-modal learning?      450         50           64000     17.777778    4.024282     0.559318\n",
      "8        9               The future of quantum computing in AI      600         50          128000     35.555556    5.584807     0.227638\n",
      "9       10                     Difference between CNN and RNN?      350         50          256000     71.111111    8.166010     0.042205\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "data = {\n",
    "    'post_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
    "    'post_title': [\n",
    "        \"How do I start with machine learning?\",\n",
    "        \"Best practices for deep learning optimization?\",\n",
    "        \"Has anyone tried the new reinforcement learning library?\",\n",
    "        \"Discussion: Evolutionary algorithms vs gradient descent\",\n",
    "        \"Tips for handling imbalanced datasets?\",\n",
    "        \"Which GPU is best for neural network training?\",\n",
    "        \"Looking for resources on probabilistic programming\",\n",
    "        \"Experience with multi-modal learning?\",\n",
    "        \"The future of quantum computing in AI\",\n",
    "        \"Difference between CNN and RNN?\"\n",
    "    ],\n",
    "    'upvotes': [600, 400, 900, 800, 1100, 300, 700, 450, 600, 350],\n",
    "    'downvotes': [100, 50, 100, 100, 100, 50, 50, 50, 50, 50],\n",
    "    'age_in_seconds': [500, 1000, 2000, 4000, 8000, 16000, 32000, 64000, 128000, 256000]\n",
    "}\n",
    "\n",
    "# Convert to DataFrame\n",
    "reddit_df = pd.DataFrame(data)\n",
    "\n",
    "# Calculate age in hours from age in seconds\n",
    "reddit_df['age_in_hours'] = reddit_df['age_in_seconds'] / 3600\n",
    "\n",
    "# Reddit's \"Hot\" formula (t is taken to be the post's age in seconds here; see the note after this cell)\n",
    "def reddit_hot(U, D, t):\n",
    "    return math.log10(max(abs(U - D), 1)) + np.sign(U - D) * t / 45000\n",
    "\n",
    "# Modified Hacker News formula (T is the age in hours)\n",
    "def hacker_news(U, D, T, P=1, alpha=0.8, gamma=1.8):\n",
    "    return P * pow((U - D - 1), alpha) / pow((T + 2), gamma)\n",
    "\n",
    "# Apply the formulas\n",
    "reddit_df['reddit_hot'] = reddit_df.apply(lambda row: reddit_hot(row['upvotes'], row['downvotes'], row['age_in_seconds']), axis=1)\n",
    "reddit_df['hacker_news'] = reddit_df.apply(lambda row: hacker_news(row['upvotes'], row['downvotes'], row['age_in_hours']), axis=1)\n",
    "\n",
    "# Sort by Reddit's \"Hot\" score and print\n",
    "reddit_df_sorted = reddit_df.sort_values(by='reddit_hot', ascending=False)\n",
    "print(\"Posts sorted by Reddit's 'Hot' score:\")\n",
    "print(reddit_df_sorted)\n",
    "\n",
    "# Sort by Hacker News score and print\n",
    "hacker_news_df_sorted = reddit_df.sort_values(by='hacker_news', ascending=False)\n",
    "print(\"\\nPosts sorted by Hacker News score:\")\n",
    "print(hacker_news_df_sorted)"
   ]
  },
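  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two scores implemented above, written out with $U$ upvotes and $D$ downvotes:\n",
    "\n",
    "$$ \\text{hot}(U, D, t) = \\log_{10}\\big(\\max(|U - D|, 1)\\big) + \\text{sign}(U - D) \\cdot \\frac{t}{45000} $$\n",
    "\n",
    "$$ \\text{score}(U, D, T) = P \\cdot \\frac{(U - D - 1)^{\\alpha}}{(T + 2)^{\\gamma}}, \\qquad \\alpha = 0.8, \\; \\gamma = 1.8 $$\n",
    "\n",
    "One caveat: Reddit's published ranking uses the post's creation time (seconds since a fixed epoch) for $t$, so newer posts score higher. Here $t$ is the post's age in seconds, which inverts that effect - hence the oldest posts top the reddit_hot ordering above, while the Hacker News score, which divides by the age $T$ in hours, favors newer posts as expected."
   ]
  },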
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,6 @@
# TF-IDF from scratch

import numpy as np
import pandas as pd
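
# A minimal sketch of TF-IDF (illustrative, with a made-up toy corpus; not a
# finished implementation):

corpus = [
    "the cat sat on the mat",
    "the dog sat on the log",
    "cats and dogs",
]

# Tokenize and build the vocabulary.
tokenized = [doc.split() for doc in corpus]
vocab = sorted({token for doc in tokenized for token in doc})

# Term frequency: raw count of each term in a document, normalized by document length.
tf = pd.DataFrame(
    [[doc.count(token) / len(doc) for token in vocab] for doc in tokenized],
    columns=vocab,
)

# Inverse document frequency: log(total documents / documents containing the term).
document_frequency = (tf > 0).sum(axis=0)
idf = np.log(len(corpus) / document_frequency)

# TF-IDF is the element-wise product; columns align on the shared vocabulary.
tfidf = tf * idf
print(tfidf.round(3))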
README.md
@@ -5,9 +5,10 @@

[](https://travis-ci.com/aladdinpersson/Machine-Learning-Collection) [](https://opensource.org/licenses/MIT)

[logo]: https://github.com/AladdinPerzon/Machine-Learning-Collection/blob/master/ML/others/logo/youtube_logo.png
[logo]: https://github.com/AladdinPerzon/Machine-Learning-Collection/blob/master/ML/others/logo/youtube_logo.png#

# Machine Learning Collection

In this repository you will find tutorials and projects related to Machine Learning. I try to make the code as clear as possible, and the goal is for it to be used as a learning resource and a way to look up solutions to specific problems. For most of them I have also made video explanations on YouTube if you want a walkthrough of the code. If you have any questions or suggestions for future videos, I prefer that you ask on [YouTube](https://www.youtube.com/c/AladdinPersson). This repository is contribution friendly, so if you feel you want to add something then I'd happily merge a PR :smiley:

## Table Of Contents
@@ -22,6 +23,8 @@ In this repository you will find tutorials and projects related to Machine Learn
- [TensorFlow Tutorials](#tensorflow-tutorials)
- [Beginner Tutorials](#beginner-tutorials)
- [Architectures](#CNN-Architectures)
- [Docker setup](#Docker-Setup)

## Machine Learning
* [![Youtube Link][logo]](https://youtu.be/pCCUnoes1Po) [Linear Regression](https://github.com/AladdinPersson/Machine-Learning-Collection/blob/master/ML/algorithms/linearregression/linear_regression_gradient_descent.py) **- With Gradient Descent** :white_check_mark:
@@ -32,7 +35,6 @@ In this repository you will find tutorials and projects related to Machine Learn
* [![Youtube Link][logo]](https://youtu.be/W4fSRHeafMo) [K-means clustering](https://github.com/AladdinPersson/Machine-Learning-Collection/blob/master/ML/algorithms/kmeans/kmeansclustering.py)
* [![Youtube Link][logo]](https://youtu.be/gBTtR0bs-1k) [Support Vector Machine](https://github.com/AladdinPersson/Machine-Learning-Collection/blob/master/ML/algorithms/svm/svm.py) **- Using CVXOPT**
* [![Youtube Link][logo]](https://youtu.be/NJvojeoTnNM) [Neural Network](https://github.com/AladdinPersson/Machine-Learning-Collection/blob/master/ML/algorithms/neuralnetwork/NN.py)
* [Decision Tree](https://github.com/AladdinPersson/Machine-Learning-Collection/blob/master/ML/algorithms/decisiontree/decision_tree.py)

## PyTorch Tutorials
If you have any specific video suggestion please make a comment on YouTube :)
@@ -139,10 +141,52 @@ If you have any specific video suggestion please make a comment on YouTube :)
* [![Youtube Link][logo]](https://youtu.be/NoKvCREx36Q) [Tutorial 19 - Custom Dataset Text](https://github.com/AladdinPerzon/Machine-Learning-Collection/tree/master/ML/TensorFlow/Basics/tutorial19-customdata-text)
* [![Youtube Link][logo]](https://youtu.be/ea5Z1smiR3U) [Tutorial 20 - Classifying Skin Cancer](https://github.com/AladdinPerzon/Machine-Learning-Collection/tree/master/ML/TensorFlow/Basics/tutorial20-classify-cancer-beginner-project-example) **- Beginner Project Example**

### CNN Architectures
* [LeNet](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/TensorFlow/CNN_architectures/LeNet5)
* [AlexNet](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/TensorFlow/CNN_architectures/AlexNet)
* [VGG](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/TensorFlow/CNN_architectures/VGGNet)
* [GoogLeNet](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/TensorFlow/CNN_architectures/GoogLeNet)
* [ResNet](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/TensorFlow/CNN_architectures/ResNet)

## Docker Setup

### Step 1: Install Docker

If you don't have Docker installed, follow the link below to install Docker for your system:
- [Install Docker Engine](https://docs.docker.com/engine/install/)

### Step 2: Install Nvidia Container Toolkit (Optional)
If you plan to use GPU acceleration with CUDA, install the Nvidia Container Toolkit:

- [Nvidia Container Toolkit Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)

### Step 3: Build the Docker Image
Navigate to the directory containing the Dockerfile and build the Docker image with:

```bash
docker build -t aladdin-image .
```

### Step 4: Run the Docker Container in Detached Mode
Run the following command to start the container in detached mode:

```bash
docker run -d \
  --gpus all \
  -v "${PWD}:/code" \
  -p 8080:8080 \
  --name "aladdin-container" \
  --env AUTHENTICATE_VIA_JUPYTER="mytoken" \
  aladdin-image \
  tail -f /dev/null
```

This will start a new Docker container named `aladdin-container` that will not exit immediately. The `-d` flag runs the container in detached mode, letting it run in the background.

### Step 5: Interact with the Docker Container
To attach an interactive shell to the running container, use the command:

```bash
docker exec -it aladdin-container /bin/bash
```

You can now interact with your container using the bash shell.
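
If you started the container with `--gpus all`, a quick sanity check (a sketch, assuming PyTorch is available in the image, as it is with the base image used here) that CUDA is visible from inside the container:

```bash
docker exec aladdin-container python -c "import torch; print(torch.cuda.is_available())"
```

This should print `True` when the Nvidia Container Toolkit is set up correctly.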

### Additional Notes
- If you wish to stop the container, you can do so with `docker stop aladdin-container`.
- To start the container again after stopping, use `docker start aladdin-container`.
- In case you need to remove the container, make sure it's stopped and then run `docker rm aladdin-container`.
- To see the output from the container (logs), use `docker logs aladdin-container`.
requirements.txt
@@ -0,0 +1,42 @@
# requirements.txt file with basic libraries to install for a machine learning workflow
numpy
pandas
scikit-learn
matplotlib
seaborn
scipy

# deep learning
torchvision
torchaudio
transformers
lightning
torchmetrics

# all you need is xgboost
xgboost
lightgbm

# nlp libraries
nltk
spacy

# image processing
opencv-python-headless
Pillow

# data loading
pyarrow

# model optimization/experiment tracking
tensorboard
wandb
mlflow

# utilities
tqdm

# notebooks
jupyter
ipywidgets