|
| 1 | +import numpy as np |
| 2 | +from tensorflow.python.keras.utils import Sequence, to_categorical |
| 3 | +from tensorflow.python.keras.preprocessing.sequence import pad_sequences |
| 4 | + |
| 5 | + |
| 6 | +class EchoData(Sequence): |
| 7 | + |
| 8 | + def __init__(self, series_length=40000, batch_size=32, |
| 9 | + echo_step=3, truncated_length=10, seed=None): |
| 10 | + |
| 11 | + self.series_length = series_length |
| 12 | + self.truncated_length = truncated_length |
| 13 | + self.n_batches = series_length//truncated_length |
| 14 | + |
| 15 | + self.echo_step = echo_step |
| 16 | + self.batch_size = batch_size |
| 17 | + if seed is not None: |
| 18 | + np.random.seed(seed) |
| 19 | + self.raw_x = None |
| 20 | + self.raw_y = None |
| 21 | + self.x_batches = [] |
| 22 | + self.y_batches = [] |
| 23 | + self.generate_new_series() |
| 24 | + self.prepare_batches() |
| 25 | + |
| 26 | + def __getitem__(self, index): |
| 27 | + if index == 0: |
| 28 | + self.generate_new_series() |
| 29 | + self.prepare_batches() |
| 30 | + return self.x_batches[index], self.y_batches[index] |
| 31 | + |
| 32 | + def __len__(self): |
| 33 | + return self.n_batches |
| 34 | + |
| 35 | + def generate_new_series(self): |
| 36 | + x = np.random.choice( |
| 37 | + 2, |
| 38 | + size=(self.batch_size, self.series_length), |
| 39 | + p=[0.5, 0.5]) |
| 40 | + y = np.roll(x, self.echo_step, axis=1) |
| 41 | + y[:, 0:self.echo_step] = 0 |
| 42 | + self.raw_x = x |
| 43 | + self.raw_y = y |
| 44 | + |
| 45 | + def prepare_batches(self): |
| 46 | + x = np.expand_dims(self.raw_x, axis=-1) |
| 47 | + y = np.expand_dims(self.raw_y, axis=-1) |
| 48 | + self.x_batches = np.split(x, self.n_batches, axis=1) |
| 49 | + self.y_batches = np.split(y, self.n_batches, axis=1) |
| 50 | + |
| 51 | + |
| 52 | +class TemporalOrderExp6aSequence(Sequence): |
| 53 | + """ |
| 54 | + From Hochreiter&Schmidhuber(1997): |
| 55 | +
|
| 56 | + The goal is to classify sequences. Elements and targets are represented locally |
| 57 | + (input vectors with only one non-zero bit). The sequence starts with an E, ends |
| 58 | + with a B (the "trigger symbol") and otherwise consists of randomly chosen symbols |
| 59 | + from the set {a, b, c, d} except for two elements at positions t1 and t2 that are |
| 60 | + either X or Y . The sequence length is randomly chosen between 100 and 110, t1 is |
| 61 | + randomly chosen between 10 and 20, and t2 is randomly chosen between 50 and 60. |
| 62 | + There are 4 sequence classes Q, R, S, U which depend on the temporal order of X and Y. |
| 63 | + The rules are: |
| 64 | + X, X -> Q, |
| 65 | + X, Y -> R, |
| 66 | + Y , X -> S, |
| 67 | + Y , Y -> U. |
| 68 | +
|
| 69 | + """ |
| 70 | + |
| 71 | + def __init__(self, length_range=(100, 111), t1_range=(10, 21), t2_range=(50, 61), |
| 72 | + batch_size=32, seed=None): |
| 73 | + |
| 74 | + self.classes = ['Q', 'R', 'S', 'U'] |
| 75 | + self.n_classes = len(self.classes) |
| 76 | + |
| 77 | + self.relevant_symbols = ['X', 'Y'] |
| 78 | + self.distraction_symbols = ['a', 'b', 'c', 'd'] |
| 79 | + self.start_symbol = 'B' |
| 80 | + self.end_symbol = 'E' |
| 81 | + |
| 82 | + self.length_range = length_range |
| 83 | + self.t1_range = t1_range |
| 84 | + self.t2_range = t2_range |
| 85 | + self.batch_size = batch_size |
| 86 | + |
| 87 | + if seed is not None: |
| 88 | + np.random.seed(seed) |
| 89 | + |
| 90 | + all_symbols = self.relevant_symbols + self.distraction_symbols + \ |
| 91 | + [self.start_symbol] + [self.end_symbol] |
| 92 | + self.n_symbols = len(all_symbols) |
| 93 | + self.s_to_idx = {s: n for n, s in enumerate(all_symbols)} |
| 94 | + self.idx_to_s = {n: s for n, s in enumerate(all_symbols)} |
| 95 | + |
| 96 | + self.c_to_idx = {c: n for n, c in enumerate(self.classes)} |
| 97 | + self.idx_to_c = {n: c for n, c in enumerate(self.classes)} |
| 98 | + |
| 99 | + def generate_pair(self): |
| 100 | + length = np.random.randint(self.length_range[0], self.length_range[1]) |
| 101 | + t1 = np.random.randint(self.t1_range[0], self.t1_range[1]) |
| 102 | + t2 = np.random.randint(self.t2_range[0], self.t2_range[1]) |
| 103 | + |
| 104 | + x = np.random.choice(self.distraction_symbols, length) |
| 105 | + x[0] = self.start_symbol |
| 106 | + x[-1] = self.end_symbol |
| 107 | + |
| 108 | + y = np.random.choice(self.classes) |
| 109 | + if y == 'Q': |
| 110 | + x[t1], x[t2] = self.relevant_symbols[0], self.relevant_symbols[0] |
| 111 | + elif y == 'R': |
| 112 | + x[t1], x[t2] = self.relevant_symbols[0], self.relevant_symbols[1] |
| 113 | + elif y == 'S': |
| 114 | + x[t1], x[t2] = self.relevant_symbols[1], self.relevant_symbols[0] |
| 115 | + else: |
| 116 | + x[t1], x[t2] = self.relevant_symbols[1], self.relevant_symbols[1] |
| 117 | + |
| 118 | + return ''.join(x), y |
| 119 | + |
| 120 | + # encoding/decoding single instance version |
| 121 | + |
| 122 | + def encode_x(self, x): |
| 123 | + idx_x = [self.s_to_idx[s] for s in x] |
| 124 | + return to_categorical(idx_x, num_classes=self.n_symbols) |
| 125 | + |
| 126 | + def encode_y(self, y): |
| 127 | + idx_y = self.c_to_idx[y] |
| 128 | + return to_categorical(idx_y, num_classes=self.n_classes) |
| 129 | + |
| 130 | + def decode_x(self, x): |
| 131 | + x = x[np.sum(x, axis=1) > 0] # remove padding |
| 132 | + return ''.join([self.idx_to_s[pos] for pos in np.argmax(x, axis=1)]) |
| 133 | + |
| 134 | + def decode_y(self, y): |
| 135 | + return self.idx_to_c[np.argmax(y)] |
| 136 | + |
| 137 | + # encoding/decoding batch versions |
| 138 | + |
| 139 | + def encode_x_batch(self, x_batch): |
| 140 | + return pad_sequences([self.encode_x(x) for x in x_batch], |
| 141 | + maxlen=self.length_range[1]) |
| 142 | + |
| 143 | + def encode_y_batch(self, y_batch): |
| 144 | + return np.array([self.encode_y(y) for y in y_batch]) |
| 145 | + |
| 146 | + def decode_x_batch(self, x_batch): |
| 147 | + return [self.decode_x(x) for x in x_batch] |
| 148 | + |
| 149 | + def decode_y_batch(self, y_batch): |
| 150 | + return [self.idx_to_c[pos] for pos in np.argmax(y_batch, axis=1)] |
| 151 | + |
| 152 | + def __len__(self): |
| 153 | + """ Let's assume 1000 sequences as the size of data. """ |
| 154 | + return int(1000. / self.batch_size) |
| 155 | + |
| 156 | + def __getitem__(self, index): |
| 157 | + batch_x, batch_y = [], [] |
| 158 | + for _ in range(self.batch_size): |
| 159 | + x, y = self.generate_pair() |
| 160 | + batch_x.append(x) |
| 161 | + batch_y.append(y) |
| 162 | + return self.encode_x_batch(batch_x), self.encode_y_batch(batch_y) |
| 163 | + |
| 164 | + class DifficultyLevel: |
| 165 | + """ On HARD, settings are identical to the original settings from the '97 paper.""" |
| 166 | + EASY, NORMAL, MODERATE, HARD, NIGHTMARE = range(5) |
| 167 | + |
| 168 | + @staticmethod |
| 169 | + def get_predefined_generator(difficulty_level, batch_size=32, seed=8382): |
| 170 | + EASY = TemporalOrderExp6aSequence.DifficultyLevel.EASY |
| 171 | + NORMAL = TemporalOrderExp6aSequence.DifficultyLevel.NORMAL |
| 172 | + MODERATE = TemporalOrderExp6aSequence.DifficultyLevel.MODERATE |
| 173 | + HARD = TemporalOrderExp6aSequence.DifficultyLevel.HARD |
| 174 | + |
| 175 | + if difficulty_level == EASY: |
| 176 | + length_range = (7, 9) |
| 177 | + t1_range = (1, 3) |
| 178 | + t2_range = (4, 6) |
| 179 | + elif difficulty_level == NORMAL: |
| 180 | + length_range = (30, 41) |
| 181 | + t1_range = (2, 6) |
| 182 | + t2_range = (20, 28) |
| 183 | + elif difficulty_level == MODERATE: |
| 184 | + length_range = (60, 81) |
| 185 | + t1_range = (10, 21) |
| 186 | + t2_range = (45, 55) |
| 187 | + elif difficulty_level == HARD: |
| 188 | + length_range = (100, 111) |
| 189 | + t1_range = (10, 21) |
| 190 | + t2_range = (50, 61) |
| 191 | + else: |
| 192 | + length_range = (300, 501) |
| 193 | + t1_range = (10, 81) |
| 194 | + t2_range = (250, 291) |
| 195 | + return TemporalOrderExp6aSequence(length_range, t1_range, t2_range, |
| 196 | + batch_size, seed) |
0 commit comments