!pip install -Uqq trax
!pip install -Uqq fastcore
Imports
import os
import nltk
import re
import string
import itertools
import pandas as pd
import trax
import trax.fastmath.numpy as np
import random as rnd
from trax import layers as tl
from trax import shapes
from trax.supervised import training
from sklearn.model_selection import train_test_split
from fastcore.all import *
nltk.download('twitter_samples')
nltk.download('stopwords')
from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords, twitter_samples
from nltk.stem import PorterStemmer
[nltk_data] Downloading package twitter_samples to /root/nltk_data...
[nltk_data] Package twitter_samples is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data] Package stopwords is already up-to-date!
JAX Autograd
x = np.array(5.0)
x
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray(5., dtype=float32, weak_type=True)
def f(x): return x**2
f(x)
DeviceArray(25., dtype=float32, weak_type=True)
grad_f = trax.fastmath.grad(fun=f)
grad_f(x)
DeviceArray(10., dtype=float32, weak_type=True)
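The same machinery differentiates any composition of fastmath ops. A minimal sketch (not part of the original notebook; g and grad_g are illustrative names):
# Hedged sketch: trax.fastmath.grad wraps JAX autograd, so richer functions work too.
def g(x): return x**3 + 2.0 * x   # analytically, g'(x) = 3*x**2 + 2
grad_g = trax.fastmath.grad(g)
grad_g(np.array(2.0))             # expected: DeviceArray(14., ...)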
Data Loading
def get_tweet_df(frac=1):
df_positive = pd.DataFrame(twitter_samples.strings('positive_tweets.json'), columns=['sentence'])
df_positive['sentiment'] = 'positive'
df_negative = pd.DataFrame(twitter_samples.strings('negative_tweets.json'), columns=['sentence'])
df_negative['sentiment'] = 'negative'
df = pd.concat([df_positive, df_negative])
return df.sample(frac=frac).reset_index(drop=True)
df = get_tweet_df()
df
|   | sentence | sentiment |
|---|---|---|
| 0 | Why is no one awake :( | negative |
| 1 | @JediOMG Id like to think people know who I am :) | positive |
| 2 | and another hour goes by :(( | negative |
| 3 | @beat_trees thank you Beatriz :) | positive |
| 4 | @ashleelynda yes let's do it :) | positive |
| ... | ... | ... |
| 9995 | everytime i look at the clock its 3:02 and som... | positive |
| 9996 | @ReddBlock1 @MadridGamesWeek Too far away :( | negative |
| 9997 | It will all get better in time. :) | positive |
| 9998 | Stats for the week have arrived. 2 new followe... | positive |
| 9999 | 3:33 am and all of my roommates are officially... | negative |
10000 rows × 2 columns
Data Splitting
df_model, df_blind = train_test_split(df, stratify=df['sentiment'])
df_train, df_valid = train_test_split(df_model, stratify=df_model['sentiment'])
df_blind.shape, df_model.shape, df_train.shape, df_valid.shape
((2500, 2), (7500, 2), (5625, 2), (1875, 2))
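As a quick sanity check (not in the original notebook), each stratified split should preserve the roughly 50/50 class balance:
# Hedged sketch: confirm stratify= kept the sentiment balance in every split.
for name, d in [('model', df_model), ('blind', df_blind), ('train', df_train), ('valid', df_valid)]:
    print(name, d['sentiment'].value_counts(normalize=True).round(3).to_dict())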
Data Preprocessing
def remove_old_style(tweet): return re.sub(r'^RT[\s]+', '', tweet)
def remove_url(tweet): return re.sub(r'https?://[^\s\n\r]+', '', tweet)
def remove_hash(tweet): return re.sub(r'#', "", tweet)
def remove_numbers(tweet): return re.sub(r'\d*\.?\d+', "", tweet)
tokenizer = TweetTokenizer(preserve_case=False, strip_handles=True, reduce_len=True)
skip_words = stopwords.words('english')+list(string.punctuation)
stemmer = PorterStemmer()
def filter_stem_tokens(tweet_tokens, skip_words=skip_words, stemmer=stemmer):
return [ stemmer.stem(token) for token in tweet_tokens if token not in skip_words]
process_tweet = compose(remove_old_style, remove_url, remove_hash, remove_numbers, tokenizer.tokenize, filter_stem_tokens)
tweet = df_train.loc[df_train.index[1000], 'sentence']
tweet, process_tweet(tweet)
('@My_Old_Dutch ahh I did thank you so much :) xxx',
 ['ahh', 'thank', 'much', ':)', 'xxx'])
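fastcore's compose applies its arguments left to right, so process_tweet is equivalent to the explicit chain sketched below (the helper name is hypothetical):
# Hedged sketch of what compose(...) builds, with the steps spelled out in order.
def process_tweet_explicit(tweet):
    tweet = remove_old_style(tweet)    # strip old-style "RT" prefix
    tweet = remove_url(tweet)
    tweet = remove_hash(tweet)
    tweet = remove_numbers(tweet)
    tokens = tokenizer.tokenize(tweet)
    return filter_stem_tokens(tokens)
assert process_tweet_explicit(tweet) == process_tweet(tweet)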
Calculate Vocab and Encode a Tweet
df_train['sentence'].apply(process_tweet)
484 [inde, ran, charact, knew, get, :)]
7733 [cute, :(]
3743 [yay, weekend', happi, punt, everyon, :d, hors...
6276 [that', one, sweetest, thing, i'v, ever, said,...
1411 [goodnight, love, luke, heart, :-), love, 💜]
...
8825 [happi, princess, today, :(, sigh, get, work, ...
4994 [went, see, appar, use, gb, extra, data, :(]
8112 [followfriday, top, influenc, commun, week, :)]
9961 [sr, financi, analyst, expedia, inc, bellevu, ...
499 [i'v, support, sinc, start, alway, ignor, :(, ...
Name: sentence, Length: 5625, dtype: object
['__PAD__', '__</e>__', '__UNK__'] + list(set(df_train['sentence'].apply(process_tweet).sum()))
['__PAD__',
 '__</e>__',
 '__UNK__',
 'stand',
 'cross',
 'minut',
 'itna',
 'chillin',
 'dubai',
 'documentari',
 ...]
df_vocab = pd.DataFrame(['__PAD__', '__</e>__', '__UNK__'] + list(set(df_train['sentence'].apply(process_tweet).sum())), columns=['vocab']).reset_index()
df_vocab = df_vocab.set_index('vocab')
df_vocab
| vocab | index |
|---|---|
| __PAD__ | 0 |
| __</e>__ | 1 |
| __UNK__ | 2 |
| stand | 3 |
| cross | 4 |
| ... | ... |
| belov | 6999 |
| arr | 7000 |
| ng | 7001 |
| art | 7002 |
| jule | 7003 |
7004 rows × 1 columns
def get_vocab(df, text_col, cb_process, pad_token='__PAD__', eof_token='__</e>__', unknown_token='__UNK__'):
df_vocab = pd.DataFrame([pad_token, eof_token, unknown_token] + list(set(df[text_col].apply(cb_process).sum())),
columns=['vocab']).reset_index()
df_vocab = df_vocab.set_index('vocab')
return df_vocab
df_vocab = get_vocab(df_train, text_col='sentence', cb_process=process_tweet); df_vocab
| vocab | index |
|---|---|
| __PAD__ | 0 |
| __</e>__ | 1 |
| __UNK__ | 2 |
| stand | 3 |
| cross | 4 |
| ... | ... |
| belov | 6999 |
| arr | 7000 |
| ng | 7001 |
| art | 7002 |
| jule | 7003 |
7004 rows × 1 columns
df['sentence']
0 Why is no one awake :(
1 @JediOMG Id like to think people know who I am :)
2 and another hour goes by :((
3 @beat_trees thank you Beatriz :)
4 @ashleelynda yes let's do it :)
...
9995 everytime i look at the clock its 3:02 and som...
9996 @ReddBlock1 @MadridGamesWeek Too far away :(
9997 It will all get better in time. :)
9998 Stats for the week have arrived. 2 new followe...
9999 3:33 am and all of my roommates are officially...
Name: sentence, Length: 10000, dtype: object
msg = "My name is Rahul Saraf :)"
tokens = process_tweet(msg)
tokens
['name', 'rahul', 'saraf', ':)']
unknown_token = "__UNK__"
processed_token = [ token if token in df_vocab.index else unknown_token for token in tokens]
processed_token
['name', '__UNK__', '__UNK__', ':)']
df_vocab.loc[processed_token]
| vocab | index |
|---|---|
| name | 5423 |
| __UNK__ | 2 |
| __UNK__ | 2 |
| :) | 2877 |
df_vocab.loc[processed_token, 'index'].tolist()
[5423, 2, 2, 2877]
def tweet2idx(tweet, cb_process, df_vocab=df_vocab, unknown_token='__UNK__'):
tokens = cb_process(tweet)
processed_token = [token if token in df_vocab.index else unknown_token for token in tokens]
return df_vocab.loc[processed_token, 'index'].tolist()
tweet2idx(msg, cb_process=process_tweet)
[5423, 2, 2, 2877]
Data Batching (Generator)
(df_train['sentiment'].value_counts()/len(df_train)).to_dict()
{'negative': 0.5000888888888889, 'positive': 0.4999111111111111}
shuffle = True
batch_sz = 4
x_col = 'sentence'
y_col = 'sentiment'
class_dict = {'negative':0, 'positive':1}
class_weights = (df_train['sentiment'].value_counts()/len(df_train)).to_dict()
class_weights = {'negative':1.0, 'positive':1.0}
df = df_train
if shuffle : df = df.sample(frac=1.0)
itr = itertools.cycle(df.iterrows())
next(itr)
(2200, sentence I love my body, when it's fuller, and shit, ev...
sentiment positive
Name: 2200, dtype: object)
pad_token = '__PAD__'
pad_id = df_vocab.loc[pad_token, 'index']
pad_id
0
batch = (next(itr) for i in range(batch_sz)); list(batch)
[(9113, sentence @casslovesdrake yes cass I think you should :)
sentiment positive
Name: 9113, dtype: object),
(7538, sentence Someone talk to me I'm boreddd :(
sentiment negative
Name: 7538, dtype: object),
(2785, sentence Biodiversity, Taxonomic Infrastructure, Intern...
sentiment negative
Name: 2785, dtype: object),
(6754, sentence @TroutAmbush I'm at work :(\nI can show you on...
sentiment negative
Name: 6754, dtype: object)]
batch = (next(itr) for i in range(batch_sz))
X, y = zip(*[(tweet2idx(entry[1][x_col], cb_process=process_tweet, df_vocab=df_vocab, unknown_token=unknown_token), entry[1][y_col]) for entry in batch])
X, y
(([4224, 2550, 4207, 1530, 2877],
[931, 3201, 4390, 1039, 5509],
[702, 5549, 4180],
[6583,
1546,
5782,
5314,
2222,
6583,
3549,
4309,
6583,
4349,
11,
1671,
4254,
1329,
5525,
509]),
('positive', 'negative', 'positive', 'negative'))
pd.DataFrame(X)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4224 | 2550 | 4207 | 1530.0 | 2877.0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | 931 | 3201 | 4390 | 1039.0 | 5509.0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 2 | 702 | 5549 | 4180 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 3 | 6583 | 1546 | 5782 | 5314.0 | 2222.0 | 6583.0 | 3549.0 | 4309.0 | 6583.0 | 4349.0 | 11.0 | 1671.0 | 4254.0 | 1329.0 | 5525.0 | 509.0 |
inputs = np.array(pd.DataFrame(X).fillna(pad_id), dtype='int32'); inputs
DeviceArray([[4224, 2550, 4207, 1530, 2877, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[ 931, 3201, 4390, 1039, 5509, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[ 702, 5549, 4180, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[6583, 1546, 5782, 5314, 2222, 6583, 3549, 4309, 6583, 4349,
11, 1671, 4254, 1329, 5525, 509]], dtype=int32)
targets = np.array([class_dict[i] for i in y]); targets
DeviceArray([1, 0, 1, 0], dtype=int32)
weights = np.array([class_weights[i] for i in y]); weights
DeviceArray([1., 1., 1., 1.], dtype=float32)
# We assume df_vocab always has an 'index' column. Perhaps we can parameterize that as well.
def text2idx(text, cb_process, df_vocab=df_vocab, unknown_token='__UNK__', vocab_idx_col='index'):
tokens = cb_process(text)
processed_token = [token if token in df_vocab.index else unknown_token for token in tokens]
return df_vocab.loc[processed_token, vocab_idx_col].tolist()
def process_batch(batch, func_text2idx, pad_id, class_dict, class_weights):
X, y = zip(*[(func_text2idx(entry[1][x_col]), entry[1][y_col]) for entry in batch])
inputs = np.array(pd.DataFrame(X).fillna(pad_id), dtype='int32')
targets = np.array([class_dict[i] for i in y])
weights = np.array([class_weights[i] for i in y])
return inputs, targets, weights
def data_generator(df, batch_sz, df_vocab, x_col, y_col, class_dict, class_weights, cb_process,
unknown_token='__UNK__', pad_token='__PAD__', eof_token="__</e>__", vocab_idx_col='index',
shuffle=False, loop=True, stop=False,
):
func_text2idx = lambda text: text2idx(text, cb_process, df_vocab=df_vocab, unknown_token=unknown_token, vocab_idx_col=vocab_idx_col)
pad_id = df_vocab.loc[pad_token, vocab_idx_col]
# unk_id = df_vocab.loc[unknown_token, 'index']
# eof_id = df_vocab.loc[eof_token, 'index']
while not stop:
index = 0 # Resets index to zero on data exhaustion
if shuffle: df = df.sample(frac=1.0) # Shuffle the data - only relevant for train tasks, not eval tasks
itr = itertools.cycle(df.iterrows()) # cycle only exists to fill the last batch once the dataset's rows have been exhausted
while index <= len(df):
batch = (next(itr) for i in range(batch_sz))
index += batch_sz
yield process_batch(batch, func_text2idx, pad_id, class_dict, class_weights)
if loop: continue
else: break
g = data_generator(df_train[:10].copy(), 3, df_vocab,
x_col='sentence',
y_col='sentiment',
class_dict={'negative':0, 'positive':1},
class_weights={'negative':1.0, 'positive':1.0},
cb_process=process_tweet
)
next(g)
(DeviceArray([[6347, 3211, 1899, 4447, 59, 2877, 0, 0, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0, 0, 0],
[1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812, 239]], dtype=int32),
DeviceArray([1, 0, 1], dtype=int32),
DeviceArray([1., 1., 1.], dtype=float32))
count = 0
while count<6:
batch = next(g)
count +=1
print(batch)
(DeviceArray([[2311, 6480, 6563, 2915, 4879, 5819, 4236, 4670, 2215, 5509],
[1222, 1148, 2733, 34, 4180, 1148, 3552, 0, 0, 0],
[5441, 4763, 2915, 5453, 2877, 0, 0, 0, 0, 0]], dtype=int32), DeviceArray([0, 1, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6477, 2196, 3309, 5509, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5962, 6958, 4779, 6958, 3651, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5651, 5208, 4343, 718, 2983, 2800, 266, 3749, 1961, 1306,
4879, 3988, 2596, 5509]], dtype=int32), DeviceArray([0, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[2836, 1148, 1825, 3520, 6388, 2877, 87, 2877],
[6347, 3211, 1899, 4447, 59, 2877, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0]], dtype=int32), DeviceArray([1, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6347, 3211, 1899, 4447, 59, 2877, 0, 0, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0, 0, 0],
[1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812, 239]], dtype=int32), DeviceArray([1, 0, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[2311, 6480, 6563, 2915, 4879, 5819, 4236, 4670, 2215, 5509],
[1222, 1148, 2733, 34, 4180, 1148, 3552, 0, 0, 0],
[5441, 4763, 2915, 5453, 2877, 0, 0, 0, 0, 0]], dtype=int32), DeviceArray([0, 1, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6477, 2196, 3309, 5509, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5962, 6958, 4779, 6958, 3651, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5651, 5208, 4343, 718, 2983, 2800, 266, 3749, 1961, 1306,
4879, 3988, 2596, 5509]], dtype=int32), DeviceArray([0, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
Model Definition
emb_dims = 256
vocab_sz = df_vocab.shape[0]; vocab_sz
7004
embed_layer = tl.Embedding(vocab_size=vocab_sz, d_feature=emb_dims); embed_layer
Embedding_7004_256
g = data_generator(df_train[:10].copy(), 3, df_vocab,
x_col='sentence',
y_col='sentiment',
class_dict={'negative':0, 'positive':1},
class_weights={'negative':1.0, 'positive':1.0},
cb_process=process_tweet
)
inputs, targets, weights = next(g); inputs, targets, weights
inputs
DeviceArray([[6347, 3211, 1899, 4447, 59, 2877, 0, 0, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0, 0, 0],
[1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812, 239]], dtype=int32)
inputs.shape
(3, 10)
embed_layer.init(trax.shapes.signature(inputs))
embed_layer(inputs).shape
(3, 10, 256)
mean_layer = tl.Mean(axis=1)
mean_layer(embed_layer(inputs)).shape
(3, 256)
ml_o = mean_layer(embed_layer(inputs)); ml_o
DeviceArray([[ 2.79488266e-02, -1.45270219e-02,  4.40993309e-02, ...,
               1.03713451e-02,  1.28128221e-02],
             [ 3.38993296e-02,  1.35733951e-02,  6.66464493e-02, ...,
               1.98510122e-02,  5.23267277e-02],
             [-1.39780194e-02,  5.57200797e-03,  6.62390841e-03, ...,
              -9.91707761e-03,  1.73570886e-02]], dtype=float32)
dense_layer = tl.Dense(n_units=2)
dense_layer.init(trax.shapes.signature(ml_o))
dense_layer(ml_o)
DeviceArray([[-0.09669925, -0.00369832],
[-0.14038305, -0.01882405],
[-0.08846572, -0.01865664]], dtype=float32)
def classifier(output_dims, vocab_sz, emb_dims=256, mode='train'):
emb_layer = tl.Embedding(vocab_size=vocab_sz, d_feature=emb_dims)
mean_layer = tl.Mean(axis=1)
dense_output_layer = tl.Dense(n_units=output_dims)
log_softmax_layer = tl.LogSoftmax()
model = tl.Serial(
emb_layer,
mean_layer,
dense_output_layer,
log_softmax_layer
)
return model
classifier(2, df_vocab.shape[0])
Serial[
Embedding_7004_256
Mean
Dense_2
LogSoftmax
]
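Before wiring up training, a minimal sketch (not in the original notebook; it reuses inputs from the sample batch above, and tmp_model is an illustrative name) of initializing the untrained model and checking that it emits one log-probability per class:
# Hedged sketch: forward pass of the untrained classifier on the sample batch.
tmp_model = classifier(2, df_vocab.shape[0])
tmp_model.init(trax.shapes.signature(inputs))
tmp_model(inputs).shape   # expected: (3, 2) log-probabilities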
Define Train and Eval Tasks
def get_train_eval_tasks(df_train, df_valid, df_vocab, x_col, y_col, class_dict, class_weights, cb_process, batch_sz=16,
unknown_token='__UNK__', pad_token='__PAD__', eof_token="__</e>__", vocab_idx_col='index'):
rnd.seed(271)
train_task = training.TrainTask(
labeled_data=data_generator(df_train, batch_sz, df_vocab,
x_col=x_col,
y_col=y_col,
class_dict=class_dict,
class_weights=class_weights,
cb_process=cb_process,
unknown_token=unknown_token,
pad_token=pad_token,
eof_token=eof_token,
vocab_idx_col=vocab_idx_col,
shuffle=True,
loop=True,
stop=False
),
loss_layer = tl.WeightedCategoryCrossEntropy(),
optimizer = trax.optimizers.Adam(0.01),
n_steps_per_checkpoint=10)
eval_task = training.EvalTask(
labeled_data=data_generator(df_valid, batch_sz, df_vocab,
x_col=x_col,
y_col=y_col,
class_dict=class_dict,
class_weights=class_weights,
cb_process=cb_process,
unknown_token=unknown_token,
pad_token=pad_token,
eof_token=eof_token,
vocab_idx_col=vocab_idx_col,
shuffle=False,
loop=False,
stop=False
),
metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()]
)
return train_task, eval_task
len(class_dict)
2
Model Training
def train_model(classifier, train_task, eval_task, n_steps, output_dir='.'):
rnd.seed(1234)
training_loop = training.Loop(classifier, train_task, eval_tasks=eval_task, output_dir=output_dir, random_seed=112)
training_loop.run(n_steps=n_steps)
return training_loop
x_col = 'sentence'
y_col = 'sentiment'
class_dict = {'negative':0, 'positive':1}
class_weights = (df_train['sentiment'].value_counts()/len(df_train)).to_dict()
class_weights = {'negative':1.0, 'positive':1.0}
batch_sz=16
unknown_token='__UNK__'
pad_token='__PAD__'
eof_token="__</e>__"
cb_process = process_tweet
output_dims = len(class_dict)
emb_dims = 256
output_dir="."
df_model, df_blind = train_test_split(df, stratify=df[y_col])
df_train, df_valid = train_test_split(df_model, stratify=df_model[y_col])
df_vocab = get_vocab(df_train, text_col=x_col, cb_process=cb_process, pad_token=pad_token, eof_token=eof_token, unknown_token=unknown_token)
vocab_sz = df_vocab.shape[0]
vocab_idx_col='index'
cb_process
train_task, eval_task = get_train_eval_tasks(
df_train,
df_valid,
df_vocab,
x_col, y_col, class_dict, class_weights,
cb_process, batch_sz,
unknown_token, pad_token, eof_token,
vocab_idx_col='index') ; train_task, eval_task
model = classifier(output_dims, vocab_sz, emb_dims=emb_dims, mode='train')
training_loop = train_model(model, train_task, eval_task, n_steps=40, output_dir=output_dir)
/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
warnings.warn(
Step 1: Total number of trainable weights: 1235202
Step 1: Ran 1 train steps in 0.94 secs
Step 1: train WeightedCategoryCrossEntropy | 0.69573909
Step 1: eval WeightedCategoryCrossEntropy | 0.63845330
Step 1: eval WeightedCategoryAccuracy | 0.81250000
Step 10: Ran 9 train steps in 3.55 secs
Step 10: train WeightedCategoryCrossEntropy | 0.65314281
Step 10: eval WeightedCategoryCrossEntropy | 0.60192943
Step 10: eval WeightedCategoryAccuracy | 0.87500000
Step 20: Ran 10 train steps in 2.19 secs
Step 20: train WeightedCategoryCrossEntropy | 0.51353288
Step 20: eval WeightedCategoryCrossEntropy | 0.41577297
Step 20: eval WeightedCategoryAccuracy | 1.00000000
Step 30: Ran 10 train steps in 1.23 secs
Step 30: train WeightedCategoryCrossEntropy | 0.26818269
Step 30: eval WeightedCategoryCrossEntropy | 0.14980227
Step 30: eval WeightedCategoryAccuracy | 1.00000000
Step 40: Ran 10 train steps in 1.30 secs
Step 40: train WeightedCategoryCrossEntropy | 0.15570463
Step 40: eval WeightedCategoryCrossEntropy | 0.08368409
Step 40: eval WeightedCategoryAccuracy | 1.00000000
Model Evaluation
blind_gen = data_generator(df_blind, batch_sz, df_vocab,
x_col=x_col,
y_col=y_col,
class_dict=class_dict,
class_weights=class_weights,
cb_process=cb_process,
unknown_token=unknown_token,
pad_token=pad_token,
eof_token=eof_token,
vocab_idx_col=vocab_idx_col,
shuffle=False,
loop=False,
stop=False
)
batch = next(blind_gen)
blind_inputs, blind_targets, blind_weights = batch
training_loop.eval_model(blind_inputs)
DeviceArray([[-6.7398906e-02, -2.7306366e+00],
[-6.8514347e-03, -4.9867296e+00],
[-3.4929681e+00, -3.0882478e-02],
[-5.0519943e-02, -3.0105407e+00],
[-2.9234269e+00, -5.5247664e-02],
[-3.8873537e+00, -2.0712495e-02],
[-5.2998271e+00, -5.0048828e-03],
[-3.5535395e+00, -2.9040813e-02],
[-4.6031094e+00, -1.0071278e-02],
[-2.1165485e+00, -1.2834114e-01],
[-4.1328192e-02, -3.2068014e+00],
[-2.3841858e-05, -1.0641651e+01],
[-2.1345377e-02, -3.8575726e+00],
[-2.6583314e-02, -3.6407375e+00],
[-7.7986407e+00, -4.1031837e-04],
[-1.1112690e-02, -4.5052152e+00]], dtype=float32)
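The rows above are per-class log-probabilities, so argmax over the last axis gives the predicted class. A minimal sketch (not from the original notebook; eval_accuracy is an illustrative helper) for estimating blind-set accuracy over a few batches:
# Hedged sketch: compare argmax predictions with targets over several batches.
def eval_accuracy(gen, eval_model, n_batches=10):
    correct, total = 0, 0
    for _ in range(n_batches):
        inputs, targets, weights = next(gen)
        preds = np.argmax(eval_model(inputs), axis=-1)
        correct += int(np.sum(preds == targets))
        total += int(targets.shape[0])
    return correct / total
# eval_accuracy(blind_gen, training_loop.eval_model)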