Sentiment Analysis

Imports

!pip install -Uqq trax
!pip install -Uqq fastcore
import os 
import nltk
import re
import string
import itertools
import pandas as pd
import trax
import trax.fastmath.numpy as np
import random as rnd
from trax import layers as tl
from trax import shapes
from trax.supervised import training
from sklearn.model_selection import train_test_split
from fastcore.all import *
nltk.download('twitter_samples')
nltk.download('stopwords')
from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords, twitter_samples
from nltk.stem import PorterStemmer
[nltk_data] Downloading package twitter_samples to /root/nltk_data...
[nltk_data]   Package twitter_samples is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!

JAX Autograd

x = np.array(5.0)
x
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray(5., dtype=float32, weak_type=True)
def f(x): return x**2
f(x)
DeviceArray(25., dtype=float32, weak_type=True)
grad_f = trax.fastmath.grad(fun=f)
grad_f(x)
DeviceArray(10., dtype=float32, weak_type=True)
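
Since f(x) = x**2, the gradient at x = 5 is 2*x = 10, which matches the output above. As a minimal sketch (assuming trax.fastmath.grad wraps jax.grad, so calls can be nested), higher-order derivatives fall out for free:

grad2_f = trax.fastmath.grad(trax.fastmath.grad(f))
grad2_f(x)  # expect 2.0, since d^2(x**2)/dx^2 = 2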

Data Loading

def get_tweet_df(frac=1):
  df_positive = pd.DataFrame(twitter_samples.strings('positive_tweets.json'), columns=['sentence'])
  df_positive['sentiment'] = 'positive'
  df_negative = pd.DataFrame(twitter_samples.strings('negative_tweets.json'), columns=['sentence'])
  df_negative['sentiment'] = 'negative'
  df = pd.concat([df_positive, df_negative])
  return df.sample(frac=frac).reset_index(drop=True)  # sample(frac=1) shuffles the concatenated rows

df = get_tweet_df()
df
sentence sentiment
0 Why is no one awake :( negative
1 @JediOMG Id like to think people know who I am :) positive
2 and another hour goes by :(( negative
3 @beat_trees thank you Beatriz :) positive
4 @ashleelynda yes let's do it :) positive
... ... ...
9995 everytime i look at the clock its 3:02 and som... positive
9996 @ReddBlock1 @MadridGamesWeek Too far away :( negative
9997 It will all get better in time. :) positive
9998 Stats for the week have arrived. 2 new followe... positive
9999 3:33 am and all of my roommates are officially... negative

10000 rows × 2 columns

Data Splitting

df_model, df_blind = train_test_split(df, stratify=df['sentiment'])
df_train, df_valid = train_test_split(df_model, stratify=df_model['sentiment'])
df_blind.shape, df_model.shape, df_train.shape, df_valid.shape
((2500, 2), (7500, 2), (5625, 2), (1875, 2))
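
Note: train_test_split holds out 25% by default, so df_blind gets 2500 of the 10000 tweets and df_valid gets 25% of the remaining 7500; stratify keeps the positive/negative ratio identical in every split.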

Data Preprocessing

def remove_old_style(tweet): return re.sub(r'^RT[\s]+', '', tweet)      # drop the old-style retweet marker "RT"
def remove_url(tweet): return re.sub(r'https?://[^\s\n\r]+', '', tweet)  # drop URLs
def remove_hash(tweet): return re.sub(r'#', "", tweet)                   # keep hashtag text, drop the '#'
def remove_numbers(tweet): return re.sub(r'\d*\.?\d+', "", tweet)        # drop integers and decimals
tokenizer = TweetTokenizer(preserve_case=False, strip_handles=True, reduce_len=True)
skip_words = stopwords.words('english') + list(string.punctuation)
stemmer = PorterStemmer()
def filter_stem_tokens(tweet_tokens, skip_words=skip_words, stemmer=stemmer):
    return [stemmer.stem(token) for token in tweet_tokens if token not in skip_words]

process_tweet = compose(remove_old_style, remove_url, remove_hash, remove_numbers, tokenizer.tokenize, filter_stem_tokens)
tweet = df_train.loc[df_train.index[1000],'sentence']
tweet, process_tweet(tweet)
('@My_Old_Dutch ahh I did thank you so much :) xxx',
 ['ahh', 'thank', 'much', ':)', 'xxx'])
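
fastcore's compose chains the functions left-to-right, so the regex cleanups run first, then tokenization, then stopword filtering and stemming. A tiny sketch with hypothetical helpers to make the ordering concrete:

add_one = lambda x: x + 1
double = lambda x: x * 2
compose(add_one, double)(3)  # (3 + 1) * 2 == 8, i.e. applied left-to-right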

Calculate Vocab and Encode a Tweet

df_train['sentence'].apply(process_tweet)
484                   [inde, ran, charact, knew, get, :)]
7733                                           [cute, :(]
3743    [yay, weekend', happi, punt, everyon, :d, hors...
6276    [that', one, sweetest, thing, i'v, ever, said,...
1411         [goodnight, love, luke, heart, :-), love, 💜]
                              ...                        
8825    [happi, princess, today, :(, sigh, get, work, ...
4994         [went, see, appar, use, gb, extra, data, :(]
8112      [followfriday, top, influenc, commun, week, :)]
9961    [sr, financi, analyst, expedia, inc, bellevu, ...
499     [i'v, support, sinc, start, alway, ignor, :(, ...
Name: sentence, Length: 5625, dtype: object
['__PAD__', '__</e>__', '__UNK__'] + list(set(df_train['sentence'].apply(process_tweet).sum()))
['__PAD__',
 '__</e>__',
 '__UNK__',
 'stand',
 'cross',
 'minut',
 'itna',
 'chillin',
 'dubai',
 'documentari',
 ...]
df_vocab = pd.DataFrame(['__PAD__', '__</e>__', '__UNK__'] + list(set(df_train['sentence'].apply(process_tweet).sum())), columns=['vocab']).reset_index()
df_vocab = df_vocab.set_index('vocab')
df_vocab
index
vocab
__PAD__ 0
__</e>__ 1
__UNK__ 2
stand 3
cross 4
... ...
belov 6999
arr 7000
ng 7001
art 7002
jule 7003

7004 rows × 1 columns

def get_vocab(df, text_col, cb_process,  pad_token='__PAD__', eof_token='__</e>__', unknown_token='__UNK__'):
  df_vocab = pd.DataFrame([pad_token, eof_token, unknown_token] + list(set(df[text_col].apply(cb_process).sum())), 
                          columns=['vocab']).reset_index()
  df_vocab = df_vocab.set_index('vocab')
  return df_vocab
          
df_vocab = get_vocab(df_train, text_col='sentence', cb_process=process_tweet); df_vocab
index
vocab
__PAD__ 0
__</e>__ 1
__UNK__ 2
stand 3
cross 4
... ...
belov 6999
arr 7000
ng 7001
art 7002
jule 7003

7004 rows × 1 columns

df['sentence']
0                                  Why is no one awake :(
1       @JediOMG Id like to think people know who I am :)
2                            and another hour goes by :((
3                        @beat_trees thank you Beatriz :)
4                         @ashleelynda yes let's do it :)
                              ...                        
9995    everytime i look at the clock its 3:02 and som...
9996         @ReddBlock1 @MadridGamesWeek Too far away :(
9997                   It will all get better in time. :)
9998    Stats for the week have arrived. 2 new followe...
9999    3:33 am and all of my roommates are officially...
Name: sentence, Length: 10000, dtype: object
msg = "My name is Rahul Saraf :)"
tokens = process_tweet(msg)
tokens
['name', 'rahul', 'saraf', ':)']
unknown_token = "__UNK__"
processed_token = [ token if token in df_vocab.index else unknown_token for token in tokens]
processed_token
['name', '__UNK__', '__UNK__', ':)']
df_vocab.loc[processed_token]
index
vocab
name 5423
__UNK__ 2
__UNK__ 2
:) 2877
df_vocab.loc[processed_token, 'index'].tolist()
[5423, 2, 2, 2877]
def tweet2idx(tweet, cb_process, df_vocab=df_vocab, unknown_token='__UNK__'):
  tokens = cb_process(tweet)
  processed_token = [token if token in df_vocab.index else unknown_token for token in tokens]
  return df_vocab.loc[processed_token, 'index'].tolist()

tweet2idx(msg, cb_process=process_tweet)
[5423, 2, 2, 2877]
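
Row-by-row .loc lookups on a DataFrame work, but they are slow inside a data pipeline. An optional micro-optimization (a sketch; vocab_dict and tweet2idx_fast are illustrative names, not from the original): build a plain dict once and look tokens up directly.

vocab_dict = df_vocab['index'].to_dict()  # token -> id
unk_id = vocab_dict[unknown_token]
def tweet2idx_fast(tweet, cb_process=process_tweet):
  return [vocab_dict.get(token, unk_id) for token in cb_process(tweet)]

tweet2idx_fast(msg)  # same ids as tweet2idx above: [5423, 2, 2, 2877]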

Data Batching (Generator)

(df_train['sentiment'].value_counts()/len(df_train)).to_dict()
{'negative': 0.5000888888888889, 'positive': 0.4999111111111111}
shuffle = True
batch_sz = 4
x_col = 'sentence'
y_col = 'sentiment'
class_dict = {'negative':0, 'positive':1}

class_weights = (df_train['sentiment'].value_counts()/len(df_train)).to_dict()
class_weights = {'negative':1.0, 'positive':1.0}  # classes are balanced, so overwrite with uniform weights
df_gen = df_train                                 # scratch name, so the full df is not clobbered
if shuffle : df_gen = df_gen.sample(frac=1.0)
itr = itertools.cycle(df_gen.iterrows())
next(itr)
(2200, sentence     I love my body, when it's fuller, and shit, ev...
 sentiment                                             positive
 Name: 2200, dtype: object)
pad_token = '__PAD__'
pad_id = df_vocab.loc[pad_token, 'index']
pad_id
0
batch = (next(itr) for i in range(batch_sz)); list(batch)
[(9113, sentence     @casslovesdrake yes cass I think you should :)
  sentiment                                          positive
  Name: 9113, dtype: object),
 (7538, sentence     Someone talk to me I'm boreddd :(
  sentiment                             negative
  Name: 7538, dtype: object),
 (2785, sentence     Biodiversity, Taxonomic Infrastructure, Intern...
  sentiment                                             negative
  Name: 2785, dtype: object),
 (6754, sentence     @TroutAmbush I'm at work :(\nI can show you on...
  sentiment                                             negative
  Name: 6754, dtype: object)]
batch = (next(itr) for i in range(batch_sz))
X, y = zip(*[(tweet2idx(entry[1][x_col], cb_process=process_tweet, df_vocab=df_vocab, unknown_token=unknown_token), entry[1][y_col]) for entry in batch])
X, y
(([4224, 2550, 4207, 1530, 2877],
  [931, 3201, 4390, 1039, 5509],
  [702, 5549, 4180],
  [6583,
   1546,
   5782,
   5314,
   2222,
   6583,
   3549,
   4309,
   6583,
   4349,
   11,
   1671,
   4254,
   1329,
   5525,
   509]),
 ('positive', 'negative', 'positive', 'negative'))
pd.DataFrame(X)
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
0 4224 2550 4207 1530.0 2877.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 931 3201 4390 1039.0 5509.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 702 5549 4180 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 6583 1546 5782 5314.0 2222.0 6583.0 3549.0 4309.0 6583.0 4349.0 11.0 1671.0 4254.0 1329.0 5525.0 509.0
inputs = np.array(pd.DataFrame(X).fillna(pad_id), dtype='int32'); inputs
DeviceArray([[4224, 2550, 4207, 1530, 2877,    0,    0,    0,    0,    0,
                 0,    0,    0,    0,    0,    0],
             [ 931, 3201, 4390, 1039, 5509,    0,    0,    0,    0,    0,
                 0,    0,    0,    0,    0,    0],
             [ 702, 5549, 4180,    0,    0,    0,    0,    0,    0,    0,
                 0,    0,    0,    0,    0,    0],
             [6583, 1546, 5782, 5314, 2222, 6583, 3549, 4309, 6583, 4349,
                11, 1671, 4254, 1329, 5525,  509]], dtype=int32)
targets = np.array([class_dict[i] for i in y]); targets
DeviceArray([1, 0, 1, 0], dtype=int32)
weights = np.array([class_weights[i] for i in y]); weights
DeviceArray([1., 1., 1., 1.], dtype=float32)
# We assume df_vocab always has an index column; perhaps we can parameterize that assumption as well.
def text2idx(text, cb_process, df_vocab=df_vocab, unknown_token='__UNK__', vocab_idx_col='index'):
  tokens = cb_process(text)
  processed_token = [token if token in df_vocab.index else unknown_token for token in tokens]
  return df_vocab.loc[processed_token, vocab_idx_col].tolist()

def process_batch(batch, func_text2idx, pad_id, class_dict, class_weights):
  X, y = zip(*[(func_text2idx(entry[1][x_col]), entry[1][y_col]) for entry in batch])
  inputs = np.array(pd.DataFrame(X).fillna(pad_id), dtype='int32')  # pad ragged rows to the longest tweet in the batch
  targets = np.array([class_dict[i] for i in y])                    # map string labels to class ids
  weights = np.array([class_weights[i] for i in y])                 # per-example weights for the loss/metrics
  return inputs, targets, weights


def data_generator(df, batch_sz, df_vocab, x_col, y_col, class_dict, class_weights, cb_process,
                   unknown_token='__UNK__', pad_token='__PAD__', eof_token="__</e>__", vocab_idx_col='index',
                   shuffle=False, loop=True, stop=False,
                   ):
  func_text2idx = lambda text: text2idx(text, cb_process, df_vocab=df_vocab, unknown_token=unknown_token, vocab_idx_col=vocab_idx_col)
  pad_id = df_vocab.loc[pad_token, vocab_idx_col]
  # unk_id = df_vocab.loc[unknown_token, 'index']
  # eof_id = df_vocab.loc[eof_token, 'index']

  while not stop:
    index = 0                             # Reset position at the start of each epoch
    if shuffle : df = df.sample(frac=1.0) # Shuffle data - only relevant for train, not eval, tasks
    itr = itertools.cycle(df.iterrows())  # cycle handles the final partial batch, wrapping to the start once the dataset is exhausted
    while index <= len(df):
      batch = (next(itr) for i in range(batch_sz))
      index += batch_sz
      yield process_batch(batch, func_text2idx, pad_id, class_dict, class_weights)
    if loop: continue 
    else: break

g = data_generator(df_train[:10].copy(), 3, df_vocab, 
                   x_col='sentence', 
                   y_col='sentiment', 
                   class_dict={'negative':0, 'positive':1},
                   class_weights={'negative':1.0, 'positive':1.0},
                   cb_process=process_tweet
                   )
next(g)
(DeviceArray([[6347, 3211, 1899, 4447,   59, 2877,    0,    0,    0,    0],
              [2281, 5509,    0,    0,    0,    0,    0,    0,    0,    0],
              [1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812,  239]],            dtype=int32),
 DeviceArray([1, 0, 1], dtype=int32),
 DeviceArray([1., 1., 1.], dtype=float32))
count = 0
while count<6:
  batch = next(g)
  count +=1
  print(batch)
(DeviceArray([[2311, 6480, 6563, 2915, 4879, 5819, 4236, 4670, 2215, 5509],
             [1222, 1148, 2733,   34, 4180, 1148, 3552,    0,    0,    0],
             [5441, 4763, 2915, 5453, 2877,    0,    0,    0,    0,    0]],            dtype=int32), DeviceArray([0, 1, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6477, 2196, 3309, 5509,    0,    0,    0,    0,    0,    0,
                 0,    0,    0,    0],
             [5962, 6958, 4779, 6958, 3651,    0,    0,    0,    0,    0,
                 0,    0,    0,    0],
             [5651, 5208, 4343,  718, 2983, 2800,  266, 3749, 1961, 1306,
              4879, 3988, 2596, 5509]], dtype=int32), DeviceArray([0, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[2836, 1148, 1825, 3520, 6388, 2877,   87, 2877],
             [6347, 3211, 1899, 4447,   59, 2877,    0,    0],
             [2281, 5509,    0,    0,    0,    0,    0,    0]],            dtype=int32), DeviceArray([1, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6347, 3211, 1899, 4447,   59, 2877,    0,    0,    0,    0],
             [2281, 5509,    0,    0,    0,    0,    0,    0,    0,    0],
             [1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812,  239]],            dtype=int32), DeviceArray([1, 0, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[2311, 6480, 6563, 2915, 4879, 5819, 4236, 4670, 2215, 5509],
             [1222, 1148, 2733,   34, 4180, 1148, 3552,    0,    0,    0],
             [5441, 4763, 2915, 5453, 2877,    0,    0,    0,    0,    0]],            dtype=int32), DeviceArray([0, 1, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6477, 2196, 3309, 5509,    0,    0,    0,    0,    0,    0,
                 0,    0,    0,    0],
             [5962, 6958, 4779, 6958, 3651,    0,    0,    0,    0,    0,
                 0,    0,    0,    0],
             [5651, 5208, 4343,  718, 2983, 2800,  266, 3749, 1961, 1306,
              4879, 3988, 2596, 5509]], dtype=int32), DeviceArray([0, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
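
Each epoch of the 10-row slice yields four batches of three, so the third printed batch wraps around via itertools.cycle: its last two rows repeat the epoch's first two tweets. And since shuffle=False and loop=True here, the last three printed batches simply replay the first epoch.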

Model Definition

emb_dims = 256
vocab_sz = df_vocab.shape[0]; vocab_sz
7004
embed_layer = tl.Embedding(vocab_size=vocab_sz, d_feature=emb_dims); embed_layer
Embedding_7004_256
g = data_generator(df_train[:10].copy(), 3, df_vocab, 
                   x_col='sentence', 
                   y_col='sentiment', 
                   class_dict={'negative':0, 'positive':1},
                   class_weights={'negative':1.0, 'positive':1.0},
                   cb_process=process_tweet
                   )
inputs, targets, weights = next(g); inputs, targets, weights

inputs
DeviceArray([[6347, 3211, 1899, 4447,   59, 2877,    0,    0,    0,    0],
             [2281, 5509,    0,    0,    0,    0,    0,    0,    0,    0],
             [1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812,  239]],            dtype=int32)
inputs.shape
(3, 10)
embed_layer.init(trax.shapes.signature(inputs))
embed_layer(inputs).shape
(3, 10, 256)
mean_layer = tl.Mean(axis=1)
mean_layer(embed_layer(inputs)).shape
(3, 256)
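
tl.Mean(axis=1) averages the ten 256-dimensional token embeddings of each tweet into a single 256-dimensional vector, collapsing (3, 10, 256) to (3, 256). Note that padding ids (embedding row 0) are included in the average; a weighted mean that masks padding would be a refinement.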
ml_o = mean_layer(embed_layer(inputs)); ml_o
DeviceArray([[ 2.79488266e-02, -1.45270219e-02,  4.40993309e-02, ...,
               4.38689850e-02,  1.03713451e-02,  1.28128221e-02],
             [ 3.38993296e-02,  1.35733951e-02,  6.66464493e-02, ...,
               6.88529015e-02,  1.98510122e-02,  5.23267277e-02],
             [-1.39780194e-02,  5.57200797e-03,  6.62390841e-03, ...,
              -1.09182587e-02, -9.91707761e-03,  1.73570886e-02]], dtype=float32)
dense_layer = tl.Dense(n_units=2)
dense_layer.init(trax.shapes.signature(ml_o))
dense_layer(ml_o)
DeviceArray([[-0.09669925, -0.00369832],
             [-0.14038305, -0.01882405],
             [-0.08846572, -0.01865664]], dtype=float32)
def classifier(output_dims, vocab_sz, emb_dims=256, mode='train'):  # mode is unused here; kept for parity with trax conventions
  emb_layer = tl.Embedding(vocab_size=vocab_sz, d_feature=emb_dims)
  mean_layer = tl.Mean(axis=1)
  dense_output_layer = tl.Dense(n_units=output_dims)
  log_softmax_layer = tl.LogSoftmax()

  model = tl.Serial(
      emb_layer,
      mean_layer,
      dense_output_layer,
      log_softmax_layer
  )

  return model


classifier(2, df_vocab.shape[0])
Serial[
  Embedding_7004_256
  Mean
  Dense_2
  LogSoftmax
]
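
A quick smoke test (a sketch, reusing inputs from the batch above): initialize the untrained model on the input signature and confirm it emits one log-probability pair per tweet.

model = classifier(2, df_vocab.shape[0])
model.init(trax.shapes.signature(inputs))
model(inputs).shape  # expect (3, 2): one (negative, positive) log-prob row per tweet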

Define Train and Eval Tasks

def get_train_eval_tasks(df_train, df_valid, df_vocab, x_col, y_col, class_dict, class_weights, cb_process, batch_sz=16,
                   unknown_token='__UNK__', pad_token='__PAD__', eof_token="__</e>__", vocab_idx_col='index'):
  rnd.seed(271)
  
  train_task = training.TrainTask(
      labeled_data=data_generator(df_train, batch_sz, df_vocab, 
                   x_col=x_col, 
                   y_col=y_col, 
                   class_dict=class_dict,
                   class_weights=class_weights,
                   cb_process=cb_process,
                   unknown_token=unknown_token, 
                   pad_token=pad_token, 
                   eof_token=eof_token, 
                   vocab_idx_col=vocab_idx_col,
                   shuffle=True,
                   loop=True,
                   stop=False
                   ),
      loss_layer = tl.WeightedCategoryCrossEntropy(),
      optimizer = trax.optimizers.Adam(0.01),
      n_steps_per_checkpoint=10)
  
  eval_task = training.EvalTask(
      labeled_data=data_generator(df_valid, batch_sz, df_vocab, 
                   x_col=x_col, 
                   y_col=y_col, 
                   class_dict=class_dict,
                   class_weights=class_weights,
                   cb_process=cb_process,
                   unknown_token=unknown_token, 
                   pad_token=pad_token, 
                   eof_token=eof_token, 
                   vocab_idx_col=vocab_idx_col,
                   shuffle=False,
                   loop=False,
                   stop=False
                   ),
      metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()]
                   )
  return train_task, eval_task
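
Both tasks consume the generator's (inputs, targets, weights) triples, which is the batch format Trax's TrainTask and EvalTask expect from labeled_data; the weights stream is what makes the cross-entropy and accuracy metrics "weighted".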
len(class_dict)
2

Model Training

def train_model(classifier, train_task, eval_task, n_steps, output_dir='.'):
  rnd.seed(1234)
  training_loop = training.Loop(classifier, train_task, eval_tasks=eval_task, output_dir=output_dir, random_seed=112)
  training_loop.run(n_steps=n_steps)
  return training_loop
x_col = 'sentence'
y_col = 'sentiment'
class_dict = {'negative':0, 'positive':1}
class_weights = (df_train['sentiment'].value_counts()/len(df_train)).to_dict()
class_weights = {'negative':1.0, 'positive':1.0}  # as before, overwrite with uniform weights
batch_sz=16
unknown_token='__UNK__'
pad_token='__PAD__'
eof_token="__</e>__"
cb_process = process_tweet
output_dims = len(class_dict)
emb_dims = 256
output_dir="."
df_model, df_blind = train_test_split(df, stratify=df[y_col])
df_train, df_valid = train_test_split(df_model, stratify=df_model[y_col])
df_vocab = get_vocab(df_train, text_col=x_col, cb_process=cb_process,  pad_token=pad_token, eof_token=eof_token, unknown_token=unknown_token)
vocab_sz = df_vocab.shape[0]
vocab_idx_col='index'

cb_process
train_task, eval_task = get_train_eval_tasks(
    df_train, 
    df_valid, 
    df_vocab, 
    x_col, y_col, class_dict, class_weights, 
    cb_process, batch_sz,
    unknown_token, pad_token,  eof_token,
    vocab_idx_col='index') ; train_task, eval_task

model = classifier(output_dims, vocab_sz, emb_dims=emb_dims, mode='train')
training_loop = train_model(model, train_task, eval_task, n_steps=40, output_dir=output_dir)
/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(

Step      1: Total number of trainable weights: 1235202
Step      1: Ran 1 train steps in 0.94 secs
Step      1: train WeightedCategoryCrossEntropy |  0.69573909
Step      1: eval  WeightedCategoryCrossEntropy |  0.63845330
Step      1: eval      WeightedCategoryAccuracy |  0.81250000

Step     10: Ran 9 train steps in 3.55 secs
Step     10: train WeightedCategoryCrossEntropy |  0.65314281
Step     10: eval  WeightedCategoryCrossEntropy |  0.60192943
Step     10: eval      WeightedCategoryAccuracy |  0.87500000

Step     20: Ran 10 train steps in 2.19 secs
Step     20: train WeightedCategoryCrossEntropy |  0.51353288
Step     20: eval  WeightedCategoryCrossEntropy |  0.41577297
Step     20: eval      WeightedCategoryAccuracy |  1.00000000

Step     30: Ran 10 train steps in 1.23 secs
Step     30: train WeightedCategoryCrossEntropy |  0.26818269
Step     30: eval  WeightedCategoryCrossEntropy |  0.14980227
Step     30: eval      WeightedCategoryAccuracy |  1.00000000

Step     40: Ran 10 train steps in 1.30 secs
Step     40: train WeightedCategoryCrossEntropy |  0.15570463
Step     40: eval  WeightedCategoryCrossEntropy |  0.08368409
Step     40: eval      WeightedCategoryAccuracy |  1.00000000
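
Note that the eval metrics at each checkpoint come from a single 16-tweet batch (EvalTask's default of one eval batch), so the perfect accuracy here is a noisy estimate; the blind set below gives a fairer read.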

Model Evaluation

blind_gen = data_generator(df_blind, batch_sz, df_vocab, 
                  x_col=x_col, 
                  y_col=y_col, 
                  class_dict=class_dict,
                  class_weights=class_weights,
                  cb_process=cb_process,
                  unknown_token=unknown_token, 
                  pad_token=pad_token, 
                  eof_token=eof_token, 
                  vocab_idx_col=vocab_idx_col,
                  shuffle=False,
                  loop=False,
                  stop=False
                  )

batch = next(blind_gen)
blind_inputs, blind_targets, blind_weights = batch
training_loop.eval_model(blind_inputs)
DeviceArray([[-6.7398906e-02, -2.7306366e+00],
             [-6.8514347e-03, -4.9867296e+00],
             [-3.4929681e+00, -3.0882478e-02],
             [-5.0519943e-02, -3.0105407e+00],
             [-2.9234269e+00, -5.5247664e-02],
             [-3.8873537e+00, -2.0712495e-02],
             [-5.2998271e+00, -5.0048828e-03],
             [-3.5535395e+00, -2.9040813e-02],
             [-4.6031094e+00, -1.0071278e-02],
             [-2.1165485e+00, -1.2834114e-01],
             [-4.1328192e-02, -3.2068014e+00],
             [-2.3841858e-05, -1.0641651e+01],
             [-2.1345377e-02, -3.8575726e+00],
             [-2.6583314e-02, -3.6407375e+00],
             [-7.7986407e+00, -4.1031837e-04],
             [-1.1112690e-02, -4.5052152e+00]], dtype=float32)
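
Each row holds (negative, positive) log-probabilities. A closing sketch: take the argmax over the class axis and compare against the blind targets for a quick batch accuracy (exact numbers depend on the random splits).

preds = np.argmax(training_loop.eval_model(blind_inputs), axis=1)
preds, (preds == blind_targets).mean()  # predicted class ids and fraction correct in this batch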