{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A custom architecture\n",
"\n",
"In this tutorial, we will see how to combine different inputs (text, categorical and numeric features) in a single neural network."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import pandas as pd\n",
"import numpy as np\n",
"import nltk\n",
"from nltk.tokenize import word_tokenize\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torch import optim, nn\n",
"from torch.autograd import Variable\n",
"from pytoune.framework import Model, ModelCheckpoint, Callback, CSVLogger, EarlyStopping, ReduceLROnPlateau\n",
"from pytoune.framework.metrics import acc\n",
"import torch\n",
"\n",
"# nltk.download('punkt')\n",
"torch.manual_seed(42)\n",
"np.random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"cuda_device = 0\n",
"device = torch.device(\"cuda:%d\" % cuda_device if torch.cuda.is_available() else \"cpu\")\n",
"batch_size = 32\n",
"learning_rate = 0.01\n",
"n_epoch = 10"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Unnamed: 0 country description \\\n",
"0 0 Italy Aromas include tropical fruit, broom, brimston... \n",
"1 1 Portugal This is ripe and fruity, a wine that is smooth... \n",
"2 2 US Tart and snappy, the flavors of lime flesh and... \n",
"3 3 US Pineapple rind, lemon pith and orange blossom ... \n",
"4 4 US Much like the regular bottling from 2012, this... \n",
"\n",
" designation points price province \\\n",
"0 Vulkà Bianco 87 NaN Sicily & Sardinia \n",
"1 Avidagos 87 15.0 Douro \n",
"2 NaN 87 14.0 Oregon \n",
"3 Reserve Late Harvest 87 13.0 Michigan \n",
"4 Vintner's Reserve Wild Child Block 87 65.0 Oregon \n",
"\n",
" region_1 region_2 taster_name \\\n",
"0 Etna NaN Kerin O’Keefe \n",
"1 NaN NaN Roger Voss \n",
"2 Willamette Valley Willamette Valley Paul Gregutt \n",
"3 Lake Michigan Shore NaN Alexander Peartree \n",
"4 Willamette Valley Willamette Valley Paul Gregutt \n",
"\n",
" taster_twitter_handle title \\\n",
"0 @kerinokeefe Nicosia 2013 Vulkà Bianco (Etna) \n",
"1 @vossroger Quinta dos Avidagos 2011 Avidagos Red (Douro) \n",
"2 @paulgwine Rainstorm 2013 Pinot Gris (Willamette Valley) \n",
"3 NaN St. Julian 2013 Reserve Late Harvest Riesling ... \n",
"4 @paulgwine Sweet Cheeks 2012 Vintner's Reserve Wild Child... \n",
"\n",
" variety winery \n",
"0 White Blend Nicosia \n",
"1 Portuguese Red Quinta dos Avidagos \n",
"2 Pinot Gris Rainstorm \n",
"3 Riesling St. Julian \n",
"4 Pinot Noir Sweet Cheeks "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('./winemag-data-130k-v2.csv')\n",
"data.head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell shows that the data contains several duplicates (rows sharing the same description)."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Unnamed: 0 country description \\\n",
"67614 67614 US 100% Malbec, it's redolent with dark plums, wi... \n",
"46540 46540 US 100% Malbec, it's redolent with dark plums, wi... \n",
"119702 119702 US 100% Sangiovese, this pale pink wine has notes... \n",
"72181 72181 US 100% Sangiovese, this pale pink wine has notes... \n",
"73731 73731 France 87-89 Barrel sample. A pleasurable, perfumed w... \n",
"\n",
" designation points price province region_1 \\\n",
"67614 NaN 87 20.0 Washington Rattlesnake Hills \n",
"46540 NaN 87 20.0 Washington Rattlesnake Hills \n",
"119702 Meadow 88 18.0 Washington Columbia Valley (WA) \n",
"72181 Meadow 88 18.0 Washington Columbia Valley (WA) \n",
"73731 Barrel sample 88 NaN Bordeaux Saint-Julien \n",
"\n",
" region_2 taster_name taster_twitter_handle \\\n",
"67614 Columbia Valley Sean P. Sullivan @wawinereport \n",
"46540 Columbia Valley Sean P. Sullivan @wawinereport \n",
"119702 Columbia Valley Sean P. Sullivan @wawinereport \n",
"72181 Columbia Valley Sean P. Sullivan @wawinereport \n",
"73731 NaN Roger Voss @vossroger \n",
"\n",
" title \\\n",
"67614 Roza Ridge 2010 Malbec (Rattlesnake Hills) \n",
"46540 Roza Ridge 2010 Malbec (Rattlesnake Hills) \n",
"119702 Ross Andrew 2013 Meadow Rosé (Columbia Valley ... \n",
"72181 Ross Andrew 2013 Meadow Rosé (Columbia Valley ... \n",
"73731 Château Lalande-Borie 2008 Barrel sample (Sai... \n",
"\n",
" variety winery \n",
"67614 Malbec Roza Ridge \n",
"46540 Malbec Roza Ridge \n",
"119702 Rosé Ross Andrew \n",
"72181 Rosé Ross Andrew \n",
"73731 Bordeaux-style Red Blend Château Lalande-Borie "
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[data.duplicated('description',keep=False)].sort_values('description').head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train our network, we remove the duplicate descriptions and keep only the rows that have a value for the price, the country, the points and the taster.\n",
"\n",
"We could instead use the techniques seen previously to handle the null values rather than dropping these rows."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(88244, 14)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = data.drop_duplicates('description')\n",
"data = data[pd.notnull(data.price)]\n",
"data = data[pd.notna(data.country)]\n",
"data = data[pd.notna(data.points)]\n",
"data = data[pd.notna(data.taster_name)]\n",
"data.shape"
]
},
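{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cleaning above chains several boolean filters. A minimal sketch on a small hypothetical DataFrame (`toy` and its values are made up for illustration) shows which rows `duplicated('description', keep=False)` flags, and how `dropna(subset=...)` can replace the successive `pd.notna` filters with a single call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy data standing in for the wine reviews\n",
"toy = pd.DataFrame({\n",
"    'description': ['ripe and fruity', 'ripe and fruity', 'crisp and dry', 'oaky'],\n",
"    'country': ['Portugal', 'Portugal', None, 'France'],\n",
"    'points': [87, 87, 88, 90],\n",
"    'price': [15.0, 15.0, 14.0, np.nan],\n",
"    'taster_name': ['Roger Voss', 'Roger Voss', 'Paul Gregutt', 'Roger Voss'],\n",
"})\n",
"\n",
"# keep=False flags every copy of a duplicated description\n",
"dup_mask = toy.duplicated('description', keep=False)\n",
"\n",
"# drop_duplicates keeps the first copy; dropna(subset=...) drops rows missing any listed column\n",
"clean = toy.drop_duplicates('description').dropna(subset=['price', 'country', 'points', 'taster_name'])\n",
"\n",
"print(dup_mask.tolist())  # [True, True, False, False]\n",
"print(clean.shape)        # (1, 5)"
]
},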
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 'Portugal',\n",
" \"This is ripe and fruity, a wine that is smooth while still structured. Firm tannins are filled out with juicy red berry fruits and freshened with acidity. It's already drinkable, although it will certainly be better from 2016.\",\n",
" ..., 'Quinta dos Avidagos 2011 Avidagos Red (Douro)',\n",
" 'Portuguese Red', 'Quinta dos Avidagos'],\n",
" [2, 'US',\n",
" 'Tart and snappy, the flavors of lime flesh and rind dominate. Some green pineapple pokes through, with crisp acidity underscoring the flavors. The wine was all stainless-steel fermented.',\n",
" ..., 'Rainstorm 2013 Pinot Gris (Willamette Valley)',\n",
" 'Pinot Gris', 'Rainstorm'],\n",
" [3, 'US',\n",
" 'Pineapple rind, lemon pith and orange blossom start off the aromas. The palate is a bit more opulent, with notes of honey-drizzled guava and mango giving way to a slightly astringent, semidry finish.',\n",
" ...,\n",
" 'St. Julian 2013 Reserve Late Harvest Riesling (Lake Michigan Shore)',\n",
" 'Riesling', 'St. Julian'],\n",
" ...,\n",
" [129968, 'France',\n",
" 'Well-drained gravel soil gives this wine its crisp and dry character. It is ripe and fruity, although the spice is subdued in favor of a more serious structure. This is a wine to age for a couple of years, so drink from 2017.',\n",
" ..., 'Domaine Gresser 2013 Kritt Gewurztraminer (Alsace)',\n",
" 'Gewürztraminer', 'Domaine Gresser'],\n",
" [129969, 'France',\n",
" 'A dry style of Pinot Gris, this is crisp with some acidity. It also has weight and a solid, powerful core of spice and baked apple flavors. With its structure still developing, the wine needs to age. Drink from 2015.',\n",
" ..., 'Domaine Marcel Deiss 2012 Pinot Gris (Alsace)',\n",
" 'Pinot Gris', 'Domaine Marcel Deiss'],\n",
" [129970, 'France',\n",
" 'Big, rich and off-dry, this is powered by intense spiciness and rounded texture. Lychees dominate the fruit profile, giving an opulent feel to the aftertaste. Drink now.',\n",
" ...,\n",
" 'Domaine Schoffit 2012 Lieu-dit Harth Cuvée Caroline Gewurztraminer (Alsace)',\n",
" 'Gewürztraminer', 'Domaine Schoffit']], dtype=object)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.values"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"num_in_train = round(0.8*len(data))\n",
"num_in_valid = round(0.1*len(data))\n",
"num_in_test = len(data) - (num_in_train + num_in_valid)\n",
"train, valid, test = random_split(data.values, [num_in_train, num_in_valid, num_in_test])"
]
},
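{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split sizes follow from rounding 80% and 10% of the 88,244 remaining rows, the test set receiving the remainder. The sketch below reproduces that arithmetic under a few assumptions: it splits plain row indices rather than `data.values`, uses hypothetical names (`n_rows`, `train_idx`, `valid_idx`, `test_idx`) so the notebook's own variables are not overwritten, and assumes a PyTorch version whose `random_split` accepts a `generator` argument, which makes the split reproducible independently of the global seed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import random_split\n",
"\n",
"n_rows = 88244  # rows left after cleaning (shape printed above)\n",
"n_train = round(0.8 * n_rows)          # round(70595.2) -> 70595\n",
"n_valid = round(0.1 * n_rows)          # round(8824.4) -> 8824\n",
"n_test = n_rows - (n_train + n_valid)  # remainder -> 8825\n",
"\n",
"# A dedicated seeded generator makes the split reproducible\n",
"generator = torch.Generator().manual_seed(42)\n",
"train_idx, valid_idx, test_idx = random_split(range(n_rows), [n_train, n_valid, n_test], generator=generator)\n",
"print(len(train_idx), len(valid_idx), len(test_idx))  # 70595 8824 8825"
]
},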
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70595, 8824, 8825)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train), len(valid), len(test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we select the country, the description and the number of points to try to classify the taster.\n",
"\n",
"Using the other columns to refine the model is left as an exercise to the reader."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def filter_dataset(data):\n",
"    f = list()\n",
"    for example in data:\n",
"        # 1: Country\n",
"        # 2: Description\n",
"        # 4: Points\n",
"        # 9: Taster\n",
"        e = ((example[1], [w.lower() for w in word_tokenize(example[2])], example[4]/100), example[9])\n",
"        f.append(e)\n",
"    return f\n",
"\n",
"train_formatted = filter_dataset(train)\n",
"valid_formatted = filter_dataset(valid)\n",
"test_formatted = filter_dataset(test)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(('Chile',\n",
" ['starts',\n",
" 'out',\n",
" 'a',\n",
" 'little',\n",
" 'sharp',\n",
" 'and',\n",
" 'heavy',\n",
" 'on',\n",
" 'the',\n",
" 'olive',\n",
" 'aromas',\n",
" ',',\n",
" 'but',\n",
" 'it',\n",
" 'brings',\n",
" 'enough',\n",
" 'freshness',\n",
" 'to',\n",
" 'override',\n",
" 'any',\n",
" 'roasted',\n",
" ',',\n",
" 'herbal',\n",
" 'qualities',\n",
" 'that',\n",
" 'come',\n",
" 'with',\n",
" 'the',\n",
" 'variety',\n",
" '.',\n",
" 'the',\n",
" 'palate',\n",
" 'is',\n",
" 'big',\n",
" 'and',\n",
" 'chunky',\n",
" ',',\n",
" 'with',\n",
" 'sweet',\n",
" ',',\n",
" 'jammy',\n",
" 'black',\n",
" 'fruit',\n",
" 'flavors',\n",
" '.',\n",
" 'long',\n",
" ',',\n",
" 'chocolaty',\n",
" 'and',\n",
" 'slightly',\n",
" 'herbal',\n",
" 'on',\n",
" 'the',\n",
" 'finish',\n",
" '.'],\n",
" 0.87),\n",
" 'Michael Schachner'),\n",
" (('France',\n",
" ['on',\n",
" 'the',\n",
" 'rich',\n",
" 'side',\n",
" ',',\n",
" 'this',\n",
" 'smooth',\n",
" 'wine',\n",
" 'discloses',\n",
" 'flavors',\n",
" 'of',\n",
" 'caramel',\n",
" 'and',\n",
" 'toasted',\n",
" 'almonds',\n",
" 'as',\n",
" 'well',\n",
" 'as',\n",
" 'fruit',\n",
" '.',\n",
" 'it',\n",
" 'leaves',\n",
" 'a',\n",
" 'sweet',\n",
" 'feeling',\n",
" ',',\n",
" 'rounded',\n",
" 'and',\n",
" 'ripe',\n",
" ',',\n",
" 'finished',\n",
" 'with',\n",
" 'a',\n",
" 'touch',\n",
" 'of',\n",
" 'pepper',\n",
" 'and',\n",
" 'spice',\n",
" '.'],\n",
" 0.85),\n",
" 'Roger Voss'),\n",
" (('France',\n",
" ['round',\n",
" 'and',\n",
" 'soft',\n",
" ',',\n",
" 'this',\n",
" 'warm',\n",
" 'wine',\n",
" 'from',\n",
" 'the',\n",
" 'southern',\n",
" 'end',\n",
" 'of',\n",
" 'the',\n",
" 'côte',\n",
" 'de',\n",
" 'beaune',\n",
" 'feels',\n",
" 'rich',\n",
" 'and',\n",
" 'generous',\n",
" '.',\n",
" 'it',\n",
" 'has',\n",
" 'caramel',\n",
" 'and',\n",
" 'yellow',\n",
" 'fruit',\n",
" 'flavors',\n",
" ',',\n",
" 'and',\n",
" 'its',\n",
" 'crisp',\n",
" 'acidity',\n",
" 'will',\n",
" 'ensure',\n",
" 'that',\n",
" 'it',\n",
" 'will',\n",
" 'remain',\n",
" 'fresh',\n",
" 'and',\n",
" 'fruity',\n",
" '.'],\n",
" 0.87),\n",
" 'Roger Voss'),\n",
" (('Argentina',\n",
" ['herbal',\n",
" 'aromas',\n",
" 'bring',\n",
" 'a',\n",
" 'whiff',\n",
" 'of',\n",
" 'compost',\n",
" 'and',\n",
" 'leather',\n",
" '.',\n",
" 'a',\n",
" 'wiry',\n",
" ',',\n",
" 'dilute',\n",
" 'palate',\n",
" 'shows',\n",
" 'green',\n",
" 'tannins',\n",
" ',',\n",
" 'while',\n",
" 'flavors',\n",
" 'of',\n",
" 'buttery',\n",
" ',',\n",
" 'almost',\n",
" 'greasy',\n",
" 'oak',\n",
" 'and',\n",
" 'cranberry',\n",
" 'finish',\n",
" 'weedy',\n",
" 'and',\n",
" 'weakly',\n",
" '.'],\n",
" 0.82),\n",
" 'Michael Schachner'),\n",
" (('US',\n",
" ['the',\n",
" 'aromas',\n",
" 'of',\n",
" 'roasted',\n",
" 'coffee',\n",
" 'bean',\n",
" ',',\n",
" 'chocolate',\n",
" ',',\n",
" 'and',\n",
" 'dark',\n",
" 'raspberry',\n",
" 'are',\n",
" 'light',\n",
" ',',\n",
" 'while',\n",
" 'the',\n",
" 'palate',\n",
" 'is',\n",
" 'elegant',\n",
" 'in',\n",
" 'style',\n",
" 'with',\n",
" 'a',\n",
" 'good',\n",
" 'sense',\n",
" 'of',\n",
" 'balance',\n",
" 'and',\n",
" 'a',\n",
" 'lingering',\n",
" 'finish',\n",
" '.'],\n",
" 0.89),\n",
" 'Sean P. Sullivan')]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_formatted[:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As seen in the previous tutorial, we need to build a vocabulary for every \"non-numeric\" feature."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"description_vocab = set()\n",
"country_vocab = set()\n",
"taster_vocab = set()\n",
"\n",
"for e in train_formatted:\n",
"    (country, description, points), taster = e\n",
"    country_vocab.add(country)\n",
"    for word in description:\n",
"        description_vocab.add(word)\n",
"\n",
"# Tasters are collected from every split so that each label gets an index\n",
"for e in (train_formatted + valid_formatted + test_formatted):\n",
"    (country, description, points), taster = e\n",
"    taster_vocab.add(taster)\n",
"\n",
"# Index 0 is reserved for padding and index 1 for unknown words;\n",
"# the unknown index is stored under the key '' that the Vectorizer looks up\n",
"word_to_idx = {\n",
"    '<pad>': 0,\n",
"    '': 1,\n",
"}\n",
"\n",
"for word in sorted(description_vocab):\n",
"    word_to_idx[word] = len(word_to_idx)\n",
"\n",
"\n",
"country_to_idx = {country: i for i, country in enumerate(sorted(country_vocab))}\n",
"country_to_idx[''] = len(country_to_idx)\n",
"\n",
"taster_to_idx = {taster: i for i, taster in enumerate(sorted(taster_vocab))}"
]
},
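{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vocabularies reserve their first indices for special tokens before numbering the real words in sorted order. A minimal sketch of that scheme on made-up tokens: the key names `'<pad>'` and `'<unk>'` are illustrative (any strings that can never be real tokens would do), and `toy_lookup` is a hypothetical helper playing the role of `Vectorizer.vectorize_sequence`, mapping unseen words to the unknown index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical tokenized training descriptions\n",
"toy_docs = [['ripe', 'and', 'fruity'], ['crisp', 'and', 'dry']]\n",
"\n",
"# Indices 0 and 1 are reserved; real words are numbered from 2 in sorted order\n",
"toy_word_to_idx = {'<pad>': 0, '<unk>': 1}\n",
"for token in sorted({w for doc in toy_docs for w in doc}):\n",
"    toy_word_to_idx[token] = len(toy_word_to_idx)\n",
"\n",
"def toy_lookup(tokens):\n",
"    # Words never seen during training fall back to the unknown index\n",
"    return [toy_word_to_idx.get(t, toy_word_to_idx['<unk>']) for t in tokens]\n",
"\n",
"print(toy_word_to_idx)  # {'<pad>': 0, '<unk>': 1, 'and': 2, 'crisp': 3, 'dry': 4, 'fruity': 5, 'ripe': 6}\n",
"print(toy_lookup(['ripe', 'and', 'oaky']))  # [6, 2, 1]"
]
},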
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"34879"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(word_to_idx)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Argentina': 0,\n",
" 'Armenia': 1,\n",
" 'Australia': 2,\n",
" 'Austria': 3,\n",
" 'Bosnia and Herzegovina': 4,\n",
" 'Brazil': 5,\n",
" 'Bulgaria': 6,\n",
" 'Canada': 7,\n",
" 'Chile': 8,\n",
" 'China': 9,\n",
" 'Croatia': 10,\n",
" 'Cyprus': 11,\n",
" 'Czech Republic': 12,\n",
" 'England': 13,\n",
" 'France': 14,\n",
" 'Georgia': 15,\n",
" 'Germany': 16,\n",
" 'Greece': 17,\n",
" 'Hungary': 18,\n",
" 'India': 19,\n",
" 'Israel': 20,\n",
" 'Italy': 21,\n",
" 'Lebanon': 22,\n",
" 'Luxembourg': 23,\n",
" 'Macedonia': 24,\n",
" 'Mexico': 25,\n",
" 'Moldova': 26,\n",
" 'Morocco': 27,\n",
" 'New Zealand': 28,\n",
" 'Peru': 29,\n",
" 'Portugal': 30,\n",
" 'Romania': 31,\n",
" 'Serbia': 32,\n",
" 'Slovakia': 33,\n",
" 'Slovenia': 34,\n",
" 'South Africa': 35,\n",
" 'Spain': 36,\n",
" 'Switzerland': 37,\n",
" 'Turkey': 38,\n",
" 'US': 39,\n",
" 'Ukraine': 40,\n",
" 'Uruguay': 41,\n",
" '': 42}"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"country_to_idx"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Alexander Peartree': 0,\n",
" 'Anna Lee C. Iijima': 1,\n",
" 'Anne Krebiehl\\xa0MW': 2,\n",
" 'Carrie Dykes': 3,\n",
" 'Christina Pickard': 4,\n",
" 'Fiona Adams': 5,\n",
" 'Jeff Jenssen': 6,\n",
" 'Jim Gordon': 7,\n",
" 'Joe Czerwinski': 8,\n",
" 'Kerin O’Keefe': 9,\n",
" 'Lauren Buzzeo': 10,\n",
" 'Matt Kettmann': 11,\n",
" 'Michael Schachner': 12,\n",
" 'Mike DeSimone': 13,\n",
" 'Paul Gregutt': 14,\n",
" 'Roger Voss': 15,\n",
" 'Sean P. Sullivan': 16,\n",
" 'Susan Kostrzewa': 17,\n",
" 'Virginie Boone': 18}"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"taster_to_idx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This vectorizer converts every 'non-numeric' value into numeric indices."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"class Vectorizer:\n",
"    def __init__(self, word_to_idx, country_to_idx, taster_to_idx):\n",
"        self.word_to_idx = word_to_idx\n",
"        self.country_to_idx = country_to_idx\n",
"        self.taster_to_idx = taster_to_idx\n",
"\n",
"    def vectorize_sequence(self, sequence, idx, remove_if_unk=False):\n",
"        # The unknown-token index is stored under the key '' when the vocabulary has one\n",
"        if '' in idx:\n",
"            unknown_index = idx['']\n",
"            chars = [idx.get(tok, unknown_index) for tok in sequence]\n",
"            if remove_if_unk:\n",
"                return [w for w in chars if w != unknown_index]\n",
"            else:\n",
"                return chars\n",
"\n",
"        else:\n",
"            return [idx[tok] for tok in sequence]\n",
"\n",
"    def __call__(self, example):\n",
"        (country, description, points), taster = example\n",
"        vectorized_description = self.vectorize_sequence(description, self.word_to_idx)\n",
"\n",
"        # Countries unseen during training are mapped to the unknown-country index\n",
"        unknown_country = self.country_to_idx['']\n",
"        vectorized_country = self.country_to_idx.get(country, unknown_country)\n",
"\n",
"        vectorized_taster = self.taster_to_idx[taster]\n",
"        return (\n",
"            (vectorized_country, vectorized_description, points),\n",
"            vectorized_taster,\n",
"        )\n",
"\n",
"vectorizer = Vectorizer(word_to_idx, country_to_idx, taster_to_idx)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"train_data = [vectorizer(example) for example in train_formatted]\n",
"valid_data = [vectorizer(example) for example in valid_formatted]\n",
"test_data = [vectorizer(example) for example in test_formatted]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((8,\n",
" [29056,\n",
" 21967,\n",
" 1308,\n",
" 18052,\n",
" 27464,\n",
" 2279,\n",
" 14947,\n",
" 21678,\n",
" 31004,\n",
" 21640,\n",
" 2716,\n",
" 47,\n",
" 5379,\n",
" 16304,\n",
" 5024,\n",
" 10933,\n",
" 12934,\n",
" 31371,\n",
" 22165,\n",
" 2430,\n",
" 25992,\n",
" 47,\n",
" 15064,\n",
" 24606,\n",
" 30998,\n",
" 7611,\n",
" 34304,\n",
" 31004,\n",
" 32874,\n",
" 69,\n",
" 31004,\n",
" 22341,\n",
" 16276,\n",
" 3964,\n",
" 2279,\n",
" 6846,\n",
" 47,\n",
" 34304,\n",
" 30165,\n",
" 47,\n",
" 16400,\n",
" 4084,\n",
" 13026,\n",
" 12254,\n",
" 69,\n",
" 18160,\n",
" 47,\n",
" 6785,\n",
" 2279,\n",
" 28080,\n",
" 15064,\n",
" 21678,\n",
" 31004,\n",
" 12043,\n",
" 69],\n",
" 0.87),\n",
" 12)"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Padding is an essential concept: it lets us put sequences of different lengths into a single tensor, which can then be sent to the GPU.\n",
"\n",
"For each minibatch, we therefore use the length of the longest sequence to build the matrix of examples and fill the shorter sequences with zeros (the padding index)."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def pad_sequences(vectorized_seqs, seq_lengths):\n",
"    # (batch_size, max_length) matrix filled with the padding index 0\n",
"    seq_tensor = torch.zeros((len(vectorized_seqs), int(seq_lengths.max()))).long()\n",
"    for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):\n",
"        seq_tensor[idx, :seqlen] = torch.LongTensor(seq[:seqlen])\n",
"    return seq_tensor\n",
"\n",
"def collate_examples(samples):\n",
"    # Transpose the list of examples into tuples of features and labels\n",
"    features, tasters = list(zip(*samples))\n",
"    countries, descriptions, points = list(zip(*features))\n",
"    descriptions_lengths = torch.LongTensor([len(s) for s in descriptions])\n",
"    padded_descriptions = pad_sequences(descriptions, descriptions_lengths)\n",
"    countries = torch.LongTensor(countries)\n",
"    points = torch.FloatTensor(points)\n",
"    tasters = torch.LongTensor(tasters)\n",
"    return (countries, padded_descriptions, points), tasters"
]
},
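{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch's `torch.nn.utils.rnn.pad_sequence` performs the same right-padding as `pad_sequences`. A small sketch on hypothetical index sequences, which also shows how the model recovers the true lengths by counting non-zero entries (this assumes index 0 is only ever used for padding):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"# Three hypothetical vectorized descriptions of different lengths\n",
"toy_seqs = [[5, 8, 2], [7], [3, 4]]\n",
"\n",
"# Right-pad with 0 up to the longest sequence -> (batch_size, max_length)\n",
"toy_padded = pad_sequence([torch.tensor(s) for s in toy_seqs], batch_first=True, padding_value=0)\n",
"# Recover the true lengths, assuming 0 only ever marks padding\n",
"toy_lengths = (toy_padded > 0).sum(dim=1)\n",
"\n",
"print(toy_padded.tolist())   # [[5, 8, 2], [7, 0, 0], [3, 4, 0]]\n",
"print(toy_lengths.tolist())  # [3, 1, 2]"
]
},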
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"batch_size = 64\n",
"\n",
"train_loader = DataLoader(\n",
" train_data,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_examples,\n",
" shuffle=True\n",
")\n",
"\n",
"valid_loader = DataLoader(\n",
" valid_data,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_examples,\n",
" shuffle=False\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
" test_data,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_examples,\n",
" shuffle=False\n",
")"
]
},
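{
"cell_type": "markdown",
"metadata": {},
"source": [
"`DataLoader` calls `collate_fn` on every list of sampled examples to build a batch. The sketch below uses three made-up examples in the same `((country, description, points), taster)` layout and a hypothetical `toy_collate` function that relies on `pad_sequence` instead of the hand-written `pad_sequences`; with `batch_size=2` and no shuffling, the first batch contains the first two examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Hypothetical examples: ((country, description, points), taster)\n",
"toy_examples = [\n",
"    ((8, [5, 8, 2], 0.87), 12),\n",
"    ((14, [7], 0.85), 15),\n",
"    ((39, [3, 4], 0.89), 16),\n",
"]\n",
"\n",
"def toy_collate(samples):\n",
"    # zip(*samples) transposes a list of examples into tuples of fields\n",
"    features, tasters = zip(*samples)\n",
"    countries, descriptions, points = zip(*features)\n",
"    padded = pad_sequence([torch.tensor(d) for d in descriptions], batch_first=True)\n",
"    return (torch.tensor(countries), padded, torch.tensor(points)), torch.tensor(tasters)\n",
"\n",
"toy_loader = DataLoader(toy_examples, batch_size=2, collate_fn=toy_collate, shuffle=False)\n",
"(b_countries, b_descriptions, b_points), b_tasters = next(iter(toy_loader))\n",
"print(b_descriptions.tolist())  # [[5, 8, 2], [7, 0, 0]]\n",
"print(b_tasters.tolist())       # [12, 15]"
]
},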
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((tensor([39, 14, 39, 21, 39, 39, 35, 16, 30, 21, 39, 30, 39, 21, 0, 14, 21, 39,\n",
" 8, 35, 35, 2, 17, 39, 14, 39, 21, 39, 8, 39, 21, 16, 8, 39, 21, 39,\n",
" 14, 39, 39, 39, 21, 39, 36, 14, 14, 39, 14, 14, 39, 39, 21, 39, 21, 39,\n",
" 39, 39, 39, 14, 14, 0, 8, 8, 14, 21]),\n",
" tensor([[31095, 16276, 31004, ..., 0, 0, 0],\n",
" [ 4655, 21081, 2279, ..., 0, 0, 0],\n",
" [34304, 31004, 1580, ..., 0, 0, 0],\n",
" ...,\n",
" [11590, 20781, 21678, ..., 0, 0, 0],\n",
" [31004, 9502, 11641, ..., 0, 0, 0],\n",
" [ 2716, 21508, 25918, ..., 0, 0, 0]]),\n",
" tensor([0.8900, 0.8500, 0.8400, 0.8900, 0.8900, 0.8600, 0.9000, 0.9400, 0.8700,\n",
" 0.9300, 0.8600, 0.8500, 0.9100, 0.9000, 0.8400, 0.8700, 0.8800, 0.8900,\n",
" 0.8600, 0.8600, 0.8900, 0.9200, 0.8500, 0.8600, 0.9400, 0.8900, 0.8500,\n",
" 0.9000, 0.8600, 0.8300, 0.8800, 0.8900, 0.8300, 0.8700, 0.8400, 0.8500,\n",
" 0.8900, 0.9300, 0.8700, 0.9200, 0.8800, 0.8400, 0.8500, 0.9500, 0.8400,\n",
" 0.8800, 0.9300, 0.8700, 0.9300, 0.9200, 0.8800, 0.9100, 0.8800, 0.9000,\n",
" 0.8900, 0.8600, 0.8500, 0.9300, 0.8600, 0.8800, 0.8700, 0.8700, 0.9100,\n",
" 0.8500])),\n",
" tensor([16, 2, 18, 9, 11, 18, 10, 1, 15, 12, 16, 15, 11, 9, 12, 15, 9, 16,\n",
" 12, 17, 17, 8, 17, 0, 15, 16, 9, 11, 12, 14, 9, 1, 12, 14, 9, 14,\n",
" 15, 11, 18, 18, 9, 17, 12, 2, 15, 16, 15, 15, 18, 18, 9, 14, 9, 7,\n",
" 14, 17, 7, 15, 15, 12, 12, 12, 15, 9]))"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b = next(iter(train_loader))\n",
"b"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence\n",
"\n",
"class TasterClassifier(nn.Module):\n",
"    def __init__(self, word_to_idx, country_to_idx, word_embedding_size, word_hidden_layer_size,\n",
"                 country_embedding_size, points_hidden_size, hidden_size, num_tasters):\n",
"        super(TasterClassifier, self).__init__()\n",
"\n",
"        self.word_embeddings = nn.Embedding(len(word_to_idx), word_embedding_size)\n",
"        self.word_rnn = nn.LSTM(word_embedding_size, word_hidden_layer_size)\n",
"\n",
"        self.country_embeddings = nn.Embedding(len(country_to_idx), country_embedding_size)\n",
"\n",
"        self.points_fully_connected = nn.Linear(1, points_hidden_size)\n",
"\n",
"        self.fully_connected = nn.Linear(word_hidden_layer_size + country_embedding_size + points_hidden_size, hidden_size)\n",
"\n",
"        self.last_fully_connected = nn.Linear(hidden_size, num_tasters)\n",
"\n",
"        self.loss_function = nn.CrossEntropyLoss()\n",
"        self.metrics = ['acc']\n",
"\n",
"    def forward(self, examples):\n",
"        countries, descriptions, points = examples\n",
"\n",
"        # Description handling: sort by decreasing length (the padding index is 0)\n",
"        seq_lengths, perm_idx = (descriptions > 0).sum(dim=1).sort(0, descending=True)\n",
"        _, rev_perm_idx = perm_idx.sort(0)\n",
"\n",
"        # (batch_size, max_length)\n",
"        sorted_descriptions = descriptions[perm_idx]\n",
"\n",
"        # (batch_size, max_length, embedding_size)\n",
"        embeds = self.word_embeddings(sorted_descriptions)\n",
"        # pack_padded_sequence expects the lengths on the CPU\n",
"        packed_descriptions = pack_padded_sequence(embeds, seq_lengths.cpu(), batch_first=True)\n",
"\n",
"        # h_n: (1, batch_size, word_hidden_layer_size)\n",
"        _, (h_n, _) = self.word_rnn(packed_descriptions)\n",
"        h_n = h_n.squeeze(0)\n",
"        # Restore the original batch order\n",
"        descriptions_rep = F.relu(h_n[rev_perm_idx])\n",
"\n",
"        # Country handling\n",
"        # (batch_size, country_embedding_size)\n",
"        countries_embeddings = self.country_embeddings(countries)\n",
"\n",
"        # Points handling\n",
"        # (batch_size, points_hidden_size)\n",
"        points_rep = F.relu(self.points_fully_connected(points.view(-1, 1)))\n",
"\n",
"        # (batch_size, word_hidden_layer_size + country_embedding_size + points_hidden_size)\n",
"        combined_representation = torch.cat([descriptions_rep, countries_embeddings, points_rep], dim=1)\n",
"\n",
"        # (batch_size, hidden_size)\n",
"        combined_representation = F.relu(self.fully_connected(combined_representation))\n",
"\n",
"        # (batch_size, num_tasters)\n",
"        out = self.last_fully_connected(combined_representation)\n",
"\n",
"        return out"
]
},
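{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, `pack_padded_sequence` expects the batch sorted by decreasing length, which is why `forward` sorts the descriptions and later restores the original order with `rev_perm_idx`. The sketch below checks this on a toy batch with hypothetical sizes (a 10-word vocabulary, embeddings of size 4, an LSTM with hidden size 3): after unsorting, the final hidden state of an example matches running the LSTM on that example's real tokens alone, so the padding has no effect."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn.utils.rnn import pack_padded_sequence\n",
"\n",
"# Hypothetical sizes: 10-word vocabulary, embeddings of size 4, LSTM hidden size 3\n",
"toy_embedding = nn.Embedding(10, 4, padding_idx=0)\n",
"toy_lstm = nn.LSTM(4, 3)\n",
"\n",
"# Padded batch (0 = padding) in arbitrary length order\n",
"toy_batch = torch.tensor([[5, 0, 0], [3, 4, 2], [7, 1, 0]])\n",
"\n",
"# Sort by decreasing length, remembering how to undo the permutation\n",
"lengths, perm_idx = (toy_batch > 0).sum(dim=1).sort(0, descending=True)\n",
"_, rev_perm_idx = perm_idx.sort(0)\n",
"\n",
"with torch.no_grad():\n",
"    packed = pack_padded_sequence(toy_embedding(toy_batch[perm_idx]), lengths.cpu(), batch_first=True)\n",
"    _, (h_n, _) = toy_lstm(packed)\n",
"    # (1, batch_size, 3) -> (batch_size, 3), back in the original order\n",
"    h_n = h_n.squeeze(0)[rev_perm_idx]\n",
"    # Reference: run the LSTM on the tokens of example 1 only, shape (seq_len, 1, 4)\n",
"    ref_out, _ = toy_lstm(toy_embedding(toy_batch[1].view(-1, 1)))\n",
"\n",
"print(torch.allclose(h_n[1], ref_out[-1, 0], atol=1e-5))  # True"
]
},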
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"loaders = [train_loader, valid_loader, test_loader]"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"def train(name, pytorch_module):\n",
"    optimizer = optim.Adam(pytorch_module.parameters(), lr=learning_rate)\n",
"\n",
"    # Pytoune Model\n",
"    model = Model(pytorch_module, optimizer, pytorch_module.loss_function, metrics=pytorch_module.metrics)\n",
"\n",
"    # Send the model to the GPU if one is available (`device` is defined at the top)\n",
"    model.to(device)\n",
"\n",
"    # Train\n",
"    model.fit_generator(train_loader, valid_loader, epochs=n_epoch)\n",
"\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = TasterClassifier(\n",
" word_to_idx=word_to_idx,\n",
" country_to_idx=country_to_idx,\n",
" word_embedding_size=50,\n",
" word_hidden_layer_size=20,\n",
" country_embedding_size=20,\n",
" points_hidden_size=20,\n",
" hidden_size=50,\n",
" num_tasters=len(taster_to_idx)\n",
")\n",
"model = train('taster_classifier', net)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate_generator(test_loader)"
]
},
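{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell defines `TasterClassifierAttn`, which projects the three inputs to the same size `common_hidden_size` and, once attention is enabled, adds them with softmax weights computed from their concatenation. A toy sketch of that weighted sum, with made-up representations and logits, written with `torch.stack` rather than the three explicit terms used in the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Hypothetical per-input representations for a batch of 2, each of size 4\n",
"desc_rep = torch.ones(2, 4)\n",
"country_rep = 2 * torch.ones(2, 4)\n",
"points_rep_toy = 3 * torch.ones(2, 4)\n",
"\n",
"# One logit per input and per example; softmax turns them into weights that sum to 1\n",
"attn_logits_toy = torch.tensor([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]])\n",
"attn_weights = F.softmax(attn_logits_toy, dim=1)\n",
"\n",
"# (batch_size, 3, hidden), weighted and summed over the 3 inputs -> (batch_size, hidden)\n",
"stacked = torch.stack([desc_rep, country_rep, points_rep_toy], dim=1)\n",
"attended = (attn_weights.unsqueeze(2) * stacked).sum(dim=1)\n",
"\n",
"# Equal weights give (1 + 2 + 3) / 3 = 2 for the first row; the second row is close to 1\n",
"print(attended[:, 0])"
]
},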
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"from torch.nn.utils.rnn import pack_padded_sequence\n",
"\n",
"\n",
"class TasterClassifierAttn(nn.Module):\n",
"    def __init__(self, word_to_idx, country_to_idx, word_embedding_size, common_hidden_size, num_tasters):\n",
"        super(TasterClassifierAttn, self).__init__()\n",
"\n",
"        self.word_embeddings = nn.Embedding(len(word_to_idx), word_embedding_size)\n",
"        self.word_rnn = nn.LSTM(word_embedding_size, common_hidden_size)\n",
"\n",
"        # Every input is projected to the same size (common_hidden_size) so the three can be summed\n",
"        self.country_embeddings = nn.Embedding(len(country_to_idx), common_hidden_size)\n",
"\n",
"        self.points_fully_connected = nn.Linear(1, common_hidden_size)\n",
"\n",
"        # One attention score per input (description, country, points)\n",
"        self.attn_fully_connected = nn.Linear(common_hidden_size * 3, 3)\n",
"\n",
"        self.fully_connected = nn.Linear(common_hidden_size, common_hidden_size)\n",
"\n",
"        self.last_fully_connected = nn.Linear(common_hidden_size, num_tasters)\n",
"\n",
"        self.metrics = [self.acc]\n",
"\n",
"        # Attention is disabled at first; the AttnActivation callback switches it on\n",
"        self.attention = False\n",
"\n",
"    def loss_function(self, output, y):\n",
"        # The network returns (logits, attention weights); only the logits enter the loss\n",
"        y_pred, attn = output\n",
"        return F.cross_entropy(y_pred, y)\n",
"\n",
"    def acc(self, output, y):\n",
"        y_pred, attn = output\n",
"        return acc(y_pred, y)\n",
"\n",
"    def forward(self, examples):\n",
"        countries, descriptions, points = examples\n",
"\n",
"        # Description handling here\n",
"        # Sort by decreasing length, as pack_padded_sequence expects by default (index 0 is padding)\n",
"        seq_lengths, perm_idx = (descriptions > 0).sum(dim=1).sort(0, descending=True)\n",
"        _, rev_perm_idx = perm_idx.sort(0)\n",
"\n",
"        # (batch_size, max_length)\n",
"        sorted_descriptions = descriptions[perm_idx]\n",
"\n",
"        # (batch_size, max_length, embedding_size)\n",
"        embeds = self.word_embeddings(sorted_descriptions)\n",
"        packed_descriptions = pack_padded_sequence(embeds, seq_lengths.cpu(), batch_first=True)\n",
"\n",
"        # h_n: (1, batch_size, common_hidden_size)\n",
"        _, (h_n, _) = self.word_rnn(packed_descriptions)\n",
"        h_n = h_n.squeeze(0)\n",
"        # Restore the original batch order\n",
"        descriptions_rep = F.relu(h_n[rev_perm_idx])\n",
"\n",
"        # Country handling here\n",
"        # (batch_size, common_hidden_size)\n",
"        countries_embeddings = self.country_embeddings(countries)\n",
"\n",
"        # Points handling here\n",
"        # (batch_size, common_hidden_size)\n",
"        points_rep = F.relu(self.points_fully_connected(points.view(-1, 1)))\n",
"\n",
"        # (batch_size, 3 * common_hidden_size)\n",
"        combined_representation = torch.cat([descriptions_rep, countries_embeddings, points_rep], dim=1)\n",
"\n",
"        # (batch_size, 3): one weight per input, each row sums to 1\n",
"        attn_logits = self.attn_fully_connected(combined_representation)\n",
"        attn_pond = F.softmax(attn_logits, dim=1)\n",
"        if self.attention:\n",
"            # Attention-weighted sum of the three representations\n",
"            attended_input = (attn_pond[:, 0].view(-1, 1) * descriptions_rep\n",
"                              + attn_pond[:, 1].view(-1, 1) * countries_embeddings\n",
"                              + attn_pond[:, 2].view(-1, 1) * points_rep)\n",
"        else:\n",
"            # Plain (unweighted) sum\n",
"            attended_input = descriptions_rep + countries_embeddings + points_rep\n",
"\n",
"        # (batch_size, common_hidden_size)\n",
"        hidden_rep = F.relu(self.fully_connected(attended_input))\n",
"\n",
"        # (batch_size, num_tasters)\n",
"        out = self.last_fully_connected(hidden_rep)\n",
"\n",
"        return out, attn_pond\n",
"\n",
"\n",
"class AttnActivation(Callback):\n",
"    # Turns attention on at the beginning of epoch `epoch_start` (pytoune counts epochs from 1)\n",
"    def __init__(self, epoch_start=1):\n",
"        super().__init__()\n",
"        self.epoch_start = epoch_start\n",
"\n",
"    def on_epoch_begin(self, epoch, logs):\n",
"        if self.epoch_start == epoch:\n",
"            print(\"Activating attention\")\n",
"            # self.model is the pytoune Model; self.model.model is the wrapped nn.Module\n",
"            self.model.model.attention = True\n",
"\n",
"\n",
"class GradientLogging(Callback):\n",
"    # Records the gradient norm of each input branch after every backward pass\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.gradient_logs = defaultdict(list)\n",
"\n",
"    def on_backward_end(self, batch):\n",
"        net = self.model.model\n",
"        self.gradient_logs['word_embeddings'].append(net.word_embeddings.weight.grad.norm().item())\n",
"        self.gradient_logs['country_embeddings'].append(net.country_embeddings.weight.grad.norm().item())\n",
"        self.gradient_logs['points'].append(net.points_fully_connected.weight.grad.norm().item())\n",
"\n",
"    def on_epoch_end(self, epoch, logs):\n",
"        # Print the mean norm over the epoch, then reset the logs\n",
"        print(\"Word Embeddings grad norm: {}\".format(np.mean(self.gradient_logs['word_embeddings'])))\n",
"        print(\"Country Embeddings grad norm: {}\".format(np.mean(self.gradient_logs['country_embeddings'])))\n",
"        print(\"Points grad norm: {}\".format(np.mean(self.gradient_logs['points'])))\n",
"        self.gradient_logs.clear()\n",
"\n",
"\n",
"def train_attn(name, pytorch_module, attn_activation=1):\n",
"    optimizer = optim.Adam(pytorch_module.parameters(), lr=learning_rate)\n",
"\n",
"    callbacks = [GradientLogging(), AttnActivation(attn_activation)]\n",
"\n",
"    # Pytoune Model\n",
"    model = Model(pytorch_module, optimizer, pytorch_module.loss_function, metrics=pytorch_module.metrics)\n",
"\n",
"    # Uncomment to send the model to the GPU\n",
"    # model.to(device)\n",
"\n",
"    # Train\n",
"    model.fit_generator(train_loader, valid_loader, epochs=n_epoch, callbacks=callbacks)\n",
"\n",
"    return model"
]
},
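{
"cell_type": "markdown",
"metadata": {},
"source": [
"The attention variant scores its three inputs (description, country, points) with a single linear layer followed by a softmax, so each example gets three weights that sum to 1. The toy cell below is a sketch on random tensors (arbitrary sizes, not the wine data, and all names prefixed with `toy_` so no notebook variable is overwritten). It checks the shapes and shows that the explicit three-term sum in `forward` equals a batched weighted sum over the stacked representations, which is the form to use if more inputs are added later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"\n",
"# Arbitrary toy sizes; prefixed with toy_ to avoid overwriting the notebook's globals\n",
"toy_batch_size, toy_hidden = 4, 20\n",
"\n",
"# Random stand-ins for descriptions_rep, countries_embeddings and points_rep\n",
"toy_desc = torch.randn(toy_batch_size, toy_hidden)\n",
"toy_country = torch.randn(toy_batch_size, toy_hidden)\n",
"toy_points = torch.randn(toy_batch_size, toy_hidden)\n",
"\n",
"# Same scoring as attn_fully_connected: 3 * hidden -> 3 scores, softmax over the inputs\n",
"toy_attn_fc = nn.Linear(toy_hidden * 3, 3)\n",
"toy_combined = torch.cat([toy_desc, toy_country, toy_points], dim=1)  # (batch, 3 * hidden)\n",
"toy_attn = F.softmax(toy_attn_fc(toy_combined), dim=1)  # (batch, 3)\n",
"\n",
"# Batched version: stack to (batch, 3, hidden), weight each input, sum over the inputs\n",
"toy_stacked = torch.stack([toy_desc, toy_country, toy_points], dim=1)\n",
"toy_attended = (toy_attn.unsqueeze(2) * toy_stacked).sum(dim=1)  # (batch, hidden)\n",
"\n",
"# Explicit version, as written in TasterClassifierAttn.forward\n",
"toy_explicit = (toy_attn[:, 0].view(-1, 1) * toy_desc\n",
"                + toy_attn[:, 1].view(-1, 1) * toy_country\n",
"                + toy_attn[:, 2].view(-1, 1) * toy_points)\n",
"\n",
"print(toy_attended.shape)\n",
"print(torch.allclose(toy_attended, toy_explicit, atol=1e-6))"
]
},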
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5Activating attention\n",
"Epoch 1/5 63.58s Step 1104/1104: loss: 0.946356, acc: 64.333168, val_loss: 0.904871, val_acc: 65.525839\n",
"Word Embeddings grad norm: 3.422738518565893e-05\n",
"Country Embeddings grad norm: 0.043636590242385864\n",
"Points grad norm: 0.0001556285424157977\n",
"Epoch 2/5 64.71s Step 1104/1104: loss: 0.907408, acc: 65.395566, val_loss: 0.899876, val_acc: 65.525839\n",
"Word Embeddings grad norm: 0.0014634561957791448\n",
"Country Embeddings grad norm: 0.037704743444919586\n",
"Points grad norm: 0.0004377620934974402\n",
"Epoch 3/5 63.48s Step 1104/1104: loss: 0.878339, acc: 65.843190, val_loss: 0.856257, val_acc: 66.829102\n",
"Word Embeddings grad norm: 0.022100228816270828\n",
"Country Embeddings grad norm: 0.046045031398534775\n",
"Points grad norm: 7.195908983703703e-05\n",
"Epoch 4/5 64.04s Step 1104/1104: loss: 0.841254, acc: 67.027410, val_loss: 0.834168, val_acc: 67.826383\n",
"Word Embeddings grad norm: 0.03915749490261078\n",
"Country Embeddings grad norm: 0.0646233782172203\n",
"Points grad norm: 0.0\n",
"Epoch 5/5 63.40s Step 1104/1104: loss: 0.794653, acc: 69.113960, val_loss: 0.784556, val_acc: 70.002267\n",
"Word Embeddings grad norm: 0.06821389496326447\n",
"Country Embeddings grad norm: 0.07550068199634552\n",
"Points grad norm: 0.0\n"
]
}
],
"source": [
"net = TasterClassifierAttn(\n",
" word_to_idx=word_to_idx,\n",
" country_to_idx=country_to_idx,\n",
" word_embedding_size=50,\n",
" common_hidden_size=20,\n",
" num_tasters=len(taster_to_idx)\n",
")\n",
"n_epoch = 5\n",
"model = train_attn('taster_classifier', net)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5 63.53s Step 1104/1104: loss: 0.702741, acc: 75.511014, val_loss: 0.479272, val_acc: 85.086129\n",
"Word Embeddings grad norm: 0.08628664165735245\n",
"Country Embeddings grad norm: 0.12846454977989197\n",
"Points grad norm: 0.07020499557256699\n",
"Epoch 2/5Activating attention\n",
"Epoch 2/5 65.13s Step 1104/1104: loss: 0.340948, acc: 89.813726, val_loss: 0.300813, val_acc: 91.035811\n",
"Word Embeddings grad norm: 0.13509760797023773\n",
"Country Embeddings grad norm: 0.07974996417760849\n",
"Points grad norm: 0.01175676565617323\n",
"Epoch 3/5 64.27s Step 1104/1104: loss: 0.235133, acc: 93.013670, val_loss: 0.264922, val_acc: 92.327743\n",
"Word Embeddings grad norm: 0.14123034477233887\n",
"Country Embeddings grad norm: 0.07340138405561447\n",
"Points grad norm: 0.008250435814261436\n",
"Epoch 4/5 64.26s Step 1104/1104: loss: 0.199422, acc: 94.105815, val_loss: 0.237391, val_acc: 92.962375\n",
"Word Embeddings grad norm: 0.1365615427494049\n",
"Country Embeddings grad norm: 0.0728415995836258\n",
"Points grad norm: 0.005137351341545582\n",
"Epoch 5/5 64.68s Step 1104/1104: loss: 0.167909, acc: 94.979814, val_loss: 0.225570, val_acc: 93.653672\n",
"Word Embeddings grad norm: 0.1322409063577652\n",
"Country Embeddings grad norm: 0.06858542561531067\n",
"Points grad norm: 0.0022519982885569334\n"
]
}
],
"source": [
"net = TasterClassifierAttn(\n",
" word_to_idx=word_to_idx,\n",
" country_to_idx=country_to_idx,\n",
" word_embedding_size=50,\n",
" common_hidden_size=20,\n",
" num_tasters=len(taster_to_idx)\n",
")\n",
"n_epoch = 5\n",
"model = train_attn('taster_classifier', net, attn_activation=2)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Taster (true) (pred): Lauren Buzzeo Lauren Buzzeo\n",
"(0.022559083998203278) Country: France\n",
"(0.1913205236196518) Points: 0.8999999761581421\n",
"(0.7861202955245972) Description: this is vibrant and expressive , with upfront purple-flower fragrances of violet and iris mixed with fruity tones of boysenberry and raspberry sauce . the texture offers medium tannins , ample acidity and skin-driven fruit flavors . roasted coffee bean and sweet smoke accents mark the lingering finish .\n",
"Taster (true) (pred): Sean P. Sullivan Paul Gregutt\n",
"(0.4225955009460449) Country: US\n",
"(0.5669879913330078) Points: 0.8600000143051147\n",
"(0.010416511446237564) Description: the aromas of char , sawdust and vanilla seem to clash with the notes of herbs , leather and cedar . the fruit flavors are lighter in style with the oak ( 50 % new french along with american ) taking over the show .\n",
"Taster (true) (pred): Roger Voss Roger Voss\n",
"(0.09736345708370209) Country: Austria\n",
"(0.007549024652689695) Points: 0.8600000143051147\n",
"(0.8950875401496887) Description: lively , vital and fresh , this is typical of the new generation of grüners that are made for early drinking . it 's light and tangy , bursting with apple and lemon flavors . screwcap .\n",
"Taster (true) (pred): Anna Lee C. Iijima Anna Lee C. Iijima\n",
"(0.9932239055633545) Country: Germany\n",
"(0.006519009359180927) Points: 0.9100000262260437\n",
"(0.00025722666759975255) Description: offering buoyantly fresh fruit and mineral intensity , this pinot noir rosé is well worth finding in quantity to enjoy year-round , not just as a summer quaffer . zesty acidity and a light , finely textured body are amplified by plump , fresh raspberry and peach flavors . its briskly mineral finish is exhilarating .\n",
"Taster (true) (pred): Roger Voss Roger Voss\n",
"(0.1235518828034401) Country: France\n",
"(0.08739478141069412) Points: 0.9200000166893005\n",
"(0.7890533804893494) Description: bright and fruity , this crisp and fresh wine is full of lively red currants and berries cut with acidity and a tangy , orange peel texture . this is a great summer rosé .\n",
"Taster (true) (pred): Matt Kettmann Virginie Boone\n",
"(0.7471133470535278) Country: US\n",
"(0.25210368633270264) Points: 0.949999988079071\n",
"(0.00078294932609424) Description: hailing from the vineyard that surrounds the historic mt . carmel monastery and vineyard , this weaves together aromas of buttered black cherries , dr pepper , black slate and the slightest tinge of dried herbs . from that hedonistic nose , the blackberry-laced palate tightens around pencil shavings , black-tea leaves , graphite , eucalyptus and utterly balanced experience between savory and sweet .\n"
]
}
],
"source": [
"test_loader = DataLoader(\n",
" test_data,\n",
" batch_size=1,\n",
" collate_fn=collate_examples,\n",
" shuffle=False\n",
")\n",
"\n",
"idx_to_country = {v: k for k, v in country_to_idx.items()}\n",
"idx_to_word = {v: k for k, v in word_to_idx.items()}\n",
"idx_to_taster = {v: k for k, v in taster_to_idx.items()}\n",
"\n",
"# Print the attention weight given to each input for the first few test examples\n",
"n_examples = 6\n",
"for i, (x, y) in enumerate(test_loader):\n",
"    if i >= n_examples:\n",
"        break\n",
"    country, description, points = x\n",
"    pred, attn = model.predict_on_batch(x)\n",
"    print(\"Taster (true) (pred): {} {}\".format(idx_to_taster[int(y[0])], idx_to_taster[int(np.argmax(pred[0]))]))\n",
"    print(\"({}) Country: {}\".format(float(attn[0][1]), idx_to_country[int(country[0])]))\n",
"    print(\"({}) Points: {}\".format(float(attn[0][2]), float(points[0])))\n",
"    print(\"({}) Description: {}\".format(float(attn[0][0]), \" \".join([idx_to_word[int(w)] for w in description[0]])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
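"We could use pre-trained word representations in our own model.\n",
"\n",
"Good sources of pre-trained word embeddings include [fastText](https://fasttext.cc/), [GloVe](https://nlp.stanford.edu/projects/glove/) and [ELMo](https://allennlp.org/elmo). Note that ELMo produces contextual representations computed by a language model, so it cannot be loaded as a fixed lookup table the way fastText and GloVe vectors can.\n",
"\n",
"Here is one way to load word vectors trained with fastText (`.vec` files use the word2vec text format) and use them to initialize your own `Embedding` layer in PyTorch."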
]
},
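{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before turning to gensim, here is the core idea on toy data: each pre-trained vector is copied into the row of a weight matrix given by the word's index, and that matrix initializes the embedding layer. This sketch uses PyTorch's `nn.Embedding.from_pretrained` (assuming a version whose `from_pretrained` accepts `padding_idx`) instead of the `nn.Embedding` subclass used further down. The vocabulary `word_to_idx_toy` and the vectors `pretrained_toy` are hypothetical stand-ins for the real `word_to_idx` and fastText vectors. Passing `freeze=False` keeps the vectors trainable; `freeze=True` would keep them fixed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"\n",
"# Hypothetical toy vocabulary (index 0 = padding) and toy pre-trained vectors of size 4\n",
"word_to_idx_toy = {'<pad>': 0, 'wine': 1, 'cherry': 2, 'oak': 3}\n",
"pretrained_toy = {'wine': np.ones(4), 'oak': np.full(4, 0.5)}\n",
"toy_dim = 4\n",
"\n",
"# Rows for words without a pre-trained vector keep a small random initialization\n",
"# (a local generator is used so the notebook's global NumPy seed is not disturbed)\n",
"toy_rng = np.random.default_rng(0)\n",
"toy_weights = toy_rng.normal(scale=0.1, size=(len(word_to_idx_toy), toy_dim))\n",
"toy_weights[0] = 0.0  # padding row\n",
"for toy_word, toy_idx in word_to_idx_toy.items():\n",
"    if toy_word in pretrained_toy:\n",
"        toy_weights[toy_idx] = pretrained_toy[toy_word]\n",
"\n",
"# freeze=False lets the vectors be fine-tuned during training\n",
"toy_embedding_layer = nn.Embedding.from_pretrained(\n",
"    torch.tensor(toy_weights, dtype=torch.float32), freeze=False, padding_idx=0\n",
")\n",
"print(toy_embedding_layer(torch.tensor([1, 3])))"
]
},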
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gensim.models import KeyedVectors\n",
"\n",
"vec_model_path = './vectors.vec'\n",
"vec_model = KeyedVectors.load_word2vec_format(vec_model_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class MyEmbeddings(nn.Embedding):\n",
"    def __init__(self, word_to_idx, embedding_dim):\n",
"        # Index 0 is assumed to be the padding token\n",
"        super(MyEmbeddings, self).__init__(len(word_to_idx), embedding_dim, padding_idx=0)\n",
"        self.vocab_size = len(word_to_idx)\n",
"        self.word_to_idx = word_to_idx\n",
"\n",
"    def set_item_embedding(self, idx, embedding):\n",
"        # Copy the vector in place without recording the operation in autograd\n",
"        with torch.no_grad():\n",
"            self.weight[idx] = torch.as_tensor(embedding, dtype=self.weight.dtype)\n",
"\n",
"    def load_words_embeddings(self, vec_model):\n",
"        # Words missing from vec_model keep their random initialization\n",
"        for word, idx in self.word_to_idx.items():\n",
"            if word in vec_model:\n",
"                self.set_item_embedding(idx, vec_model[word])\n",
"\n",
"\n",
"embeddings_layer = MyEmbeddings(word_to_idx, vec_model.vector_size)\n",
"embeddings_layer.load_words_embeddings(vec_model)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}