Character Encoding, GRU and Language Model

Open In Colab

import sys
# Detect a Colab runtime by checking whether the `google.colab` module has
# already been imported into this interpreter.
IN_COLAB = 'google.colab' in sys.modules
IN_COLAB
True
# On Colab, quietly install/upgrade trax and fastcore with pip (IPython `!`
# shell magics). Otherwise nothing is installed and only a message is printed.
if IN_COLAB:
  !pip install -Uqq trax
  !pip install -Uqq fastcore
else:
  print("Running locally")
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 637.9/637.9 KB 10.2 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.8/5.8 MB 71.7 MB/s eta 0:00:00
import os 
import nltk
import re
import string
import itertools
import requests
import pandas as pd
import trax
import trax.fastmath.numpy as np
import random as rnd
from trax import layers as tl
from trax import shapes
from trax.supervised import training
from sklearn.model_selection import train_test_split
from fastcore.all import *
import io
import string
rnd.seed(32)

Read data and create train/test split

# Download the Shakespeare text file and split it into a list of lines.
url = "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
# A timeout keeps the cell from hanging forever on a stalled connection.
resp = requests.get(url, timeout=60)
# Fail loudly on an HTTP error instead of treating an error page as the corpus.
resp.raise_for_status()

# Read through a text wrapper so line endings are translated to '\n'.
# The encoding is pinned so decoding does not depend on the machine's locale.
s = io.BytesIO(resp.content)
wrapper = io.TextIOWrapper(s, encoding="utf-8")
lines = wrapper.readlines()
len(lines)
124456
# resp.content
lines[-10:]
['PERSONAL USE ONLY, AND (2) ARE NOT DISTRIBUTED OR USED\n',
 'COMMERCIALLY.  PROHIBITED COMMERCIAL DISTRIBUTION INCLUDES BY ANY\n',
 'SERVICE THAT CHARGES FOR DOWNLOAD TIME OR FOR MEMBERSHIP.>>\n',
 '\n',
 '\n',
 '\n',
 'End of this Etext of The Complete Works of William Shakespeare\n',
 '\n',
 '\n',
 '\n']
# Lower-case every line of the corpus.
lines = list(map(str.lower, lines))
# Keep the final 1250 lines aside as the blind set; all earlier lines are
# the data used for modelling.
data_model = lines[:-1250]
data_blind = lines[-1250:]
data_model[-1]
'                           exeunt florizel, perdita, and camillo\n'
data_blind[0]
'  autolycus. i understand the business, i hear it. to have an open\n'

Convert line to tensor

# Show the code points of a few sample characters: letters, a space, digits.
# One line per character, in the order a, b, c, ' ', x, y, z, 1, 2, 3.
print(*(f"ord('{ch}'): {ord(ch)}" for ch in "abc xyz123"), sep="\n")
ord('a'): 97
ord('b'): 98
ord('c'): 99
ord(' '): 32
ord('x'): 120
ord('y'): 121
ord('z'): 122
ord('1'): 49
ord('2'): 50
ord('3'): 51
# Print the code point of every lower-case ASCII letter, digit and
# punctuation character, in that order.
for i in itertools.chain(string.ascii_lowercase, string.digits, string.punctuation):
  print(f"ord({i}): {ord(i)}")
ord(a): 97
ord(b): 98
ord(c): 99
ord(d): 100
ord(e): 101
ord(f): 102
ord(g): 103
ord(h): 104
ord(i): 105
ord(j): 106
ord(k): 107
ord(l): 108
ord(m): 109
ord(n): 110
ord(o): 111
ord(p): 112
ord(q): 113
ord(r): 114
ord(s): 115
ord(t): 116
ord(u): 117
ord(v): 118
ord(w): 119
ord(x): 120
ord(y): 121
ord(z): 122
ord(0): 48
ord(1): 49
ord(2): 50
ord(3): 51
ord(4): 52
ord(5): 53
ord(6): 54
ord(7): 55
ord(8): 56
ord(9): 57
ord(!): 33
ord("): 34
ord(#): 35
ord($): 36
ord(%): 37
ord(&): 38
ord('): 39
ord((): 40
ord()): 41
ord(*): 42
ord(+): 43
ord(,): 44
ord(-): 45
ord(.): 46
ord(/): 47
ord(:): 58
ord(;): 59
ord(<): 60
ord(=): 61
ord(>): 62
ord(?): 63
ord(@): 64
ord([): 91
ord(\): 92
ord(]): 93
ord(^): 94
ord(_): 95
ord(`): 96
ord({): 123
ord(|): 124
ord(}): 125
ord(~): 126
ord('\n')
10
chr(ord('a'))
'a'
chr(1)
'\x01'
def line2tensor(line, EOS_int=1):
  """Encode a line as character code points followed by an end-of-sequence id.

  Args:
    line: the text to encode.
    EOS_int: integer appended after the last character to mark the end.

  Returns:
    A new list holding ``ord(c)`` for each character of ``line`` and then
    ``EOS_int`` as the final element.
  """
  return [*map(ord, line), EOS_int]

lines[0], line2tensor(lines[0])
('this is the 100th etext file presented by project gutenberg, and\n',
 [116,
  104,
  105,
  115,
  32,
  105,
  115,
  32,
  116,
  104,
  101,
  32,
  49,
  48,
  48,
  116,
  104,
  32,
  101,
  116,
  101,
  120,
  116,
  32,
  102,
  105,
  108,
  101,
  32,
  112,
  114,
  101,
  115,
  101,
  110,
  116,
  101,
  100,
  32,
  98,
  121,
  32,
  112,
  114,
  111,
  106,
  101,
  99,
  116,
  32,
  103,
  117,
  116,
  101,
  110,
  98,
  101,
  114,
  103,
  44,
  32,
  97,
  110,
  100,
  10,
  1])
line2tensor('abc xyz')
[97, 98, 99, 32, 120, 121, 122, 1]

Data Generator

# def data_generator(data, batch_sz,  cb_process, max_len, shuffle=False):
#   while True:
#     index = 0                             # Resets index to zero on data exhaustion 
#     if shuffle : df = df.sample(frac=1.0) # Shuffles data - only relevant for train not eval tasks
#     itr = itertools.cycle(df.iterrows())  # Only purpose of cycle here is to handle last batch case when elements of dataset has been exhausted
#     while index <= len(df):
#       batch = (next(itr) for i in range(batch_sz))
#       index += batch_sz
#       yield process_batch(batch, func_text2idx, pad_id, class_dict, class_weights)
#     if loop: continue 
#     else: break


def batched(iterable, n):
    """Yield successive tuples of up to ``n`` items taken from ``iterable``.

    Example: ``batched('ABCDEFG', 3)`` produces ``ABC``, ``DEF``, ``G``; only
    the final tuple may hold fewer than ``n`` items.

    Raises:
        ValueError: if ``n`` is smaller than one. Because this is a generator
            function, the error surfaces on the first iteration, not at call
            time.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(source, n))
        if not chunk:
            # An empty slice means the source is exhausted.
            return
        yield chunk

batched('ABCDEFG', 3) 


def text_data_generator(data, batch_sz, max_len, shuffle=True):
  """Yield tuples of up to ``batch_sz`` raw lines shorter than ``max_len``.

  Args:
    data: indexable collection of lines.
    batch_sz: maximum number of lines per yielded tuple.
    max_len: lines whose length is ``max_len`` or more are dropped.
    shuffle: when true, visit the lines in a random order drawn from the
      module-level ``rnd`` generator; otherwise keep the original order.

  Yields:
    Tuples of lines. The generator makes a single pass over ``data``, and the
    last tuple may hold fewer than ``batch_sz`` lines.
  """
  order = list(range(len(data)))
  if shuffle:
    rnd.shuffle(order)
  lines_in_order = map(data.__getitem__, order)
  short_lines = (line for line in lines_in_order if len(line) < max_len)
  yield from batched(short_lines, batch_sz)
  
def data_generator(data, batch_sz, max_len, cb_process=line2tensor, shuffle=True):
  """Yield ``(inputs, targets, mask)`` array triples for batches of lines.

  Args:
    data: indexable collection of text lines.
    batch_sz: maximum number of lines per batch.
    max_len: width every encoded line is zero-padded to. Lines are
      pre-filtered on raw length (``len(line) < max_len``).
    cb_process: callable turning one line into a list of integer ids.
      NOTE(review): a callback that returns more than ``max_len`` ids for a
      line gives a row wider than ``max_len``; the default ``line2tensor``
      cannot, since it adds exactly one id.
    shuffle: forwarded to ``text_data_generator``.

  Yields:
    A 3-tuple of arrays with one row per line in the batch:
    the padded ids, the same array object again (inputs double as targets),
    and a mask holding 1 over encoded ids and 0 over padding.

  The generator makes a single pass over ``data``; wrap it in
  ``itertools.cycle`` for an endless stream.
  """
  for batch in text_data_generator(data, batch_sz, max_len, shuffle=shuffle):
    # Encode once and size the padding and mask from the encoded length, so
    # they stay correct whatever number of ids cb_process adds to a line.
    encoded = [cb_process(line) for line in batch]
    num_batch_arr = np.array(
        [ids + [0] * (max_len - len(ids)) for ids in encoded])
    mask_batch_arr = np.array(
        [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in encoded])
    yield num_batch_arr, num_batch_arr, mask_batch_arr
# Scratch values for trying out the generator pieces step by step.
data, shuffle, max_len = data_model, True, 10
idx_lines_tmp = [*range(10)]
if shuffle:
  rnd.shuffle(idx_lines_tmp)
# Display the first ten lines next to the (possibly shuffled) index order.
data[:10], idx_lines_tmp
(['this is the 100th etext file presented by project gutenberg, and\n',
  'is presented in cooperation with world library, inc., from their\n',
  'library of the future and shakespeare cdroms.  project gutenberg\n',
  'often releases etexts that are not placed in the public domain!!\n',
  '\n',
  'shakespeare\n',
  '\n',
  '*this etext has certain copyright implications you should read!*\n',
  '\n',
  '<<this electronic version of the complete works of william\n'],
 [4, 6, 0, 8, 9, 5, 7, 2, 3, 1])
# Unshuffled check on the first 10 lines: with batch_sz=4 they come back in
# their original order as tuples of 4, 4 and 2 lines.
g = text_data_generator(data[:10].copy(), batch_sz=4, max_len=100,  shuffle=False)
list(g)
[('this is the 100th etext file presented by project gutenberg, and\n',
  'is presented in cooperation with world library, inc., from their\n',
  'library of the future and shakespeare cdroms.  project gutenberg\n',
  'often releases etexts that are not placed in the public domain!!\n'),
 ('\n',
  'shakespeare\n',
  '\n',
  '*this etext has certain copyright implications you should read!*\n'),
 ('\n', '<<this electronic version of the complete works of william\n')]
# itertools.cycle saves the batches from the first pass and replays them, so
# the shuffled order is drawn once and then repeats instead of being
# reshuffled on every cycle. Print stop_id + 1 batches and stop.
g = text_data_generator(data[:10].copy(), batch_sz=4, max_len=100, shuffle=True)
index = 0
stop_id = 10
for index, batch in enumerate(itertools.cycle(g), start=1):
  print(batch)
  if index > stop_id:
    break
('library of the future and shakespeare cdroms.  project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
('shakespeare\n', 'is presented in cooperation with world library, inc., from their\n')
('library of the future and shakespeare cdroms.  project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
('shakespeare\n', 'is presented in cooperation with world library, inc., from their\n')
('library of the future and shakespeare cdroms.  project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
('shakespeare\n', 'is presented in cooperation with world library, inc., from their\n')
('library of the future and shakespeare cdroms.  project gutenberg\n', '*this etext has certain copyright implications you should read!*\n', '<<this electronic version of the complete works of william\n', 'often releases etexts that are not placed in the public domain!!\n')
('\n', 'this is the 100th etext file presented by project gutenberg, and\n', '\n', '\n')
# Unshuffled check of the array generator on the first 10 lines: batch_sz=3
# gives batches of 3, 3, 3 and 1 rows, each an (inputs, targets, mask) triple
# whose rows are zero-padded to max_len=100.
g = data_generator(data[:10].copy(), batch_sz=3, max_len=100,  shuffle=False)
list(g)
[(DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
                 49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
                102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
                101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
                116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
                 32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
                 32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
                105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
                108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
                105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
                104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [108, 105,  98, 114,  97, 114, 121,  32, 111, 102,  32, 116,
                104, 101,  32, 102, 117, 116, 117, 114, 101,  32,  97, 110,
                100,  32, 115, 104,  97, 107, 101, 115, 112, 101,  97, 114,
                101,  32,  99, 100, 114, 111, 109, 115,  46,  32,  32, 112,
                114, 111, 106, 101,  99, 116,  32, 103, 117, 116, 101, 110,
                 98, 101, 114, 103,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
                 49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
                102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
                101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
                116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
                 32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
                 32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
                105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
                108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
                105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
                104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [108, 105,  98, 114,  97, 114, 121,  32, 111, 102,  32, 116,
                104, 101,  32, 102, 117, 116, 117, 114, 101,  32,  97, 110,
                100,  32, 115, 104,  97, 107, 101, 115, 112, 101,  97, 114,
                101,  32,  99, 100, 114, 111, 109, 115,  46,  32,  32, 112,
                114, 111, 106, 101,  99, 116,  32, 103, 117, 116, 101, 110,
                 98, 101, 114, 103,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32)),
 (DeviceArray([[111, 102, 116, 101, 110,  32, 114, 101, 108, 101,  97, 115,
                101, 115,  32, 101, 116, 101, 120, 116, 115,  32, 116, 104,
                 97, 116,  32,  97, 114, 101,  32, 110, 111, 116,  32, 112,
                108,  97,  99, 101, 100,  32, 105, 110,  32, 116, 104, 101,
                 32, 112, 117,  98, 108, 105,  99,  32, 100, 111, 109,  97,
                105, 110,  33,  33,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [ 10,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [115, 104,  97, 107, 101, 115, 112, 101,  97, 114, 101,  10,
                  1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[111, 102, 116, 101, 110,  32, 114, 101, 108, 101,  97, 115,
                101, 115,  32, 101, 116, 101, 120, 116, 115,  32, 116, 104,
                 97, 116,  32,  97, 114, 101,  32, 110, 111, 116,  32, 112,
                108,  97,  99, 101, 100,  32, 105, 110,  32, 116, 104, 101,
                 32, 112, 117,  98, 108, 105,  99,  32, 100, 111, 109,  97,
                105, 110,  33,  33,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [ 10,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [115, 104,  97, 107, 101, 115, 112, 101,  97, 114, 101,  10,
                  1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32)),
 (DeviceArray([[ 10,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [ 42, 116, 104, 105, 115,  32, 101, 116, 101, 120, 116,  32,
                104,  97, 115,  32,  99, 101, 114, 116,  97, 105, 110,  32,
                 99, 111, 112, 121, 114, 105, 103, 104, 116,  32, 105, 109,
                112, 108, 105,  99,  97, 116, 105, 111, 110, 115,  32, 121,
                111, 117,  32, 115, 104, 111, 117, 108, 100,  32, 114, 101,
                 97, 100,  33,  42,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [ 10,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[ 10,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [ 42, 116, 104, 105, 115,  32, 101, 116, 101, 120, 116,  32,
                104,  97, 115,  32,  99, 101, 114, 116,  97, 105, 110,  32,
                 99, 111, 112, 121, 114, 105, 103, 104, 116,  32, 105, 109,
                112, 108, 105,  99,  97, 116, 105, 111, 110, 115,  32, 121,
                111, 117,  32, 115, 104, 111, 117, 108, 100,  32, 114, 101,
                 97, 100,  33,  42,  10,   1,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0],
               [ 10,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32)),
 (DeviceArray([[ 60,  60, 116, 104, 105, 115,  32, 101, 108, 101,  99, 116,
                114, 111, 110, 105,  99,  32, 118, 101, 114, 115, 105, 111,
                110,  32, 111, 102,  32, 116, 104, 101,  32,  99, 111, 109,
                112, 108, 101, 116, 101,  32, 119, 111, 114, 107, 115,  32,
                111, 102,  32, 119, 105, 108, 108, 105,  97, 109,  10,   1,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[ 60,  60, 116, 104, 105, 115,  32, 101, 108, 101,  99, 116,
                114, 111, 110, 105,  99,  32, 118, 101, 114, 115, 105, 111,
                110,  32, 111, 102,  32, 116, 104, 101,  32,  99, 111, 109,
                112, 108, 101, 116, 101,  32, 119, 111, 114, 107, 115,  32,
                111, 102,  32, 119, 105, 108, 108, 105,  97, 109,  10,   1,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                  0,   0,   0,   0]], dtype=int32),
  DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32))]
# Wrap the single-pass generator in itertools.cycle so next() keeps returning
# batches: 3 lines with batch_sz=2 give a batch of 2 rows, then a batch of
# 1 row, and then the same two batches are replayed.
g = itertools.cycle(data_generator(data[:3].copy(), batch_sz=2, max_len=100,  shuffle=False))
for i in range(5):
  print(i, next(g))
0 (DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
               49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
              102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
              101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
              116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
               32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0],
             [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
               32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
              105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
              108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
              105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
              104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
               49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
              102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
              101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
              116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
               32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0],
             [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
               32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
              105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
              108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
              105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
              104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32))
1 (DeviceArray([[108, 105,  98, 114,  97, 114, 121,  32, 111, 102,  32, 116,
              104, 101,  32, 102, 117, 116, 117, 114, 101,  32,  97, 110,
              100,  32, 115, 104,  97, 107, 101, 115, 112, 101,  97, 114,
              101,  32,  99, 100, 114, 111, 109, 115,  46,  32,  32, 112,
              114, 111, 106, 101,  99, 116,  32, 103, 117, 116, 101, 110,
               98, 101, 114, 103,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[108, 105,  98, 114,  97, 114, 121,  32, 111, 102,  32, 116,
              104, 101,  32, 102, 117, 116, 117, 114, 101,  32,  97, 110,
              100,  32, 115, 104,  97, 107, 101, 115, 112, 101,  97, 114,
              101,  32,  99, 100, 114, 111, 109, 115,  46,  32,  32, 112,
              114, 111, 106, 101,  99, 116,  32, 103, 117, 116, 101, 110,
               98, 101, 114, 103,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32))
2 (DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
               49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
              102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
              101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
              116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
               32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0],
             [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
               32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
              105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
              108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
              105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
              104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
               49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
              102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
              101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
              116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
               32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0],
             [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
               32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
              105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
              108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
              105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
              104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32))
3 (DeviceArray([[108, 105,  98, 114,  97, 114, 121,  32, 111, 102,  32, 116,
              104, 101,  32, 102, 117, 116, 117, 114, 101,  32,  97, 110,
              100,  32, 115, 104,  97, 107, 101, 115, 112, 101,  97, 114,
              101,  32,  99, 100, 114, 111, 109, 115,  46,  32,  32, 112,
              114, 111, 106, 101,  99, 116,  32, 103, 117, 116, 101, 110,
               98, 101, 114, 103,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[108, 105,  98, 114,  97, 114, 121,  32, 111, 102,  32, 116,
              104, 101,  32, 102, 117, 116, 117, 114, 101,  32,  97, 110,
              100,  32, 115, 104,  97, 107, 101, 115, 112, 101,  97, 114,
              101,  32,  99, 100, 114, 111, 109, 115,  46,  32,  32, 112,
              114, 111, 106, 101,  99, 116,  32, 103, 117, 116, 101, 110,
               98, 101, 114, 103,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32))
4 (DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
               49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
              102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
              101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
              116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
               32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0],
             [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
               32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
              105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
              108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
              105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
              104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[116, 104, 105, 115,  32, 105, 115,  32, 116, 104, 101,  32,
               49,  48,  48, 116, 104,  32, 101, 116, 101, 120, 116,  32,
              102, 105, 108, 101,  32, 112, 114, 101, 115, 101, 110, 116,
              101, 100,  32,  98, 121,  32, 112, 114, 111, 106, 101,  99,
              116,  32, 103, 117, 116, 101, 110,  98, 101, 114, 103,  44,
               32,  97, 110, 100,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0],
             [105, 115,  32, 112, 114, 101, 115, 101, 110, 116, 101, 100,
               32, 105, 110,  32,  99, 111, 111, 112, 101, 114,  97, 116,
              105, 111, 110,  32, 119, 105, 116, 104,  32, 119, 111, 114,
              108, 100,  32, 108, 105,  98, 114,  97, 114, 121,  44,  32,
              105, 110,  99,  46,  44,  32, 102, 114, 111, 109,  32, 116,
              104, 101, 105, 114,  10,   1,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0]], dtype=int32), DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],            dtype=int32))

Model Definition

def GRULM(vocab_sz=216, d_model=512, n_layers=2, mode='train'):
  """Assemble a GRU language model as a Trax ``Serial`` stack.

  The layers are, in order: ``ShiftRight`` -> ``Embedding`` ->
  ``n_layers`` x ``GRU`` -> ``Dense`` -> ``LogSoftmax``.

  Args:
    vocab_sz: Vocabulary size given to the embedding layer; also the number
      of units of the final dense layer.
    d_model: Feature depth of the embedding and number of units of each GRU.
    n_layers: How many GRU layers to stack.
    mode: Forwarded unchanged to ``tl.ShiftRight``.

  Returns:
    The ``tl.Serial`` model combining the layers above.
  """
  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(vocab_size=vocab_sz, d_feature=d_model),
      # A list of layers is accepted as a single argument to Serial.
      [tl.GRU(n_units=d_model) for _ in range(n_layers)],
      tl.Dense(n_units=vocab_sz),
      tl.LogSoftmax(),
  )


# Build a model with the default arguments so the notebook prints its layer stack.
GRULM() # The leading ShiftRight(1) layer in the printed stack is what shifts the input by 1
Serial[
  Serial[
    ShiftRight(1)
  ]
  Embedding_216_512
  GRU_512
  GRU_512
  Dense_216
  LogSoftmax
]

Count Number of Batches

max_len = 64
batch_sz = 128
# Keep only the lines strictly shorter than max_len.
filter_line_gen = (line for line in data_model if len(line) < max_len)
# Count the surviving lines without building an intermediate list
# (this consumes filter_line_gen).
n = sum(1 for _ in filter_line_gen)
# Number of whole batches of batch_sz lines that fit into n lines.
steps_epoch = n // batch_sz
n, steps_epoch
(107723, 841)

Model Training

def train_model(model, data_generator, lines, eval_lines, batch_sz=32, max_len=64, n_steps=3, output_dir='model/.',
                learning_rate=0.01, n_eval_batches=3):
  """Train ``model`` with a Trax training loop and return the loop.

  Args:
    model: Trax model to train.
    data_generator: Callable invoked as
      ``data_generator(lines, batch_sz=..., max_len=..., shuffle=...)``;
      its result is iterated to obtain batches.
    lines: Lines handed to ``data_generator`` for the training stream
      (``shuffle=True``).
    eval_lines: Lines handed to ``data_generator`` for the evaluation stream
      (``shuffle=False``).
    batch_sz: Forwarded to ``data_generator``.
    max_len: Forwarded to ``data_generator``.
    n_steps: Number of steps passed to ``training_loop.run``.
    output_dir: Directory given to ``training.Loop``.
    learning_rate: Learning rate for the Adam optimizer. The default keeps
      the value that was previously hard-coded here (0.01); the earlier
      inline note mentioned 0.0005, so pass a smaller value if training
      proves unstable.
    n_eval_batches: Number of batches per evaluation, forwarded to
      ``training.EvalTask``.

  Returns:
    The ``training.Loop`` instance after ``run(n_steps=n_steps)`` returned.
  """
  # itertools.cycle restarts from its saved copies once the wrapped iterable
  # is exhausted, and it keeps a copy of every item it has yielded.
  # NOTE(review): if data_generator is already infinite, cycle only adds
  # memory growth; if it is finite, shuffle=True only affects the first
  # pass -- confirm against data_generator.
  train_gen = itertools.cycle(
      data_generator(lines, batch_sz=batch_sz, max_len=max_len, shuffle=True))
  eval_gen = itertools.cycle(
      data_generator(eval_lines, batch_sz=batch_sz, max_len=max_len, shuffle=False))

  train_task = training.TrainTask(
      labeled_data=train_gen,
      loss_layer=tl.CrossEntropyLoss(),
      optimizer=trax.optimizers.Adam(learning_rate),
  )

  eval_task = training.EvalTask(
      labeled_data=eval_gen,
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
      n_eval_batches=n_eval_batches,
  )

  training_loop = training.Loop(model,
                                train_task,
                                eval_tasks=[eval_task],
                                output_dir=output_dir)

  training_loop.run(n_steps=n_steps)
  return training_loop
# Shell command (notebook `!` escape): delete any existing `model` directory,
# presumably so training does not pick up files left by an earlier run.
!rm -rf model
# Train a freshly built default GRULM for 1000 steps on data_model, evaluating on data_blind.
# batch_sz, max_len and output_dir are not passed, so train_model's defaults apply.
# NOTE(review): the batch-count cell above sets batch_sz = 128 but it is not forwarded here -- confirm which is intended.
training_loop = train_model(GRULM(), data_generator, lines=data_model, eval_lines=data_blind, n_steps=1000)
/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py:553: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(

Step      1: Total number of trainable weights: 3370200
Step      1: Ran 1 train steps in 9.98 secs
Step      1: train CrossEntropyLoss |  5.37536573
Step      1: eval  CrossEntropyLoss |  5.28850317
Step      1: eval          Accuracy |  0.21879732

Step    100: Ran 99 train steps in 168.48 secs
Step    100: train CrossEntropyLoss |  3.27563405
Step    100: eval  CrossEntropyLoss |  2.82953095
Step    100: eval          Accuracy |  0.21625705

Step    200: Ran 100 train steps in 169.83 secs
Step    200: train CrossEntropyLoss |  2.81332421
Step    200: eval  CrossEntropyLoss |  2.68687224
Step    200: eval          Accuracy |  0.25851414

Step    300: Ran 100 train steps in 168.30 secs
Step    300: train CrossEntropyLoss |  2.72730327
Step    300: eval  CrossEntropyLoss |  2.63422775
Step    300: eval          Accuracy |  0.26217009

Step    400: Ran 100 train steps in 169.75 secs
Step    400: train CrossEntropyLoss |  3.07694721
Step    400: eval  CrossEntropyLoss |  3.22344669
Step    400: eval          Accuracy |  0.21283427

Step    500: Ran 100 train steps in 169.61 secs
Step    500: train CrossEntropyLoss |  2.92221594
Step    500: eval  CrossEntropyLoss |  2.89374344
Step    500: eval          Accuracy |  0.21506842

Step    600: Ran 100 train steps in 170.35 secs
Step    600: train CrossEntropyLoss |  2.84008431
Step    600: eval  CrossEntropyLoss |  2.91676569
Step    600: eval          Accuracy |  0.20170474

Step    700: Ran 100 train steps in 169.49 secs
Step    700: train CrossEntropyLoss |  2.75251889
Step    700: eval  CrossEntropyLoss |  2.87050303
Step    700: eval          Accuracy |  0.21942994

Step    800: Ran 100 train steps in 169.41 secs
Step    800: train CrossEntropyLoss |  2.68192506
Step    800: eval  CrossEntropyLoss |  2.79392648
Step    800: eval          Accuracy |  0.25168787

Step    900: Ran 100 train steps in 169.13 secs
Step    900: train CrossEntropyLoss |  2.66031885
Step    900: eval  CrossEntropyLoss |  2.75871555
Step    900: eval          Accuracy |  0.23988337
# Call run again on the loop returned by train_model, this time with n_steps=2.
training_loop.run(n_steps=2)