!pip install -Uqq trax
!pip install -Uqq fastcore
Imports
import os
import nltk
import re
import string
import itertools
import pandas as pd
import trax
import trax.fastmath.numpy as np
import random as rnd
from trax import layers as tl
from trax import shapes
from trax.supervised import training
from sklearn.model_selection import train_test_split
from fastcore.all import *
nltk.download('twitter_samples')
nltk.download('stopwords')

from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords, twitter_samples
from nltk.stem import PorterStemmer
[nltk_data] Downloading package twitter_samples to /root/nltk_data...
[nltk_data] Package twitter_samples is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data] Package stopwords is already up-to-date!
JAX Autograd
x = np.array(5.0)
x
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray(5., dtype=float32, weak_type=True)
def f(x): return x**2
f(x)
DeviceArray(25., dtype=float32, weak_type=True)
grad_f = trax.fastmath.grad(fun=f)
grad_f(x)
DeviceArray(10., dtype=float32, weak_type=True)
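The same `trax.fastmath.grad` machinery works for array arguments as long as the function returns a scalar. The snippet below is a small illustrative sketch, not part of the original run; the function `loss` and its expected gradient are assumptions for demonstration only.

def loss(w): return np.sum((2.0 * w - 1.0) ** 2)   # scalar-valued, so grad is defined

grad_loss = trax.fastmath.grad(loss)
w = np.array([0.0, 1.0, 2.0])
grad_loss(w)   # analytically 4 * (2w - 1) = [-4., 4., 12.]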
Data Loading
def get_tweet_df(frac=1):
    df_positive = pd.DataFrame(twitter_samples.strings('positive_tweets.json'), columns=['sentence'])
    df_positive['sentiment'] = 'positive'
    df_negative = pd.DataFrame(twitter_samples.strings('negative_tweets.json'), columns=['sentence'])
    df_negative['sentiment'] = 'negative'
    df = pd.concat([df_positive, df_negative])
    return df.sample(frac=frac).reset_index(drop=True)

df = get_tweet_df()
df
 | sentence | sentiment |
---|---|---|
0 | Why is no one awake :( | negative |
1 | @JediOMG Id like to think people know who I am :) | positive |
2 | and another hour goes by :(( | negative |
3 | @beat_trees thank you Beatriz :) | positive |
4 | @ashleelynda yes let's do it :) | positive |
... | ... | ... |
9995 | everytime i look at the clock its 3:02 and som... | positive |
9996 | @ReddBlock1 @MadridGamesWeek Too far away :( | negative |
9997 | It will all get better in time. :) | positive |
9998 | Stats for the week have arrived. 2 new followe... | positive |
9999 | 3:33 am and all of my roommates are officially... | negative |
10000 rows × 2 columns
Data Splitting
df_model, df_blind = train_test_split(df, stratify=df['sentiment'])
df_train, df_valid = train_test_split(df_model, stratify=df_model['sentiment'])
df_blind.shape, df_model.shape, df_train.shape, df_valid.shape
((2500, 2), (7500, 2), (5625, 2), (1875, 2))
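Because both splits pass `stratify`, the positive/negative ratio is preserved in every subset. A quick check along these lines (a hedged sketch reusing the frames defined above, not part of the original run):

for name, frame in [('blind', df_blind), ('train', df_train), ('valid', df_valid)]:
    # value_counts(normalize=True) gives the class proportions within each split
    print(name, frame['sentiment'].value_counts(normalize=True).round(3).to_dict())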
Data Preprocessing
def remove_old_style(tweet): return re.sub(r'^RT[\s]+', '', tweet)
def remove_url(tweet): return re.sub(r'https?://[^\s\n\r]+', '', tweet)
def remove_hash(tweet): return re.sub(r'#', "", tweet)
def remove_numbers(tweet): return re.sub(r'\d*\.?\d+', "", tweet)
tokenizer = TweetTokenizer(preserve_case=False, strip_handles=True, reduce_len=True)
skip_words = stopwords.words('english') + list(string.punctuation)
stemmer = PorterStemmer()

def filter_stem_tokens(tweet_tokens, skip_words=skip_words, stemmer=stemmer):
    return [stemmer.stem(token) for token in tweet_tokens if token not in skip_words]

process_tweet = compose(remove_old_style, remove_url, remove_hash, remove_numbers, tokenizer.tokenize, filter_stem_tokens)
tweet = df_train.loc[df_train.index[1000], 'sentence']
tweet, process_tweet(tweet)
('@My_Old_Dutch ahh I did thank you so much :) xxx',
['ahh', 'thank', 'much', ':)', 'xxx'])
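fastcore's `compose` chains the callables left to right, so the pipeline above is equivalent to the explicit version sketched below. `process_tweet_explicit` is a hypothetical name introduced here purely for illustration.

def process_tweet_explicit(tweet):
    tweet = remove_old_style(tweet)        # drop the legacy "RT" prefix
    tweet = remove_url(tweet)              # strip links
    tweet = remove_hash(tweet)             # keep hashtag text, drop the '#'
    tweet = remove_numbers(tweet)          # drop numeric tokens
    tokens = tokenizer.tokenize(tweet)     # lowercase, strip handles, squash repeated letters
    return filter_stem_tokens(tokens)      # remove stopwords/punctuation, then stem

process_tweet_explicit(tweet)   # same result as process_tweet(tweet)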
Calculate Vocab and Encode a Tweet
df_train['sentence'].apply(process_tweet)
484 [inde, ran, charact, knew, get, :)]
7733 [cute, :(]
3743 [yay, weekend', happi, punt, everyon, :d, hors...
6276 [that', one, sweetest, thing, i'v, ever, said,...
1411 [goodnight, love, luke, heart, :-), love, ๐]
...
8825 [happi, princess, today, :(, sigh, get, work, ...
4994 [went, see, appar, use, gb, extra, data, :(]
8112 [followfriday, top, influenc, commun, week, :)]
9961 [sr, financi, analyst, expedia, inc, bellevu, ...
499 [i'v, support, sinc, start, alway, ignor, :(, ...
Name: sentence, Length: 5625, dtype: object
['__PAD__', '__</e>__', '__UNK__'] + list(set(df_train['sentence'].apply(process_tweet).sum()))
['__PAD__',
'__</e>__',
'__UNK__',
'stand',
'cross',
'minut',
'itna',
'chillin',
'dubai',
'documentari',
'ulti',
'sg',
'episod',
'na',
...]
df_vocab = pd.DataFrame(['__PAD__', '__</e>__', '__UNK__'] + list(set(df_train['sentence'].apply(process_tweet).sum())), columns=['vocab']).reset_index()
df_vocab = df_vocab.set_index('vocab')
df_vocab
vocab | index |
---|---|
__PAD__ | 0 |
__</e>__ | 1 |
__UNK__ | 2 |
stand | 3 |
cross | 4 |
... | ... |
belov | 6999 |
arr | 7000 |
ng | 7001 |
art | 7002 |
jule | 7003 |
7004 rows × 1 columns
def get_vocab(df, text_col, cb_process, pad_token='__PAD__', eof_token='__</e>__', unknown_token='__UNK__'):
    df_vocab = pd.DataFrame([pad_token, eof_token, unknown_token] + list(set(df[text_col].apply(cb_process).sum())),
                            columns=['vocab']).reset_index()
    df_vocab = df_vocab.set_index('vocab')
    return df_vocab

df_vocab = get_vocab(df_train, text_col='sentence', cb_process=process_tweet); df_vocab
vocab | index |
---|---|
__PAD__ | 0 |
__</e>__ | 1 |
__UNK__ | 2 |
stand | 3 |
cross | 4 |
... | ... |
belov | 6999 |
arr | 7000 |
ng | 7001 |
art | 7002 |
jule | 7003 |
7004 rows × 1 columns
df['sentence']
0 Why is no one awake :(
1 @JediOMG Id like to think people know who I am :)
2 and another hour goes by :((
3 @beat_trees thank you Beatriz :)
4 @ashleelynda yes let's do it :)
...
9995 everytime i look at the clock its 3:02 and som...
9996 @ReddBlock1 @MadridGamesWeek Too far away :(
9997 It will all get better in time. :)
9998 Stats for the week have arrived. 2 new followe...
9999 3:33 am and all of my roommates are officially...
Name: sentence, Length: 10000, dtype: object
= "My name is Rahul Saraf :)"
msg = process_tweet(msg)
tokens tokens
['name', 'rahul', 'saraf', ':)']
= "__UNK__"
unknown_token = [ token if token in df_vocab.index else unknown_token for token in tokens]
processed_token processed_token
['name', '__UNK__', '__UNK__', ':)']
df_vocab.loc[processed_token]
vocab | index |
---|---|
name | 5423 |
__UNK__ | 2 |
__UNK__ | 2 |
:) | 2877 |
df_vocab.loc[processed_token, 'index'].tolist()
[5423, 2, 2, 2877]
def tweet2idx(tweet, cb_process, df_vocab=df_vocab, unknown_token='__UNK__'):
    tokens = cb_process(tweet)
    processed_token = [token if token in df_vocab.index else unknown_token for token in tokens]
    return df_vocab.loc[processed_token, 'index'].tolist()

tweet2idx(msg, cb_process=process_tweet)
[5423, 2, 2, 2877]
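Applied across a Series, the same helper maps every tweet to its list of vocabulary ids. A quick sketch over the first few training rows (not part of the original run):

df_train['sentence'].head().apply(lambda t: tweet2idx(t, cb_process=process_tweet))   # one id list per tweet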
Data Batching (Generator)
(df_train['sentiment'].value_counts()/len(df_train)).to_dict()
{'negative': 0.5000888888888889, 'positive': 0.4999111111111111}
shuffle = True
batch_sz = 4
x_col = 'sentence'
y_col = 'sentiment'
class_dict = {'negative':0, 'positive':1}

class_weights = (df_train['sentiment'].value_counts()/len(df_train)).to_dict()
class_weights = {'negative':1.0, 'positive':1.0}
df = df_train
if shuffle: df = df.sample(frac=1.0)
itr = itertools.cycle(df.iterrows())
next(itr)
(2200, sentence I love my body, when it's fuller, and shit, ev...
sentiment positive
Name: 2200, dtype: object)
pad_token = '__PAD__'
pad_id = df_vocab.loc[pad_token, 'index']
pad_id
0
batch = (next(itr) for i in range(batch_sz)); list(batch)
[(9113, sentence @casslovesdrake yes cass I think you should :)
sentiment positive
Name: 9113, dtype: object),
(7538, sentence Someone talk to me I'm boreddd :(
sentiment negative
Name: 7538, dtype: object),
(2785, sentence Biodiversity, Taxonomic Infrastructure, Intern...
sentiment negative
Name: 2785, dtype: object),
(6754, sentence @TroutAmbush I'm at work :(\nI can show you on...
sentiment negative
Name: 6754, dtype: object)]
batch = (next(itr) for i in range(batch_sz))
X, y = zip(*[(tweet2idx(entry[1][x_col], cb_process=process_tweet, df_vocab=df_vocab, unknown_token=unknown_token), entry[1][y_col]) for entry in batch])
X, y
(([4224, 2550, 4207, 1530, 2877],
[931, 3201, 4390, 1039, 5509],
[702, 5549, 4180],
[6583,
1546,
5782,
5314,
2222,
6583,
3549,
4309,
6583,
4349,
11,
1671,
4254,
1329,
5525,
509]),
('positive', 'negative', 'positive', 'negative'))
pd.DataFrame(X)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 4224 | 2550 | 4207 | 1530.0 | 2877.0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | 931 | 3201 | 4390 | 1039.0 | 5509.0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2 | 702 | 5549 | 4180 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
3 | 6583 | 1546 | 5782 | 5314.0 | 2222.0 | 6583.0 | 3549.0 | 4309.0 | 6583.0 | 4349.0 | 11.0 | 1671.0 | 4254.0 | 1329.0 | 5525.0 | 509.0 |
inputs = np.array(pd.DataFrame(X).fillna(pad_id), dtype='int32'); inputs
DeviceArray([[4224, 2550, 4207, 1530, 2877, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[ 931, 3201, 4390, 1039, 5509, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[ 702, 5549, 4180, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[6583, 1546, 5782, 5314, 2222, 6583, 3549, 4309, 6583, 4349,
11, 1671, 4254, 1329, 5525, 509]], dtype=int32)
targets = np.array([class_dict[i] for i in y]); targets
DeviceArray([1, 0, 1, 0], dtype=int32)
weights = np.array([class_weights[i] for i in y]); weights
DeviceArray([1., 1., 1., 1.], dtype=float32)
# We assume df_vocab always has an index col. Perhaps we can parameterize it as well.
def text2idx(text, cb_process, df_vocab=df_vocab, unknown_token='__UNK__', vocab_idx_col='index'):
    tokens = cb_process(text)
    processed_token = [token if token in df_vocab.index else unknown_token for token in tokens]
    return df_vocab.loc[processed_token, vocab_idx_col].tolist()

# Note: process_batch reads the module-level x_col and y_col defined above.
def process_batch(batch, func_text2idx, pad_id, class_dict, class_weights):
    X, y = zip(*[(func_text2idx(entry[1][x_col]), entry[1][y_col]) for entry in batch])
    inputs = np.array(pd.DataFrame(X).fillna(pad_id), dtype='int32')
    targets = np.array([class_dict[i] for i in y])
    weights = np.array([class_weights[i] for i in y])
    return inputs, targets, weights
def data_generator(df, batch_sz, df_vocab, x_col, y_col, class_dict, class_weights, cb_process,
                   unknown_token='__UNK__', pad_token='__PAD__', eof_token="__</e>__", vocab_idx_col='index',
                   shuffle=False, loop=True, stop=False,
                   ):
    func_text2idx = lambda text: text2idx(text, cb_process, df_vocab=df_vocab, unknown_token=unknown_token, vocab_idx_col=vocab_idx_col)
    pad_id = df_vocab.loc[pad_token, vocab_idx_col]
    # unk_id = df_vocab.loc[unknown_token, 'index']
    # eof_id = df_vocab.loc[eof_token, 'index']
    while not stop:
        index = 0  # Reset the index to zero whenever the data is exhausted
        if shuffle: df = df.sample(frac=1.0)  # Shuffle the data - only relevant for train, not eval, tasks
        itr = itertools.cycle(df.iterrows())  # cycle only serves to fill the last batch once the dataset is exhausted
        while index <= len(df):
            batch = (next(itr) for i in range(batch_sz))
            index += batch_sz
            yield process_batch(batch, func_text2idx, pad_id, class_dict, class_weights)
        if loop: continue
        else: break
g = data_generator(df_train[:10].copy(), 3, df_vocab,
                   x_col='sentence',
                   y_col='sentiment',
                   class_dict={'negative':0, 'positive':1},
                   class_weights={'negative':1.0, 'positive':1.0},
                   cb_process=process_tweet
                   )
next(g)
(DeviceArray([[6347, 3211, 1899, 4447, 59, 2877, 0, 0, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0, 0, 0],
[1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812, 239]], dtype=int32),
DeviceArray([1, 0, 1], dtype=int32),
DeviceArray([1., 1., 1.], dtype=float32))
count = 0
while count < 6:
    batch = next(g)
    count += 1
    print(batch)
(DeviceArray([[2311, 6480, 6563, 2915, 4879, 5819, 4236, 4670, 2215, 5509],
[1222, 1148, 2733, 34, 4180, 1148, 3552, 0, 0, 0],
[5441, 4763, 2915, 5453, 2877, 0, 0, 0, 0, 0]], dtype=int32), DeviceArray([0, 1, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6477, 2196, 3309, 5509, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5962, 6958, 4779, 6958, 3651, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5651, 5208, 4343, 718, 2983, 2800, 266, 3749, 1961, 1306,
4879, 3988, 2596, 5509]], dtype=int32), DeviceArray([0, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[2836, 1148, 1825, 3520, 6388, 2877, 87, 2877],
[6347, 3211, 1899, 4447, 59, 2877, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0]], dtype=int32), DeviceArray([1, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6347, 3211, 1899, 4447, 59, 2877, 0, 0, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0, 0, 0],
[1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812, 239]], dtype=int32), DeviceArray([1, 0, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[2311, 6480, 6563, 2915, 4879, 5819, 4236, 4670, 2215, 5509],
[1222, 1148, 2733, 34, 4180, 1148, 3552, 0, 0, 0],
[5441, 4763, 2915, 5453, 2877, 0, 0, 0, 0, 0]], dtype=int32), DeviceArray([0, 1, 1], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
(DeviceArray([[6477, 2196, 3309, 5509, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5962, 6958, 4779, 6958, 3651, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[5651, 5208, 4343, 718, 2983, 2800, 266, 3749, 1961, 1306,
4879, 3988, 2596, 5509]], dtype=int32), DeviceArray([0, 1, 0], dtype=int32), DeviceArray([1., 1., 1.], dtype=float32))
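Note that each yielded batch is padded only to the length of the longest tweet in that batch, so the input width varies from batch to batch while targets and weights stay one per example. A hedged sanity check of the shapes, reusing the generator defined above (`g_check` is a throwaway name introduced here):

g_check = data_generator(df_train[:10].copy(), 3, df_vocab,
                         x_col='sentence', y_col='sentiment',
                         class_dict={'negative':0, 'positive':1},
                         class_weights={'negative':1.0, 'positive':1.0},
                         cb_process=process_tweet)
for _, (inputs, targets, weights) in zip(range(3), g_check):
    print(inputs.shape, targets.shape, weights.shape)   # e.g. (3, 10) (3,) (3,)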
Model Definition
emb_dims = 256
vocab_sz = df_vocab.shape[0]; vocab_sz
7004
embed_layer = tl.Embedding(vocab_size=vocab_sz, d_feature=emb_dims); embed_layer
Embedding_7004_256
g = data_generator(df_train[:10].copy(), 3, df_vocab,
                   x_col='sentence',
                   y_col='sentiment',
                   class_dict={'negative':0, 'positive':1},
                   class_weights={'negative':1.0, 'positive':1.0},
                   cb_process=process_tweet
                   )
inputs, targets, weights = next(g)
inputs
DeviceArray([[6347, 3211, 1899, 4447, 59, 2877, 0, 0, 0, 0],
[2281, 5509, 0, 0, 0, 0, 0, 0, 0, 0],
[1026, 5522, 2035, 1922, 4847, 3651, 4250, 2288, 5812, 239]], dtype=int32)
inputs.shape
(3, 10)
embed_layer.init(trax.shapes.signature(inputs))
embed_layer(inputs).shape
(3, 10, 256)
mean_layer = tl.Mean(axis=1)
mean_layer(embed_layer(inputs)).shape
(3, 256)
ml_o = mean_layer(embed_layer(inputs)); ml_o
DeviceArray([[ 2.79488266e-02, -1.45270219e-02,  4.40993309e-02, ...,
               4.38689850e-02,  1.03713451e-02,  1.28128221e-02],
             [ 3.38993296e-02,  1.35733951e-02,  6.66464493e-02, ...,
               6.88529015e-02,  1.98510122e-02,  5.23267277e-02],
             [-1.39780194e-02,  5.57200797e-03,  6.62390841e-03, ...,
              -1.09182587e-02, -9.91707761e-03,  1.73570886e-02]], dtype=float32)
dense_layer = tl.Dense(n_units=2)
dense_layer.init(trax.shapes.signature(ml_o))
dense_layer(ml_o)
DeviceArray([[-0.09669925, -0.00369832],
[-0.14038305, -0.01882405],
[-0.08846572, -0.01865664]], dtype=float32)
def classifier(output_dims, vocab_sz, emb_dims=256, mode='train'):
    emb_layer = tl.Embedding(vocab_size=vocab_sz, d_feature=emb_dims)
    mean_layer = tl.Mean(axis=1)
    dense_output_layer = tl.Dense(n_units=output_dims)
    log_softmax_layer = tl.LogSoftmax()

    model = tl.Serial(
        emb_layer,
        mean_layer,
        dense_output_layer,
        log_softmax_layer
    )
    return model
classifier(2, df_vocab.shape[0])
Serial[
Embedding_7004_256
Mean
Dense_2
LogSoftmax
]
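`tl.Serial` simply applies its sub-layers in order: token ids are embedded, averaged over the time axis, projected to two units, and log-softmaxed. A hedged sketch of a forward pass on one generator batch, reusing `inputs` from above (`clf` is a throwaway name, not the trained model):

clf = classifier(2, df_vocab.shape[0])
clf.init(trax.shapes.signature(inputs))   # initialize weights for this input shape
log_probs = clf(inputs)
log_probs.shape   # (batch, 2) log-probabilities, one row per tweet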
Define Train and Eval Tasks
def get_train_eval_tasks(df_train, df_valid, df_vocab, x_col, y_col, class_dict, class_weights, cb_process, batch_sz=16,
                         unknown_token='__UNK__', pad_token='__PAD__', eof_token="__</e>__", vocab_idx_col='index'):
    rnd.seed(271)

    train_task = training.TrainTask(
        labeled_data=data_generator(df_train, batch_sz, df_vocab,
                                    x_col=x_col,
                                    y_col=y_col,
                                    class_dict=class_dict,
                                    class_weights=class_weights,
                                    cb_process=cb_process,
                                    unknown_token=unknown_token,
                                    pad_token=pad_token,
                                    eof_token=eof_token,
                                    vocab_idx_col=vocab_idx_col,
                                    shuffle=True,
                                    loop=True,
                                    stop=False
                                    ),
        loss_layer=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax.optimizers.Adam(0.01),
        n_steps_per_checkpoint=10)

    eval_task = training.EvalTask(
        labeled_data=data_generator(df_valid, batch_sz, df_vocab,
                                    x_col=x_col,
                                    y_col=y_col,
                                    class_dict=class_dict,
                                    class_weights=class_weights,
                                    cb_process=cb_process,
                                    unknown_token=unknown_token,
                                    pad_token=pad_token,
                                    eof_token=eof_token,
                                    vocab_idx_col=vocab_idx_col,
                                    shuffle=False,
                                    loop=False,
                                    stop=False
                                    ),
        metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()]
    )
    return train_task, eval_task
len(class_dict)
2
Model Training
def train_model(classifier, train_task, eval_task, n_steps, output_dir='.'):
    rnd.seed(1234)
    training_loop = training.Loop(classifier, train_task, eval_tasks=eval_task, output_dir=output_dir, random_seed=112)
    training_loop.run(n_steps=n_steps)
    return training_loop
x_col = 'sentence'
y_col = 'sentiment'
class_dict = {'negative':0, 'positive':1}
class_weights = (df_train['sentiment'].value_counts()/len(df_train)).to_dict()
class_weights = {'negative':1.0, 'positive':1.0}
batch_sz = 16
unknown_token = '__UNK__'
pad_token = '__PAD__'
eof_token = "__</e>__"
cb_process = process_tweet
output_dims = len(class_dict)
emb_dims = 256
output_dir = "."
df_model, df_blind = train_test_split(df, stratify=df[y_col])
df_train, df_valid = train_test_split(df_model, stratify=df_model[y_col])
df_vocab = get_vocab(df_train, text_col=x_col, cb_process=cb_process, pad_token=pad_token, eof_token=eof_token, unknown_token=unknown_token)
vocab_sz = df_vocab.shape[0]
vocab_idx_col = 'index'

train_task, eval_task = get_train_eval_tasks(
    df_train,
    df_valid,
    df_vocab,
    x_col, y_col, class_dict, class_weights,
    cb_process, batch_sz,
    unknown_token, pad_token, eof_token,
    vocab_idx_col='index'); train_task, eval_task

model = classifier(output_dims, vocab_sz, emb_dims=emb_dims, mode='train')
training_loop = train_model(model, train_task, eval_task, n_steps=40, output_dir=output_dir)
/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
warnings.warn(
Step 1: Total number of trainable weights: 1235202
Step 1: Ran 1 train steps in 0.94 secs
Step 1: train WeightedCategoryCrossEntropy | 0.69573909
Step 1: eval WeightedCategoryCrossEntropy | 0.63845330
Step 1: eval WeightedCategoryAccuracy | 0.81250000
Step 10: Ran 9 train steps in 3.55 secs
Step 10: train WeightedCategoryCrossEntropy | 0.65314281
Step 10: eval WeightedCategoryCrossEntropy | 0.60192943
Step 10: eval WeightedCategoryAccuracy | 0.87500000
Step 20: Ran 10 train steps in 2.19 secs
Step 20: train WeightedCategoryCrossEntropy | 0.51353288
Step 20: eval WeightedCategoryCrossEntropy | 0.41577297
Step 20: eval WeightedCategoryAccuracy | 1.00000000
Step 30: Ran 10 train steps in 1.23 secs
Step 30: train WeightedCategoryCrossEntropy | 0.26818269
Step 30: eval WeightedCategoryCrossEntropy | 0.14980227
Step 30: eval WeightedCategoryAccuracy | 1.00000000
Step 40: Ran 10 train steps in 1.30 secs
Step 40: train WeightedCategoryCrossEntropy | 0.15570463
Step 40: eval WeightedCategoryCrossEntropy | 0.08368409
Step 40: eval WeightedCategoryAccuracy | 1.00000000
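`training.Loop` also checkpoints the model under `output_dir`. Assuming the standard Trax checkpoint name (typically `model.pkl.gz`, an assumption here), the trained weights could be reloaded into a fresh eval-mode model along these lines:

# Hedged sketch: reload the checkpoint written by training.Loop (file name assumed).
model_eval = classifier(output_dims, vocab_sz, emb_dims=emb_dims, mode='eval')
model_eval.init_from_file(os.path.join(output_dir, 'model.pkl.gz'), weights_only=True)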
Model Evaluation
blind_gen = data_generator(df_blind, batch_sz, df_vocab,
                           x_col=x_col,
                           y_col=y_col,
                           class_dict=class_dict,
                           class_weights=class_weights,
                           cb_process=cb_process,
                           unknown_token=unknown_token,
                           pad_token=pad_token,
                           eof_token=eof_token,
                           vocab_idx_col=vocab_idx_col,
                           shuffle=False,
                           loop=False,
                           stop=False
                           )

batch = next(blind_gen)
blind_inputs, blind_targets, blind_weights = batch
training_loop.eval_model(blind_inputs)
DeviceArray([[-6.7398906e-02, -2.7306366e+00],
[-6.8514347e-03, -4.9867296e+00],
[-3.4929681e+00, -3.0882478e-02],
[-5.0519943e-02, -3.0105407e+00],
[-2.9234269e+00, -5.5247664e-02],
[-3.8873537e+00, -2.0712495e-02],
[-5.2998271e+00, -5.0048828e-03],
[-3.5535395e+00, -2.9040813e-02],
[-4.6031094e+00, -1.0071278e-02],
[-2.1165485e+00, -1.2834114e-01],
[-4.1328192e-02, -3.2068014e+00],
[-2.3841858e-05, -1.0641651e+01],
[-2.1345377e-02, -3.8575726e+00],
[-2.6583314e-02, -3.6407375e+00],
[-7.7986407e+00, -4.1031837e-04],
[-1.1112690e-02, -4.5052152e+00]], dtype=float32)
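`eval_model` returns log-probabilities, one column per class, so the predicted label is the argmax of each row. A hedged sketch of turning the batch above into predictions and a simple accuracy figure (not part of the original run):

preds = np.argmax(training_loop.eval_model(blind_inputs), axis=1)   # 0 = negative, 1 = positive
accuracy = np.mean(preds == blind_targets)
preds, accuracy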