diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb
index b763a5c..e3220a5 100644
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb
+++ b/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb
@@ -2,3288 +2,22 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 6,
- "id": "87ef8027",
+ "execution_count": 2,
+ "id": "ec1aae37",
"metadata": {},
"outputs": [
{
- "data": {
- "text/html": [
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
+ "2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
+ "2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
+ ]
}
],
- "source": [
- "from jupyterthemes.stylefx import set_nb_theme\n",
- "set_nb_theme('chesterish')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "225eab36",
- "metadata": {},
- "outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
@@ -3312,8 +46,19 @@
},
{
"cell_type": "code",
- "execution_count": 18,
- "id": "9f7d2829",
+ "execution_count": 3,
+ "id": "5fd7cb0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"t5-small\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "04530b1e",
"metadata": {},
"outputs": [],
"source": [
@@ -3325,12 +70,12 @@
" \n",
" def prepare_data(self):\n",
" # Download and preprocess the data\n",
- " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
+ " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
" \n",
" def setup(self, stage=None):\n",
" # Load and preprocess the data\n",
- " train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
+ " train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
" val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
" self.train_ds = train_data.map(\n",
@@ -3364,8 +109,8 @@
},
{
"cell_type": "code",
- "execution_count": 19,
- "id": "a99bdbb0",
+ "execution_count": 7,
+ "id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
@@ -3395,9 +140,8 @@
" input_ids = batch[\"input_ids\"]\n",
" attention_mask = batch[\"attention_mask\"]\n",
" labels = batch[\"labels\"]\n",
- " \n",
" loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('train_loss', loss, on_epoch=True, on_step=True)\n",
+ " self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
" return {'loss': loss, 'logits': logits}\n",
" \n",
" def validation_step(self, batch, batch_idx):\n",
@@ -3406,7 +150,7 @@
" labels = batch[\"labels\"]\n",
" loss, logits = self(input_ids, attention_mask, labels)\n",
" self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
- " return {'loss': loss, 'logits': logits}\n",
+ " return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
" \n",
" def validation_epoch_end(self, outputs):\n",
" decoded_preds = []\n",
@@ -3430,26 +174,250 @@
},
{
"cell_type": "code",
- "execution_count": 20,
- "id": "3c28da7c",
- "metadata": {},
+ "execution_count": 8,
+ "id": "dd63c628",
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [
{
- "ename": "TypeError",
- "evalue": "Trainer.__init__() got an unexpected keyword argument 'num_epochs'",
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "\n",
+ " 0%| | 0/1795 [00:00, ?ba/s]\u001b[A\n",
+ " 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
+ " 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
+ " 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
+ " 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
+ " 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
+ " 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
+ " 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
+ " 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
+ " 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
+ " 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
+ " 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
+ " 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
+ " 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
+ " 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
+ " 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
+ " 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
+ " 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
+ " 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
+ " 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
+ " 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
+ " 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
+ " 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
+ " 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
+ " 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
+ " 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
+ " 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
+ " 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
+ " 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
+ " 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
+ " 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
+ " 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
+ " 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
+ " 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
+ " 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
+ " 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
+ " 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
+ " 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
+ " 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
+ " 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
+ " 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
+ " 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
+ " 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
+ " 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
+ " 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
+ " 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
+ " 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
+ " 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
+ " 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
+ " 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
+ " 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
+ " 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
+ " 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
+ " 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
+ " 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
+ " 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
+ " 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
+ " 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
+ " 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
+ " 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
+ " 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
+ " 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
+ " 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
+ " 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
+ " 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
+ " 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
+ " 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
+ " 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
+ " 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
+ " 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
+ " 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
+ " 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
+ " 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
+ " 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
+ " 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
+ " 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
+ " 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
+ " 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
+ " 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
+ " 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
+ " 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
+ " 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
+ " 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
+ " 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
+ " 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
+ " 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
+ " 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
+ " 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
+ " 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
+ " 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
+ " 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
+ " 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
+ " 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
+ " 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
+ " 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
+ " 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
+ " 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
+ " 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
+ " 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
+ " 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
+ " 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
+ " 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
+ " 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
+ " 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
+ " 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
+ " 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
+ " 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
+ " 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
+ " 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
+ " 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
+ " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
+ " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
+ " 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
+ " 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
+ " 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
+ " 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
+ " 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
+ " 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
+ " 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
+ " 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
+ " 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
+ " 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
+ " 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
+ " 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
+ " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
+ " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
+ " 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
+ " 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
+ " 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
+ " 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
+ " 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
+ " 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
+ " 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
+ " 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
+ " 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
+ " 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
+ " 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
+ " 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
+ " 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
+ " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
+ " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
+ " 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
+ " 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
+ " 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
+ " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
+ " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
+ " 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
+ " 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
+ " 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
+ " 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
+ " 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
+ " 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
+ "\n",
+ " 0%| | 0/84 [00:00, ?ba/s]\u001b[A\n",
+ " 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
+ " 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
+ " 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
+ " 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
+ " 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
+ " 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "-----------------------------------------------------\n",
+ "0 | model | T5ForConditionalGeneration | 60.5 M\n",
+ "-----------------------------------------------------\n",
+ "60.5 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "60.5 M Total params\n",
+ "242.026 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "ename": "AttributeError",
+ "evalue": "'list' object has no attribute 'size'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m MyLightningModule(model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt5-small\u001b[39m\u001b[38;5;124m\"\u001b[39m, learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m, weight_decay\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-4\u001b[39m, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTrainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogger\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[1;32m 4\u001b[0m trainer\u001b[38;5;241m.\u001b[39mfit(model, datamodule\u001b[38;5;241m=\u001b[39mdm)\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:348\u001b[0m, in \u001b[0;36m_defaults_from_env_vars..insert_env_defaults\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 345\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\u001b[38;5;28mlist\u001b[39m(env_variables\u001b[38;5;241m.\u001b[39mitems()) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mitems()))\n\u001b[1;32m 347\u001b[0m \u001b[38;5;66;03m# all args were already moved to kwargs\u001b[39;00m\n\u001b[0;32m--> 348\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mTypeError\u001b[0m: Trainer.__init__() got an unexpected keyword argument 'num_epochs'"
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+ "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
]
}
],
"source": [
+ "torch.set_float32_matmul_precision(\"medium\")\n",
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
- "trainer = pl.Trainer(devices=[0], num_epochs=10, deterministic=True, logger=False)\n",
+ "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
"dm = MyDataModule(batch_size=16)\n",
"trainer.fit(model, datamodule=dm)"
]
@@ -3457,7 +425,15 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "55729d94",
+ "id": "1395d5d2",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "80a2efab",
"metadata": {},
"outputs": [],
"source": []
diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb
index 1a9ac79..c821b42 100644
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb
+++ b/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb
@@ -1,5 +1,69 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "7d5e92c6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
+ "model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
+ "\n",
+ "pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
+ "example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
+ "\n",
+ "ner_entity_results = pipe(example)\n",
+ "print(ner_entity_results)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "bf67ee76",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Turtle s Original Cara mel Chocolate P eca n luster\n"
+ ]
+ }
+ ],
+ "source": [
+ "ner_entity_results = pipe(example)\n",
+ "\n",
+ "# Initialize the entity words list with an empty string\n",
+ "entity_words = [\"\"]\n",
+ "\n",
+ "# Loop through each dictionary in the list and extract the entity word\n",
+ "for result in ner_entity_results:\n",
+ " if result[\"entity\"] == \"B-FOOD\":\n",
+ " entity_words.append(result[\"word\"])\n",
+ " elif result[\"entity\"] == \"I-FOOD\":\n",
+ " entity_words[-1] += \" \" + result[\"word\"]\n",
+ "\n",
+ "# Remove any remaining ## symbols and extra spaces\n",
+ "entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
+ "\n",
+ "# Join the entity words into a single string\n",
+ "output = \" \".join(entity_words)\n",
+ "\n",
+ "print(output)\n"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb b/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb
index b763a5c..da1dc98 100644
--- a/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb
+++ b/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb
@@ -2,3288 +2,22 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 6,
- "id": "87ef8027",
+ "execution_count": 2,
+ "id": "ec1aae37",
"metadata": {},
"outputs": [
{
- "data": {
- "text/html": [
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
+ "2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
+ "2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
+ ]
}
],
- "source": [
- "from jupyterthemes.stylefx import set_nb_theme\n",
- "set_nb_theme('chesterish')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "225eab36",
- "metadata": {},
- "outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
@@ -3312,8 +46,19 @@
},
{
"cell_type": "code",
- "execution_count": 18,
- "id": "9f7d2829",
+ "execution_count": 3,
+ "id": "5fd7cb0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"t5-small\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "04530b1e",
"metadata": {},
"outputs": [],
"source": [
@@ -3325,12 +70,12 @@
" \n",
" def prepare_data(self):\n",
" # Download and preprocess the data\n",
- " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
+ " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
" \n",
" def setup(self, stage=None):\n",
" # Load and preprocess the data\n",
- " train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
+ " train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
" val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
" self.train_ds = train_data.map(\n",
@@ -3364,8 +109,8 @@
},
{
"cell_type": "code",
- "execution_count": 19,
- "id": "a99bdbb0",
+ "execution_count": 7,
+ "id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
@@ -3395,9 +140,8 @@
" input_ids = batch[\"input_ids\"]\n",
" attention_mask = batch[\"attention_mask\"]\n",
" labels = batch[\"labels\"]\n",
- " \n",
" loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('train_loss', loss, on_epoch=True, on_step=True)\n",
+ " self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
" return {'loss': loss, 'logits': logits}\n",
" \n",
" def validation_step(self, batch, batch_idx):\n",
@@ -3406,7 +150,7 @@
" labels = batch[\"labels\"]\n",
" loss, logits = self(input_ids, attention_mask, labels)\n",
" self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
- " return {'loss': loss, 'logits': logits}\n",
+ " return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
" \n",
" def validation_epoch_end(self, outputs):\n",
" decoded_preds = []\n",
@@ -3430,34 +174,273 @@
},
{
"cell_type": "code",
- "execution_count": 20,
- "id": "3c28da7c",
- "metadata": {},
+ "execution_count": 8,
+ "id": "dd63c628",
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [
{
- "ename": "TypeError",
- "evalue": "Trainer.__init__() got an unexpected keyword argument 'num_epochs'",
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
+ "\n",
+ " 0%| | 0/1795 [00:00, ?ba/s]\u001b[A\n",
+ " 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
+ " 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
+ " 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
+ " 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
+ " 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
+ " 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
+ " 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
+ " 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
+ " 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
+ " 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
+ " 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
+ " 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
+ " 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
+ " 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
+ " 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
+ " 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
+ " 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
+ " 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
+ " 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
+ " 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
+ " 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
+ " 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
+ " 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
+ " 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
+ " 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
+ " 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
+ " 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
+ " 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
+ " 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
+ " 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
+ " 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
+ " 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
+ " 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
+ " 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
+ " 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
+ " 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
+ " 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
+ " 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
+ " 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
+ " 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
+ " 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
+ " 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
+ " 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
+ " 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
+ " 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
+ " 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
+ " 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
+ " 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
+ " 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
+ " 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
+ " 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
+ " 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
+ " 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
+ " 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
+ " 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
+ " 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
+ " 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
+ " 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
+ " 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
+ " 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
+ " 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
+ " 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
+ " 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
+ " 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
+ " 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
+ " 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
+ " 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
+ " 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
+ " 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
+ " 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
+ " 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
+ " 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
+ " 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
+ " 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
+ " 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
+ " 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
+ " 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
+ " 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
+ " 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
+ " 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
+ " 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
+ " 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
+ " 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
+ " 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
+ " 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
+ " 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
+ " 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
+ " 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
+ " 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
+ " 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
+ " 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
+ " 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
+ " 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
+ " 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
+ " 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
+ " 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
+ " 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
+ " 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
+ " 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
+ " 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
+ " 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
+ " 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
+ " 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
+ " 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
+ " 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
+ " 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
+ " 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
+ " 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
+ " 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
+ " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
+ " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
+ " 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
+ " 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
+ " 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
+ " 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
+ " 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
+ " 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
+ " 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
+ " 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
+ " 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
+ " 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
+ " 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
+ " 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
+ " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
+ " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
+ " 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
+ " 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
+ " 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
+ " 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
+ " 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
+ " 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
+ " 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
+ " 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
+ " 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
+ " 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
+ " 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
+ " 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
+ " 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
+ " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
+ " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
+ " 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
+ " 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
+ " 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
+ " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
+ " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
+ " 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
+ " 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
+ " 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
+ " 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
+ " 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
+ " 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
+ "\n",
+ " 0%| | 0/84 [00:00, ?ba/s]\u001b[A\n",
+ " 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
+ " 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
+ " 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
+ " 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
+ " 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
+ " 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "-----------------------------------------------------\n",
+ "0 | model | T5ForConditionalGeneration | 60.5 M\n",
+ "-----------------------------------------------------\n",
+ "60.5 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "60.5 M Total params\n",
+ "242.026 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "ename": "AttributeError",
+ "evalue": "'list' object has no attribute 'size'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m MyLightningModule(model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt5-small\u001b[39m\u001b[38;5;124m\"\u001b[39m, learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m, weight_decay\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-4\u001b[39m, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTrainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogger\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[1;32m 4\u001b[0m trainer\u001b[38;5;241m.\u001b[39mfit(model, datamodule\u001b[38;5;241m=\u001b[39mdm)\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:348\u001b[0m, in \u001b[0;36m_defaults_from_env_vars..insert_env_defaults\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 345\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\u001b[38;5;28mlist\u001b[39m(env_variables\u001b[38;5;241m.\u001b[39mitems()) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mitems()))\n\u001b[1;32m 347\u001b[0m \u001b[38;5;66;03m# all args were already moved to kwargs\u001b[39;00m\n\u001b[0;32m--> 348\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mTypeError\u001b[0m: Trainer.__init__() got an unexpected keyword argument 'num_epochs'"
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+ "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
]
}
],
"source": [
+ "torch.set_float32_matmul_precision(\"medium\")\n",
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
- "trainer = pl.Trainer(devices=[0], num_epochs=10, deterministic=True, logger=False)\n",
+ "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
"dm = MyDataModule(batch_size=16)\n",
"trainer.fit(model, datamodule=dm)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "aa7b1ab0",
+ "metadata": {},
+ "source": [
+ "### Recap of what we did:\n",
+ "* Fine-tuned T5-Small on CNN/DailyMail (summarizing news articles) using the HF Trainer and data loading\n",
+ "* Converted to Lightning code \n",
+ "\n",
+ "### To do next:\n",
+ "* Make the evaluation work: something is wrong with it right now, but I don't think it's a big issue\n",
+ "* Clean up the code a bit\n",
+ "* Compare it with HF, add predict function, modify data loading so it's from scratch / more general way of doing it."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
- "id": "55729d94",
+ "id": "80a2efab",
"metadata": {},
"outputs": [],
"source": []
diff --git a/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb b/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb
index 8cfe998..09bebc9 100644
--- a/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb
+++ b/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb
@@ -2,3277 +2,10 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 7,
- "id": "bd8e3b95",
+ "execution_count": null,
+ "id": "5372055b",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"from jupyterthemes.stylefx import set_nb_theme\n",
"set_nb_theme('chesterish')"
@@ -3280,8 +13,8 @@
},
{
"cell_type": "code",
- "execution_count": 2,
- "id": "8c2a24cb",
+ "execution_count": null,
+ "id": "11214a4a",
"metadata": {},
"outputs": [],
"source": [
@@ -3292,24 +25,10 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "f45eb6b0",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "2023-02-21 15:40:52.888700: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
- "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2023-02-21 15:40:53.473104: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 15:40:53.473149: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 15:40:53.473154: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
@@ -3333,23 +52,10 @@
},
{
"cell_type": "code",
- "execution_count": 4,
- "id": "7fc4eb40",
+ "execution_count": null,
+ "id": "b2d26af4",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
- "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
- "- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.\n",
- "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
- "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
- " warnings.warn(\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Load the pre-trained model and tokenizer\n",
"model_name = \"t5-small\"\n",
@@ -3359,21 +65,10 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "363045f5",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1122/1122 [02:06<00:00, 8.88ba/s]\n",
- "Loading cached processed dataset at /home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de/cache-2d3b7edd75fb1188.arrow\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"def preprocess_function(batch):\n",
" inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
@@ -3404,53 +99,10 @@
},
{
"cell_type": "code",
- "execution_count": 6,
- "id": "6faa8c86",
+ "execution_count": null,
+ "id": "0d58818f",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/tmp/ipykernel_478601/1088570042.py:23: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
- " metric = load_metric(\"rouge\")\n",
- "max_steps is given, it will override any value given in num_train_epochs\n",
- "Using cuda_amp half precision backend\n",
- "The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: id, article, highlights. If id, article, highlights are not expected by `T5ForConditionalGeneration.forward`, you can safely ignore this message.\n",
- "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
- " warnings.warn(\n",
- "***** Running training *****\n",
- " Num examples = 0\n",
- " Num Epochs = 1\n",
- " Instantaneous batch size per device = 16\n",
- " Total train batch size (w. parallel, distributed & accumulation) = 16\n",
- " Gradient Accumulation steps = 1\n",
- " Total optimization steps = 5000\n",
- " Number of trainable parameters = 60506624\n"
- ]
- },
- {
- "ename": "IndexError",
- "evalue": "Invalid key: 90427 is out of bounds for size 0",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[6], line 47\u001b[0m\n\u001b[1;32m 36\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 37\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 38\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 43\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics,\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# Start the training\u001b[39;00m\n\u001b[0;32m---> 47\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m 1536\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m 1538\u001b[0m )\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/trainer.py:1761\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1758\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_load_rng_state(resume_from_checkpoint)\n\u001b[1;32m 1760\u001b[0m step \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1761\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, inputs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(epoch_iterator):\n\u001b[1;32m 1762\u001b[0m \n\u001b[1;32m 1763\u001b[0m \u001b[38;5;66;03m# Skip past any already trained steps if resuming training\u001b[39;00m\n\u001b[1;32m 1764\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m steps_trained_in_current_epoch \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1765\u001b[0m steps_trained_in_current_epoch \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 56\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 56\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/arrow_dataset.py:2601\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 2599\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key): \u001b[38;5;66;03m# noqa: F811\u001b[39;00m\n\u001b[1;32m 2600\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).\"\"\"\u001b[39;00m\n\u001b[0;32m-> 2601\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2602\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2603\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/arrow_dataset.py:2585\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, **kwargs)\u001b[0m\n\u001b[1;32m 2583\u001b[0m format_kwargs \u001b[38;5;241m=\u001b[39m format_kwargs \u001b[38;5;28;01mif\u001b[39;00m format_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m {}\n\u001b[1;32m 2584\u001b[0m formatter \u001b[38;5;241m=\u001b[39m get_formatter(format_type, features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mformat_kwargs)\n\u001b[0;32m-> 2585\u001b[0m pa_subtable \u001b[38;5;241m=\u001b[39m \u001b[43mquery_table\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 2586\u001b[0m formatted_output \u001b[38;5;241m=\u001b[39m format_table(\n\u001b[1;32m 2587\u001b[0m pa_subtable, key, formatter\u001b[38;5;241m=\u001b[39mformatter, format_columns\u001b[38;5;241m=\u001b[39mformat_columns, 
output_all_columns\u001b[38;5;241m=\u001b[39moutput_all_columns\n\u001b[1;32m 2588\u001b[0m )\n\u001b[1;32m 2589\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m formatted_output\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/formatting/formatting.py:588\u001b[0m, in \u001b[0;36mquery_table\u001b[0;34m(table, key, indices)\u001b[0m\n\u001b[1;32m 586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 587\u001b[0m size \u001b[38;5;241m=\u001b[39m indices\u001b[38;5;241m.\u001b[39mnum_rows \u001b[38;5;28;01mif\u001b[39;00m indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m table\u001b[38;5;241m.\u001b[39mnum_rows\n\u001b[0;32m--> 588\u001b[0m \u001b[43m_check_valid_index_key\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[38;5;66;03m# Query the main table\u001b[39;00m\n\u001b[1;32m 590\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/formatting/formatting.py:531\u001b[0m, in \u001b[0;36m_check_valid_index_key\u001b[0;34m(key, size)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (key \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m key \u001b[38;5;241m+\u001b[39m size \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (key \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m size):\n\u001b[0;32m--> 531\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid key: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is out of bounds for size \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msize\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mslice\u001b[39m):\n",
- "\u001b[0;31mIndexError\u001b[0m: Invalid key: 90427 is out of bounds for size 0"
- ]
- }
- ],
+ "outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
" def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
@@ -3531,7 +183,7 @@
},
{
"cell_type": "markdown",
- "id": "1b0f9a76",
+ "id": "5148159b",
"metadata": {},
"source": [
"# Steps:\n",
@@ -3545,7 +197,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ff03c8bb",
+ "id": "95e33e40",
"metadata": {},
"outputs": [],
"source": [
@@ -3555,7 +207,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "aafc4b27",
+ "id": "4c0348c2",
"metadata": {},
"outputs": [],
"source": []
diff --git a/ML/Pytorch/huggingface/learning.ipynb b/ML/Pytorch/huggingface/learning.ipynb
index 1a9ac79..c821b42 100644
--- a/ML/Pytorch/huggingface/learning.ipynb
+++ b/ML/Pytorch/huggingface/learning.ipynb
@@ -1,5 +1,69 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "7d5e92c6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
+ "model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
+ "\n",
+ "pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
+ "example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
+ "\n",
+ "ner_entity_results = pipe(example)\n",
+ "print(ner_entity_results)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "bf67ee76",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Turtle s Original Cara mel Chocolate P eca n luster\n"
+ ]
+ }
+ ],
+ "source": [
+ "ner_entity_results = pipe(example)\n",
+ "\n",
+ "# Initialize the entity words list with an empty string\n",
+ "entity_words = [\"\"]\n",
+ "\n",
+ "# Loop through each dictionary in the list and extract the entity word\n",
+ "for result in ner_entity_results:\n",
+ " if result[\"entity\"] == \"B-FOOD\":\n",
+ " entity_words.append(result[\"word\"])\n",
+ " elif result[\"entity\"] == \"I-FOOD\":\n",
+ " entity_words[-1] += \" \" + result[\"word\"]\n",
+ "\n",
+ "# Remove any remaining ## symbols and extra spaces\n",
+ "entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
+ "\n",
+ "# Join the entity words into a single string\n",
+ "output = \" \".join(entity_words)\n",
+ "\n",
+ "print(output)\n"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_0/events.out.tfevents.1676993704.mrbeast.566861.0 b/ML/Pytorch/huggingface/lightning_logs/version_0/events.out.tfevents.1676993704.mrbeast.566861.0
new file mode 100644
index 0000000..ab84c26
Binary files /dev/null and b/ML/Pytorch/huggingface/lightning_logs/version_0/events.out.tfevents.1676993704.mrbeast.566861.0 differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_0/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_0/hparams.yaml
new file mode 100644
index 0000000..0967ef4
--- /dev/null
+++ b/ML/Pytorch/huggingface/lightning_logs/version_0/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_1/events.out.tfevents.1676993775.mrbeast.568809.0 b/ML/Pytorch/huggingface/lightning_logs/version_1/events.out.tfevents.1676993775.mrbeast.568809.0
new file mode 100644
index 0000000..9ac0e2a
Binary files /dev/null and b/ML/Pytorch/huggingface/lightning_logs/version_1/events.out.tfevents.1676993775.mrbeast.568809.0 differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_1/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_1/hparams.yaml
new file mode 100644
index 0000000..0967ef4
--- /dev/null
+++ b/ML/Pytorch/huggingface/lightning_logs/version_1/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_2/events.out.tfevents.1676993814.mrbeast.570170.0 b/ML/Pytorch/huggingface/lightning_logs/version_2/events.out.tfevents.1676993814.mrbeast.570170.0
new file mode 100644
index 0000000..256eb1d
Binary files /dev/null and b/ML/Pytorch/huggingface/lightning_logs/version_2/events.out.tfevents.1676993814.mrbeast.570170.0 differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_2/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_2/hparams.yaml
new file mode 100644
index 0000000..0967ef4
--- /dev/null
+++ b/ML/Pytorch/huggingface/lightning_logs/version_2/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_3/events.out.tfevents.1676993905.mrbeast.570170.1 b/ML/Pytorch/huggingface/lightning_logs/version_3/events.out.tfevents.1676993905.mrbeast.570170.1
new file mode 100644
index 0000000..a60d00b
Binary files /dev/null and b/ML/Pytorch/huggingface/lightning_logs/version_3/events.out.tfevents.1676993905.mrbeast.570170.1 differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_3/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_3/hparams.yaml
new file mode 100644
index 0000000..0967ef4
--- /dev/null
+++ b/ML/Pytorch/huggingface/lightning_logs/version_3/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/ML/Pytorch/pytorch_lightning/3. Lightning Trainer/simple_fc.py b/ML/Pytorch/pytorch_lightning/3. Lightning Trainer/simple_fc.py
index 6bdcef3..8ea9065 100644
--- a/ML/Pytorch/pytorch_lightning/3. Lightning Trainer/simple_fc.py
+++ b/ML/Pytorch/pytorch_lightning/3. Lightning Trainer/simple_fc.py
@@ -6,7 +6,7 @@ from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
-import pytorch_lightning as pl
+import pytorch_lightning as pl
class NN(pl.LightningModule):
@@ -23,28 +23,28 @@ class NN(pl.LightningModule):
def training_step(self, batch, batch_idx):
loss, scores, y = self._common_step(batch, batch_idx)
- self.log('train_loss', loss)
+ self.log("train_loss", loss)
return loss
-
+
def validation_step(self, batch, batch_idx):
loss, scores, y = self._common_step(batch, batch_idx)
- self.log('val_loss', loss)
+ self.log("val_loss", loss)
return loss
def test_step(self, batch, batch_idx):
loss, scores, y = self._common_step(batch, batch_idx)
- self.log('test_loss', loss)
+ self.log("test_loss", loss)
return loss
def _common_step(self, batch, batch_idx):
- x, y = batch
+ x, y = batch
x = x.reshape(x.size(0), -1)
scores = self.forward(x)
loss = self.loss_fn(scores, y)
return loss, scores, y
def predict_step(self, batch, batch_idx):
- x, y = batch
+ x, y = batch
x = x.reshape(x.size(0), -1)
scores = self.forward(x)
preds = torch.argmax(scores, dim=1)
@@ -53,6 +53,7 @@ class NN(pl.LightningModule):
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=0.001)
+
# Set device cuda for GPU if it's available otherwise run on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -82,7 +83,13 @@ model = NN(input_size=input_size, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-trainer = pl.Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
+trainer = pl.Trainer(
+ accelerator="gpu",
+ devices=1,
+ min_epochs=1,
+ max_epochs=3,
+ precision=16,
+)
trainer.fit(model, train_loader, val_loader)
trainer.validate(model, val_loader)
trainer.test(model, test_loader)