import sys

# True when the `google.colab` module has already been imported into this
# interpreter, which is presumably the case only when running under Colab.
IN_COLAB = 'google.colab' in sys.modules
IN_COLAB
True
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 637.9/637.9 KB 10.2 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.8/5.8 MB 71.7 MB/s eta 0:00:00
import os
import nltk
import re
import string
import itertools
import requests
import pandas as pd
import trax
import trax.fastmath.numpy as np
import random as rnd
from trax import layers as tl
from trax import shapes
from trax.supervised import training
from sklearn.model_selection import train_test_split
from fastcore.all import *
import io
import string
url = "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
# Bound the wait and fail loudly on an HTTP error instead of silently parsing an
# error page as if it were the corpus.
resp = requests.get(url, timeout=60)
resp.raise_for_status()
s = io.BytesIO(resp.content)  # a fresh BytesIO is already positioned at offset 0
# Decode explicitly rather than relying on the platform's locale default encoding.
wrapper = io.TextIOWrapper(s, encoding="utf-8")
lines = wrapper.readlines()  # keeps the trailing "\n" on every line
len(lines)
124456
['PERSONAL USE ONLY, AND (2) ARE NOT DISTRIBUTED OR USED\n',
'COMMERCIALLY. PROHIBITED COMMERCIAL DISTRIBUTION INCLUDES BY ANY\n',
'SERVICE THAT CHARGES FOR DOWNLOAD TIME OR FOR MEMBERSHIP.>>\n',
'\n',
'\n',
'\n',
'End of this Etext of The Complete Works of William Shakespeare\n',
'\n',
'\n',
'\n']
# Show the code point of a few sample characters: letters, a space, and digits.
for sample_char in "abc xyz123":
    print(f"ord('{sample_char}'): {ord(sample_char)}")
ord('a'): 97
ord('b'): 98
ord('c'): 99
ord(' '): 32
ord('x'): 120
ord('y'): 121
ord('z'): 122
ord('1'): 49
ord('2'): 50
ord('3'): 51
# Print one "ord(<char>): <code point>" line per lowercase letter, digit and
# punctuation character.
charset = string.ascii_lowercase + string.digits + string.punctuation
print("\n".join(f"ord({ch}): {ord(ch)}" for ch in charset))
ord(a): 97
ord(b): 98
ord(c): 99
ord(d): 100
ord(e): 101
ord(f): 102
ord(g): 103
ord(h): 104
ord(i): 105
ord(j): 106
ord(k): 107
ord(l): 108
ord(m): 109
ord(n): 110
ord(o): 111
ord(p): 112
ord(q): 113
ord(r): 114
ord(s): 115
ord(t): 116
ord(u): 117
ord(v): 118
ord(w): 119
ord(x): 120
ord(y): 121
ord(z): 122
ord(0): 48
ord(1): 49
ord(2): 50
ord(3): 51
ord(4): 52
ord(5): 53
ord(6): 54
ord(7): 55
ord(8): 56
ord(9): 57
ord(!): 33
ord("): 34
ord(#): 35
ord($): 36
ord(%): 37
ord(&): 38
ord('): 39
ord((): 40
ord()): 41
ord(*): 42
ord(+): 43
ord(,): 44
ord(-): 45
ord(.): 46
ord(/): 47
ord(:): 58
ord(;): 59
ord(<): 60
ord(=): 61
ord(>): 62
ord(?): 63
ord(@): 64
ord([): 91
ord(\): 92
ord(]): 93
ord(^): 94
ord(_): 95
ord(`): 96
ord({): 123
ord(|): 124
ord(}): 125
ord(~): 126
def line2tensor(line, EOS_int=1):
    """Encode a text line as its character code points followed by an end marker.

    Args:
        line: the string to encode.
        EOS_int: id appended after the last character (default 1).

    Returns:
        list[int]: ``ord`` of every character of ``line``, then ``EOS_int``.
    """
    return [*map(ord, line), EOS_int]
# Display the first line next to its encoding (code points plus EOS id 1).
# NOTE(review): the displayed line is lowercased and is not the raw file's first
# line, so `lines` was presumably re-processed in a cell not shown here — confirm.
lines[0], line2tensor(lines[0])
('this is the 100th etext file presented by project gutenberg, and\n',
[116,
104,
105,
115,
32,
105,
115,
32,
116,
104,
101,
32,
49,
48,
48,
116,
104,
32,
101,
116,
101,
120,
116,
32,
102,
105,
108,
101,
32,
112,
114,
101,
115,
101,
110,
116,
101,
100,
32,
98,
121,
32,
112,
114,
111,
106,
101,
99,
116,
32,
103,
117,
116,
101,
110,
98,
101,
114,
103,
44,
32,
97,
110,
100,
10,
1])
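# NOTE: superseded draft kept for reference only — it is replaced by data_generator below and refers to names that are never defined here (df, process_batch, func_text2idx, pad_id, class_dict, class_weights, loop).
# def data_generator(data, batch_sz, cb_process, max_len, shuffle=False):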
# while True:
# index = 0 # Resets index to zero on data exhaustion
# if shuffle : df = df.sample(frac=1.0) # Shuffles data - only relevant for train not eval tasks
# itr = itertools.cycle(df.iterrows()) # Only purpose of cycle here is to handle last batch case when elements of dataset has been exhausted
# while index <= len(df):
# batch = (next(itr) for i in range(batch_sz))
# index += batch_sz
# yield process_batch(batch, func_text2idx, pad_id, class_dict, class_weights)
# if loop: continue
# else: break
def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter.

    Example: batched('ABCDEFG', 3) --> ABC DEF G

    Raises:
        ValueError: if ``n`` < 1 (raised when iteration starts, since this is a generator).
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(source, n))
        if not chunk:
            return
        yield chunk
# Materialise the generator so the cell displays the batches instead of a bare
# generator object.
list(batched('ABCDEFG', 3))
def text_data_generator(data, batch_sz, max_len, shuffle=True):
    """Yield tuples of up to ``batch_sz`` lines taken from ``data``.

    Only lines strictly shorter than ``max_len`` characters are kept, which leaves
    room for the EOS id that line2tensor appends. A single pass is made over
    ``data``, and the final batch may be smaller than ``batch_sz``.

    Args:
        data: indexable sequence of text lines.
        batch_sz: maximum number of lines per batch.
        max_len: exclusive upper bound on a kept line's character length.
        shuffle: if True, visit the lines in a random order for this pass.

    Yields:
        tuple[str, ...]: one batch of lines.
    """
    idx_lines = list(range(len(data)))
    if shuffle:
        # Shuffle indices rather than `data` so the caller's sequence is not mutated.
        rnd.shuffle(idx_lines)
    line_gen = (data[i] for i in idx_lines)
    short_lines = filter(lambda line: len(line) < max_len, line_gen)
    yield from batched(short_lines, batch_sz)
def data_generator(data, batch_sz, max_len, cb_process=line2tensor, shuffle=True):
    """Yield ``(inputs, targets, mask)`` array triples for language-model training.

    Batches come from text_data_generator, so only lines shorter than ``max_len``
    characters are used. Each line is encoded with ``cb_process`` and right-padded
    with 0 up to ``max_len``.

    ``inputs`` and ``targets`` are the same array. Presumably the consuming model
    provides the one-step offset itself (GRULM starts with a ShiftRight layer).
    ``mask`` is 1 on every encoded token (including any EOS the encoder appends)
    and 0 on padding.

    Args:
        data: indexable sequence of text lines.
        batch_sz: maximum number of lines per batch.
        max_len: padded row length.
        cb_process: callable mapping a line to a sequence of ints.
        shuffle: forwarded to text_data_generator.

    Yields:
        tuple: three arrays, each of shape (lines_in_batch, max_len).

    Raises:
        ValueError: if ``cb_process`` returns more than ``max_len`` ids for a line.
    """
    for batch in text_data_generator(data, batch_sz, max_len, shuffle=shuffle):
        tensors = [list(cb_process(line)) for line in batch]
        # Pad from the real encoded length instead of assuming len(line) + 1, so any
        # encoder yields rectangular arrays and masks that cover exactly its tokens.
        overflow = [len(t) for t in tensors if len(t) > max_len]
        if overflow:
            raise ValueError(
                f"encoded line length {max(overflow)} exceeds max_len={max_len}"
            )
        padded = [t + [0] * (max_len - len(t)) for t in tensors]
        mask = [[1] * len(t) + [0] * (max_len - len(t)) for t in tensors]
        num_batch_arr = np.array(padded)
        yield num_batch_arr, num_batch_arr, np.array(mask)
# Scratch cell: preview the first ten lines and a shuffled order of their indices.
# NOTE(review): `data_model` is not defined in any visible cell — presumably the
# training split of `lines`; confirm.
data = data_model
shuffle = True
max_len = 10  # not used by this cell
data[:10]
idx_lines_tmp = list(range(10)); idx_lines_tmp
if shuffle: rnd.shuffle(idx_lines_tmp)
data[:10], idx_lines_tmp
(['this is the 100th etext file presented by project gutenberg, and\n',
'is presented in cooperation with world library, inc., from their\n',
'library of the future and shakespeare cdroms. project gutenberg\n',
'often releases etexts that are not placed in the public domain!!\n',
'\n',
'shakespeare\n',
'\n',
'*this etext has certain copyright implications you should read!*\n',
'\n',
'<<this electronic version of the complete works of william\n'],
[4, 6, 0, 8, 9, 5, 7, 2, 3, 1])
[('this is the 100th etext file presented by project gutenberg, and\n',
'is presented in cooperation with world library, inc., from their\n',
'library of the future and shakespeare cdroms. project gutenberg\n',
'often releases etexts that are not placed in the public domain!!\n'),
('\n',
'shakespeare\n',
'\n',
'*this etext has certain copyright implications you should read!*\n'),
('\n', '<<this electronic version of the complete works of william\n')]
# Demo: itertools.cycle replays the cached batches of one shuffled pass, so the
# same three batches repeat in the same order. Prints stop_id + 1 batches.
g = text_data_generator(data[:10].copy(), batch_sz=4, max_len=100, shuffle=True)
index = 0
stop_id = 10
for index, batch in enumerate(itertools.cycle(g), start=1):
    print(batch)
    if index > stop_id:
        break
('library of the future and shakespeare cdroms. project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
('shakespeare\n', 'is presented in cooperation with world library, inc., from their\n')
('library of the future and shakespeare cdroms. project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
('shakespeare\n', 'is presented in cooperation with world library, inc., from their\n')
('library of the future and shakespeare cdroms. project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
('shakespeare\n', 'is presented in cooperation with world library, inc., from their\n')
('library of the future and shakespeare cdroms. project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
[(DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[108, 105, 98, 114, 97, 114, 121, 32, 111, 102, 32, 116,
104, 101, 32, 102, 117, 116, 117, 114, 101, 32, 97, 110,
100, 32, 115, 104, 97, 107, 101, 115, 112, 101, 97, 114,
101, 32, 99, 100, 114, 111, 109, 115, 46, 32, 32, 112,
114, 111, 106, 101, 99, 116, 32, 103, 117, 116, 101, 110,
98, 101, 114, 103, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[108, 105, 98, 114, 97, 114, 121, 32, 111, 102, 32, 116,
104, 101, 32, 102, 117, 116, 117, 114, 101, 32, 97, 110,
100, 32, 115, 104, 97, 107, 101, 115, 112, 101, 97, 114,
101, 32, 99, 100, 114, 111, 109, 115, 46, 32, 32, 112,
114, 111, 106, 101, 99, 116, 32, 103, 117, 116, 101, 110,
98, 101, 114, 103, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)),
(DeviceArray([[111, 102, 116, 101, 110, 32, 114, 101, 108, 101, 97, 115,
101, 115, 32, 101, 116, 101, 120, 116, 115, 32, 116, 104,
97, 116, 32, 97, 114, 101, 32, 110, 111, 116, 32, 112,
108, 97, 99, 101, 100, 32, 105, 110, 32, 116, 104, 101,
32, 112, 117, 98, 108, 105, 99, 32, 100, 111, 109, 97,
105, 110, 33, 33, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[115, 104, 97, 107, 101, 115, 112, 101, 97, 114, 101, 10,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[111, 102, 116, 101, 110, 32, 114, 101, 108, 101, 97, 115,
101, 115, 32, 101, 116, 101, 120, 116, 115, 32, 116, 104,
97, 116, 32, 97, 114, 101, 32, 110, 111, 116, 32, 112,
108, 97, 99, 101, 100, 32, 105, 110, 32, 116, 104, 101,
32, 112, 117, 98, 108, 105, 99, 32, 100, 111, 109, 97,
105, 110, 33, 33, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[115, 104, 97, 107, 101, 115, 112, 101, 97, 114, 101, 10,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)),
(DeviceArray([[ 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 42, 116, 104, 105, 115, 32, 101, 116, 101, 120, 116, 32,
104, 97, 115, 32, 99, 101, 114, 116, 97, 105, 110, 32,
99, 111, 112, 121, 114, 105, 103, 104, 116, 32, 105, 109,
112, 108, 105, 99, 97, 116, 105, 111, 110, 115, 32, 121,
111, 117, 32, 115, 104, 111, 117, 108, 100, 32, 114, 101,
97, 100, 33, 42, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[ 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 42, 116, 104, 105, 115, 32, 101, 116, 101, 120, 116, 32,
104, 97, 115, 32, 99, 101, 114, 116, 97, 105, 110, 32,
99, 111, 112, 121, 114, 105, 103, 104, 116, 32, 105, 109,
112, 108, 105, 99, 97, 116, 105, 111, 110, 115, 32, 121,
111, 117, 32, 115, 104, 111, 117, 108, 100, 32, 114, 101,
97, 100, 33, 42, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)),
(DeviceArray([[ 60, 60, 116, 104, 105, 115, 32, 101, 108, 101, 99, 116,
114, 111, 110, 105, 99, 32, 118, 101, 114, 115, 105, 111,
110, 32, 111, 102, 32, 116, 104, 101, 32, 99, 111, 109,
112, 108, 101, 116, 101, 32, 119, 111, 114, 107, 115, 32,
111, 102, 32, 119, 105, 108, 108, 105, 97, 109, 10, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[ 60, 60, 116, 104, 105, 115, 32, 101, 108, 101, 99, 116,
114, 111, 110, 105, 99, 32, 118, 101, 114, 115, 105, 111,
110, 32, 111, 102, 32, 116, 104, 101, 32, 99, 111, 109,
112, 108, 101, 116, 101, 32, 119, 111, 114, 107, 115, 32,
111, 102, 32, 119, 105, 108, 108, 105, 97, 109, 10, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32),
DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32))]
# Cycle over the three-line sample in fixed order and show the first five draws.
sample_batches = data_generator(data[:3].copy(), batch_sz=2, max_len=100, shuffle=False)
g = itertools.cycle(sample_batches)
for i, item in zip(range(5), g):
    print(i, item)
0 (DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32))
1 (DeviceArray([[108, 105, 98, 114, 97, 114, 121, 32, 111, 102, 32, 116,
104, 101, 32, 102, 117, 116, 117, 114, 101, 32, 97, 110,
100, 32, 115, 104, 97, 107, 101, 115, 112, 101, 97, 114,
101, 32, 99, 100, 114, 111, 109, 115, 46, 32, 32, 112,
114, 111, 106, 101, 99, 116, 32, 103, 117, 116, 101, 110,
98, 101, 114, 103, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[108, 105, 98, 114, 97, 114, 121, 32, 111, 102, 32, 116,
104, 101, 32, 102, 117, 116, 117, 114, 101, 32, 97, 110,
100, 32, 115, 104, 97, 107, 101, 115, 112, 101, 97, 114,
101, 32, 99, 100, 114, 111, 109, 115, 46, 32, 32, 112,
114, 111, 106, 101, 99, 116, 32, 103, 117, 116, 101, 110,
98, 101, 114, 103, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32))
2 (DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32))
3 (DeviceArray([[108, 105, 98, 114, 97, 114, 121, 32, 111, 102, 32, 116,
104, 101, 32, 102, 117, 116, 117, 114, 101, 32, 97, 110,
100, 32, 115, 104, 97, 107, 101, 115, 112, 101, 97, 114,
101, 32, 99, 100, 114, 111, 109, 115, 46, 32, 32, 112,
114, 111, 106, 101, 99, 116, 32, 103, 117, 116, 101, 110,
98, 101, 114, 103, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[108, 105, 98, 114, 97, 114, 121, 32, 111, 102, 32, 116,
104, 101, 32, 102, 117, 116, 117, 114, 101, 32, 97, 110,
100, 32, 115, 104, 97, 107, 101, 115, 112, 101, 97, 114,
101, 32, 99, 100, 114, 111, 109, 115, 46, 32, 32, 112,
114, 111, 106, 101, 99, 116, 32, 103, 117, 116, 101, 110,
98, 101, 114, 103, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32))
4 (DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[116, 104, 105, 115, 32, 105, 115, 32, 116, 104, 101, 32,
49, 48, 48, 116, 104, 32, 101, 116, 101, 120, 116, 32,
102, 105, 108, 101, 32, 112, 114, 101, 115, 101, 110, 116,
101, 100, 32, 98, 121, 32, 112, 114, 111, 106, 101, 99,
116, 32, 103, 117, 116, 101, 110, 98, 101, 114, 103, 44,
32, 97, 110, 100, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[105, 115, 32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
32, 105, 110, 32, 99, 111, 111, 112, 101, 114, 97, 116,
105, 111, 110, 32, 119, 105, 116, 104, 32, 119, 111, 114,
108, 100, 32, 108, 105, 98, 114, 97, 114, 121, 44, 32,
105, 110, 99, 46, 44, 32, 102, 114, 111, 109, 32, 116,
104, 101, 105, 114, 10, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32))
def GRULM(vocab_sz=216, d_model=512, n_layers=2, mode='train'):
    """Build a GRU language model as a trax ``Serial`` stack.

    Layers, in order:
      * ShiftRight — shifts the input sequence right by one position;
      * Embedding of ``vocab_sz`` ids into ``d_model`` features;
      * ``n_layers`` GRU layers of width ``d_model``;
      * Dense projection back to ``vocab_sz``;
      * LogSoftmax.

    Args:
        vocab_sz: number of token ids (embedding rows and output units).
        d_model: embedding and GRU width.
        n_layers: number of stacked GRU layers.
        mode: forwarded to ShiftRight (e.g. 'train').

    Returns:
        tl.Serial: the assembled model.

    NOTE(review): inputs are presumably line2tensor code points, so every character
    in the corpus must have ord() < vocab_sz — confirm for non-ASCII text.
    """
    return tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(vocab_size=vocab_sz, d_feature=d_model),
        [tl.GRU(n_units=d_model) for _ in range(n_layers)],
        tl.Dense(n_units=vocab_sz),
        tl.LogSoftmax(),
    )
GRULM() # Build the default model and display its layer stack (its ShiftRight layer shifts inputs right by one position)
Serial[
Serial[
ShiftRight(1)
]
Embedding_216_512
GRU_512
GRU_512
Dense_216
LogSoftmax
]
def _cycle_epochs(make_epoch):
    """Yield items from ``make_epoch()`` forever, creating a fresh generator per pass.

    Unlike itertools.cycle, nothing is cached: memory stays constant and a
    shuffling generator reshuffles every pass.

    Raises:
        ValueError: if a pass yields no items (otherwise this would loop forever).
    """
    while True:
        produced = False
        for item in make_epoch():
            produced = True
            yield item
        if not produced:
            raise ValueError(
                "data generator produced no batches; check max_len and the input lines"
            )


def train_model(model, data_generator, lines, eval_lines, batch_sz=32, max_len=64,
                n_steps=3, output_dir='model/.', learning_rate=0.01, n_eval_batches=3):
    """Train ``model`` with trax and return the finished training loop.

    Args:
        model: trax layer to train (e.g. GRULM()).
        data_generator: callable ``(lines, batch_sz=, max_len=, shuffle=)`` returning
            an iterable of ``(inputs, targets, mask)`` batches.
        lines: training lines (reshuffled every pass).
        eval_lines: evaluation lines (fixed order).
        batch_sz: lines per batch.
        max_len: padded sequence length passed to ``data_generator``.
        n_steps: number of training steps to run.
        output_dir: directory where the trax Loop writes its outputs.
        learning_rate: Adam learning rate. NOTE(review): an earlier comment
            suggested 0.0005; with 0.01 the logged train loss rose between
            steps 300 and 400 — consider lowering.
        n_eval_batches: batches drawn per evaluation.

    Returns:
        training.Loop: the loop after ``run(n_steps)`` has completed.
    """
    # Re-create the generators each pass instead of itertools.cycle, which keeps
    # every batch it has yielded in memory and replays the first shuffle forever.
    train_gen = _cycle_epochs(
        lambda: data_generator(lines, batch_sz=batch_sz, max_len=max_len, shuffle=True)
    )
    eval_gen = _cycle_epochs(
        lambda: data_generator(eval_lines, batch_sz=batch_sz, max_len=max_len, shuffle=False)
    )
    train_task = training.TrainTask(
        labeled_data=train_gen,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(learning_rate),
    )
    eval_task = training.EvalTask(
        labeled_data=eval_gen,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
        n_eval_batches=n_eval_batches,
    )
    training_loop = training.Loop(
        model,
        train_task,
        eval_tasks=[eval_task],
        output_dir=output_dir,
    )
    training_loop.run(n_steps=n_steps)
    return training_loop
# Train the default GRULM for 1000 steps.
# NOTE(review): `data_model` and `data_blind` are not defined in any visible cell —
# presumably the train / held-out splits of the processed lines; confirm.
training_loop = train_model(GRULM(), data_generator, lines=data_model, eval_lines=data_blind, n_steps=1000)
/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
warnings.warn(
Step 1: Total number of trainable weights: 3370200
Step 1: Ran 1 train steps in 9.98 secs
Step 1: train CrossEntropyLoss | 5.37536573
Step 1: eval CrossEntropyLoss | 5.28850317
Step 1: eval Accuracy | 0.21879732
Step 100: Ran 99 train steps in 168.48 secs
Step 100: train CrossEntropyLoss | 3.27563405
Step 100: eval CrossEntropyLoss | 2.82953095
Step 100: eval Accuracy | 0.21625705
Step 200: Ran 100 train steps in 169.83 secs
Step 200: train CrossEntropyLoss | 2.81332421
Step 200: eval CrossEntropyLoss | 2.68687224
Step 200: eval Accuracy | 0.25851414
Step 300: Ran 100 train steps in 168.30 secs
Step 300: train CrossEntropyLoss | 2.72730327
Step 300: eval CrossEntropyLoss | 2.63422775
Step 300: eval Accuracy | 0.26217009
Step 400: Ran 100 train steps in 169.75 secs
Step 400: train CrossEntropyLoss | 3.07694721
Step 400: eval CrossEntropyLoss | 3.22344669
Step 400: eval Accuracy | 0.21283427
Step 500: Ran 100 train steps in 169.61 secs
Step 500: train CrossEntropyLoss | 2.92221594
Step 500: eval CrossEntropyLoss | 2.89374344
Step 500: eval Accuracy | 0.21506842
Step 600: Ran 100 train steps in 170.35 secs
Step 600: train CrossEntropyLoss | 2.84008431
Step 600: eval CrossEntropyLoss | 2.91676569
Step 600: eval Accuracy | 0.20170474
Step 700: Ran 100 train steps in 169.49 secs
Step 700: train CrossEntropyLoss | 2.75251889
Step 700: eval CrossEntropyLoss | 2.87050303
Step 700: eval Accuracy | 0.21942994
Step 800: Ran 100 train steps in 169.41 secs
Step 800: train CrossEntropyLoss | 2.68192506
Step 800: eval CrossEntropyLoss | 2.79392648
Step 800: eval Accuracy | 0.25168787
Step 900: Ran 100 train steps in 169.13 secs
Step 900: train CrossEntropyLoss | 2.66031885
Step 900: eval CrossEntropyLoss | 2.75871555
Step 900: eval Accuracy | 0.23988337