Skip to content

Commit ca8fad4

Browse files
authored
Add files via upload
1 parent 531c239 commit ca8fad4

File tree

1 file changed

+196
-0
lines changed

1 file changed

+196
-0
lines changed

pytorch/New_Tuts/sequential_tasks.py

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import numpy as np
2+
from tensorflow.python.keras.utils import Sequence, to_categorical
3+
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
4+
5+
6+
class EchoData(Sequence):
    """Batch generator for the "echo" task.

    The input ``x`` is a random series of 0/1 values with shape
    ``(batch_size, series_length)``.  The target ``y`` is ``x`` shifted
    right by ``echo_step`` positions along the time axis, with the first
    ``echo_step`` positions set to 0.

    The series is cut along the time axis into ``n_batches`` consecutive
    windows of ``truncated_length`` steps.  Item ``i`` of this object is the
    pair ``(x_window, y_window)``, each of shape
    ``(batch_size, truncated_length, 1)``.

    Args:
        series_length: Number of time steps generated per series.
        batch_size: Number of independent series generated in parallel.
        echo_step: Delay (in time steps) between input and target.
        truncated_length: Number of time steps per returned window.
        seed: If not ``None``, seeds NumPy's *global* random generator.

    Raises:
        ValueError: If not even one full window fits into the series.
    """

    def __init__(self, series_length=40000, batch_size=32,
                 echo_step=3, truncated_length=10, seed=None):
        super().__init__()

        self.series_length = series_length
        self.truncated_length = truncated_length
        self.n_batches = series_length // truncated_length
        if self.n_batches < 1:
            raise ValueError(
                "series_length (%r) must contain at least one window of "
                "truncated_length (%r) steps" % (series_length, truncated_length))

        self.echo_step = echo_step
        self.batch_size = batch_size
        if seed is not None:
            # NOTE: this seeds the process-wide NumPy generator.
            np.random.seed(seed)
        self.raw_x = None
        self.raw_y = None
        self.x_batches = []
        self.y_batches = []
        self.generate_new_series()
        self.prepare_batches()

    def __getitem__(self, index):
        """Return the ``index``-th ``(x, y)`` window.

        Asking for window 0 first draws a brand-new random series, so every
        pass that starts at index 0 works on fresh data.
        """
        if index == 0:
            self.generate_new_series()
            self.prepare_batches()
        return self.x_batches[index], self.y_batches[index]

    def __len__(self):
        """Return the number of windows per series."""
        return self.n_batches

    def generate_new_series(self):
        """Draw a new random input series and derive its delayed target."""
        x = np.random.choice(
            2,
            size=(self.batch_size, self.series_length),
            p=[0.5, 0.5])
        # np.roll wraps around, so the first `echo_step` targets would hold
        # values from the end of the series; blank them out instead.
        y = np.roll(x, self.echo_step, axis=1)
        y[:, 0:self.echo_step] = 0
        self.raw_x = x
        self.raw_y = y

    def prepare_batches(self):
        """Cut the raw series into ``n_batches`` equally sized windows."""
        # np.split demands an exact division.  Drop the trailing steps that
        # do not fill a whole window so that a series_length which is not a
        # multiple of truncated_length no longer raises.
        usable_steps = self.n_batches * self.truncated_length
        x = np.expand_dims(self.raw_x[:, :usable_steps], axis=-1)
        y = np.expand_dims(self.raw_y[:, :usable_steps], axis=-1)
        self.x_batches = np.split(x, self.n_batches, axis=1)
        self.y_batches = np.split(y, self.n_batches, axis=1)
50+
51+
52+
class TemporalOrderExp6aSequence(Sequence):
    """Batch generator for the temporal-order task (Experiment 6a).

    From Hochreiter & Schmidhuber (1997):

    The goal is to classify sequences. Elements and targets are represented
    locally (input vectors with only one non-zero bit). A sequence consists
    of randomly chosen symbols from the set {a, b, c, d}, except for two
    elements at positions t1 and t2 that are either X or Y. With the default
    settings the sequence length is randomly chosen between 100 and 110, t1
    between 10 and 20, and t2 between 50 and 60.

    NOTE: in this implementation every sequence *starts* with ``B`` and
    *ends* with ``E`` (see ``start_symbol`` / ``end_symbol``). The paper
    text this docstring was taken from names the two markers the other way
    round.

    There are 4 sequence classes Q, R, S, U which depend on the temporal
    order of X and Y. The rules are:

        X, X -> Q,
        X, Y -> R,
        Y, X -> S,
        Y, Y -> U.

    Args:
        length_range: ``(low, high)`` for the sequence length.
        t1_range: ``(low, high)`` for the position of the first X/Y symbol.
        t2_range: ``(low, high)`` for the position of the second X/Y symbol.
        batch_size: Number of sequences per generated batch.
        seed: If not ``None``, seeds NumPy's *global* random generator.

    All ranges are half-open, i.e. ``high`` itself is never drawn.

    Raises:
        ValueError: If the ranges are empty, or could place the X/Y symbols
            on top of each other, out of order, on the start/end marker, or
            beyond the end of the shortest possible sequence.
    """

    def __init__(self, length_range=(100, 111), t1_range=(10, 21), t2_range=(50, 61),
                 batch_size=32, seed=None):
        super().__init__()

        self._validate_ranges(length_range, t1_range, t2_range)

        self.classes = ['Q', 'R', 'S', 'U']
        self.n_classes = len(self.classes)

        self.relevant_symbols = ['X', 'Y']
        self.distraction_symbols = ['a', 'b', 'c', 'd']
        self.start_symbol = 'B'
        self.end_symbol = 'E'

        self.length_range = length_range
        self.t1_range = t1_range
        self.t2_range = t2_range
        self.batch_size = batch_size

        if seed is not None:
            # NOTE: this seeds the process-wide NumPy generator.
            np.random.seed(seed)

        all_symbols = self.relevant_symbols + self.distraction_symbols + \
            [self.start_symbol] + [self.end_symbol]
        self.n_symbols = len(all_symbols)
        self.s_to_idx = {s: n for n, s in enumerate(all_symbols)}
        self.idx_to_s = {n: s for n, s in enumerate(all_symbols)}

        self.c_to_idx = {c: n for n, c in enumerate(self.classes)}
        self.idx_to_c = {n: c for n, c in enumerate(self.classes)}

    @staticmethod
    def _validate_ranges(length_range, t1_range, t2_range):
        """Reject range settings that could yield malformed sequences."""
        for name, bounds in (('length_range', length_range),
                             ('t1_range', t1_range),
                             ('t2_range', t2_range)):
            if bounds[0] >= bounds[1]:
                raise ValueError(
                    "%s must satisfy low < high, got %r" % (name, (bounds[0], bounds[1])))
        if t1_range[0] < 1:
            raise ValueError("t1_range must start at 1 or later so that the "
                             "start symbol at position 0 is not overwritten")
        if t1_range[1] > t2_range[0]:
            raise ValueError("t1_range and t2_range must not overlap: "
                             "t1 has to come strictly before t2")
        # Largest t2 is t2_range[1] - 1; the end marker of the shortest
        # sequence sits at index length_range[0] - 1.
        if t2_range[1] > length_range[0] - 1:
            raise ValueError("t2_range must end before the last position of "
                             "the shortest possible sequence")

    def generate_pair(self):
        """Generate one random ``(sequence_string, class_label)`` pair."""
        # The order of the random draws is kept stable so that a given seed
        # keeps producing the same data.
        length = np.random.randint(self.length_range[0], self.length_range[1])
        t1 = np.random.randint(self.t1_range[0], self.t1_range[1])
        t2 = np.random.randint(self.t2_range[0], self.t2_range[1])

        x = np.random.choice(self.distraction_symbols, length)
        x[0] = self.start_symbol
        x[-1] = self.end_symbol

        y = np.random.choice(self.classes)
        # Class -> indices into relevant_symbols for positions (t1, t2).
        symbol_order = {'Q': (0, 0), 'R': (0, 1), 'S': (1, 0)}
        first, second = symbol_order.get(y, (1, 1))
        x[t1] = self.relevant_symbols[first]
        x[t2] = self.relevant_symbols[second]

        return ''.join(x), y

    # encoding/decoding single instance version

    def encode_x(self, x):
        """One-hot encode a sequence string, one row per symbol."""
        idx_x = [self.s_to_idx[s] for s in x]
        return to_categorical(idx_x, num_classes=self.n_symbols)

    def encode_y(self, y):
        """One-hot encode a single class label."""
        idx_y = self.c_to_idx[y]
        return to_categorical(idx_y, num_classes=self.n_classes)

    def decode_x(self, x):
        """Turn a (possibly padded) one-hot sequence back into a string."""
        x = np.asarray(x)
        x = x[np.sum(x, axis=1) > 0]  # remove padding (all-zero rows)
        return ''.join([self.idx_to_s[pos] for pos in np.argmax(x, axis=1)])

    def decode_y(self, y):
        """Return the class label with the highest score in ``y``."""
        return self.idx_to_c[np.argmax(y)]

    # encoding/decoding batch versions

    def encode_x_batch(self, x_batch):
        """One-hot encode sequence strings and pad them to a common length.

        The padded length is ``length_range[1]``. Because that bound is
        exclusive when lengths are drawn, every encoded sequence receives at
        least one padding row.
        """
        return pad_sequences([self.encode_x(x) for x in x_batch],
                             maxlen=self.length_range[1])

    def encode_y_batch(self, y_batch):
        """One-hot encode a list of class labels into a single array."""
        return np.array([self.encode_y(y) for y in y_batch])

    def decode_x_batch(self, x_batch):
        """Decode every one-hot sequence of a batch into a string."""
        return [self.decode_x(x) for x in x_batch]

    def decode_y_batch(self, y_batch):
        """Decode every row of a batch of class scores into a label."""
        return [self.idx_to_c[pos] for pos in np.argmax(y_batch, axis=1)]

    def __len__(self):
        """ Let's assume 1000 sequences as the size of data. """
        return int(1000. / self.batch_size)

    def __getitem__(self, index):
        """Return a freshly generated, encoded batch.

        ``index`` is not used: every call draws ``batch_size`` new random
        sequences.
        """
        batch_x, batch_y = [], []
        for _ in range(self.batch_size):
            x, y = self.generate_pair()
            batch_x.append(x)
            batch_y.append(y)
        return self.encode_x_batch(batch_x), self.encode_y_batch(batch_y)

    class DifficultyLevel:
        """ On HARD, settings are identical to the original settings from the '97 paper."""
        EASY, NORMAL, MODERATE, HARD, NIGHTMARE = range(5)

    @staticmethod
    def get_predefined_generator(difficulty_level, batch_size=32, seed=8382):
        """Build a generator preconfigured for one of the difficulty levels.

        Args:
            difficulty_level: One of the ``DifficultyLevel`` constants. Any
                value other than EASY, NORMAL, MODERATE or HARD selects the
                NIGHTMARE settings.
            batch_size: Number of sequences per batch.
            seed: Seed handed to the generator's constructor.

        Returns:
            A configured ``TemporalOrderExp6aSequence``.
        """
        levels = TemporalOrderExp6aSequence.DifficultyLevel
        # level -> (length_range, t1_range, t2_range)
        presets = {
            levels.EASY: ((7, 9), (1, 3), (4, 6)),
            levels.NORMAL: ((30, 41), (2, 6), (20, 28)),
            levels.MODERATE: ((60, 81), (10, 21), (45, 55)),
            levels.HARD: ((100, 111), (10, 21), (50, 61)),
        }
        nightmare = ((300, 501), (10, 81), (250, 291))
        length_range, t1_range, t2_range = presets.get(difficulty_level, nightmare)
        return TemporalOrderExp6aSequence(length_range, t1_range, t2_range,
                                          batch_size, seed)

0 commit comments

Comments
 (0)