diff --git a/.gitignore b/.gitignore
index 6c0eec71..4336d66b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,4 +15,5 @@ logs
 html/
 diagrams/
 .comet.config
-settings.md
\ No newline at end of file
+settings.md
+labml_app.log
\ No newline at end of file
diff --git a/.labml.yaml b/.labml.yaml
index 1290b7bf..2e384051 100644
--- a/.labml.yaml
+++ b/.labml.yaml
@@ -19,3 +19,4 @@ indicators:
     name: optim.*
 options:
   comet: false
+web_api: http://localhost:5005/api/v1/track?
diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html
new file mode 100644
index 00000000..c35bbd9a
--- /dev/null
+++ b/docs/experiments/arithmetic_dataset.html
@@ -0,0 +1,900 @@
+ +This is based on code by Georges Harik (@gharik).
+import random
+import string
+from typing import List
+
+import torch
+from labml.logger import Text
+from torch.utils.data import DataLoader, Dataset
+
+from labml import monit, logger, tracker
+from labml.configs import option
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
+
+This creates arithmetic addition problems and solutions with workings. We've only implemented addition so far.
+
+It's based on character-level tokenization.
+class ArithmeticDataset(Dataset):
+
+:seq_len: is the sequence length of the generated math problems. We pack as many problems as possible up to this length.
+:max_digits: is the maximum number of digits in the operand integers.
+:n_sequences: is the number of sequences per epoch.
+
+    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
+        self.n_sequences = n_sequences
+        self.max_digits = max_digits
+        self.seq_len = seq_len
+
+Token id to string
+
+        self.itos = list(string.digits + 'xe =\n?+;')
+
+Character to token id
+
+        self.stoi = {c: i for i, c in enumerate(self.itos)}
+
+Generates an integer with n_digits digits
+    @staticmethod
+    def make_int(n_digits: int):
+        res = 0
+        for i in range(n_digits):
+            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
+            res = res * 10 + d
+
+        return res
+
+Generates the workings for x + y. For example, for 11+29 it generates 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1.
+    @staticmethod
+    def get_add_explanation(x: int, y: int):
+        carry = 0
+        e = 0
+        explanation = []
+        while x > 0 or y > 0 or carry > 0:
+            rx, ry = x % 10, y % 10
+            total = rx + ry + carry
+            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
+            x, y, carry = x // 10, y // 10, total // 10
+            e += 1
+
+        return ' '.join(explanation)
+
+Creates an arithmetic addition problem with workings and answer.
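To make the format concrete before the method below, here is a small worked sketch (illustrative, not part of the generated page) for the fixed operands 11 and 29; at training time the operands are random.

# Illustrative only, assuming ArithmeticDataset from above is importable.
# get_add_explanation is a @staticmethod, so it can be called on the class.
explanation = ArithmeticDataset.get_add_explanation(11, 29)
assert explanation == "1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1"

# make_add_problem wraps a random problem in the same shape, e.g.:
#   "x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40\n"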
+    def make_add_problem(self):
+        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+
+        explanation = self.get_add_explanation(x, y)
+        return f"x={x}+{y}; {explanation} x=={x + y}\n"
+
+Get an arithmetic problem and answer. This is used for evaluation.
+    def get_qa(self):
+        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+
+        return f'x={x}+{y};', f'{x + y}'
+
+Generate multiple problems and pack them into a sequence.
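Before the packing code below, a small sketch (illustrative, not from the page) of what a packed sequence looks like once decoded; each problem is prefixed with the '?' separator and problems are appended until the sequence length is exceeded.

# Illustrative only, assuming ArithmeticDataset from above is importable.
ds = ArithmeticDataset(seq_len=64, max_digits=2, n_sequences=1)
packed = ds.get_packed_math_input()
assert len(packed) > ds.seq_len   # packing stops only after seq_len is exceeded
print(ds.decode(packed))
# e.g. "?x=4+56; 4e0+6e0+0e0==10e0 0e1+5e1+1e1==6e1 x==60\n?x=7+3; ..."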
+    def get_packed_math_input(self):
+        s_enc = []
+        while len(s_enc) <= self.seq_len:
+            s_part = self.make_add_problem()
+            s_part_enc = self.encode('?' + s_part)
+            s_enc = s_enc + s_part_enc
+        return s_enc
+
+Encode a given string
+    def encode(self, s: str):
+        return [self.stoi[c] for c in s]
+
+Decode a list of token ids
+
+    def decode(self, arr: List[int]):
+        return ''.join([self.itos[c] for c in arr])
+
+Get an input and target pair for auto-regressive modelling
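As a quick sketch of that pair (illustrative, not from the page): the target is simply the input shifted left by one token, which is what next-token prediction training expects.

# Illustrative only, assuming ArithmeticDataset from above is importable.
ds = ArithmeticDataset(seq_len=32, max_digits=2, n_sequences=1)
inp, target = ds[0]
assert inp.shape == target.shape == (32,)
assert (inp[1:] == target[:-1]).all()   # target is the input shifted by one position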
+    def __getitem__(self, idx: int):
+        s = torch.tensor(self.get_packed_math_input())
+        return s[:self.seq_len], s[1:self.seq_len + 1]
+
+Number of sequences per epoch
+    def __len__(self):
+        return self.n_sequences
+
+class ArithmeticAutoregression(NLPAutoRegressionConfigs):
+
+Maximum number of digits per operand integer
+
+    max_digits: int = 4
+
+Number of training sequences per epoch
+
+    train_sequences_per_epoch: int = 2 ** 12
+
+Training data loader
+
+    train_loader: DataLoader = 'arithmetic_train_loader'
+
+Number of problems in evaluation
+
+    n_tests: int = 64
+
+No validation dataset is needed
+
+    validator = None
+
+Number of times to run evaluations per epoch
+
+    inner_iterations = 4
+
+Number of tokens in the vocabulary
+
+    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
+    @torch.no_grad()
+    def sample(self):
+
+Skip in the first epoch
+
+        if self.training_loop.idx < 1:
+            return
+
+Create a dataset to generate problems
+
+        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+
+Get a set of problems and answers
+
+        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+
+Collect the problems only
+
+        questions = [p[0] for p in qa]
+
+Create a tensor with only the initial token
+
+        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
+
+Move to device
+
+        data = data.to(self.device)
+
+Number of sequences that have completed
+
+        finished = torch.zeros((len(questions),)).bool().to(self.device)
+
+Token id of the new line character; this marks the end of the answer
+
+        new_line = dataset.stoi['\n']
+
+Sampled results
+
+        results = [p[0] for p in questions]
+
+Sample up to the sequence length
+        for i in monit.iterate('Sample', self.seq_len - 1):
+
+If all the sequences have completed we skip this
+
+            if finished.sum() == len(finished):
+                continue
+
+Get the model output
+
+            output, *_ = self.model(data)
+
+Get the model prediction (greedy)
+
+            output = output[-1].argmax(dim=-1)
+
+Find which sequences have finished
+
+            finished = finished | (output == new_line)
+
+Skip if all have finished
+
+            if finished.sum() == len(finished):
+                continue
+
+Override with the question
+
+            for j, p in enumerate(questions):
+                if len(p) > i + 1:
+                    output[j] = dataset.stoi[p[i + 1]]
+
+Add the next token to the input
+
+            data = torch.cat([data, output[None, :]], dim=0)
+
+Get the sampled results
+
+            for j, c in enumerate(output):
+                results[j] += dataset.itos[c]
+
+Discard everything after the answer in the results
+
+        results = [r.split('\n')[0] for r in results]
+
+Log a sample
+
+        res_sample = results[0].split(';')
+        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+
+Get the answers
+        results = [r.split('x==')[-1] for r in results]
+
+Count the number of correct answers
+
+        correct = 0
+        for r, _qa in zip(results, qa):
+            if r == _qa[1]:
+                correct += 1
+
+Log the score
+
+        tracker.save('score', correct / len(results))
+
+Training data loader
+@option(ArithmeticAutoregression.train_loader)
+def arithmetic_train_loader(c: ArithmeticAutoregression):
+    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+                      batch_size=c.batch_size,
+                      collate_fn=transpose_batch,
+                      num_workers=4)
+
+Code to test generated problems
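Before the test code below, a sanity-check sketch for the loader just defined (illustrative, not part of the page; it assumes transpose_batch stacks the samples and transposes them to [seq_len, batch_size], which is the layout the NLP auto-regression trainer expects).

# Illustrative only: inspect one batch from a small arithmetic loader.
dl = DataLoader(ArithmeticDataset(128, 4, 16),
                batch_size=4,
                collate_fn=transpose_batch)
data, target = next(iter(dl))
print(data.shape, target.shape)   # expected: [seq_len, batch_size] = [128, 4] for both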
+def _test():
+    dataset = ArithmeticDataset(256, 8, 10)
+
+    print(dataset.decode(dataset.get_packed_math_input()))
+
+
+if __name__ == '__main__':
+    _test()
-    accuracy = Accuracy()
+    accuracy = AccuracyMovingAvg()
+import copy
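To round off this section, here is a sketch of how the configuration above could be wired into a labml experiment (illustrative only; the experiment name and the overridden values are made up, and a model option still has to be provided as NLPAutoRegressionConfigs requires).

# Illustrative only: the usual labml experiment pattern for these configs.
from labml import experiment

def main():
    experiment.create(name="arithmetic_addition")
    conf = ArithmeticAutoregression()
    # A `model` option must be defined elsewhere (e.g. by a subclass) before running.
    experiment.configs(conf, {'max_digits': 4, 'epochs': 16})
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()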
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 42cf8562..8f6077e4 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -204,7 +204,7 @@
https://nn.labml.ai/normalization/deep_norm/index.html
- 2022-04-23T16:30:00+00:00
+ 2022-05-18T16:30:00+00:00
1.00
@@ -244,6 +244,13 @@
+
+ https://nn.labml.ai/experiments/arithmetic_dataset.html
+ 2022-06-02T16:30:00+00:00
+ 1.00
+
+
+
https://nn.labml.ai/experiments/index.html
2020-12-26T16:30:00+00:00
@@ -603,14 +610,35 @@
https://nn.labml.ai/transformers/rope/index.html
- 2022-04-05T16:30:00+00:00
+ 2022-05-31T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html
+ 2022-06-02T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/transformers/rope/value_pe/index.html
+ 2022-06-02T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/transformers/rope/value_pe/experiment.html
+ 2022-05-31T16:30:00+00:00
1.00
https://nn.labml.ai/transformers/rope/experiment.html
- 2022-03-12T16:30:00+00:00
+ 2022-05-31T16:30:00+00:00
1.00
diff --git a/docs/transformers/rope/experiment.html b/docs/transformers/rope/experiment.html
index 156b2092..5ce9d242 100644
--- a/docs/transformers/rope/experiment.html
+++ b/docs/transformers/rope/experiment.html
@@ -92,7 +92,7 @@ Rotary PE attention
def _rotary_pe_mha(c: TransformerConfigs):
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)
-experiment.create(name="rotary_pe_transformer")
+experiment.create(name="rotary_pe_transformer", writers={'screen'})

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the features as pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.

Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or for simplicity assume $x$ has only two features. Then the transformation is,

$$\begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix} \mapsto \begin{pmatrix} x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\ x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta \end{pmatrix}$$

where $\theta$ is a constant angle. The other pairs of features are transformed similarly.

For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

$$\big(x^{(1)}_m x^{(1)}_n + x^{(2)}_m x^{(2)}_n\big) \cos (m - n)\theta + \big(x^{(1)}_m x^{(2)}_n - x^{(2)}_m x^{(1)}_n\big) \sin (m - n)\theta,$$

which depends on the positions only through the relative distance $m - n$. This shows that for dot-product attention the rotary encodings give relative attention.
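A quick numeric check of this relative-position property (an illustrative sketch, not part of the page):

import torch

def rotate(x1, x2, pos, theta):
    # Rotate the feature pair (x1, x2) by the angle pos * theta, as in the equations above.
    c, s = torch.cos(pos * theta), torch.sin(pos * theta)
    return torch.stack([x1 * c - x2 * s, x2 * c + x1 * s])

theta = torch.tensor(0.1)
q = torch.tensor([0.3, -1.2])   # a query feature pair
k = torch.tensor([0.7, 0.4])    # a key feature pair

def score(m, n):
    # Dot product after rotating the query to position m and the key to position n
    return torch.dot(rotate(q[0], q[1], torch.tensor(float(m)), theta),
                     rotate(k[0], k[1], torch.tensor(float(n)), theta))

# The score only depends on m - n: positions (5, 2) and (13, 10) give the same value.
assert torch.allclose(score(5, 2), score(13, 10))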
The features are grouped into pairs and handled as above. They use a different $\theta$ for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-2(i-1)/d},\ i \in [1, 2, \dots, d/2]\big\}$ for the pairs of features.

We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

$$\begin{pmatrix} x^{(i)}_m \\ x^{(i + \frac{d}{2})}_m \end{pmatrix}$$

to

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

$d$ is the number of features; base is the constant used for calculating $\theta_i$

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
-        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
+        self.base = base
+        self.d = d
+        self.cos_cached = None
+        self.sin_cached = None

-    def forward(self, x: torch.Tensor):
-        seq_len, batch_size, n_heads, d = x.shape
-        d_2 = d // 2
-        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
-        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
-        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
-        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
-        rx = (x * idx_theta2.cos()[:, None, None, :]) + (neg_half_x * idx_theta2.sin()[:, None, None, :])
-        return rx

x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d]

+    def _build_cache(self, x: torch.Tensor):

Return if cache is already built

+        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+            return
+
+        seq_len = x.shape[0]
+        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

Calculate the product of position index and $\theta_i$

+        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}, m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}]$

+        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+        self.cos_cached = idx_theta2.cos()[:, None, None, :]
+        self.sin_cached = idx_theta2.sin()[:, None, None, :]

Calculate $[-x^{(\frac{d}{2}+1)}, -x^{(\frac{d}{2}+2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}]$

+    def _neg_half(self, x: torch.Tensor):
+        d_2 = self.d // 2
+        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d]

+    def forward(self, x: torch.Tensor):

Cache $\cos$ and $\sin$ values

+        self._build_cache(x)

Split the features; we can choose to apply rotary embeddings only to a partial set of features.

+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

Calculate $[-x^{(\frac{d}{2}+1)}, \dots, -x^{(d)}, x^{(1)}, \dots, x^{(\frac{d}{2})}]$ for the features we are applying rotary embeddings to

+        neg_half_x = self._neg_half(x_rope)
+
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+
+        return torch.cat((x_rope, x_pass), dim=-1)

We override multi-head attention from original transformer.

-class RotaryPEMultiHeadAttention(MultiHeadAttention):

The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for value might make sense.

-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
-        super().__init__(heads, d_model, dropout_prob, bias=False)

Rotary positional embedding layers

-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)

Calculate dot-product with RoPE

-    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

+class RotaryPEMultiHeadAttention(MultiHeadAttention):
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
+        super().__init__(heads, d_model, dropout_prob)

Rotary positional embedding layers

+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

Calculate dot-product with RoPE

+    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

Testing RoPE with a simple example
+def _test_rotary():
+    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
+    x = x[:, None, None, :]
+    inspect(x)
+
+    rotary_pe = RotaryPositionalEmbeddings(3)
+    inspect(rotary_pe(x))
+
+
+if __name__ == '__main__':
+    _test_rotary()
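A small usage sketch for the partial-RoPE attention defined above (illustrative only; it assumes, as in labml's MultiHeadAttention, that self.d_k equals d_model // heads, so here d_k = 4 and half of those features get rotary embeddings):

# Illustrative only: attention scores with rotary embeddings on half the features.
import torch

mha = RotaryPEMultiHeadAttention(heads=2, d_model=8, rope_percentage=0.5)
q = torch.randn(5, 1, 2, 4)   # [seq_len, batch_size, heads, d_k]
k = torch.randn(5, 1, 2, 4)
scores = mha.get_scores(q, k)
print(scores.shape)           # torch.Size([5, 5, 1, 2]) -> [query_pos, key_pos, batch, heads]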