"""
Arithmetic addition dataset for training autoregressive language models.

Generates packed token sequences of addition problems with digit-by-digit
explanations, e.g. `?x=12+34; 2e0+4e0+0e0==6e0 1e1+3e1+0e1==4e1 x==46`.

This is based on code by [@gharik](https://twitter.com/gharik).
"""

import random
import string
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset

from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch


class ArithmeticDataset(Dataset):
    """
    Generates random addition problems with step-by-step explanations,
    packed into fixed-length token sequences.
    """

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        """
        * `seq_len` is the token length of each generated sequence
        * `max_digits` is the maximum number of digits in the operands
        * `n_sequences` is the number of sequences per epoch
        """
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
        # Vocabulary: the ten digits plus the symbols used in problems
        self.itos = list(string.digits + 'xe =\n?+;')
        self.stoi = {c: i for i, c in enumerate(self.itos)}

    @staticmethod
    def make_int(n_digits: int) -> int:
        """Generate a random integer with exactly `n_digits` decimal digits."""
        res = 0
        for i in range(n_digits):
            # First digit is 1-9 so the number has exactly `n_digits` digits;
            # the rest are 0-9. (`randrange`'s upper bound is exclusive; the
            # original `randrange(1, 11)` could yield a "digit" of 10.)
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res

    @staticmethod
    def get_add_explanation(x: int, y: int) -> str:
        """Digit-by-digit explanation of `x + y`, including carries."""
        carry = 0
        e = 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            # Propagate the carry to the next digit position.
            # (The original assigned it to an unused `c`, so the carry was
            # always 0 and multi-digit explanations were wrong.)
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)

    def make_add_problem(self) -> str:
        """Make a single addition problem with its explanation."""
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"

    def get_packed_math_input(self) -> List[int]:
        """Concatenate problems until the sequence exceeds `seq_len` tokens."""
        s_enc = []
        while len(s_enc) <= self.seq_len:
            s_part = self.make_add_problem()
            s_enc += self.encode('?' + s_part)
        return s_enc

    def encode(self, s: str) -> List[int]:
        """Encode a string into token ids."""
        return [self.stoi[c] for c in s]

    def decode(self, arr: List[int]) -> str:
        """Decode token ids back into a string."""
        return ''.join([self.itos[c] for c in arr])

    def __getitem__(self, idx):
        """Return an (input, target) pair, target shifted by one token."""
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]

    def __len__(self):
        return self.n_sequences


class ArithmeticAutoregression(NLPAutoRegressionConfigs):
    """Configurations for training an autoregressive model on arithmetic."""
    # Maximum number of digits in the operands
    max_digits: int = 4
    # Sequences per training / validation epoch
    train_sequences_per_epoch: int = 1024
    valid_sequences_per_epoch: int = 128
    train_loader: DataLoader = 'arithmetic_train_loader'
    valid_loader: DataLoader = 'arithmetic_valid_loader'


@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    """Training data loader."""
    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch)


@option(ArithmeticAutoregression.valid_loader)
def arithmetic_valid_loader(c: ArithmeticAutoregression):
    """Validation data loader."""
    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch)


def _test():
    """Print one decoded sample sequence."""
    dataset = ArithmeticDataset(256, 8, 10)

    print(dataset.decode(dataset.get_packed_math_input()))


if __name__ == '__main__':
    _test()
class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module

    Rotates each feature pair $(x^{(i)}, x^{(i + \frac{d}{2})})$ of a query
    or key by the angle $m \theta_i$, where $m$ is the token position, so
    dot-products depend only on relative positions.
    """

    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = d
        # Lazily built cos/sin caches of shape `[seq_len, 1, 1, d]`
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """
        Cache $\cos(m\theta_i)$ and $\sin(m\theta_i)$ values
        """
        # Return if the cache is already built and long enough
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        # (created directly on `x.device`; the original's extra `.to(x.device)` was redundant)
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, ..., m \theta_{\frac{d}{2}}, m \theta_0, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache cosines and sines.
        # Fix: the sine cache must use `.sin()` — the original filled both
        # caches with `.cos()`, which broke the rotation entirely.
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        self._build_cache(x)

        # $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x)

        # For $i \in {1, ..., \frac{d}{2}}$:
        # $x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i$ and
        # $x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i$
        rx = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

        #
        return rx
24c80f9c..3ebc65c9 100644 --- a/labml_nn/transformers/rope/experiment.py +++ b/labml_nn/transformers/rope/experiment.py @@ -43,7 +43,7 @@ def _model(c: Configs): def main(): # Create experiment - experiment.create(name="rotary_pe_transformer") + experiment.create(name="rotary_pe_transformer", writers={'screen'}) # Create configs conf = Configs() # Override configurations diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py new file mode 100644 index 00000000..1305e222 --- /dev/null +++ b/labml_nn/transformers/rope/value_pe/__init__.py @@ -0,0 +1,163 @@ +""" +--- +title: Rotary Positional Embeddings (RoPE) +summary: > + Annotated implementation of RoPE from paper + RoFormer: Enhanced Transformer with Rotary Position Embedding +--- + +# Rotary Positional Embeddings (RoPE) + +This is an implementation of +[Rotary Positional Embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864) +in [PyTorch](https://pytorch.org). + +Rotary Positional Embeddings (RoPE) encode position information of tokens +with a rotation matrix that naturally incorporates explicit relative position +dependency. + +Here's [the training code](experiment.html) for training a transformer model with RoPE + on Tiny Shakespeare dataset. 
"""
Rotary positional embeddings applied to attention *values* as well as
queries and keys. Values are rotated before the attention-weighted sum,
and the result is rotated in the opposite direction afterwards, so the
output embeddings carry relative positional information.

[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe)
"""
from typing import Optional

import torch

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.rope import RotaryPositionalEmbeddings


class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
    """
    ## RoPE module that rotates in the opposite direction

    Same as [RoPE](../index.html) but rotates by $-m \theta_i$, i.e. the
    sign of the sine term is flipped, undoing a forward rotation.
    """

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        self._build_cache(x)

        # $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x)

        # Rotation by $-m \theta_i$; since $\cos$ is even and $\sin$ is odd,
        # this is the forward RoPE formula with the sine term negated.
        rx = (x * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])

        #
        return rx


class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
    """
    ## Multi-head attention with rotary positional embeddings on Q, K and V

    We override [multi-head attention from original transformer](../mha.html).
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we
        # explicitly include it when calculating scores.
        # However having a bias for `value` might make sense.
        super().__init__(heads, d_model, dropout_prob, bias=False)

        # Rotary positional embedding layers
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
        self.value_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(self.d_k)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys
        """

        # Calculate dot-product with RoPE-rotated queries and keys
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        query at position `i` has access to key-value at position `j`.
        """

        # `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)

        # Apply dropout
        attn = self.dropout(attn)

        # Rotate value embeddings once before taking the weighted sum
        # so that they contain positional information
        value = self.value_rotary_pe(value)

        # Multiply by the (already rotated) values.
        # Fix: the original passed `self.value_rotary_pe(value)` here again,
        # rotating the values a second time.
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Rotate in the opposite direction so that each embedding holds
        # the relative positions
        x = self.value_reverse_rotary_pe(x)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)
"""
---
title: Rotary Positional Embeddings (RoPE) Experiment
summary: This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset.
---

# Rotary Positional Embeddings (RoPE) Experiment

This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE)
applied to attention values as well as queries and keys.

[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe)
"""

from labml import experiment
from labml.configs import calculate
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.rope.experiment import Configs


# ### Rotary value PE attention


def _rotary_value_pe_mha(c: TransformerConfigs):
    """Create a multi-head attention module with rotary value PE."""
    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model)


# Register the `'rotary_value'` option for all attention configurations
calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)


def main():
    # Create experiment
    experiment.create(name="rotary_pe_transformer", writers={'screen'})
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # No fixed positional embeddings
        'transformer.src_embed': 'no_pos',
        'transformer.tgt_embed': 'no_pos',

        # Encoder with rotary value PE attention
        'transformer.encoder_attn': 'rotary_value',

        #
        'model': 'rotary_pe_transformer',

        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of $512$
        # (comment fixed: the original said $256$ but set 512)
        'seq_len': 512,
        # Train for 32 epochs
        'epochs': 32,
        # Batch size $4$
        'batch_size': 4,
        # Switch between training and validation for $10$ times
        # per epoch
        'inner_iterations': 10,

        # Model size
        'd_model': 128,
        'transformer.ffn.d_ff': 512,
        'transformer.n_heads': 16,
        'transformer.dropout': 0.0,

        # Use [Noam optimizer](../../optimizers/noam.html)
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,

        'dataloader_shuffle_with_replacement': True
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


#
if __name__ == '__main__':
    main()
diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py index bbbc5aa1..2afd077b 100644 --- a/labml_nn/transformers/rope/__init__.py +++ b/labml_nn/transformers/rope/__init__.py @@ -163,8 +163,10 @@ def forward(self, x: torch.Tensor): """ self._build_cache(x) + x_rope, x_pass = x[..., :self.d], x[..., self.d:] + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ - neg_half_x = self._neg_half(x) + neg_half_x = self._neg_half(x_rope) # Calculate # @@ -176,10 +178,10 @@ def forward(self, x: torch.Tensor): # \end{align} # # for $i \in {1, 2, ..., \frac{d}{2}}$ - rx = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]]) + x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]]) # - return rx + return torch.cat((x_rope, x_pass), dim=-1) class RotaryPEMultiHeadAttention(MultiHeadAttention): @@ -189,15 +191,16 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention): We override [multi-head attention from original transformer](../mha.html). """ - def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): + def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1): # The linear transformations do not need a bias since we # explicitly include it when calculating scores. # However having a bias for `value` might make sense. 
super().__init__(heads, d_model, dropout_prob, bias=False) # Rotary positional embedding layers - self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k) - self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k) + d_rope = int(self.d_k * rope_percentage) + self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope) + self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope) def get_scores(self, query: torch.Tensor, key: torch.Tensor): """ diff --git a/labml_nn/transformers/rope/experiment.py b/labml_nn/transformers/rope/experiment.py index 3ebc65c9..b9aa3c5e 100644 --- a/labml_nn/transformers/rope/experiment.py +++ b/labml_nn/transformers/rope/experiment.py @@ -20,7 +20,7 @@ # ### Rotary PE attention def _rotary_pe_mha(c: TransformerConfigs): from labml_nn.transformers.rope import RotaryPEMultiHeadAttention - return RotaryPEMultiHeadAttention(c.n_heads, c.d_model) + return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5) # Configuration options diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py index 1305e222..b2064dc1 100644 --- a/labml_nn/transformers/rope/value_pe/__init__.py +++ b/labml_nn/transformers/rope/value_pe/__init__.py @@ -48,8 +48,10 @@ def forward(self, x: torch.Tensor): """ self._build_cache(x) + x_rope, x_pass = x[..., :self.d], x[..., self.d:] + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ - neg_half_x = self._neg_half(x) + neg_half_x = self._neg_half(x_rope) # Calculate # @@ -65,10 +67,10 @@ def forward(self, x: torch.Tensor): # \end{align} # # for $i \in {1, 2, ..., \frac{d}{2}}$ - rx = (x * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]]) + x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]]) # - return rx + return torch.cat((x_rope, x_pass), dim=-1) class RotaryValuePEMultiHeadAttention(MultiHeadAttention): @@ -78,17 +80,18 @@ class 
RotaryValuePEMultiHeadAttention(MultiHeadAttention): We override [multi-head attention from original transformer](../mha.html). """ - def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): + def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1): # The linear transformations do not need a bias since we # explicitly include it when calculating scores. # However having a bias for `value` might make sense. super().__init__(heads, d_model, dropout_prob, bias=False) # Rotary positional embedding layers - self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k) - self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k) - self.value_rotary_pe = RotaryPositionalEmbeddings(self.d_k) - self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(self.d_k) + d_rope = int(self.d_k * rope_percentage) + self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope) + self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope) + self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope) + self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope) def get_scores(self, query: torch.Tensor, key: torch.Tensor): """ diff --git a/labml_nn/transformers/rope/value_pe/experiment.py b/labml_nn/transformers/rope/value_pe/experiment.py index 4b29658c..791efbed 100644 --- a/labml_nn/transformers/rope/value_pe/experiment.py +++ b/labml_nn/transformers/rope/value_pe/experiment.py @@ -13,16 +13,19 @@ from labml import experiment from labml.configs import calculate +from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression from labml_nn.transformers import TransformerConfigs -from labml_nn.transformers.rope.experiment import Configs +from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs # ### Rotary PE attention +class Configs(RoPEConfigs): # , ArithmeticAutoregression): + pass def _rotary_value_pe_mha(c: TransformerConfigs): from labml_nn.transformers.rope.value_pe import 
RotaryValuePEMultiHeadAttention - return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model) + return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 0.5) # Configuration options @@ -33,7 +36,7 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="rotary_pe_transformer", writers={'screen'}) + experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'}) # Create configs conf = Configs() # Override configurations @@ -43,8 +46,8 @@ def main(): 'transformer.tgt_embed': 'no_pos', # Encoder with RoPE - 'transformer.encoder_attn': 'rotary_value', # 'transformer.encoder_attn': 'rotary_value', + 'transformer.encoder_attn': 'rotary', # 'model': 'rotary_pe_transformer', From 13686c1d282a53450d19fc2a47976327f85a1202 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Tue, 31 May 2022 18:07:54 +0530 Subject: [PATCH 04/27] arithmetic exp --- .labml.yaml | 2 +- labml_nn/experiments/arithmetic_dataset.py | 32 +++++++ labml_nn/transformers/rope/experiment.py | 2 +- .../transformers/rope/value_pe/__init__.py | 11 ++- .../rope/value_pe/arithmetic_experiment.py | 95 +++++++++++++++++++ .../transformers/rope/value_pe/experiment.py | 20 ++-- 6 files changed, 146 insertions(+), 16 deletions(-) create mode 100644 labml_nn/transformers/rope/value_pe/arithmetic_experiment.py diff --git a/.labml.yaml b/.labml.yaml index 5578d582..2e384051 100644 --- a/.labml.yaml +++ b/.labml.yaml @@ -19,4 +19,4 @@ indicators: name: optim.* options: comet: false -web_api: http://localhost:5000/api/v1/track? +web_api: http://localhost:5005/api/v1/track? 
diff --git a/labml_nn/experiments/arithmetic_dataset.py b/labml_nn/experiments/arithmetic_dataset.py index c4b7d3f0..2622b4c7 100644 --- a/labml_nn/experiments/arithmetic_dataset.py +++ b/labml_nn/experiments/arithmetic_dataset.py @@ -9,7 +9,9 @@ import torch from torch.utils.data import DataLoader, Dataset +from labml import monit, logger from labml.configs import option +from labml.logger import Text from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch @@ -82,6 +84,36 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs): train_loader: DataLoader = 'arithmetic_train_loader' valid_loader: DataLoader = 'arithmetic_valid_loader' + n_tokens = len(ArithmeticDataset(1, 1, 1).itos) + + def sample(self): + """ + ### Sampling function to generate samples periodically while training + """ + + # Starting prompt + prompt = self.prompt + # Collect output for printing + log = [(prompt, Text.subtle)] + # Dataset for decoding + dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1) + # Sample 25 tokens + for i in monit.iterate('Sample', self.seq_len - len(prompt)): + # Tokenize the prompt + data = torch.tensor(dataset.encode(prompt))[:, None] + data = data.to(self.device) + # Get the model output + output, *_ = self.model(data) + # Get the model prediction (greedy) + output = output.argmax(dim=-1).squeeze() + # Add the prediction to prompt + prompt += self.prompt_separator + dataset.itos[output[-1]] + # Add the prediction for logging + log += [(self.prompt_separator + dataset.itos[output[-1]], Text.value)] + + # Print the sampled output + logger.log(log) + @option(ArithmeticAutoregression.train_loader) def arithmetic_train_loader(c: ArithmeticAutoregression): diff --git a/labml_nn/transformers/rope/experiment.py b/labml_nn/transformers/rope/experiment.py index b9aa3c5e..9f63e6e4 100644 --- a/labml_nn/transformers/rope/experiment.py +++ b/labml_nn/transformers/rope/experiment.py @@ -20,7 +20,7 @@ # ### Rotary PE attention def 
_rotary_pe_mha(c: TransformerConfigs): from labml_nn.transformers.rope import RotaryPEMultiHeadAttention - return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5) + return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.) # Configuration options diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py index b2064dc1..335bf030 100644 --- a/labml_nn/transformers/rope/value_pe/__init__.py +++ b/labml_nn/transformers/rope/value_pe/__init__.py @@ -25,7 +25,6 @@ import torch -from labml.logger import inspect from labml_nn.transformers.mha import MultiHeadAttention from labml_nn.transformers.rope import RotaryPositionalEmbeddings @@ -80,7 +79,9 @@ class RotaryValuePEMultiHeadAttention(MultiHeadAttention): We override [multi-head attention from original transformer](../mha.html). """ - def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1): + def __init__(self, heads: int, d_model: int, + rope_percentage: float = 0.5, rope_value_percentage: float = 0.5, + dropout_prob: float = 0.1): # The linear transformations do not need a bias since we # explicitly include it when calculating scores. # However having a bias for `value` might make sense. 
@@ -88,10 +89,12 @@ def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropo # Rotary positional embedding layers d_rope = int(self.d_k * rope_percentage) + d_rope_value = int(self.d_k * rope_value_percentage) + self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope) self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope) - self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope) - self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope) + self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value) + self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value) def get_scores(self, query: torch.Tensor, key: torch.Tensor): """ diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py new file mode 100644 index 00000000..75dda880 --- /dev/null +++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py @@ -0,0 +1,95 @@ +""" +--- +title: Rotary Positional Embeddings (RoPE) Experiment +summary: This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset. +--- + +# Rotary Positional Embeddings (RoPE) Experiment + +This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE). 
+ +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe) +""" + +from labml import experiment +from labml.configs import calculate +from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression +from labml_nn.transformers import TransformerConfigs +from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs + + +# ### Rotary PE attention + +class Configs(RoPEConfigs, ArithmeticAutoregression): # , ArithmeticAutoregression): + pass + + +def _rotary_value_pe_mha(c: TransformerConfigs): + from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention + return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5) + + +# Configuration options +calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha) +calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha) +calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha) + + +def main(): + # Create experiment + experiment.create(name="rope_arithmetic", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'}) + # Create configs + conf = Configs() + # Override configurations + experiment.configs(conf, { + # No fixed positional embeddings + 'transformer.src_embed': 'no_pos', + 'transformer.tgt_embed': 'no_pos', + + # Encoder with RoPE + # 'transformer.encoder_attn': 'rotary_value', + 'transformer.encoder_attn': 'rotary', + + # + 'model': 'rotary_pe_transformer', + + # Prompt separator is blank + 'prompt_separator': '', + # Starting prompt for sampling + 'prompt': '?x=2345+998;', + + # Use a context size of $256$ + 'seq_len': 128, + # Train for 32 epochs + 'epochs': 32, + # Batch size $4$ + 'batch_size': 4, + # Switch between training and validation for $10$ times + # per epoch + 'inner_iterations': 10, + + # Model size + 'd_model': 256, + 'transformer.ffn.d_ff': 1024, + 'transformer.n_heads': 8, + 'transformer.dropout': 
0.0, + + # Use [Noam optimizer](../../optimizers/noam.html) + 'optimizer.optimizer': 'Noam', + 'optimizer.learning_rate': 1., + + 'dataloader_shuffle_with_replacement': True + }) + + # Set models for saving and loading + experiment.add_pytorch_models({'model': conf.model}) + + # Start the experiment + with experiment.start(): + # Run training + conf.run() + + +# +if __name__ == '__main__': + main() diff --git a/labml_nn/transformers/rope/value_pe/experiment.py b/labml_nn/transformers/rope/value_pe/experiment.py index 791efbed..5fbdfbed 100644 --- a/labml_nn/transformers/rope/value_pe/experiment.py +++ b/labml_nn/transformers/rope/value_pe/experiment.py @@ -13,19 +13,19 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs # ### Rotary PE attention -class Configs(RoPEConfigs): # , ArithmeticAutoregression): +class Configs(RoPEConfigs): # , ArithmeticAutoregression): pass + def _rotary_value_pe_mha(c: TransformerConfigs): from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention - return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 0.5) + return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5) # Configuration options @@ -36,7 +36,7 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'}) + experiment.create(name="rotary_pe_transformer", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'}) # Create configs conf = Configs() # Override configurations @@ -46,8 +46,8 @@ def main(): 'transformer.tgt_embed': 'no_pos', # Encoder with RoPE - # 'transformer.encoder_attn': 'rotary_value', - 'transformer.encoder_attn': 'rotary', + 'transformer.encoder_attn': 'rotary_value', + # 'transformer.encoder_attn': 
'rotary', # 'model': 'rotary_pe_transformer', @@ -62,7 +62,7 @@ def main(): 'text': 'tiny_shakespeare', # Use a context size of $256$ - 'seq_len': 512, + 'seq_len': 128, # Train for 32 epochs 'epochs': 32, # Batch size $4$ @@ -72,9 +72,9 @@ def main(): 'inner_iterations': 10, # Model size - 'd_model': 128, - 'transformer.ffn.d_ff': 512, - 'transformer.n_heads': 16, + 'd_model': 256, + 'transformer.ffn.d_ff': 1024, + 'transformer.n_heads': 8, 'transformer.dropout': 0.0, # Use [Noam optimizer](../../optimizers/noam.html) From c08af45b0309c5c9a2b82f9d9c73f1ef31e8c5f6 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Tue, 31 May 2022 22:39:58 +0530 Subject: [PATCH 05/27] experiment --- labml_nn/experiments/arithmetic_dataset.py | 24 ++++++++++++------- .../rope/value_pe/arithmetic_experiment.py | 22 ++++++++--------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/labml_nn/experiments/arithmetic_dataset.py b/labml_nn/experiments/arithmetic_dataset.py index 2622b4c7..e75335f2 100644 --- a/labml_nn/experiments/arithmetic_dataset.py +++ b/labml_nn/experiments/arithmetic_dataset.py @@ -41,7 +41,7 @@ def get_add_explanation(x, y): rx, ry = x % 10, y % 10 total = rx + ry + carry explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}") - x, y, c = x // 10, y // 10, total // 10 + x, y, carry = x // 10, y // 10, total // 10 e += 1 return ' '.join(explanation) @@ -51,11 +51,13 @@ def make_add_problem(self): x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) - explanation = self.get_add_explanation(x, y) - return f"x={x}+{y}; {explanation} x=={x + y}\n" + if random.randrange(0, 5) < 1: + return f"x={x}+{y}; x=={x + y}\n" + else: + explanation = self.get_add_explanation(x, y) + return f"x={x}+{y}; {explanation} x=={x + y}\n" def get_packed_math_input(self): - s = "" s_enc = [] while len(s_enc) <= self.seq_len: s_part = self.make_add_problem() @@ -79,8 +81,8 @@ def 
__len__(self): class ArithmeticAutoregression(NLPAutoRegressionConfigs): max_digits: int = 4 - train_sequences_per_epoch: int = 1024 - valid_sequences_per_epoch: int = 128 + train_sequences_per_epoch: int = 2 ** 14 + valid_sequences_per_epoch: int = 2 ** 4 train_loader: DataLoader = 'arithmetic_train_loader' valid_loader: DataLoader = 'arithmetic_valid_loader' @@ -106,6 +108,10 @@ def sample(self): output, *_ = self.model(data) # Get the model prediction (greedy) output = output.argmax(dim=-1).squeeze() + + if dataset.itos[output[-1]] == '\n': + break + # Add the prediction to prompt prompt += self.prompt_separator + dataset.itos[output[-1]] # Add the prediction for logging @@ -119,14 +125,16 @@ def sample(self): def arithmetic_train_loader(c: ArithmeticAutoregression): return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch), batch_size=c.batch_size, - collate_fn=transpose_batch) + collate_fn=transpose_batch, + num_workers=4) @option(ArithmeticAutoregression.valid_loader) def arithmetic_valid_loader(c: ArithmeticAutoregression): return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch), batch_size=c.batch_size, - collate_fn=transpose_batch) + collate_fn=transpose_batch, + num_workers=4) def _test(): diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py index 75dda880..ef4d3c0e 100644 --- a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py +++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py @@ -37,18 +37,20 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="rope_arithmetic", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'}) + experiment.create(name="rope_arithmetic", comment="rotary_value 1.0", writers={'screen', 'labml'}) # Create configs conf = Configs() # Override configurations experiment.configs(conf, { + 'max_digits': 
9, + # No fixed positional embeddings 'transformer.src_embed': 'no_pos', 'transformer.tgt_embed': 'no_pos', # Encoder with RoPE - # 'transformer.encoder_attn': 'rotary_value', - 'transformer.encoder_attn': 'rotary', + 'transformer.encoder_attn': 'rotary_value', + # 'transformer.encoder_attn': 'rotary', # 'model': 'rotary_pe_transformer', @@ -56,29 +58,27 @@ def main(): # Prompt separator is blank 'prompt_separator': '', # Starting prompt for sampling - 'prompt': '?x=2345+998;', + 'prompt': '?x=123456789+1091919;', # Use a context size of $256$ - 'seq_len': 128, + 'seq_len': 512, # Train for 32 epochs 'epochs': 32, # Batch size $4$ - 'batch_size': 4, + 'batch_size': 16, # Switch between training and validation for $10$ times # per epoch 'inner_iterations': 10, # Model size - 'd_model': 256, - 'transformer.ffn.d_ff': 1024, - 'transformer.n_heads': 8, + 'd_model': 128, + 'transformer.ffn.d_ff': 512, + 'transformer.n_heads': 4, 'transformer.dropout': 0.0, # Use [Noam optimizer](../../optimizers/noam.html) 'optimizer.optimizer': 'Noam', 'optimizer.learning_rate': 1., - - 'dataloader_shuffle_with_replacement': True }) # Set models for saving and loading From e409e9bf98d214347e57ccf67e7e2c8089625b97 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Wed, 1 Jun 2022 14:07:27 +0530 Subject: [PATCH 06/27] arthmetic test score --- labml_nn/experiments/arithmetic_dataset.py | 76 ++++++++++++------- .../rope/value_pe/arithmetic_experiment.py | 9 +-- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/labml_nn/experiments/arithmetic_dataset.py b/labml_nn/experiments/arithmetic_dataset.py index e75335f2..d7f93665 100644 --- a/labml_nn/experiments/arithmetic_dataset.py +++ b/labml_nn/experiments/arithmetic_dataset.py @@ -9,9 +9,8 @@ import torch from torch.utils.data import DataLoader, Dataset -from labml import monit, logger +from labml import monit, logger, tracker from labml.configs import option -from labml.logger import Text from 
labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch @@ -57,6 +56,12 @@ def make_add_problem(self): explanation = self.get_add_explanation(x, y) return f"x={x}+{y}; {explanation} x=={x + y}\n" + def get_qa(self): + x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) + y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) + + return f'x={x}+{y};', f'{x + y}' + def get_packed_math_input(self): s_enc = [] while len(s_enc) <= self.seq_len: @@ -81,10 +86,11 @@ def __len__(self): class ArithmeticAutoregression(NLPAutoRegressionConfigs): max_digits: int = 4 - train_sequences_per_epoch: int = 2 ** 14 - valid_sequences_per_epoch: int = 2 ** 4 + train_sequences_per_epoch: int = 2 ** 12 train_loader: DataLoader = 'arithmetic_train_loader' - valid_loader: DataLoader = 'arithmetic_valid_loader' + n_tests: int = 32 + validator = None + inner_iterations = 4 n_tokens = len(ArithmeticDataset(1, 1, 1).itos) @@ -93,32 +99,52 @@ def sample(self): ### Sampling function to generate samples periodically while training """ - # Starting prompt - prompt = self.prompt - # Collect output for printing - log = [(prompt, Text.subtle)] - # Dataset for decoding + if self.training_loop.idx < 1: + return + dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1) + qa = [dataset.get_qa() for _ in range(self.n_tests)] + prompt = [p[0] for p in qa] + + data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]]) + data = data.to(self.device) + + finished = torch.zeros((len(prompt),)).bool().to(self.device) + new_line = dataset.stoi['\n'] + + results = [p[0] for p in prompt] + # Sample 25 tokens - for i in monit.iterate('Sample', self.seq_len - len(prompt)): + for i in monit.iterate('Sample', self.seq_len - 1): # Tokenize the prompt - data = torch.tensor(dataset.encode(prompt))[:, None] - data = data.to(self.device) # Get the model output output, *_ = self.model(data) # Get the model prediction (greedy) - output = 
output.argmax(dim=-1).squeeze() + output = output[-1].argmax(dim=-1) - if dataset.itos[output[-1]] == '\n': + finished = finished | (output == new_line) + if finished.sum() == len(finished): break - # Add the prediction to prompt - prompt += self.prompt_separator + dataset.itos[output[-1]] - # Add the prediction for logging - log += [(self.prompt_separator + dataset.itos[output[-1]], Text.value)] + for j, p in enumerate(prompt): + if len(p) > i + 1: + output[j] = dataset.stoi[p[i + 1]] - # Print the sampled output - logger.log(log) + data = torch.cat([data, output[None, :]], dim=0) + + for j, c in enumerate(output): + results[j] += dataset.itos[c] + + results = [r.split('\n')[0] for r in results] + logger.log(results[0]) + results = [r.split('x==')[-1] for r in results] + + correct = 0 + for r, _qa in zip(results, qa): + if r == _qa[1]: + correct += 1 + + tracker.save('score', correct / len(results)) @option(ArithmeticAutoregression.train_loader) @@ -129,14 +155,6 @@ def arithmetic_train_loader(c: ArithmeticAutoregression): num_workers=4) -@option(ArithmeticAutoregression.valid_loader) -def arithmetic_valid_loader(c: ArithmeticAutoregression): - return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch), - batch_size=c.batch_size, - collate_fn=transpose_batch, - num_workers=4) - - def _test(): dataset = ArithmeticDataset(256, 8, 10) diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py index ef4d3c0e..2f65d211 100644 --- a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py +++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py @@ -26,7 +26,7 @@ class Configs(RoPEConfigs, ArithmeticAutoregression): # , ArithmeticAutoregress def _rotary_value_pe_mha(c: TransformerConfigs): from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention - return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5) + return 
RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.) # Configuration options @@ -42,7 +42,7 @@ def main(): conf = Configs() # Override configurations experiment.configs(conf, { - 'max_digits': 9, + 'max_digits': 6, # No fixed positional embeddings 'transformer.src_embed': 'no_pos', @@ -63,12 +63,9 @@ def main(): # Use a context size of $256$ 'seq_len': 512, # Train for 32 epochs - 'epochs': 32, + 'epochs': 64, # Batch size $4$ 'batch_size': 16, - # Switch between training and validation for $10$ times - # per epoch - 'inner_iterations': 10, # Model size 'd_model': 128, From bd5e9354e490187a951496159d5822d4b3bf04a3 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Wed, 1 Jun 2022 15:14:17 +0530 Subject: [PATCH 07/27] hp --- .../transformers/rope/value_pe/arithmetic_experiment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py index 2f65d211..196cfc96 100644 --- a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py +++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py @@ -37,7 +37,7 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="rope_arithmetic", comment="rotary_value 1.0", writers={'screen', 'labml'}) + experiment.create(name="rope_arithmetic", comment="rotary 1.0", writers={'screen', 'labml'}) # Create configs conf = Configs() # Override configurations @@ -49,8 +49,8 @@ def main(): 'transformer.tgt_embed': 'no_pos', # Encoder with RoPE - 'transformer.encoder_attn': 'rotary_value', - # 'transformer.encoder_attn': 'rotary', + # 'transformer.encoder_attn': 'rotary_value', + 'transformer.encoder_attn': 'rotary', # 'model': 'rotary_pe_transformer', From 104070806e1707b1ebc50c3aa7b2c8e7b2e2ffcd Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 2 Jun 2022 16:15:43 +0530 Subject: [PATCH 08/27] logs --- 
labml_nn/experiments/arithmetic_dataset.py | 11 +++++++++-- labml_nn/transformers/rope/value_pe/__init__.py | 3 ++- .../rope/value_pe/arithmetic_experiment.py | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/labml_nn/experiments/arithmetic_dataset.py b/labml_nn/experiments/arithmetic_dataset.py index d7f93665..a1e38364 100644 --- a/labml_nn/experiments/arithmetic_dataset.py +++ b/labml_nn/experiments/arithmetic_dataset.py @@ -7,6 +7,7 @@ from typing import List import torch +from labml.logger import Text from torch.utils.data import DataLoader, Dataset from labml import monit, logger, tracker @@ -116,6 +117,9 @@ def sample(self): # Sample 25 tokens for i in monit.iterate('Sample', self.seq_len - 1): + if finished.sum() == len(finished): + continue + # Tokenize the prompt # Get the model output output, *_ = self.model(data) @@ -124,7 +128,7 @@ def sample(self): finished = finished | (output == new_line) if finished.sum() == len(finished): - break + continue for j, p in enumerate(prompt): if len(p) > i + 1: @@ -136,7 +140,10 @@ def sample(self): results[j] += dataset.itos[c] results = [r.split('\n')[0] for r in results] - logger.log(results[0]) + + res_sample = results[0].split(';') + logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)]) + results = [r.split('x==')[-1] for r in results] correct = 0 diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py index 335bf030..96eb3561 100644 --- a/labml_nn/transformers/rope/value_pe/__init__.py +++ b/labml_nn/transformers/rope/value_pe/__init__.py @@ -21,6 +21,7 @@ [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe) """ + from typing import Optional import torch @@ -32,7 +33,7 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings): """ ## RoPE module - """ + """ def __init__(self, d: int, base: int = 10_000): 
""" diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py index 196cfc96..6975f4a6 100644 --- a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py +++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py @@ -37,7 +37,7 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="rope_arithmetic", comment="rotary 1.0", writers={'screen', 'labml'}) + experiment.create(name="roper_addition", comment="rotary", writers={'screen', 'labml', 'comet'}) # Create configs conf = Configs() # Override configurations @@ -63,7 +63,7 @@ def main(): # Use a context size of $256$ 'seq_len': 512, # Train for 32 epochs - 'epochs': 64, + 'epochs': 50, # Batch size $4$ 'batch_size': 16, From ccf6b6bd25fe9001abd4e5b316738a53fba03aef Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 2 Jun 2022 18:14:37 +0530 Subject: [PATCH 09/27] rotary value --- .../transformers/rope/value_pe/arithmetic_experiment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py index 6975f4a6..d712162d 100644 --- a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py +++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py @@ -37,7 +37,7 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="roper_addition", comment="rotary", writers={'screen', 'labml', 'comet'}) + experiment.create(name="roper_addition", comment="rotary value", writers={'screen', 'labml', 'comet'}) # Create configs conf = Configs() # Override configurations @@ -49,8 +49,8 @@ def main(): 'transformer.tgt_embed': 'no_pos', # Encoder with RoPE - # 'transformer.encoder_attn': 'rotary_value', - 'transformer.encoder_attn': 'rotary', + 'transformer.encoder_attn': 'rotary_value', + # 
'transformer.encoder_attn': 'rotary', # 'model': 'rotary_pe_transformer', From 669b920d6a82cd5dab1add9448dc2f3b6a6ad5f7 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Fri, 3 Jun 2022 10:10:35 +0530 Subject: [PATCH 10/27] roper docs --- docs/experiments/arithmetic_dataset.html | 440 ++++++++++++++ docs/normalization/deep_norm/experiment.html | 3 +- docs/sitemap.xml | 34 +- docs/transformers/rope/experiment.html | 4 +- docs/transformers/rope/index.html | 229 +++++--- .../rope/value_pe/arithmetic_experiment.html | 418 +++++++++++++ .../rope/value_pe/experiment.html | 454 +++++++++++++++ docs/transformers/rope/value_pe/index.html | 549 ++++++++++++++++++ labml_nn/transformers/rope/__init__.py | 8 + .../transformers/rope/value_pe/__init__.py | 130 ++++- setup.py | 6 +- 11 files changed, 2176 insertions(+), 99 deletions(-) create mode 100644 docs/experiments/arithmetic_dataset.html create mode 100644 docs/transformers/rope/value_pe/arithmetic_experiment.html create mode 100644 docs/transformers/rope/value_pe/experiment.html create mode 100644 docs/transformers/rope/value_pe/index.html diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html new file mode 100644 index 00000000..5a9a2d54 --- /dev/null +++ b/docs/experiments/arithmetic_dataset.html @@ -0,0 +1,440 @@ + + + + + + + + + + + + + + + + + + + + + + + arithmetic_dataset.py + + + + + + + + + + +
+
+
+
+

+ home + experiments +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

This is based on code by @gharik.

+ +
+
+
5import random
+6import string
+7from typing import List
+8
+9import torch
+10from labml.logger import Text
+11from torch.utils.data import DataLoader, Dataset
+12
+13from labml import monit, logger, tracker
+14from labml.configs import option
+15from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
+
+
+
+
+ + +
+
+
18class ArithmeticDataset(Dataset):
+
+
+
+
+ + +
+
+
19    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
+20        self.n_sequences = n_sequences
+21        self.max_digits = max_digits
+22        self.seq_len = seq_len
+23        self.itos = list(string.digits + 'xe =\n?+;')
+24        self.stoi = {c: i for i, c in enumerate(self.itos)}
+
+
+
+
+ + +
+
+
26    @staticmethod
+27    def make_int(n_digits):
+28        res = 0
+29        for i in range(n_digits):
+30            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
+31            res = res * 10 + d
+32
+33        return res
+34
+35    @staticmethod
+36    def get_add_explanation(x, y):
+37        carry = 0
+38        e = 0
+39        explanation = []
+40        while x > 0 or y > 0 or carry > 0:
+41            rx, ry = x % 10, y % 10
+42            total = rx + ry + carry
+43            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
+44            x, y, carry = x // 10, y // 10, total // 10
+45            e += 1
+46
+47        return ' '.join(explanation)
+
+
+
+
+ +

Make a problem with a pre_explanation or not

+ +
+
+
50    def make_add_problem(self):
+51        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+52        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+53
+54        if random.randrange(0, 5) < 1:
+55            return f"x={x}+{y}; x=={x + y}\n"
+56        else:
+57            explanation = self.get_add_explanation(x, y)
+58            return f"x={x}+{y}; {explanation} x=={x + y}\n"
+
+
+
+
+ + +
+
+
60    def get_qa(self):
+61        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+62        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+63
+64        return f'x={x}+{y};', f'{x + y}'
+
+
+
+
+ + +
+
+
66    def get_packed_math_input(self):
+67        s_enc = []
+68        while len(s_enc) <= self.seq_len:
+69            s_part = self.make_add_problem()
+70            s_part_enc = self.encode('?' + s_part)
+71            s_enc = s_enc + s_part_enc
+72        return s_enc
+
+
+
+
+ + +
+
+
74    def encode(self, s: str):
+75        return [self.stoi[c] for c in s]
+
+
+
+
+ + +
+
+
77    def decode(self, arr: List[int]):
+78        return ''.join([self.itos[c] for c in arr])
+
+
+
+
+ + +
+
+
80    def __getitem__(self, idx):
+81        s = torch.tensor(self.get_packed_math_input())
+82        return s[:self.seq_len], s[1:self.seq_len + 1]
+
+
+
+
+ + +
+
+
84    def __len__(self):
+85        return self.n_sequences
+
+
+
+
+ + +
+
+
88class ArithmeticAutoregression(NLPAutoRegressionConfigs):
+89    max_digits: int = 4
+90    train_sequences_per_epoch: int = 2 ** 12
+91    train_loader: DataLoader = 'arithmetic_train_loader'
+92    n_tests: int = 32
+93    validator = None
+94    inner_iterations = 4
+95
+96    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
+
+
+
+ +

Sampling function to generate samples periodically while training

+ +
+
+
98    def sample(self):
+
+
+
+
+ + +
+
+
103        if self.training_loop.idx < 1:
+104            return
+105
+106        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+107        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+108        prompt = [p[0] for p in qa]
+109
+110        data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]])
+111        data = data.to(self.device)
+112
+113        finished = torch.zeros((len(prompt),)).bool().to(self.device)
+114        new_line = dataset.stoi['\n']
+115
+116        results = [p[0] for p in prompt]
+
+
+
+
+ +

Sample tokens until the sequence length is reached (or all prompts finish)

+ +
+
+
119        for i in monit.iterate('Sample', self.seq_len - 1):
+120            if finished.sum() == len(finished):
+121                continue
+
+
+
+
+ +

Get the model output for the tokens generated so far

+ +
+
+
125            output, *_ = self.model(data)
+
+
+
+
+ +

Get the model prediction (greedy)

+ +
+
+
127            output = output[-1].argmax(dim=-1)
+128
+129            finished = finished | (output == new_line)
+130            if finished.sum() == len(finished):
+131                continue
+132
+133            for j, p in enumerate(prompt):
+134                if len(p) > i + 1:
+135                    output[j] = dataset.stoi[p[i + 1]]
+136
+137            data = torch.cat([data, output[None, :]], dim=0)
+138
+139            for j, c in enumerate(output):
+140                results[j] += dataset.itos[c]
+141
+142        results = [r.split('\n')[0] for r in results]
+143
+144        res_sample = results[0].split(';')
+145        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+146
+147        results = [r.split('x==')[-1] for r in results]
+148
+149        correct = 0
+150        for r, _qa in zip(results, qa):
+151            if r == _qa[1]:
+152                correct += 1
+153
+154        tracker.save('score', correct / len(results))
+
+
+
+
+ + +
+
+
157@option(ArithmeticAutoregression.train_loader)
+158def arithmetic_train_loader(c: ArithmeticAutoregression):
+159    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+160                      batch_size=c.batch_size,
+161                      collate_fn=transpose_batch,
+162                      num_workers=4)
+163
+164
+165def _test():
+166    dataset = ArithmeticDataset(256, 8, 10)
+167
+168    print(dataset.decode(dataset.get_packed_math_input()))
+169
+170
+171if __name__ == '__main__':
+172    _test()
+
+
+ +
+ + + + \ No newline at end of file diff --git a/docs/normalization/deep_norm/experiment.html b/docs/normalization/deep_norm/experiment.html index 9f7e0095..8ccf394a 100644 --- a/docs/normalization/deep_norm/experiment.html +++ b/docs/normalization/deep_norm/experiment.html @@ -70,7 +70,8 @@ #

DeepNorm Experiment

-

Open In Colab View Run

+

Open In Colab View Run Open In Comet

+
15import copy
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 42cf8562..8f6077e4 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -204,7 +204,7 @@
 
     
       https://nn.labml.ai/normalization/deep_norm/index.html
-      2022-04-23T16:30:00+00:00
+      2022-05-18T16:30:00+00:00
       1.00
     
     
@@ -244,6 +244,13 @@
     
     
 
+    
+      https://nn.labml.ai/experiments/arithmetic_dataset.html
+      2022-06-02T16:30:00+00:00
+      1.00
+    
+    
+
     
       https://nn.labml.ai/experiments/index.html
       2020-12-26T16:30:00+00:00
@@ -603,14 +610,35 @@
 
     
       https://nn.labml.ai/transformers/rope/index.html
-      2022-04-05T16:30:00+00:00
+      2022-05-31T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html
+      2022-06-02T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/transformers/rope/value_pe/index.html
+      2022-06-02T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/transformers/rope/value_pe/experiment.html
+      2022-05-31T16:30:00+00:00
       1.00
     
     
 
     
       https://nn.labml.ai/transformers/rope/experiment.html
-      2022-03-12T16:30:00+00:00
+      2022-05-31T16:30:00+00:00
       1.00
     
     
diff --git a/docs/transformers/rope/experiment.html b/docs/transformers/rope/experiment.html
index 156b2092..5ce9d242 100644
--- a/docs/transformers/rope/experiment.html
+++ b/docs/transformers/rope/experiment.html
@@ -92,7 +92,7 @@ 

Rotary PE attention

21def _rotary_pe_mha(c: TransformerConfigs):
 22    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-23    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+23 return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)
@@ -157,7 +157,7 @@

Rotary PE attention

-
46    experiment.create(name="rotary_pe_transformer")
+
46    experiment.create(name="rotary_pe_transformer", writers={'screen'})
diff --git a/docs/transformers/rope/index.html b/docs/transformers/rope/index.html index 06512a47..27455c61 100644 --- a/docs/transformers/rope/index.html +++ b/docs/transformers/rope/index.html @@ -90,19 +90,19 @@

Rotary Positional Embeddings (RoPE)

#

RoPE module

-

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the features as pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.

+

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the features as pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.

For a pair of features

-

Let and be two features of the key or query of any head at position . Or for simplicity assume has only two features. Then the transformation is,

-

where is a constant angle. The other pairs of features are transformed similarly.

+

Let and be two features of the key or query of any head at position . Or for simplicity assume has only two features. Then the transformation is,

+

where is a constant angle. The other pairs of features are transformed similarly.

Attention is relative

-

For a pair of features, dot-product attention score between two positions and would be

-

This shows that for dot-production attention the rotary encodings gives relative attention.

+

For a pair of features, dot-product attention score between two positions and would be

+

This shows that for dot-production attention the rotary encodings gives relative attention.

For all features

The features are grouped into pairs and handled as above. They use a different for each pair.

-

The paper suggests using for the pairs of features.

-

We pair feature with feature . So for position we transform

-

to

- +

The paper suggests using for the pairs of features.

+

We pair feature with feature . So for position we transform

+

to

+
32class RotaryPositionalEmbeddings(nn.Module):
@@ -114,13 +114,13 @@

For all features

#
  • d - is the number of features
  • + is the number of features
  • base is the constant used for calculating
-
118    def __init__(self, d: int, base: int = 10_000):
+
119    def __init__(self, d: int, base: int = 10_000):
@@ -131,33 +131,37 @@

For all features

-
123        super().__init__()
+
124        super().__init__()
+125
+126        self.base = base
+127        self.d = d
+128        self.cos_cached = None
+129        self.sin_cached = None
-
+
-

+

Cache and values

-
125        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
+
131    def _build_cache(self, x: torch.Tensor):
-
+
-
  • x - is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d] -
+

Return if cache is already built

-
127    def forward(self, x: torch.Tensor):
+
136        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+137            return
@@ -165,11 +169,11 @@

For all features

-

Extract the shape

+

Get sequence length

-
132        seq_len, batch_size, n_heads, d = x.shape
+
140        seq_len = x.shape[0]
@@ -177,11 +181,11 @@

For all features

-

+

-
135        d_2 = d // 2
+
143        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
@@ -194,7 +198,7 @@

For all features

-
138        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
+
146        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
@@ -202,11 +206,11 @@

For all features

-

Calculate the product of position index and

+

Calculate the product of position index and

-
141        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
+
149        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
@@ -214,11 +218,11 @@

For all features

-

Concatenate so that for row we have

+

Concatenate so that for row we have

-
145        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
153        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
@@ -226,11 +230,12 @@

For all features

-

Calculate

+

Cache them

-
148        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
156        self.cos_cached = idx_theta2.cos()[:, None, None, :]
+157        self.sin_cached = idx_theta2.cos()[:, None, None, :]
@@ -238,12 +243,10 @@

For all features

-

Calculate

-

for

- +
-
160        rx = (x * idx_theta2.cos()[:, None, None, :]) + (neg_half_x * idx_theta2.sin()[:, None, None, :])
+
159    def _neg_half(self, x: torch.Tensor):
@@ -251,35 +254,37 @@

For all features

-

+

-
163        return rx
+
161        d_2 = self.d // 2
-
+
-

Multi-head attention with rotary positional embeddings

-

We override multi-head attention from original transformer.

+

Calculate

-
166class RotaryPEMultiHeadAttention(MultiHeadAttention):
+
164        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
-
+
- +
  • x + is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d] +
+
-
173    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+
166    def forward(self, x: torch.Tensor):
@@ -287,12 +292,11 @@

Multi-head attention with rotary positional embeddings

-

The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for value - might make sense.

+

Cache and values

-
177        super().__init__(heads, d_model, dropout_prob, bias=False)
+
171        self._build_cache(x)
@@ -300,24 +304,23 @@

Multi-head attention with rotary positional embeddings

-

Rotary positional embedding layers

+

Split the features, we can choose to apply rotary embeddings only to a partial set of features.

-
180        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-181        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+
174        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
-
+
-

Calculate scores between queries and keys

+

Calculate

-
183    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
177        neg_half_x = self._neg_half(x_rope)
@@ -325,43 +328,131 @@

Calculate scores between queries and keys

-

Calculate dot-product with RoPE

+

Calculate

+

for

-
189        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
189        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
-
+
-

Testing RoPE with a simple example

+

-
192def _test_rotary():
+
192        return torch.cat((x_rope, x_pass), dim=-1)
-
+
+

Multi-head attention with rotary positional embeddings

+

We override multi-head attention from original transformer.

+ +
+
+
195class RotaryPEMultiHeadAttention(MultiHeadAttention):
+
+
+
+
+ + +
+
+
202    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
+
+
+
+
+ +

The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for value + might make sense.

+ +
+
+
206        super().__init__(heads, d_model, dropout_prob, bias=False)
+
+
+
+
+ +

Rotary positional embedding layers

+ +
+
+
209        d_rope = int(self.d_k * rope_percentage)
+210        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+211        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+
+
+
+
+ +

Calculate scores between queries and keys

+ +
+
+
213    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
+
+
+
+ +

Calculate dot-product with RoPE

+ +
+
+
219        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
+
+
+
+ +

Testing RoPE with a simple example

+ +
+
+
222def _test_rotary():
+
+
+
+
+
-
196    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
-197    x = x[:, None, None, :]
-198    inspect(x)
-199
-200    rotary_pe = RotaryPositionalEmbeddings(3)
-201    inspect(rotary_pe(x))
-202
-203
-204if __name__ == '__main__':
-205    _test_rotary()
+
226    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
+227    x = x[:, None, None, :]
+228    inspect(x)
+229
+230    rotary_pe = RotaryPositionalEmbeddings(3)
+231    inspect(rotary_pe(x))
+232
+233
+234if __name__ == '__main__':
+235    _test_rotary()
156        self.cos_cached = idx_theta2.cos()[:, None, None, :]
-157        self.sin_cached = idx_theta2.cos()[:, None, None, :]
+157 self.sin_cached = idx_theta2.sin()[:, None, None, :]
@@ -320,7 +320,7 @@

For all features

-
177        neg_half_x = self._neg_half(x_rope)
+
178        neg_half_x = self._neg_half(x_rope)
@@ -333,7 +333,7 @@

For all features

-
189        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+
190        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -345,7 +345,7 @@

For all features

-
192        return torch.cat((x_rope, x_pass), dim=-1)
+
193        return torch.cat((x_rope, x_pass), dim=-1)
@@ -358,7 +358,7 @@

Multi-head attention with rotary positional embeddings

-
195class RotaryPEMultiHeadAttention(MultiHeadAttention):
+
196class RotaryPEMultiHeadAttention(MultiHeadAttention):
@@ -369,7 +369,7 @@

Multi-head attention with rotary positional embeddings

-
202    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
+
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
@@ -382,7 +382,7 @@

Multi-head attention with rotary positional embeddings

-
206        super().__init__(heads, d_model, dropout_prob, bias=False)
+
207        super().__init__(heads, d_model, dropout_prob, bias=False)
@@ -394,9 +394,9 @@

Multi-head attention with rotary positional embeddings

-
209        d_rope = int(self.d_k * rope_percentage)
-210        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-211        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+
210        d_rope = int(self.d_k * rope_percentage)
+211        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+212        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
@@ -408,7 +408,7 @@

Calculate scores between queries and keys

-
213    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
214    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -420,7 +420,7 @@

Calculate scores between queries and keys

-
219        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
220        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
@@ -432,7 +432,7 @@

Calculate scores between queries and keys

-
222def _test_rotary():
+
223def _test_rotary():
@@ -443,16 +443,16 @@

Calculate scores between queries and keys

-
226    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
-227    x = x[:, None, None, :]
-228    inspect(x)
-229
-230    rotary_pe = RotaryPositionalEmbeddings(3)
-231    inspect(rotary_pe(x))
-232
+            
227    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
+228    x = x[:, None, None, :]
+229    inspect(x)
+230
+231    rotary_pe = RotaryPositionalEmbeddings(3)
+232    inspect(rotary_pe(x))
 233
-234if __name__ == '__main__':
-235    _test_rotary()
+234 +235if __name__ == '__main__': +236 _test_rotary()
-
142        neg_half_x = self._neg_half(x_rope)
+
143        neg_half_x = self._neg_half(x_rope)
@@ -173,7 +173,7 @@

RoPE module that rotates in the opposite direction

-
158        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+
159        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -185,7 +185,7 @@

RoPE module that rotates in the opposite direction

-
161        return torch.cat((x_rope, x_pass), dim=-1)
+
162        return torch.cat((x_rope, x_pass), dim=-1)
@@ -198,7 +198,7 @@

Multi-head attention with rotary positional embeddings

-
164class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
+
165class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
@@ -209,9 +209,9 @@

Multi-head attention with rotary positional embeddings

-
171    def __init__(self, heads: int, d_model: int,
-172                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
-173                 dropout_prob: float = 0.1):
+
172    def __init__(self, heads: int, d_model: int,
+173                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
+174                 dropout_prob: float = 0.1):
@@ -224,7 +224,7 @@

Multi-head attention with rotary positional embeddings

-
177        super().__init__(heads, d_model, dropout_prob, bias=False)
+
178        super().__init__(heads, d_model, dropout_prob, bias=False)
@@ -236,13 +236,13 @@

Multi-head attention with rotary positional embeddings

-
180        d_rope = int(self.d_k * rope_percentage)
-181        d_rope_value = int(self.d_k * rope_value_percentage)
-182
-183        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-184        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-185        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
-186        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
+
181        d_rope = int(self.d_k * rope_percentage)
+182        d_rope_value = int(self.d_k * rope_value_percentage)
+183
+184        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+185        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+186        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
+187        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
@@ -254,7 +254,7 @@

Calculate scores between queries and keys

-
188    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
189    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -266,7 +266,7 @@

Calculate scores between queries and keys

-
194        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
195        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
@@ -289,11 +289,11 @@

Calculate scores between queries and keys

-
196    def forward(self, *,
-197                query: torch.Tensor,
-198                key: torch.Tensor,
-199                value: torch.Tensor,
-200                mask: Optional[torch.Tensor] = None):
+
197    def forward(self, *,
+198                query: torch.Tensor,
+199                key: torch.Tensor,
+200                value: torch.Tensor,
+201                mask: Optional[torch.Tensor] = None):
@@ -309,10 +309,10 @@

Calculate scores between queries and keys

-
212        seq_len, batch_size, _ = query.shape
-213
-214        if mask is not None:
-215            mask = self.prepare_mask(mask, query.shape, key.shape)
+
213        seq_len, batch_size, _ = query.shape
+214
+215        if mask is not None:
+216            mask = self.prepare_mask(mask, query.shape, key.shape)
@@ -328,9 +328,9 @@

Calculate scores between queries and keys

-
219        query = self.query(query)
-220        key = self.key(key)
-221        value = self.value(value)
+
220        query = self.query(query)
+221        key = self.key(key)
+222        value = self.value(value)
@@ -343,7 +343,7 @@

Calculate scores between queries and keys

-
225        scores = self.get_scores(query, key)
+
226        scores = self.get_scores(query, key)
@@ -366,7 +366,7 @@

Calculate scores between queries and keys

-
228        scores *= self.scale
+
229        scores *= self.scale
@@ -378,8 +378,8 @@

Calculate scores between queries and keys

-
231        if mask is not None:
-232            scores = scores.masked_fill(mask == 0, float('-inf'))
+
232        if mask is not None:
+233            scores = scores.masked_fill(mask == 0, float('-inf'))
@@ -402,7 +402,7 @@

Calculate scores between queries and keys

-
236        attn = self.softmax(scores)
+
237        attn = self.softmax(scores)
@@ -414,7 +414,7 @@

Calculate scores between queries and keys

-
239        attn = self.dropout(attn)
+
240        attn = self.dropout(attn)
@@ -426,7 +426,7 @@

Calculate scores between queries and keys

-
242        value = self.value_rotary_pe(value)
+
243        value = self.value_rotary_pe(value)
@@ -449,7 +449,7 @@

Calculate scores between queries and keys

-
246        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+
247        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
@@ -461,7 +461,7 @@

Calculate scores between queries and keys

-
249        x = self.value_reverse_rotary_pe(x)
+
250        x = self.value_reverse_rotary_pe(x)
@@ -473,7 +473,7 @@

Calculate scores between queries and keys

-
252        self.attn = attn.detach()
+
253        self.attn = attn.detach()
@@ -485,7 +485,7 @@

Calculate scores between queries and keys

-
255        x = x.reshape(seq_len, batch_size, -1)
+
256        x = x.reshape(seq_len, batch_size, -1)
@@ -497,7 +497,7 @@

Calculate scores between queries and keys

-
258        return self.output(x)
+
259        return self.output(x)
-
5import random
-6import string
-7from typing import List
-8
-9import torch
-10from labml.logger import Text
-11from torch.utils.data import DataLoader, Dataset
-12
-13from labml import monit, logger, tracker
-14from labml.configs import option
-15from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
+
11import random
+12import string
+13from typing import List
+14
+15import torch
+16from labml.logger import Text
+17from torch.utils.data import DataLoader, Dataset
+18
+19from labml import monit, logger, tracker
+20from labml.configs import option
+21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
-
+
- +

Arithmetic Dataset

+

This creates arithmetic addition problems and solutions with workings. We've only implemented addition so far.

+

It's based on a character level tokenization.

+
-
18class ArithmeticDataset(Dataset):
+
24class ArithmeticDataset(Dataset):
-
+
- +
  • seq_len is the sequence length of generated math problems. We fill as many problems as possible upto this length :max_digits: is the maximum number of digits in the operand integers :n_sequences: is the number of sequences per epoch
+
-
19    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
-20        self.n_sequences = n_sequences
-21        self.max_digits = max_digits
-22        self.seq_len = seq_len
-23        self.itos = list(string.digits + 'xe =\n?+;')
-24        self.stoi = {c: i for i, c in enumerate(self.itos)}
+
33    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
@@ -120,28 +119,9 @@
-
26    @staticmethod
-27    def make_int(n_digits):
-28        res = 0
-29        for i in range(n_digits):
-30            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
-31            res = res * 10 + d
-32
-33        return res
-34
-35    @staticmethod
-36    def get_add_explanation(x, y):
-37        carry = 0
-38        e = 0
-39        explanation = []
-40        while x > 0 or y > 0 or carry > 0:
-41            rx, ry = x % 10, y % 10
-42            total = rx + ry + carry
-43            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
-44            x, y, carry = x // 10, y // 10, total // 10
-45            e += 1
-46
-47        return ' '.join(explanation)
+
40        self.n_sequences = n_sequences
+41        self.max_digits = max_digits
+42        self.seq_len = seq_len
@@ -149,19 +129,11 @@ -

Make a problem with a pre_explanation or not

+

Token id to string

-
50    def make_add_problem(self):
-51        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-52        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-53
-54        if random.randrange(0, 5) < 1:
-55            return f"x={x}+{y}; x=={x + y}\n"
-56        else:
-57            explanation = self.get_add_explanation(x, y)
-58            return f"x={x}+{y}; {explanation} x=={x + y}\n"
+
44        self.itos = list(string.digits + 'xe =\n?+;')
@@ -169,31 +141,25 @@ - +

Character to token id

+
-
60    def get_qa(self):
-61        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-62        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-63
-64        return f'x={x}+{y};', f'{x + y}'
+
46        self.stoi = {c: i for i, c in enumerate(self.itos)}
-
+
- +

Generates an integer with n_digit + number of digits

+
-
66    def get_packed_math_input(self):
-67        s_enc = []
-68        while len(s_enc) <= self.seq_len:
-69            s_part = self.make_add_problem()
-70            s_part_enc = self.encode('?' + s_part)
-71            s_enc = s_enc + s_part_enc
-72        return s_enc
+
48    @staticmethod
+49    def make_int(n_digits: int):
@@ -204,20 +170,28 @@
-
74    def encode(self, s: str):
-75        return [self.stoi[c] for c in s]
+
53        res = 0
+54        for i in range(n_digits):
+55            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
+56            res = res * 10 + d
+57
+58        return res
-
+
- +

Generates the workings for x + y +. For example for 11+29 + it generates 1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0 +.

+
-
77    def decode(self, arr: List[int]):
-78        return ''.join([self.itos[c] for c in arr])
+
60    @staticmethod
+61    def get_add_explanation(x: int, y: int):
@@ -228,21 +202,30 @@
-
80    def __getitem__(self, idx):
-81        s = torch.tensor(self.get_packed_math_input())
-82        return s[:self.seq_len], s[1:self.seq_len + 1]
+
68        carry = 0
+69        e = 0
+70        explanation = []
+71        while x > 0 or y > 0 or carry > 0:
+72            rx, ry = x % 10, y % 10
+73            total = rx + ry + carry
+74            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
+75            x, y, carry = x // 10, y // 10, total // 10
+76            e += 1
+77
+78        return ' '.join(explanation)
-
+
- +

Make a problem with a pre_explanation or not

+

Creates an arithmetic addition problem with workings and answer.

+
-
84    def __len__(self):
-85        return self.n_sequences
+
81    def make_add_problem(self):
@@ -253,27 +236,26 @@
-
88class ArithmeticAutoregression(NLPAutoRegressionConfigs):
-89    max_digits: int = 4
-90    train_sequences_per_epoch: int = 2 ** 12
-91    train_loader: DataLoader = 'arithmetic_train_loader'
-92    n_tests: int = 32
-93    validator = None
-94    inner_iterations = 4
-95
-96    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
85        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+86        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+87
+88        explanation = self.get_add_explanation(x, y)
+89        return f"x={x}+{y}; {explanation} x=={x + y}\n"
-
+
-

Sampling function to generate samples periodically while training

- +
-
98    def sample(self):
+
91    def get_qa(self):
+92        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+93        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+94
+95        return f'x={x}+{y};', f'{x + y}'
@@ -284,20 +266,13 @@

Sampling function to generate samples periodically while training

-
103        if self.training_loop.idx < 1:
-104            return
-105
-106        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
-107        qa = [dataset.get_qa() for _ in range(self.n_tests)]
-108        prompt = [p[0] for p in qa]
-109
-110        data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]])
-111        data = data.to(self.device)
-112
-113        finished = torch.zeros((len(prompt),)).bool().to(self.device)
-114        new_line = dataset.stoi['\n']
-115
-116        results = [p[0] for p in prompt]
+
97    def get_packed_math_input(self):
+98        s_enc = []
+99        while len(s_enc) <= self.seq_len:
+100            s_part = self.make_add_problem()
+101            s_part_enc = self.encode('?' + s_part)
+102            s_enc = s_enc + s_part_enc
+103        return s_enc
@@ -305,13 +280,11 @@

Sampling function to generate samples periodically while training

-

Sample 25 tokens

- +
-
119        for i in monit.iterate('Sample', self.seq_len - 1):
-120            if finished.sum() == len(finished):
-121                continue
+
105    def encode(self, s: str):
+106        return [self.stoi[c] for c in s]
@@ -319,11 +292,11 @@

Sampling function to generate samples periodically while training

-

Tokenize the prompt Get the model output

- +
-
125            output, *_ = self.model(data)
+
108    def decode(self, arr: List[int]):
+109        return ''.join([self.itos[c] for c in arr])
@@ -331,64 +304,170 @@

Sampling function to generate samples periodically while training

-

Get the model prediction (greedy)

+ +
+
+
111    def __getitem__(self, idx):
+112        s = torch.tensor(self.get_packed_math_input())
+113        return s[:self.seq_len], s[1:self.seq_len + 1]
+
+ +
+
+ + +
+
+
115    def __len__(self):
+116        return self.n_sequences
+
+
+
+
+ + +
+
+
119class ArithmeticAutoregression(NLPAutoRegressionConfigs):
+120    max_digits: int = 4
+121    train_sequences_per_epoch: int = 2 ** 12
+122    train_loader: DataLoader = 'arithmetic_train_loader'
+123    n_tests: int = 32
+124    validator = None
+125    inner_iterations = 4
+126
+127    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
+
+
+
+ +

Sampling function to generate samples periodically while training

-
127            output = output[-1].argmax(dim=-1)
-128
-129            finished = finished | (output == new_line)
-130            if finished.sum() == len(finished):
-131                continue
-132
-133            for j, p in enumerate(prompt):
-134                if len(p) > i + 1:
-135                    output[j] = dataset.stoi[p[i + 1]]
+            
129    def sample(self):
+
+
+
+
+ + +
+
+
134        if self.training_loop.idx < 1:
+135            return
 136
-137            data = torch.cat([data, output[None, :]], dim=0)
-138
-139            for j, c in enumerate(output):
-140                results[j] += dataset.itos[c]
-141
-142        results = [r.split('\n')[0] for r in results]
+137        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+138        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+139        prompt = [p[0] for p in qa]
+140
+141        data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]])
+142        data = data.to(self.device)
 143
-144        res_sample = results[0].split(';')
-145        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+144        finished = torch.zeros((len(prompt),)).bool().to(self.device)
+145        new_line = dataset.stoi['\n']
 146
-147        results = [r.split('x==')[-1] for r in results]
-148
-149        correct = 0
-150        for r, _qa in zip(results, qa):
-151            if r == _qa[1]:
-152                correct += 1
-153
-154        tracker.save('score', correct / len(results))
+147 results = [p[0] for p in prompt]
-
+
- +

Sample 25 tokens

+ +
+
+
150        for i in monit.iterate('Sample', self.seq_len - 1):
+151            if finished.sum() == len(finished):
+152                continue
+
+
+
+
+ +

Tokenize the prompt Get the model output

+
-
157@option(ArithmeticAutoregression.train_loader)
-158def arithmetic_train_loader(c: ArithmeticAutoregression):
-159    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
-160                      batch_size=c.batch_size,
-161                      collate_fn=transpose_batch,
-162                      num_workers=4)
+            
156            output, *_ = self.model(data)
+
+
+
+
+ +

Get the model prediction (greedy)

+ +
+
+
158            output = output[-1].argmax(dim=-1)
+159
+160            finished = finished | (output == new_line)
+161            if finished.sum() == len(finished):
+162                continue
 163
-164
-165def _test():
-166    dataset = ArithmeticDataset(256, 8, 10)
+164            for j, p in enumerate(prompt):
+165                if len(p) > i + 1:
+166                    output[j] = dataset.stoi[p[i + 1]]
 167
-168    print(dataset.decode(dataset.get_packed_math_input()))
+168            data = torch.cat([data, output[None, :]], dim=0)
 169
-170
-171if __name__ == '__main__':
-172    _test()
+170 for j, c in enumerate(output): +171 results[j] += dataset.itos[c] +172 +173 results = [r.split('\n')[0] for r in results] +174 +175 res_sample = results[0].split(';') +176 logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)]) +177 +178 results = [r.split('x==')[-1] for r in results] +179 +180 correct = 0 +181 for r, _qa in zip(results, qa): +182 if r == _qa[1]: +183 correct += 1 +184 +185 tracker.save('score', correct / len(results))
+
+
+
+
+ + +
+
+
188@option(ArithmeticAutoregression.train_loader)
+189def arithmetic_train_loader(c: ArithmeticAutoregression):
+190    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+191                      batch_size=c.batch_size,
+192                      collate_fn=transpose_batch,
+193                      num_workers=4)
+194
+195
+196def _test():
+197    dataset = ArithmeticDataset(256, 8, 10)
+198
+199    print(dataset.decode(dataset.get_packed_math_input()))
+200
+201
+202if __name__ == '__main__':
+203    _test()
-
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
+
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
diff --git a/docs/transformers/rope/value_pe/arithmetic_experiment.html b/docs/transformers/rope/value_pe/arithmetic_experiment.html index dc58bc05..45d9d634 100644 --- a/docs/transformers/rope/value_pe/arithmetic_experiment.html +++ b/docs/transformers/rope/value_pe/arithmetic_experiment.html @@ -3,24 +3,24 @@ - + - - + + - + - - + + - Rotary Positional Embeddings (RoPE) Experiment + Rotary Positional Embeddings with Relative distance (RoPER) Experiment @@ -70,29 +70,28 @@ -

Rotary Positional Embeddings (RoPE) Experiment

-

This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).

-

View Run

+

Rotary Positional Embeddings with Relative distance (RoPER) Experiment

-
14from labml import experiment
-15from labml.configs import calculate
-16from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
-17from labml_nn.transformers import TransformerConfigs
-18from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
+
11from labml import experiment
+12from labml.configs import calculate
+13from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
+14from labml_nn.transformers import TransformerConfigs
+15from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
-
+
-

Rotary PE attention

+

We inherit RoPE experiment and use it for arithmetic addition task.

+

We add the option to change attention to use Rotary Positional Embeddings with Relative distance (RoPER) below.

-
+
18class Configs(RoPEConfigs, ArithmeticAutoregression):
@@ -103,21 +102,19 @@

Rotary PE attention

-
23class Configs(RoPEConfigs, ArithmeticAutoregression):  # , ArithmeticAutoregression):
-24    pass
+
26    pass
-
+
- +

Use Rotary Positional Embeddings with Relative distance (RoPER) in attention.

+
-
27def _rotary_value_pe_mha(c: TransformerConfigs):
-28    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-29    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
+
29def _rotary_value_pe_mha(c: TransformerConfigs):
@@ -125,13 +122,11 @@

Rotary PE attention

-

Configuration options

- +
-
33calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
-34calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
-35calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
+
33    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
+34    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
@@ -139,10 +134,13 @@

Rotary PE attention

- +

Configuration options

+
-
38def main():
+
38calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
+39calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
+40calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
@@ -150,11 +148,10 @@

Rotary PE attention

-

Create experiment

- +
-
40    experiment.create(name="roper_addition", comment="rotary value", writers={'screen', 'labml', 'comet'})
+
43def main():
@@ -162,11 +159,11 @@

Rotary PE attention

-

Create configs

+

Create experiment

-
42    conf = Configs()
+
45    experiment.create(name="roper_addition", comment="rotary value 8", writers={'screen', 'labml', 'comet'})
@@ -174,12 +171,11 @@

Rotary PE attention

-

Override configurations

+

Create configs

-
44    experiment.configs(conf, {
-45        'max_digits': 6,
+
47    conf = Configs()
@@ -187,12 +183,12 @@

Rotary PE attention

-

No fixed positional embeddings

+

Override configurations

-
48        'transformer.src_embed': 'no_pos',
-49        'transformer.tgt_embed': 'no_pos',
+
49    experiment.configs(conf, {
+50        'max_digits': 8,
@@ -200,11 +196,12 @@

Rotary PE attention

-

Encoder with RoPE

+

No fixed positional embeddings

-
52        'transformer.encoder_attn': 'rotary_value',
+
53        'transformer.src_embed': 'no_pos',
+54        'transformer.tgt_embed': 'no_pos',
@@ -212,11 +209,11 @@

Rotary PE attention

-

'transformer.encoder_attn': 'rotary',

+

Encoder with RoPER attention

-
+
57        'transformer.encoder_attn': 'rotary_value',
@@ -224,11 +221,11 @@

Rotary PE attention

-

+

Encoder with RoPE attention 'transformer.encoder_attn': 'rotary',

-
56        'model': 'rotary_pe_transformer',
+
@@ -236,11 +233,11 @@

Rotary PE attention

-

Prompt separator is blank

+

-
59        'prompt_separator': '',
+
62        'model': 'rotary_pe_transformer',
@@ -248,11 +245,11 @@

Rotary PE attention

-

Starting prompt for sampling

+

Use a context size of

-
61        'prompt': '?x=123456789+1091919;',
+
65        'seq_len': 512,
@@ -260,11 +257,11 @@

Rotary PE attention

-

Use a context size of

+

Train for 32 epochs

-
64        'seq_len': 512,
+
67        'epochs': 20,
@@ -272,11 +269,11 @@

Rotary PE attention

-

Train for 32 epochs

+

Batch size

-
66        'epochs': 50,
+
69        'batch_size': 16,
@@ -284,11 +281,14 @@

Rotary PE attention

-

Batch size

+

Model size

-
68        'batch_size': 16,
+
72        'd_model': 128,
+73        'transformer.ffn.d_ff': 512,
+74        'transformer.n_heads': 4,
+75        'transformer.dropout': 0.0,
@@ -296,14 +296,13 @@

Rotary PE attention

-

Model size

+

Use Noam optimizer

-
71        'd_model': 128,
-72        'transformer.ffn.d_ff': 512,
-73        'transformer.n_heads': 4,
-74        'transformer.dropout': 0.0,
+
78        'optimizer.optimizer': 'Noam',
+79        'optimizer.learning_rate': 1.,
+80    })
@@ -311,13 +310,11 @@

Rotary PE attention

-

Use Noam optimizer

+

Set models for saving and loading

-
77        'optimizer.optimizer': 'Noam',
-78        'optimizer.learning_rate': 1.,
-79    })
+
83    experiment.add_pytorch_models({'model': conf.model})
@@ -325,11 +322,11 @@

Rotary PE attention

-

Set models for saving and loading

+

Start the experiment

-
82    experiment.add_pytorch_models({'model': conf.model})
+
86    with experiment.start():
@@ -337,11 +334,11 @@

Rotary PE attention

-

Start the experiment

+

Run training

-
85    with experiment.start():
+
88        conf.run()
@@ -349,24 +346,12 @@

Rotary PE attention

-

Run training

- -
-
-
87        conf.run()
-
- -
-
-

-
91if __name__ == '__main__':
-92    main()
+
92if __name__ == '__main__':
+93    main()

RoPER is work by Georges Harik (@gharik), and this implementation is based on his original code.

-

Rotary Positional Embeddings with Relative Distance (RoPER)

+

Rotary Positional Embeddings with Relative distance (RoPER)

Rotary Positional Embeddings (RoPE) includes relative positions in attention score calculation. However, the embeddings themselves do not get any positional information , except what it can get implicitly from causal attention.

RoPER adds relative positional information explicitly to value embeddings. Specifically, it adds the relative positions of the tokens it paid attention to. We use same rotary positional embeddings to rotate the values in attention, Then, after taking the weighted sum, we rotate the final in the opposite direction. Which is equivalent to rotating each of the values (before attention) relative to the current position.

Here's the training code for training a transformer model with RoPER on an arithmetic addition where we can see significant improvement over RoPE.

@@ -89,15 +89,16 @@

Relative distances in embeddings

Which gives,

That is, the weighted average of values rotated relative to current position.

+

Here's an experiment that uses RoPER on an arthmetic addition task.

-
116from typing import Optional
-117
-118import torch
+            
118from typing import Optional
 119
-120from labml_nn.transformers.mha import MultiHeadAttention
-121from labml_nn.transformers.rope import RotaryPositionalEmbeddings
+120import torch +121 +122from labml_nn.transformers.mha import MultiHeadAttention +123from labml_nn.transformers.rope import RotaryPositionalEmbeddings
@@ -110,7 +111,7 @@

RoPE module that rotates in the opposite direction

-
124class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
+
126class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
@@ -124,7 +125,7 @@

RoPE module that rotates in the opposite direction

-
131    def forward(self, x: torch.Tensor):
+
133    def forward(self, x: torch.Tensor):
@@ -136,7 +137,7 @@

RoPE module that rotates in the opposite direction

-
136        self._build_cache(x)
+
138        self._build_cache(x)
@@ -148,7 +149,7 @@

RoPE module that rotates in the opposite direction

-
139        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
141        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
@@ -160,7 +161,7 @@

RoPE module that rotates in the opposite direction

-
143        neg_half_x = self._neg_half(x_rope)
+
145        neg_half_x = self._neg_half(x_rope)
@@ -173,7 +174,7 @@

RoPE module that rotates in the opposite direction

-
159        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+
161        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -185,7 +186,7 @@

RoPE module that rotates in the opposite direction

-
162        return torch.cat((x_rope, x_pass), dim=-1)
+
164        return torch.cat((x_rope, x_pass), dim=-1)
@@ -198,7 +199,7 @@

Multi-head attention with rotary positional embeddings

-
165class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
+
167class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
@@ -209,9 +210,9 @@

Multi-head attention with rotary positional embeddings

-
172    def __init__(self, heads: int, d_model: int,
-173                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
-174                 dropout_prob: float = 0.1):
+
174    def __init__(self, heads: int, d_model: int,
+175                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
+176                 dropout_prob: float = 0.0):
@@ -224,7 +225,7 @@

Multi-head attention with rotary positional embeddings

-
178        super().__init__(heads, d_model, dropout_prob, bias=False)
+
180        super().__init__(heads, d_model, dropout_prob, bias=False)
@@ -236,13 +237,13 @@

Multi-head attention with rotary positional embeddings

-
181        d_rope = int(self.d_k * rope_percentage)
-182        d_rope_value = int(self.d_k * rope_value_percentage)
-183
-184        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-185        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-186        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
-187        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
+
183        d_rope = int(self.d_k * rope_percentage)
+184        d_rope_value = int(self.d_k * rope_value_percentage)
+185
+186        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+187        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+188        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
+189        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
@@ -254,7 +255,7 @@

Calculate scores between queries and keys

-
189    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
191    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -266,7 +267,7 @@

Calculate scores between queries and keys

-
195        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
197        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
@@ -289,11 +290,11 @@

Calculate scores between queries and keys

-
197    def forward(self, *,
-198                query: torch.Tensor,
-199                key: torch.Tensor,
-200                value: torch.Tensor,
-201                mask: Optional[torch.Tensor] = None):
+
199    def forward(self, *,
+200                query: torch.Tensor,
+201                key: torch.Tensor,
+202                value: torch.Tensor,
+203                mask: Optional[torch.Tensor] = None):
@@ -309,10 +310,10 @@

Calculate scores between queries and keys

-
213        seq_len, batch_size, _ = query.shape
-214
-215        if mask is not None:
-216            mask = self.prepare_mask(mask, query.shape, key.shape)
+
215        seq_len, batch_size, _ = query.shape
+216
+217        if mask is not None:
+218            mask = self.prepare_mask(mask, query.shape, key.shape)
@@ -328,9 +329,9 @@

Calculate scores between queries and keys

-
220        query = self.query(query)
-221        key = self.key(key)
-222        value = self.value(value)
+
222        query = self.query(query)
+223        key = self.key(key)
+224        value = self.value(value)
@@ -343,7 +344,7 @@

Calculate scores between queries and keys

-
226        scores = self.get_scores(query, key)
+
228        scores = self.get_scores(query, key)
@@ -366,7 +367,7 @@

Calculate scores between queries and keys

-
229        scores *= self.scale
+
231        scores *= self.scale
@@ -378,8 +379,8 @@

Calculate scores between queries and keys

-
232        if mask is not None:
-233            scores = scores.masked_fill(mask == 0, float('-inf'))
+
234        if mask is not None:
+235            scores = scores.masked_fill(mask == 0, float('-inf'))
@@ -402,7 +403,7 @@

Calculate scores between queries and keys

-
237        attn = self.softmax(scores)
+
239        attn = self.softmax(scores)
@@ -414,7 +415,7 @@

Calculate scores between queries and keys

-
240        attn = self.dropout(attn)
+
242        attn = self.dropout(attn)
@@ -426,7 +427,7 @@

Calculate scores between queries and keys

-
243        value = self.value_rotary_pe(value)
+
245        value = self.value_rotary_pe(value)
@@ -449,7 +450,7 @@

Calculate scores between queries and keys

-
247        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+
249        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
@@ -461,7 +462,7 @@

Calculate scores between queries and keys

-
250        x = self.value_reverse_rotary_pe(x)
+
252        x = self.value_reverse_rotary_pe(x)
@@ -473,7 +474,7 @@

Calculate scores between queries and keys

-
253        self.attn = attn.detach()
+
255        self.attn = attn.detach()
@@ -485,7 +486,7 @@

Calculate scores between queries and keys

-
256        x = x.reshape(seq_len, batch_size, -1)
+
258        x = x.reshape(seq_len, batch_size, -1)
@@ -497,7 +498,7 @@

Calculate scores between queries and keys

-
259        return self.output(x)
+
261        return self.output(x)
-
33    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
+
34    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
@@ -119,9 +119,9 @@

Arithmetic Dataset

-
40        self.n_sequences = n_sequences
-41        self.max_digits = max_digits
-42        self.seq_len = seq_len
+
41        self.n_sequences = n_sequences
+42        self.max_digits = max_digits
+43        self.seq_len = seq_len
@@ -133,7 +133,7 @@

Arithmetic Dataset

-
44        self.itos = list(string.digits + 'xe =\n?+;')
+
45        self.itos = list(string.digits + 'xe =\n?+;')
@@ -145,7 +145,7 @@

Arithmetic Dataset

-
46        self.stoi = {c: i for i, c in enumerate(self.itos)}
+
47        self.stoi = {c: i for i, c in enumerate(self.itos)}
@@ -158,8 +158,8 @@

Arithmetic Dataset

-
48    @staticmethod
-49    def make_int(n_digits: int):
+
49    @staticmethod
+50    def make_int(n_digits: int):
@@ -170,12 +170,12 @@

Arithmetic Dataset

-
53        res = 0
-54        for i in range(n_digits):
-55            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
-56            res = res * 10 + d
-57
-58        return res
+
54        res = 0
+55        for i in range(n_digits):
+56            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
+57            res = res * 10 + d
+58
+59        return res
@@ -190,8 +190,8 @@

Arithmetic Dataset

-
60    @staticmethod
-61    def get_add_explanation(x: int, y: int):
+
61    @staticmethod
+62    def get_add_explanation(x: int, y: int):
@@ -202,17 +202,17 @@

Arithmetic Dataset

-
68        carry = 0
-69        e = 0
-70        explanation = []
-71        while x > 0 or y > 0 or carry > 0:
-72            rx, ry = x % 10, y % 10
-73            total = rx + ry + carry
-74            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
-75            x, y, carry = x // 10, y // 10, total // 10
-76            e += 1
-77
-78        return ' '.join(explanation)
+
69        carry = 0
+70        e = 0
+71        explanation = []
+72        while x > 0 or y > 0 or carry > 0:
+73            rx, ry = x % 10, y % 10
+74            total = rx + ry + carry
+75            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
+76            x, y, carry = x // 10, y // 10, total // 10
+77            e += 1
+78
+79        return ' '.join(explanation)
@@ -225,7 +225,7 @@

Arithmetic Dataset

-
81    def make_add_problem(self):
+
82    def make_add_problem(self):
@@ -236,26 +236,23 @@

Arithmetic Dataset

-
85        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-86        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-87
-88        explanation = self.get_add_explanation(x, y)
-89        return f"x={x}+{y}; {explanation} x=={x + y}\n"
+
86        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+87        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+88
+89        explanation = self.get_add_explanation(x, y)
+90        return f"x={x}+{y}; {explanation} x=={x + y}\n"
-
+
- +

Get arithmetic problem and answer. This is used for evaluation.

+
-
91    def get_qa(self):
-92        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-93        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-94
-95        return f'x={x}+{y};', f'{x + y}'
+
92    def get_qa(self):
@@ -266,25 +263,22 @@

Arithmetic Dataset

-
97    def get_packed_math_input(self):
-98        s_enc = []
-99        while len(s_enc) <= self.seq_len:
-100            s_part = self.make_add_problem()
-101            s_part_enc = self.encode('?' + s_part)
-102            s_enc = s_enc + s_part_enc
-103        return s_enc
+
96        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+97        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+98
+99        return f'x={x}+{y};', f'{x + y}'
-
+
- +

Generate multiple problems and pack them into a sequence.

+
-
105    def encode(self, s: str):
-106        return [self.stoi[c] for c in s]
+
101    def get_packed_math_input(self):
@@ -295,21 +289,24 @@

Arithmetic Dataset

-
108    def decode(self, arr: List[int]):
-109        return ''.join([self.itos[c] for c in arr])
+
105        s_enc = []
+106        while len(s_enc) <= self.seq_len:
+107            s_part = self.make_add_problem()
+108            s_part_enc = self.encode('?' + s_part)
+109            s_enc = s_enc + s_part_enc
+110        return s_enc
-
+
- +

Encode a given string

+
-
111    def __getitem__(self, idx):
-112        s = torch.tensor(self.get_packed_math_input())
-113        return s[:self.seq_len], s[1:self.seq_len + 1]
+
112    def encode(self, s: str):
@@ -320,63 +317,42 @@

Arithmetic Dataset

-
115    def __len__(self):
-116        return self.n_sequences
+
116        return [self.stoi[c] for c in s]
-
+
- +

Decode a list of token ids

+
-
119class ArithmeticAutoregression(NLPAutoRegressionConfigs):
-120    max_digits: int = 4
-121    train_sequences_per_epoch: int = 2 ** 12
-122    train_loader: DataLoader = 'arithmetic_train_loader'
-123    n_tests: int = 32
-124    validator = None
-125    inner_iterations = 4
-126
-127    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
118    def decode(self, arr: List[int]):
-
+
-

Sampling function to generate samples periodically while training

- +
-
129    def sample(self):
+
122        return ''.join([self.itos[c] for c in arr])
-
+
- +

Get a input and target pair for auto-regressive modelling

+
-
134        if self.training_loop.idx < 1:
-135            return
-136
-137        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
-138        qa = [dataset.get_qa() for _ in range(self.n_tests)]
-139        prompt = [p[0] for p in qa]
-140
-141        data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]])
-142        data = data.to(self.device)
-143
-144        finished = torch.zeros((len(prompt),)).bool().to(self.device)
-145        new_line = dataset.stoi['\n']
-146
-147        results = [p[0] for p in prompt]
+
124    def __getitem__(self, idx: int):
@@ -384,25 +360,23 @@

Sampling function to generate samples periodically while training

-

Sample 25 tokens

- +
-
150        for i in monit.iterate('Sample', self.seq_len - 1):
-151            if finished.sum() == len(finished):
-152                continue
+
128        s = torch.tensor(self.get_packed_math_input())
+129        return s[:self.seq_len], s[1:self.seq_len + 1]
-
+
-

Tokenize the prompt Get the model output

+

Number of sequences per epoch

-
156            output, *_ = self.model(data)
+
131    def __len__(self):
@@ -410,64 +384,470 @@

Sampling function to generate samples periodically while training

+ +
+
+
135        return self.n_sequences
+
+
+
+
+ +

Arithmetic Task Experiment Configurations

+ +
+
+
138class ArithmeticAutoregression(NLPAutoRegressionConfigs):
+
+
+
+
+ +

Maximum number of digits per operand integer

+ +
+
+
143    max_digits: int = 4
+
+
+
+
+ +

Number of training sequences per epoch

+ +
+
+
145    train_sequences_per_epoch: int = 2 ** 12
+
+
+
+
+ +

Training data loader

+ +
+
+
147    train_loader: DataLoader = 'arithmetic_train_loader'
+
+
+
+
+ +

Number of problems in evaluation

+ +
+
+
149    n_tests: int = 32
+
+
+
+
+ +

No need of a validation dataset

+ +
+
+
151    validator = None
+
+
+
+
+ +

Number of times to run evaluations per epoch

+ +
+
+
153    inner_iterations = 4
+
+
+
+
+ +

Number of tokens in the vocabulary

+ +
+
+
155    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
+
+
+
+ +

Evaluation

+

We use the sampling function to evaluate the model on a set of problems

+ +
+
+
157    def sample(self):
+
+
+
+
+ +

Skip in the first epoch

+ +
+
+
165        if self.training_loop.idx < 1:
+166            return
+
+
+
+
+ +

Create a dataset to generate problems

+ +
+
+
169        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+
+
+
+
+ +

Get a set of problems and answers

+ +
+
+
171        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+
+
+
+
+ +

Collect the problems only

+ +
+
+
173        questions = [p[0] for p in qa]
+
+
+
+
+ +

Create a tensor with only the initial token

+ +
+
+
176        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
+
+
+
+
+ +

Move to device

+ +
+
+
178        data = data.to(self.device)
+
+
+
+
+ +

Number of sequences that have completed

+ +
+
+
181        finished = torch.zeros((len(questions),)).bool().to(self.device)
+
+
+
+
+ +

Token id of the new line character - this marks end of the answer

+ +
+
+
183        new_line = dataset.stoi['\n']
+
+
+
+
+ +

Sampled results

+ +
+
+
186        results = [p[0] for p in questions]
+
+
+
+
+ +

Sample upto sequence length

+ +
+
+
189        for i in monit.iterate('Sample', self.seq_len - 1):
+
+
+
+
+ +

If all the sequences have completed we skip this

+ +
+
+
191            if finished.sum() == len(finished):
+192                continue
+
+
+
+
+ +

Get the model output

+ +
+
+
195            output, *_ = self.model(data)
+
+
+
+
+

Get the model prediction (greedy)

-
158            output = output[-1].argmax(dim=-1)
-159
-160            finished = finished | (output == new_line)
-161            if finished.sum() == len(finished):
-162                continue
-163
-164            for j, p in enumerate(prompt):
-165                if len(p) > i + 1:
-166                    output[j] = dataset.stoi[p[i + 1]]
-167
-168            data = torch.cat([data, output[None, :]], dim=0)
-169
-170            for j, c in enumerate(output):
-171                results[j] += dataset.itos[c]
-172
-173        results = [r.split('\n')[0] for r in results]
-174
-175        res_sample = results[0].split(';')
-176        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
-177
-178        results = [r.split('x==')[-1] for r in results]
-179
-180        correct = 0
-181        for r, _qa in zip(results, qa):
-182            if r == _qa[1]:
-183                correct += 1
-184
-185        tracker.save('score', correct / len(results))
+
197            output = output[-1].argmax(dim=-1)
-
+
+

Find which sequences have finished

+ +
+
+
200            finished = finished | (output == new_line)
+
+
+
+
+ +

Skip if all have finished

+ +
+
+
202            if finished.sum() == len(finished):
+203                continue
+
+
+
+
+ +

Override with the question

+ +
+
+
206            for j, p in enumerate(questions):
+207                if len(p) > i + 1:
+208                    output[j] = dataset.stoi[p[i + 1]]
+
+
+
+
+ +

Add the next token to the input

+ +
+
+
211            data = torch.cat([data, output[None, :]], dim=0)
+
+
+
+
+ +

Get the sampled results

+ +
+
+
214            for j, c in enumerate(output):
+215                results[j] += dataset.itos[c]
+
+
+
+
+ +

Discard everything after the answer in the results

+ +
+
+
218        results = [r.split('\n')[0] for r in results]
+
+
+
+
+ +

Log a sample

+ +
+
+
221        res_sample = results[0].split(';')
+222        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+
+
+
+
+ +

Get the answers

+ +
+
+
225        results = [r.split('x==')[-1] for r in results]
+
+
+
+
+ +

Count the number of correct answers

+ +
+
+
228        correct = 0
+229        for r, _qa in zip(results, qa):
+230            if r == _qa[1]:
+231                correct += 1
+
+
+
+
+ +

Log the score

+ +
+
+
234        tracker.save('score', correct / len(results))
+
+
+
+
+ +

Training data loader

+ +
+
+
237@option(ArithmeticAutoregression.train_loader)
+238def arithmetic_train_loader(c: ArithmeticAutoregression):
+
+
+
+
+ + +
+
+
242    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+243                      batch_size=c.batch_size,
+244                      collate_fn=transpose_batch,
+245                      num_workers=4)
+
+
+
+
+ +

Code to test generated problems

+ +
+
+
248def _test():
+
+
+
+
+
-
188@option(ArithmeticAutoregression.train_loader)
-189def arithmetic_train_loader(c: ArithmeticAutoregression):
-190    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
-191                      batch_size=c.batch_size,
-192                      collate_fn=transpose_batch,
-193                      num_workers=4)
-194
-195
-196def _test():
-197    dataset = ArithmeticDataset(256, 8, 10)
-198
-199    print(dataset.decode(dataset.get_packed_math_input()))
-200
-201
-202if __name__ == '__main__':
-203    _test()
+
252    dataset = ArithmeticDataset(256, 8, 10)
+253
+254    print(dataset.decode(dataset.get_packed_math_input()))
+
+
+
+
+ +

+ +
+
+
258if __name__ == '__main__':
+259    _test()
-
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
+
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
+204        super().__init__(heads, d_model, dropout_prob, bias=False)
@@ -377,82 +378,69 @@

Multi-head attention with rotary positional embeddings

-

The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for value - might make sense.

- -
-
-
207        super().__init__(heads, d_model, dropout_prob, bias=False)
-
-
-
-
-

Rotary positional embedding layers

-
210        d_rope = int(self.d_k * rope_percentage)
-211        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-212        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+
207        d_rope = int(self.d_k * rope_percentage)
+208        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+209        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-
+

Calculate scores between queries and keys

-
214    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
211    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-
+

Calculate dot-product with RoPE

-
220        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
217        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
-
+

Testing RoPE with a simple example

-
223def _test_rotary():
+
220def _test_rotary():
-
+
-
227    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
-228    x = x[:, None, None, :]
-229    inspect(x)
+            
224    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
+225    x = x[:, None, None, :]
+226    inspect(x)
+227
+228    rotary_pe = RotaryPositionalEmbeddings(3)
+229    inspect(rotary_pe(x))
 230
-231    rotary_pe = RotaryPositionalEmbeddings(3)
-232    inspect(rotary_pe(x))
-233
-234
-235if __name__ == '__main__':
-236    _test_rotary()
+231 +232if __name__ == '__main__': +233 _test_rotary()
@@ -220,60 +221,47 @@

Multi-head attention with rotary positional embeddings

-

The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for value - might make sense.

- -
-
-
180        super().__init__(heads, d_model, dropout_prob, bias=False)
-
-
-
-
-

Rotary positional embedding layers

-
183        d_rope = int(self.d_k * rope_percentage)
-184        d_rope_value = int(self.d_k * rope_value_percentage)
-185
-186        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-187        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-188        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
-189        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
+
180        d_rope = int(self.d_k * rope_percentage)
+181        d_rope_value = int(self.d_k * rope_value_percentage)
+182
+183        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+184        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+185        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
+186        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
-
+

Calculate scores between queries and keys

-
191    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
188    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-
+

Calculate dot-product with RoPE

-
197        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
194        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
-
+

query , key @@ -290,17 +278,17 @@

Calculate scores between queries and keys

-
199    def forward(self, *,
-200                query: torch.Tensor,
-201                key: torch.Tensor,
-202                value: torch.Tensor,
-203                mask: Optional[torch.Tensor] = None):
+
196    def forward(self, *,
+197                query: torch.Tensor,
+198                key: torch.Tensor,
+199                value: torch.Tensor,
+200                mask: Optional[torch.Tensor] = None):
-
+

query , key @@ -310,16 +298,16 @@

Calculate scores between queries and keys

-
215        seq_len, batch_size, _ = query.shape
-216
-217        if mask is not None:
-218            mask = self.prepare_mask(mask, query.shape, key.shape)
+
212        seq_len, batch_size, _ = query.shape
+213
+214        if mask is not None:
+215            mask = self.prepare_mask(mask, query.shape, key.shape)
-
+

Prepare query , key @@ -329,28 +317,28 @@

Calculate scores between queries and keys

-
222        query = self.query(query)
-223        key = self.key(key)
-224        value = self.value(value)
+
219        query = self.query(query)
+220        key = self.key(key)
+221        value = self.value(value)
-
+

Compute attention scores . This gives a tensor of shape [seq_len, seq_len, batch_size, heads] .

-
228        scores = self.get_scores(query, key)
+
225        scores = self.get_scores(query, key)
-
+

Scale scores

231        scores *= self.scale
+
228        scores *= self.scale
-
+

Apply mask

-
234        if mask is not None:
-235            scores = scores.masked_fill(mask == 0, float('-inf'))
+
231        if mask is not None:
+232            scores = scores.masked_fill(mask == 0, float('-inf'))
-
+

attention along the key sequence dimension

239        attn = self.softmax(scores)
+
236        attn = self.softmax(scores)
-
+

Apply dropout

-
242        attn = self.dropout(attn)
+
239        attn = self.dropout(attn)
-
+

Rotate value embeddings before taking the weighted sum so that they contain positional information

-
245        value = self.value_rotary_pe(value)
+
242        value = self.value_rotary_pe(value)
-
+

Multiply by values

249        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+
246        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
-
+

Rotate in the opposite direction so that each embedding hold the relative positions

-
252        x = self.value_reverse_rotary_pe(x)
+
249        x = self.value_reverse_rotary_pe(x)
-
+

Save attentions for any other calculations

-
255        self.attn = attn.detach()
+
252        self.attn = attn.detach()
-
+

Concatenate multiple heads

-
258        x = x.reshape(seq_len, batch_size, -1)
+
255        x = x.reshape(seq_len, batch_size, -1)
-
+

Output layer

-
261        return self.output(x)
+
258        return self.output(x)
-
149    n_tests: int = 32
+
149    n_tests: int = 64
@@ -496,7 +496,8 @@

Evaluation

-
157    def sample(self):
+
157    @torch.no_grad()
+158    def sample(self):
@@ -508,8 +509,8 @@

Evaluation

-
165        if self.training_loop.idx < 1:
-166            return
+
166        if self.training_loop.idx < 1:
+167            return
@@ -521,7 +522,7 @@

Evaluation

-
169        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+
170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
@@ -533,7 +534,7 @@

Evaluation

-
171        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+
172        qa = [dataset.get_qa() for _ in range(self.n_tests)]
@@ -545,7 +546,7 @@

Evaluation

-
173        questions = [p[0] for p in qa]
+
174        questions = [p[0] for p in qa]
@@ -557,7 +558,7 @@

Evaluation

-
176        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
+
177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
@@ -569,7 +570,7 @@

Evaluation

-
178        data = data.to(self.device)
+
179        data = data.to(self.device)
@@ -581,7 +582,7 @@

Evaluation

-
181        finished = torch.zeros((len(questions),)).bool().to(self.device)
+
182        finished = torch.zeros((len(questions),)).bool().to(self.device)
@@ -593,7 +594,7 @@

Evaluation

-
183        new_line = dataset.stoi['\n']
+
184        new_line = dataset.stoi['\n']
@@ -605,7 +606,7 @@

Evaluation

-
186        results = [p[0] for p in questions]
+
187        results = [p[0] for p in questions]
@@ -617,7 +618,7 @@

Evaluation

-
189        for i in monit.iterate('Sample', self.seq_len - 1):
+
190        for i in monit.iterate('Sample', self.seq_len - 1):
@@ -629,8 +630,8 @@

Evaluation

-
191            if finished.sum() == len(finished):
-192                continue
+
192            if finished.sum() == len(finished):
+193                continue
@@ -642,7 +643,7 @@

Evaluation

-
195            output, *_ = self.model(data)
+
196            output, *_ = self.model(data)
@@ -654,7 +655,7 @@

Evaluation

-
197            output = output[-1].argmax(dim=-1)
+
198            output = output[-1].argmax(dim=-1)
@@ -666,7 +667,7 @@

Evaluation

-
200            finished = finished | (output == new_line)
+
201            finished = finished | (output == new_line)
@@ -678,8 +679,8 @@

Evaluation

-
202            if finished.sum() == len(finished):
-203                continue
+
203            if finished.sum() == len(finished):
+204                continue
@@ -691,9 +692,9 @@

Evaluation

-
206            for j, p in enumerate(questions):
-207                if len(p) > i + 1:
-208                    output[j] = dataset.stoi[p[i + 1]]
+
207            for j, p in enumerate(questions):
+208                if len(p) > i + 1:
+209                    output[j] = dataset.stoi[p[i + 1]]
@@ -705,7 +706,7 @@

Evaluation

-
211            data = torch.cat([data, output[None, :]], dim=0)
+
212            data = torch.cat([data, output[None, :]], dim=0)
@@ -717,8 +718,8 @@

Evaluation

-
214            for j, c in enumerate(output):
-215                results[j] += dataset.itos[c]
+
215            for j, c in enumerate(output):
+216                results[j] += dataset.itos[c]
@@ -730,7 +731,7 @@

Evaluation

-
218        results = [r.split('\n')[0] for r in results]
+
219        results = [r.split('\n')[0] for r in results]
@@ -742,8 +743,8 @@

Evaluation

-
221        res_sample = results[0].split(';')
-222        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+
222        res_sample = results[0].split(';')
+223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
@@ -755,7 +756,7 @@

Evaluation

-
225        results = [r.split('x==')[-1] for r in results]
+
226        results = [r.split('x==')[-1] for r in results]
@@ -767,10 +768,10 @@

Evaluation

-
228        correct = 0
-229        for r, _qa in zip(results, qa):
-230            if r == _qa[1]:
-231                correct += 1
+
229        correct = 0
+230        for r, _qa in zip(results, qa):
+231            if r == _qa[1]:
+232                correct += 1
@@ -782,7 +783,7 @@

Evaluation

-
234        tracker.save('score', correct / len(results))
+
235        tracker.save('score', correct / len(results))
@@ -794,8 +795,8 @@

Evaluation

-
237@option(ArithmeticAutoregression.train_loader)
-238def arithmetic_train_loader(c: ArithmeticAutoregression):
+
238@option(ArithmeticAutoregression.train_loader)
+239def arithmetic_train_loader(c: ArithmeticAutoregression):
@@ -806,10 +807,10 @@

Evaluation

-
242    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
-243                      batch_size=c.batch_size,
-244                      collate_fn=transpose_batch,
-245                      num_workers=4)
+
243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+244                      batch_size=c.batch_size,
+245                      collate_fn=transpose_batch,
+246                      num_workers=4)
@@ -821,7 +822,7 @@

Evaluation

-
248def _test():
+
249def _test():
@@ -832,9 +833,9 @@

Evaluation

-
252    dataset = ArithmeticDataset(256, 8, 10)
-253
-254    print(dataset.decode(dataset.get_packed_math_input()))
+
253    dataset = ArithmeticDataset(256, 8, 10)
+254
+255    print(dataset.decode(dataset.get_packed_math_input()))
@@ -846,8 +847,8 @@

Evaluation

-
258if __name__ == '__main__':
-259    _test()
+
259if __name__ == '__main__':
+260    _test()
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
-204        super().__init__(heads, d_model, dropout_prob, bias=False)
+204 super().__init__(heads, d_model, dropout_prob)
diff --git a/docs/transformers/rope/value_pe/arithmetic_experiment.html b/docs/transformers/rope/value_pe/arithmetic_experiment.html index 45d9d634..bd52b4dc 100644 --- a/docs/transformers/rope/value_pe/arithmetic_experiment.html +++ b/docs/transformers/rope/value_pe/arithmetic_experiment.html @@ -163,7 +163,7 @@

Rotary Positional Embeddings with Relative distance (Ro

-
45    experiment.create(name="roper_addition", comment="rotary value 8", writers={'screen', 'labml', 'comet'})
+
45    experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml', 'comet'})
49    experiment.configs(conf, {
-50        'max_digits': 8,
+50 'max_digits': 7,
-
78        'optimizer.optimizer': 'Noam',
-79        'optimizer.learning_rate': 1.,
+            
78        'optimizer.optimizer': 'Adam',
+79        'optimizer.learning_rate': 2.5e-4,
 80    })
diff --git a/docs/transformers/rope/value_pe/experiment.html b/docs/transformers/rope/value_pe/experiment.html index 76690ddb..109a7605 100644 --- a/docs/transformers/rope/value_pe/experiment.html +++ b/docs/transformers/rope/value_pe/experiment.html @@ -116,7 +116,7 @@

Rotary PE attention

26def _rotary_value_pe_mha(c: TransformerConfigs):
 27    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-28    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5)
+28 return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
@@ -153,7 +153,7 @@

Rotary PE attention

-
39    experiment.create(name="rotary_pe_transformer", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'})
+
39    experiment.create(name="rotary_shakespeare", comment="rotary value", writers={'screen', 'labml'})
@@ -286,7 +286,7 @@

Rotary PE attention

-
65        'seq_len': 128,
+
65        'seq_len': 512,
@@ -298,7 +298,7 @@

Rotary PE attention

-
67        'epochs': 32,
+
67        'epochs': 24,
@@ -310,7 +310,7 @@

Rotary PE attention

-
69        'batch_size': 4,
+
69        'batch_size': 16,
@@ -322,7 +322,7 @@

Rotary PE attention

-
72        'inner_iterations': 10,
+
72        'inner_iterations': 4,
@@ -334,9 +334,9 @@

Rotary PE attention

-
75        'd_model': 256,
-76        'transformer.ffn.d_ff': 1024,
-77        'transformer.n_heads': 8,
+            
75        'd_model': 128,
+76        'transformer.ffn.d_ff': 512,
+77        'transformer.n_heads': 4,
 78        'transformer.dropout': 0.0,
@@ -345,12 +345,12 @@

Rotary PE attention

-

Use Noam optimizer

+

Use Adam optimizer

-
81        'optimizer.optimizer': 'Noam',
-82        'optimizer.learning_rate': 1.,
+            
81        'optimizer.optimizer': 'Adam',
+82        'optimizer.learning_rate': 2.5e-4,
 83
 84        'dataloader_shuffle_with_replacement': True
 85    })
diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html index 30281ed9..fa4f0e44 100644 --- a/docs/transformers/rope/value_pe/index.html +++ b/docs/transformers/rope/value_pe/index.html @@ -97,8 +97,7 @@

Relative distances in embeddings

119 120import torch 121 -122from labml_nn.transformers.mha import MultiHeadAttention -123from labml_nn.transformers.rope import RotaryPositionalEmbeddings
+122from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
@@ -111,7 +110,7 @@

RoPE module that rotates in the opposite direction

-
126class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
+
125class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
@@ -125,7 +124,7 @@

RoPE module that rotates in the opposite direction

-
133    def forward(self, x: torch.Tensor):
+
132    def forward(self, x: torch.Tensor):
@@ -137,7 +136,7 @@

RoPE module that rotates in the opposite direction

-
138        self._build_cache(x)
+
137        self._build_cache(x)
@@ -149,7 +148,7 @@

RoPE module that rotates in the opposite direction

-
141        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
140        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
@@ -161,7 +160,7 @@

RoPE module that rotates in the opposite direction

-
145        neg_half_x = self._neg_half(x_rope)
+
144        neg_half_x = self._neg_half(x_rope)
@@ -174,7 +173,7 @@

RoPE module that rotates in the opposite direction

-
161        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+
160        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -186,7 +185,7 @@

RoPE module that rotates in the opposite direction

-
164        return torch.cat((x_rope, x_pass), dim=-1)
+
163        return torch.cat((x_rope, x_pass), dim=-1)
@@ -199,7 +198,7 @@

Multi-head attention with rotary positional embeddings

-
167class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
+
166class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
@@ -210,10 +209,10 @@

Multi-head attention with rotary positional embeddings

-
174    def __init__(self, heads: int, d_model: int,
-175                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
-176                 dropout_prob: float = 0.0):
-177        super().__init__(heads, d_model, dropout_prob, bias=False)
+
173    def __init__(self, heads: int, d_model: int,
+174                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
+175                 dropout_prob: float = 0.0):
+176        super().__init__(heads, d_model, rope_percentage, dropout_prob)
@@ -225,13 +224,10 @@

Multi-head attention with rotary positional embeddings

-
180        d_rope = int(self.d_k * rope_percentage)
-181        d_rope_value = int(self.d_k * rope_value_percentage)
-182
-183        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-184        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-185        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
-186        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
+
179        d_rope_value = int(self.d_k * rope_value_percentage)
+180
+181        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
+182        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
@@ -239,30 +235,6 @@

Multi-head attention with rotary positional embeddings

-

Calculate scores between queries and keys

- -
-
-
188    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-
- -
-
- -

Calculate dot-product with RoPE

- -
-
-
194        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
-
-
-
-
-

query, key and value @@ -278,17 +250,17 @@

Calculate scores between queries and keys

-
196    def forward(self, *,
-197                query: torch.Tensor,
-198                key: torch.Tensor,
-199                value: torch.Tensor,
-200                mask: Optional[torch.Tensor] = None):
+
184    def forward(self, *,
+185                query: torch.Tensor,
+186                key: torch.Tensor,
+187                value: torch.Tensor,
+188                mask: Optional[torch.Tensor] = None):
-
+

query, key @@ -298,16 +270,16 @@

Calculate scores between queries and keys

-
212        seq_len, batch_size, _ = query.shape
-213
-214        if mask is not None:
-215            mask = self.prepare_mask(mask, query.shape, key.shape)
+
200        seq_len, batch_size, _ = query.shape
+201
+202        if mask is not None:
+203            mask = self.prepare_mask(mask, query.shape, key.shape)
-
+

Prepare query, key @@ -317,28 +289,28 @@

Calculate scores between queries and keys

-
219        query = self.query(query)
-220        key = self.key(key)
-221        value = self.value(value)
+
207        query = self.query(query)
+208        key = self.key(key)
+209        value = self.value(value)
-
+

Compute attention scores . This gives a tensor of shape [seq_len, seq_len, batch_size, heads] .

-
225        scores = self.get_scores(query, key)
+
213        scores = self.get_scores(query, key)
-
+

Scale scores

228        scores *= self.scale
+
216        scores *= self.scale
-
+

Apply mask

-
231        if mask is not None:
-232            scores = scores.masked_fill(mask == 0, float('-inf'))
+
219        if mask is not None:
+220            scores = scores.masked_fill(mask == 0, float('-inf'))
-
+

Softmax attention along the key sequence dimension

236        attn = self.softmax(scores)
+
224        attn = self.softmax(scores)
-
+

Apply dropout

-
239        attn = self.dropout(attn)
+
227        attn = self.dropout(attn)
-
+

Rotate value embeddings before taking the weighted sum so that they contain positional information

-
242        value = self.value_rotary_pe(value)
+
230        value = self.value_rotary_pe(value)
-
+

Multiply by values

246        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+
234        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
-
+

Rotate in the opposite direction so that each embedding holds the relative positions

-
249        x = self.value_reverse_rotary_pe(x)
+
237        x = self.value_reverse_rotary_pe(x)
-
+

Save attentions for any other calculations

-
252        self.attn = attn.detach()
+
240        self.attn = attn.detach()
-
+

Concatenate multiple heads

-
255        x = x.reshape(seq_len, batch_size, -1)
+
243        x = x.reshape(seq_len, batch_size, -1)
-
+

Output layer

-
258        return self.output(x)
+
246        return self.output(x)
@@ -293,7 +293,7 @@

Trainer configurations

-
80    accuracy = Accuracy()
+
80    accuracy = AccuracyMovingAvg()
diff --git a/setup.py b/setup.py index c923f274..2c01dc17 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ 'test', 'test.*')), install_requires=['labml>=0.4.151', - 'labml-helpers>=0.4.86', + 'labml-helpers>=0.4.87', 'torch', 'torchtext', 'torchvision', From af233cfd73d8346059dc657cf120c2449450ea0a Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Tue, 7 Jun 2022 14:00:12 +0530 Subject: [PATCH 22/27] copy multiple --- .../{copy_perm.py => copy_perm/__init__.py} | 0 labml_nn/experiments/copy_perm/continous.py | 179 ++++++++++++++++++ labml_nn/experiments/nlp_autoregression.py | 13 +- .../rope/value_pe/experiments/__init__.py | 0 .../arithmetic_experiment.py | 0 .../{ => experiments}/copy_experiment.py | 0 .../rope/value_pe/experiments/copy_repeat.py | 93 +++++++++ .../value_pe/{ => experiments}/experiment.py | 0 8 files changed, 277 insertions(+), 8 deletions(-) rename labml_nn/experiments/{copy_perm.py => copy_perm/__init__.py} (100%) create mode 100644 labml_nn/experiments/copy_perm/continous.py create mode 100644 labml_nn/transformers/rope/value_pe/experiments/__init__.py rename labml_nn/transformers/rope/value_pe/{ => experiments}/arithmetic_experiment.py (100%) rename labml_nn/transformers/rope/value_pe/{ => experiments}/copy_experiment.py (100%) create mode 100644 labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py rename labml_nn/transformers/rope/value_pe/{ => experiments}/experiment.py (100%) diff --git a/labml_nn/experiments/copy_perm.py b/labml_nn/experiments/copy_perm/__init__.py similarity index 100% rename from labml_nn/experiments/copy_perm.py rename to labml_nn/experiments/copy_perm/__init__.py diff --git a/labml_nn/experiments/copy_perm/continous.py b/labml_nn/experiments/copy_perm/continous.py new file mode 100644 index 00000000..c462b16c --- /dev/null +++ b/labml_nn/experiments/copy_perm/continous.py @@ -0,0 +1,179 @@ +import random +from typing import List + +import torch +from torch.utils.data import DataLoader, Dataset + 
+from labml import tracker +from labml.configs import option +from labml_helpers.train_valid import BatchIndex +from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch + + +class CopyPermRepeatDataset(Dataset): + """ + """ + + def __init__(self, seq_len: int, substr_len: int, rnd_len: int, n_sequences: int): + """ + :param seq_len: is the sequence length of generated math problems. + We fill as many problems as possible upto this length + """ + self.rnd_len = rnd_len + self.substr_len = substr_len + self.n_sequences = n_sequences + self.seq_len = seq_len + self.letters = 'acgt' # string.ascii_lowercase # '01' # 'acgt' # + # Token id to string + self.itos = list(self.letters + '>') + # Character to token id + self.stoi = {c: i for i, c in enumerate(self.itos)} + + def random_string(self, n_chars: int): + return ''.join(random.choice(self.letters) for _ in range(n_chars)) + + def generate_problem(self): + pure = self.random_string(self.substr_len) + out = pure + mask = [False] * len(out) + while len(out) <= self.seq_len: + s = self.random_string(random.randrange(1, self.rnd_len)) + out += s + '>' + mask += [False] * (len(s) + 1) + pure += s + + offset = random.randrange(0, len(pure) - self.substr_len) + copy = pure[offset:offset + self.substr_len] + + out += copy + mask += [False] + [True] * (self.substr_len - 1) + pure += copy + + return out, mask + + def encode(self, s: str): + """ + Encode a given string + """ + return [self.stoi[c] for c in s] + + def decode(self, arr: List[int]): + """ + Decode a list of token ids + """ + return ''.join([self.itos[c] for c in arr]) + + def __getitem__(self, idx: int): + """ + Get a input and target pair for auto-regressive modelling + """ + s, mask = self.generate_problem() + s = torch.tensor(self.encode(s)) + mask = torch.tensor(mask) + target = s * mask + -1 * (~mask) + return s[:self.seq_len], target[1:self.seq_len + 1] + + def __len__(self): + """ + Number of sequences per epoch + """ + 
return self.n_sequences + + +class CopyRepeatAutoregression(NLPAutoRegressionConfigs): + """ + ## Arithmetic Task Experiment Configurations + """ + # Number of training sequences per epoch + train_sequences_per_epoch: int = 2 ** 12 + # Training data loader + train_loader: DataLoader = 'copy_train_loader' + # Number of problems in evaluation + n_tests: int = 64 + # No need of a validation dataset + validator = None + # Number of times to run evaluations per epoch + inner_iterations = 4 + # Number of tokens in the vocabulary + n_tokens = len(CopyPermRepeatDataset(1, 1, 1, 1).itos) + + substr_len: int = 16 + rnd_len: int = 16 + + @torch.no_grad() + def sample(self): + pass + + def step(self, batch: any, batch_idx: BatchIndex): + """ + ### Training or validation step + """ + + # Set training/eval mode + self.model.train(self.mode.is_train) + + # Move data to the device + data, target = batch[0].to(self.device), batch[1].to(self.device) + + # Update global step (number of tokens processed) when in training mode + if self.mode.is_train: + tracker.add_global_step(data.shape[0] * data.shape[1]) + + # Whether to capture model outputs + with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations): + # Get model outputs. + # It's returning a tuple for states when using RNNs. + # This is not implemented yet. 
😜 + output, *_ = self.model(data) + + # Calculate and log loss + loss = self.loss_func(output, target) + tracker.add("loss.", loss) + + # Calculate and log accuracy + self.accuracy(output, target) + self.accuracy.track() + + self.other_metrics(output, target) + + # Train the model + if self.mode.is_train: + # Calculate gradients + loss.backward() + # Clip gradients + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip) + # Take optimizer step + self.optimizer.step() + # Log the model parameters and gradients on last batch of every epoch + if batch_idx.is_last and self.is_log_model_params_grads: + tracker.add('model', self.model) + # Clear the gradients + self.optimizer.zero_grad() + + # Save the tracked metrics + tracker.save() + + +@option(CopyRepeatAutoregression.train_loader) +def copy_train_loader(c: CopyRepeatAutoregression): + """ + Training data loader + """ + return DataLoader(CopyPermRepeatDataset(c.seq_len, c.substr_len, c.rnd_len, c.train_sequences_per_epoch), + batch_size=c.batch_size, + collate_fn=transpose_batch) + # num_workers=4) + + +def _test(): + """ + Code to test generated problems + """ + dataset = CopyPermRepeatDataset(32, 8, 8, 1) + + print(dataset.generate_problem()) + + +# +if __name__ == '__main__': + _test() diff --git a/labml_nn/experiments/nlp_autoregression.py b/labml_nn/experiments/nlp_autoregression.py index 1f4d1f40..cdcd0cc6 100644 --- a/labml_nn/experiments/nlp_autoregression.py +++ b/labml_nn/experiments/nlp_autoregression.py @@ -19,7 +19,7 @@ from labml.logger import Text from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset from labml_helpers.device import DeviceConfigs -from labml_helpers.metrics.accuracy import Accuracy, AccuracyMovingAvg +from labml_helpers.metrics.accuracy import AccuracyMovingAvg from labml_helpers.module import Module from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex 
from labml_nn.optimizers.configs import OptimizerConfigs @@ -30,9 +30,9 @@ class CrossEntropyLoss(Module): ### Cross entropy loss """ - def __init__(self): + def __init__(self, ignore_index: int = -100): super().__init__() - self.loss = nn.CrossEntropyLoss() + self.loss = nn.CrossEntropyLoss(ignore_index=ignore_index) def forward(self, outputs, targets): return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1)) @@ -75,7 +75,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs): is_save_models = True # Loss function - loss_func = CrossEntropyLoss() + loss_func = CrossEntropyLoss(ignore_index=-1) # Accuracy function accuracy = AccuracyMovingAvg() # Model embedding size @@ -297,10 +297,7 @@ def transpose_batch(batch): transposed_data = list(zip(*batch)) # Stack the batch along the second dimension `dim=1` - src = torch.stack(transposed_data[0], dim=1) - tgt = torch.stack(transposed_data[1], dim=1) - - return src, tgt + return tuple(torch.stack(d, dim=1) for d in transposed_data) @option(NLPAutoRegressionConfigs.train_loader) diff --git a/labml_nn/transformers/rope/value_pe/experiments/__init__.py b/labml_nn/transformers/rope/value_pe/experiments/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_experiment.py similarity index 100% rename from labml_nn/transformers/rope/value_pe/arithmetic_experiment.py rename to labml_nn/transformers/rope/value_pe/experiments/arithmetic_experiment.py diff --git a/labml_nn/transformers/rope/value_pe/copy_experiment.py b/labml_nn/transformers/rope/value_pe/experiments/copy_experiment.py similarity index 100% rename from labml_nn/transformers/rope/value_pe/copy_experiment.py rename to labml_nn/transformers/rope/value_pe/experiments/copy_experiment.py diff --git a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py 
b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py new file mode 100644 index 00000000..40c3ac85 --- /dev/null +++ b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py @@ -0,0 +1,93 @@ +""" +--- +title: Rotary Positional Embeddings with Relative distance (RoPER) Experiment +summary: This experiment trains a transformer model with Rotary Positional Embeddings with + Relative Distance (RoPER) on the arithmetic addition task. +--- + +# Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) Experiment +""" + +from labml import experiment +from labml.configs import calculate +from labml_nn.experiments.copy_perm import CopyAutoregression +from labml_nn.experiments.copy_perm.continous import CopyRepeatAutoregression +from labml_nn.transformers import TransformerConfigs +from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs + + +class Configs(RoPEConfigs, CopyRepeatAutoregression): + """ + We inherit [RoPE experiment](../experiment.html) and use it for + [arithmetic addition task](../../experiments/arithmetic_dataset.html). + + We add the option to change attention to use Rotary Positional Embeddings with Relative distance (RoPER) + below. + """ + pass + + +def _rotary_value_pe_mha(c: TransformerConfigs): + """ + Use Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) in attention. + """ + from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention + return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.) 
+ + +# Configuration options +calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha) +calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha) +calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha) + + +def main(): + # Create experiment + experiment.create(name="roper_copy", comment="rotary rl 01", writers={'screen', 'labml'}) + # Create configs + conf = Configs() + # Override configurations + experiment.configs(conf, { + # No fixed positional embeddings + 'transformer.src_embed': 'no_pos', + 'transformer.tgt_embed': 'no_pos', + + # Encoder with RoPER attention + # 'transformer.encoder_attn': 'rotary_value', + # Encoder with RoPE attention + 'transformer.encoder_attn': 'relative', + + # + 'model': 'rotary_pe_transformer', + + # Use a context size of $256$ + 'seq_len': 512, + # Train for 32 epochs + 'epochs': 20, + # Batch size $4$ + 'batch_size': 16, + + # Model size + 'd_model': 128, + 'transformer.ffn.d_ff': 512, + 'transformer.n_heads': 4, + 'transformer.n_layers': 3, + 'transformer.dropout': 0.0, + + # Use [Adam optimizer](../../optimizers/noam.html) + 'optimizer.optimizer': 'Adam', + 'optimizer.learning_rate': 2.5e-4, + }) + + # Set models for saving and loading + experiment.add_pytorch_models({'model': conf.model}) + + # Start the experiment + with experiment.start(): + # Run training + conf.run() + + +# +if __name__ == '__main__': + main() diff --git a/labml_nn/transformers/rope/value_pe/experiment.py b/labml_nn/transformers/rope/value_pe/experiments/experiment.py similarity index 100% rename from labml_nn/transformers/rope/value_pe/experiment.py rename to labml_nn/transformers/rope/value_pe/experiments/experiment.py From 3582cc97cfb211503a0ac087b5c6301fc3fc3d05 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Wed, 8 Jun 2022 14:45:24 +0530 Subject: [PATCH 23/27] results --- labml_nn/experiments/copy_perm/continous.py | 2 +- .../rope/value_pe/experiments/copy_repeat.py | 11 
+- .../rope/value_pe/experiments/results.ipynb | 250 ++++++++++++++++++ run_roper.sh | 7 + 4 files changed, 264 insertions(+), 6 deletions(-) create mode 100644 labml_nn/transformers/rope/value_pe/experiments/results.ipynb create mode 100755 run_roper.sh diff --git a/labml_nn/experiments/copy_perm/continous.py b/labml_nn/experiments/copy_perm/continous.py index c462b16c..738f6433 100644 --- a/labml_nn/experiments/copy_perm/continous.py +++ b/labml_nn/experiments/copy_perm/continous.py @@ -98,7 +98,7 @@ class CopyRepeatAutoregression(NLPAutoRegressionConfigs): n_tokens = len(CopyPermRepeatDataset(1, 1, 1, 1).itos) substr_len: int = 16 - rnd_len: int = 16 + rnd_len: int = 12 @torch.no_grad() def sample(self): diff --git a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py index 40c3ac85..53c31336 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +++ b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py @@ -10,7 +10,6 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.copy_perm import CopyAutoregression from labml_nn.experiments.copy_perm.continous import CopyRepeatAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs @@ -43,19 +42,21 @@ def _rotary_value_pe_mha(c: TransformerConfigs): def main(): # Create experiment - experiment.create(name="roper_copy", comment="rotary rl 01", writers={'screen', 'labml'}) + experiment.create(name="roper_copy_repeat", comment="rotary acgt", writers={'screen', 'sqlite'}) # Create configs conf = Configs() # Override configurations experiment.configs(conf, { + 'substr_len': 16, + # No fixed positional embeddings 'transformer.src_embed': 'no_pos', 'transformer.tgt_embed': 'no_pos', # Encoder with RoPER attention - # 'transformer.encoder_attn': 'rotary_value', + 'transformer.encoder_attn': 
'rotary_value', # Encoder with RoPE attention - 'transformer.encoder_attn': 'relative', + # 'transformer.encoder_attn': 'rotary', # 'model': 'rotary_pe_transformer', @@ -63,7 +64,7 @@ def main(): # Use a context size of $256$ 'seq_len': 512, # Train for 32 epochs - 'epochs': 20, + 'epochs': 16, # Batch size $4$ 'batch_size': 16, diff --git a/labml_nn/transformers/rope/value_pe/experiments/results.ipynb b/labml_nn/transformers/rope/value_pe/experiments/results.ipynb new file mode 100644 index 00000000..a910de60 --- /dev/null +++ b/labml_nn/transformers/rope/value_pe/experiments/results.ipynb @@ -0,0 +1,250 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a779760d-52b3-41e2-98d2-adb98b8f2a93", + "metadata": {}, + "outputs": [], + "source": [ + "from labml import analytics\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4db66c78-1949-40e3-b96d-6c797a5f061a", + "metadata": {}, + "outputs": [], + "source": [ + "ind_rope = analytics.runs(*analytics.get_experiment_runs('rope_copy_repeat'))\n", + "ind_roper = analytics.runs(*analytics.get_experiment_runs('roper_copy_repeat'))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "89bb8bfe-b2be-4542-a4e5-28daab816ec1", + "metadata": {}, + "outputs": [], + "source": [ + "analytics.set_preferred_db('sqlite')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1fd1ead4-ecd2-448a-be68-3edf25d32ed7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['accuracy_train', 'loss_train', 'time_loop']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dir(ind_rope)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "65ee55a7-5100-40c9-b80c-0ec32d317186", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.VConcatChart(...)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "analytics.distribution((ind_rope.loss_train + ind_roper.loss_train)[16_000_000:],\n", + " color_scheme='blueorange',\n", + " width=800, height=600, height_minimap=50, levels=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b5ef8769-c4a6-4e88-8485-57a9a02da497", + "metadata": {}, + "outputs": [], + "source": [ + "d_rope, _ = analytics.indicator_data((ind_rope.loss_train))\n", + "d_rope = [d[:, 5] for d in d_rope]\n", + "d_rope = np.stack(d_rope).mean(axis=0)\n", + "\n", + "d_roper, _ = analytics.indicator_data((ind_roper.loss_train))\n", + "d_roper = [d[:, 5] for d in d_roper]\n", + "d_roper = np.stack(d_roper).mean(axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e737046d-ffb8-47f0-b742-d7241914061a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.VConcatChart(...)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "analytics.distribution([d_rope[50:], d_roper[50:]], width=800, height=600, height_minimap=50)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dd17505-e567-43a0-8a2f-e092ae954183", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/run_roper.sh b/run_roper.sh new file mode 100755 index 00000000..1127f552 --- /dev/null +++ b/run_roper.sh @@ -0,0 +1,7 @@ +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +PYTHONPATH="${PYTHONPATH}:$(pwd):$(pwd)/src" python labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py From cd4e59840ebc1db22fe00f27e9a13572a1cd01b7 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 9 Jun 2022 15:20:08 +0530 Subject: [PATCH 24/27] refractor --- ...aset.py => 
arithmetic_addition_dataset.py} | 16 +-- .../copy_perm/{continous.py => repeat.py} | 0 .../rope/value_pe/experiments/results.ipynb | 98 +++++++++++-------- 3 files changed, 66 insertions(+), 48 deletions(-) rename labml_nn/experiments/{arithmetic_dataset.py => arithmetic_addition_dataset.py} (93%) rename labml_nn/experiments/copy_perm/{continous.py => repeat.py} (100%) diff --git a/labml_nn/experiments/arithmetic_dataset.py b/labml_nn/experiments/arithmetic_addition_dataset.py similarity index 93% rename from labml_nn/experiments/arithmetic_dataset.py rename to labml_nn/experiments/arithmetic_addition_dataset.py index 39c6a5a8..cb2d5896 100644 --- a/labml_nn/experiments/arithmetic_dataset.py +++ b/labml_nn/experiments/arithmetic_addition_dataset.py @@ -21,7 +21,7 @@ from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch -class ArithmeticDataset(Dataset): +class ArithmeticAdditionDataset(Dataset): """ ## Arithmetic Dataset @@ -135,7 +135,7 @@ def __len__(self): return self.n_sequences -class ArithmeticAutoregression(NLPAutoRegressionConfigs): +class ArithmeticAdditionAutoregression(NLPAutoRegressionConfigs): """ ## Arithmetic Task Experiment Configurations """ @@ -152,7 +152,7 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs): # Number of times to run evaluations per epoch inner_iterations = 4 # Number of tokens in the vocabulary - n_tokens = len(ArithmeticDataset(1, 1, 1).itos) + n_tokens = len(ArithmeticAdditionDataset(1, 1, 1).itos) @torch.no_grad() def sample(self): @@ -167,7 +167,7 @@ def sample(self): return # Create a dataset to generate problems - dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1) + dataset = ArithmeticAdditionDataset(self.seq_len, self.max_digits, 1) # Get a set of problems and answers qa = [dataset.get_qa() for _ in range(self.n_tests)] # Collect the problems only @@ -235,12 +235,12 @@ def sample(self): tracker.save('score', correct / len(results)) 
-@option(ArithmeticAutoregression.train_loader) -def arithmetic_train_loader(c: ArithmeticAutoregression): +@option(ArithmeticAdditionAutoregression.train_loader) +def arithmetic_train_loader(c: ArithmeticAdditionAutoregression): """ Training data loader """ - return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch), + return DataLoader(ArithmeticAdditionDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch), batch_size=c.batch_size, collate_fn=transpose_batch, num_workers=4) @@ -250,7 +250,7 @@ def _test(): """ Code to test generated problems """ - dataset = ArithmeticDataset(256, 8, 10) + dataset = ArithmeticAdditionDataset(256, 8, 10) print(dataset.decode(dataset.get_packed_math_input())) diff --git a/labml_nn/experiments/copy_perm/continous.py b/labml_nn/experiments/copy_perm/repeat.py similarity index 100% rename from labml_nn/experiments/copy_perm/continous.py rename to labml_nn/experiments/copy_perm/repeat.py diff --git a/labml_nn/transformers/rope/value_pe/experiments/results.ipynb b/labml_nn/transformers/rope/value_pe/experiments/results.ipynb index a910de60..1ea15295 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/results.ipynb +++ b/labml_nn/transformers/rope/value_pe/experiments/results.ipynb @@ -34,28 +34,7 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "1fd1ead4-ecd2-448a-be68-3edf25d32ed7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['accuracy_train', 'loss_train', 'time_loop']" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(ind_rope)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "id": "65ee55a7-5100-40c9-b80c-0ec32d317186", "metadata": {}, "outputs": [ @@ -63,12 +42,12 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ "alt.VConcatChart(...)" ] }, - "execution_count": 5, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -130,23 +109,23 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "b5ef8769-c4a6-4e88-8485-57a9a02da497", "metadata": {}, "outputs": [], "source": [ "d_rope, _ = analytics.indicator_data((ind_rope.loss_train))\n", - "d_rope = [d[:, 5] for d in d_rope]\n", - "d_rope = np.stack(d_rope).mean(axis=0)\n", + "step_rope = np.stack([d[:, 0] for d in d_rope]).mean(axis=0)\n", + "d_rope = np.stack([d[:, 5] for d in d_rope]).mean(axis=0)\n", "\n", "d_roper, _ = analytics.indicator_data((ind_roper.loss_train))\n", - "d_roper = [d[:, 5] for d in d_roper]\n", - "d_roper = np.stack(d_roper).mean(axis=0)" + "step_roper = np.stack([d[:, 0] for d in d_roper]).mean(axis=0)\n", + "d_roper = np.stack([d[:, 5] for d in d_roper]).mean(axis=0)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "e737046d-ffb8-47f0-b742-d7241914061a", "metadata": {}, "outputs": [ @@ -154,12 +133,12 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ "alt.VConcatChart(...)" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "analytics.distribution([d_rope[50:], d_roper[50:]], width=800, height=600, height_minimap=50)" + "offset = 50\n", + "analytics.distribution([d_rope[offset:], d_roper[offset:]], step_rope[offset:], width=800, height=600, height_minimap=50)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "5dd17505-e567-43a0-8a2f-e092ae954183", "metadata": {}, "outputs": [], + "source": [ + "d_rope, _ = analytics.indicator_data((ind_rope.loss_train))\n", + "d_rope = np.stack([d[-5:, 5] for d in d_rope]).mean(axis=1)\n", + "\n", + "d_roper, _ = analytics.indicator_data((ind_roper.loss_train))\n", + "d_roper = np.stack([d[-5:, 5] for d in d_roper]).mean(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "34c8d915-06bd-484f-906c-75abef12afe0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([0.3234671 , 0.34104063, 0.31573243, 0.32310195, 0.31422086,\n", + " 0.36494735, 0.35428971, 0.31074833, 0.32917885, 0.32994516]),\n", + " array([0.32325992, 0.31913995, 0.3186015 , 0.32333733, 0.31854622,\n", + " 0.32299762, 0.31425539, 0.32346853, 0.32077246, 0.31096461]))" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d_rope, d_roper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8b2379b-5465-4d4e-bedd-7c9cdde994ae", + "metadata": {}, + "outputs": [], "source": [] } ], From 05b24e212af0049ee1889dbc808d14bf6b88f444 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 9 Jun 2022 15:22:24 +0530 Subject: [PATCH 25/27] refractor --- .../{arithmetic_experiment.py => arithmetic_addition.py} | 4 ++-- .../transformers/rope/value_pe/experiments/copy_repeat.py | 2 +- labml_nn/transformers/rope/value_pe/experiments/experiment.py | 2 +- 3 files changed, 
4 insertions(+), 4 deletions(-) rename labml_nn/transformers/rope/value_pe/experiments/{arithmetic_experiment.py => arithmetic_addition.py} (94%) diff --git a/labml_nn/transformers/rope/value_pe/experiments/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py similarity index 94% rename from labml_nn/transformers/rope/value_pe/experiments/arithmetic_experiment.py rename to labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py index d281e281..d397f415 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/arithmetic_experiment.py +++ b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py @@ -10,12 +10,12 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression +from labml_nn.experiments.arithmetic_addition_dataset import ArithmeticAdditionAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs -class Configs(RoPEConfigs, ArithmeticAutoregression): +class Configs(RoPEConfigs, ArithmeticAdditionAutoregression): """ We inherit [RoPE experiment](../experiment.html) and use it for [arithmetic addition task](../../experiments/arithmetic_dataset.html). 
diff --git a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py index 53c31336..d2cd313e 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +++ b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py @@ -10,7 +10,7 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.copy_perm.continous import CopyRepeatAutoregression +from labml_nn.experiments.copy_perm.repeat import CopyRepeatAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs diff --git a/labml_nn/transformers/rope/value_pe/experiments/experiment.py b/labml_nn/transformers/rope/value_pe/experiments/experiment.py index db677a81..fccc582c 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/experiment.py +++ b/labml_nn/transformers/rope/value_pe/experiments/experiment.py @@ -19,7 +19,7 @@ # ### Rotary PE attention -class Configs(RoPEConfigs): # , ArithmeticAutoregression): +class Configs(RoPEConfigs): pass From 0ce0aba12a2ade2c23e984761a0293d6660e8fc4 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 9 Jun 2022 15:24:13 +0530 Subject: [PATCH 26/27] refractor --- labml_nn/experiments/algo_tasks/__init__.py | 0 .../arithmetic_addition.py} | 0 .../experiments/{copy_perm/__init__.py => algo_tasks/copy.py} | 0 .../{copy_perm/repeat.py => algo_tasks/copy_repeat.py} | 0 .../rope/value_pe/experiments/arithmetic_addition.py | 2 +- .../transformers/rope/value_pe/experiments/copy_experiment.py | 2 +- labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py | 2 +- 7 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 labml_nn/experiments/algo_tasks/__init__.py rename labml_nn/experiments/{arithmetic_addition_dataset.py => algo_tasks/arithmetic_addition.py} (100%) rename labml_nn/experiments/{copy_perm/__init__.py => algo_tasks/copy.py} (100%) rename 
labml_nn/experiments/{copy_perm/repeat.py => algo_tasks/copy_repeat.py} (100%) diff --git a/labml_nn/experiments/algo_tasks/__init__.py b/labml_nn/experiments/algo_tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/labml_nn/experiments/arithmetic_addition_dataset.py b/labml_nn/experiments/algo_tasks/arithmetic_addition.py similarity index 100% rename from labml_nn/experiments/arithmetic_addition_dataset.py rename to labml_nn/experiments/algo_tasks/arithmetic_addition.py diff --git a/labml_nn/experiments/copy_perm/__init__.py b/labml_nn/experiments/algo_tasks/copy.py similarity index 100% rename from labml_nn/experiments/copy_perm/__init__.py rename to labml_nn/experiments/algo_tasks/copy.py diff --git a/labml_nn/experiments/copy_perm/repeat.py b/labml_nn/experiments/algo_tasks/copy_repeat.py similarity index 100% rename from labml_nn/experiments/copy_perm/repeat.py rename to labml_nn/experiments/algo_tasks/copy_repeat.py diff --git a/labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py index d397f415..9ef756fe 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py +++ b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_addition.py @@ -10,7 +10,7 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.arithmetic_addition_dataset import ArithmeticAdditionAutoregression +from labml_nn.experiments.algo_tasks.arithmetic_addition import ArithmeticAdditionAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs diff --git a/labml_nn/transformers/rope/value_pe/experiments/copy_experiment.py b/labml_nn/transformers/rope/value_pe/experiments/copy_experiment.py index f85e80ed..c2065624 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/copy_experiment.py +++ 
b/labml_nn/transformers/rope/value_pe/experiments/copy_experiment.py @@ -10,7 +10,7 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.copy_perm import CopyAutoregression +from labml_nn.experiments.algo_tasks.copy import CopyAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs diff --git a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py index d2cd313e..a7da19aa 100644 --- a/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py +++ b/labml_nn/transformers/rope/value_pe/experiments/copy_repeat.py @@ -10,7 +10,7 @@ from labml import experiment from labml.configs import calculate -from labml_nn.experiments.copy_perm.repeat import CopyRepeatAutoregression +from labml_nn.experiments.algo_tasks.copy_repeat import CopyRepeatAutoregression from labml_nn.transformers import TransformerConfigs from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs From 092b8ddaf4fb6b986d58a2f7c9e62467d3bcd9c9 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 9 Jun 2022 17:25:02 +0530 Subject: [PATCH 27/27] multiplication --- .../algo_tasks/arithmetic_addition.py | 12 +- .../algo_tasks/arithmetic_multiplication.py | 273 ++++++++++++++++++ .../experiments/arithmetic_multiplication.py | 94 ++++++ 3 files changed, 373 insertions(+), 6 deletions(-) create mode 100644 labml_nn/experiments/algo_tasks/arithmetic_multiplication.py create mode 100644 labml_nn/transformers/rope/value_pe/experiments/arithmetic_multiplication.py diff --git a/labml_nn/experiments/algo_tasks/arithmetic_addition.py b/labml_nn/experiments/algo_tasks/arithmetic_addition.py index cb2d5896..8239be7a 100644 --- a/labml_nn/experiments/algo_tasks/arithmetic_addition.py +++ b/labml_nn/experiments/algo_tasks/arithmetic_addition.py @@ -96,7 +96,7 @@ def get_qa(self): x = 
self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) - return f'x={x}+{y};', f'{x + y}' + return f'?x={x}+{y};', f'{x + y}' def get_packed_math_input(self): """ @@ -197,17 +197,17 @@ def sample(self): # Get the model prediction (greedy) output = output[-1].argmax(dim=-1) + # Override with the question + for j, p in enumerate(questions): + if len(p) > i + 1: + output[j] = dataset.stoi[p[i + 1]] + # Find which sequences have finished finished = finished | (output == new_line) # Skip if all have finished if finished.sum() == len(finished): continue - # Override with the question - for j, p in enumerate(questions): - if len(p) > i + 1: - output[j] = dataset.stoi[p[i + 1]] - # Add the next token to the input data = torch.cat([data, output[None, :]], dim=0) diff --git a/labml_nn/experiments/algo_tasks/arithmetic_multiplication.py b/labml_nn/experiments/algo_tasks/arithmetic_multiplication.py new file mode 100644 index 00000000..e8d9ade1 --- /dev/null +++ b/labml_nn/experiments/algo_tasks/arithmetic_multiplication.py @@ -0,0 +1,273 @@ +""" +--- +title: Arithmetic Dataset +summary: > + This creates arithmetic problems. +--- + +*This is based on code by [Georges Harik (@gharik)](https://twitter.com/gharik).* +""" + +import random +import string +from typing import List + +import torch +from torch.utils.data import DataLoader, Dataset + +from labml import monit, logger, tracker +from labml.configs import option +from labml.logger import Text +from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch + + +class ArithmeticMultiplicationDataset(Dataset): + """ + ## Arithmetic Dataset + + This creates arithmetic addition problems and solutions with workings. + We've only implemented addition so far. + + It's based on a character level tokenization. 
+ """ + + def __init__(self, seq_len: int, max_digits: int, base: int, n_sequences: int): + """ + :param seq_len: is the sequence length of generated math problems. + We fill as many problems as possible upto this length + :max_digits: is the maximum number of digits in the operand integers + :n_sequences: is the number of sequences per epoch + """ + self.base = base + self.n_sequences = n_sequences + self.max_digits = max_digits + self.seq_len = seq_len + # Token id to string + self.itos = list(string.digits + 'x =\n?*;') + # Character to token id + self.stoi = {c: i for i, c in enumerate(self.itos)} + + def make_int(self, n_digits: int): + """ + Generates an integer with `n_digit` number of digits + """ + res = 0 + for i in range(n_digits): + d = random.randrange(1, self.base + 1) if i == 0 else random.randrange(0, self.base + 1) + res = res * self.base + d + + return res + + def get_add_explanation(self, x: int, y: int): + """ + Generates the workings for `x + y`. + For example for `11+29` it generates + `1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0`. + """ + + explanation = [] + while x > 0: + rx = x % self.base + explanation.append(f"{self.to_string(y * rx)}") + x = x // self.base + + return ' '.join(explanation) + + # Make a problem with a pre_explanation or not + def make_add_problem(self): + """ + Creates an arithmetic addition problem with workings and answer. + """ + x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) + y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) + + explanation = self.get_add_explanation(x, y) + return f"x={self.to_string(x)}*{self.to_string(y)}; {explanation} x=={self.to_string(x * y)}\n" + + def to_string(self, x: int): + if x == 0: + return '0' + a = [] + while x > 0: + a += [f'{x % self.base}'] + x = x // self.base + + return ''.join(reversed(a)) + + def get_qa(self): + """ + Get arithmetic problem and answer. This is used for evaluation. 
+ """ + x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) + y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1)) + + return f'?x={self.to_string(x)}*{self.to_string(y)};', f'{self.to_string(x * y)}' + + def get_packed_math_input(self): + """ + Generate multiple problems and pack them into a sequence. + """ + s_enc = [] + mask = [] + while len(s_enc) <= self.seq_len: + s_part = self.make_add_problem() + s_part_enc = self.encode('?' + s_part) + prob, sol = s_part.split(';') + mask += [False] * (len(prob) + 2) + mask += [True] * len(sol) + s_enc = s_enc + s_part_enc + return s_enc, mask + + def encode(self, s: str): + """ + Encode a given string + """ + return [self.stoi[c] for c in s] + + def decode(self, arr: List[int]): + """ + Decode a list of token ids + """ + return ''.join([self.itos[c] for c in arr]) + + def __getitem__(self, idx: int): + """ + Get a input and target pair for auto-regressive modelling + """ + s, mask = self.get_packed_math_input() + s = torch.tensor(s) + mask = torch.tensor(mask) + target = s * mask + -1 * (~mask) + return s[:self.seq_len], target[1:self.seq_len + 1] + + def __len__(self): + """ + Number of sequences per epoch + """ + return self.n_sequences + + +class ArithmeticMultiplicationAutoregression(NLPAutoRegressionConfigs): + """ + ## Arithmetic Task Experiment Configurations + """ + # Maximum number of digits per operand integer + max_digits: int = 4 + # Number of training sequences per epoch + train_sequences_per_epoch: int = 2 ** 12 + # Training data loader + train_loader: DataLoader = 'arithmetic_train_loader' + # Number of problems in evaluation + n_tests: int = 64 + # No need of a validation dataset + validator = None + # Number of times to run evaluations per epoch + inner_iterations = 4 + # Number of tokens in the vocabulary + base: int = 10 + n_tokens = len(ArithmeticMultiplicationDataset(1, 1, 1, 1).itos) + + @torch.no_grad() + def sample(self): + """ + ### Evaluation + + We use the 
sampling function to evaluate the model on a set of problems
+        """
+
+        # Skip in the first epoch
+        if self.training_loop.idx < 1:
+            return
+
+        # Create a dataset to generate problems
+        dataset = ArithmeticMultiplicationDataset(self.seq_len, self.max_digits, self.base, 1)
+        # Get a set of problems and answers
+        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+        # Collect the problems only
+        questions = [p[0] for p in qa]
+
+        # Create a tensor with only the initial token of each question (sequence-first layout; new tokens are concatenated along dim 0)
+        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
+        # Move to device
+        data = data.to(self.device)
+
+        # Per-sequence flags marking which sequences have completed
+        finished = torch.zeros((len(questions),)).bool().to(self.device)
+        # Token id of the new line character - this marks end of the answer
+        new_line = dataset.stoi['\n']
+
+        # Sampled results, seeded with the first character of each question
+        results = [p[0] for p in questions]
+
+        # Sample up to the sequence length
+        for i in monit.iterate('Sample', self.seq_len - 1):
+            # If all the sequences have completed we skip this
+            if finished.sum() == len(finished):
+                continue
+
+            # Get the model output
+            output, *_ = self.model(data)
+            # Get the model prediction (greedy) at the last position
+            output = output[-1].argmax(dim=-1)
+
+            # Override the prediction with the question text while it is still being fed
+            for j, p in enumerate(questions):
+                if len(p) > i + 1:
+                    output[j] = dataset.stoi[p[i + 1]]
+
+            # Find which sequences have finished
+            finished = finished | (output == new_line)
+            # Skip if all have finished
+            if finished.sum() == len(finished):
+                continue
+
+            # Add the next token to the input
+            data = torch.cat([data, output[None, :]], dim=0)
+
+            # Append the sampled tokens to the results
+            for j, c in enumerate(output):
+                results[j] += dataset.itos[c]
+
+        # Discard everything after the answer in the results
+        results = [r.split('\n')[0] for r in results]
+
+        # Log a sample
+        res_sample = results[0].split(';')
+        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+
+        # Get the answers
+        results = [r.split('x==')[-1] for r in 
results]
+
+        # Count the number of correct answers
+        correct = 0
+        for r, _qa in zip(results, qa):
+            if r == _qa[1]:
+                correct += 1
+
+        # Log the accuracy as `score`
+        tracker.save('score', correct / len(results))
+
+
+@option(ArithmeticMultiplicationAutoregression.train_loader)
+def arithmetic_train_loader(c: ArithmeticMultiplicationAutoregression):
+    """
+    Training data loader
+    """
+    return DataLoader(ArithmeticMultiplicationDataset(c.seq_len, c.max_digits, c.base, c.train_sequences_per_epoch),
+                      batch_size=c.batch_size,
+                      collate_fn=transpose_batch,
+                      num_workers=4)
+
+
+def _test():
+    """
+    Code to test generated problems
+    """
+    dataset = ArithmeticMultiplicationDataset(256, 4, 4, 10)
+
+    print(dataset.decode(dataset.get_packed_math_input()[0]))
+
+
+#
+if __name__ == '__main__':
+    _test()
diff --git a/labml_nn/transformers/rope/value_pe/experiments/arithmetic_multiplication.py b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_multiplication.py
new file mode 100644
index 00000000..f60f5016
--- /dev/null
+++ b/labml_nn/transformers/rope/value_pe/experiments/arithmetic_multiplication.py
@@ -0,0 +1,94 @@
+"""
+---
+title: Rotary Positional Embeddings with Relative distance (RoPER) Experiment
+summary: This experiment trains a transformer model with Rotary Positional Embeddings with
+ Relative Distance (RoPER) on the arithmetic multiplication task.
+---
+
+# Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) Experiment
+"""
+
+from labml import experiment
+from labml.configs import calculate
+from labml_nn.experiments.algo_tasks.arithmetic_multiplication import ArithmeticMultiplicationAutoregression
+from labml_nn.transformers import TransformerConfigs
+from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
+
+
+class Configs(RoPEConfigs, ArithmeticMultiplicationAutoregression):
+    """
+    We inherit [RoPE experiment](../experiment.html) and use it for
+    [arithmetic multiplication task](../../experiments/arithmetic_dataset.html).
+
+    We add the option to change attention to use Rotary Positional Embeddings with Relative distance (RoPER)
+    below.
+    """
+    pass
+
+
+def _rotary_value_pe_mha(c: TransformerConfigs):
+    """
+    Use Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) in attention.
+    """
+    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
+    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
+
+
+# Configuration options
+calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
+calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
+calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
+
+
+def main():
+    # Create experiment
+    experiment.create(name="roper_mult", comment="4", writers={'screen', 'labml'})
+    # Create configs
+    conf = Configs()
+    # Override configurations
+    experiment.configs(conf, {
+        'max_digits': 8,
+        'base': 4,
+
+        # No fixed positional embeddings
+        'transformer.src_embed': 'no_pos',
+        'transformer.tgt_embed': 'no_pos',
+
+        # Encoder with RoPER attention
+        # 'transformer.encoder_attn': 'rotary_value',
+        # Encoder with RoPE attention
+        'transformer.encoder_attn': 'rotary',
+
+        #
+        'model': 'rotary_pe_transformer',
+
+        # Use a context size of $512$
+        'seq_len': 512,
+        # Train for $20$ epochs
+        'epochs': 20,
+        # Batch size $16$
+        'batch_size': 16,
+
+        # Model size
+        'd_model': 128,
+        'transformer.ffn.d_ff': 512,
+        'transformer.n_heads': 4,
+        'transformer.dropout': 0.0,
+
+        # Use the Adam optimizer
+        'optimizer.optimizer': 'Adam',
+        'optimizer.learning_rate': 2.5e-4,
+    })
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models({'model': conf.model})
+
+    # Start the experiment
+    with experiment.start():
+        # Run training
+        conf.run()
+
+
+#
+if __name__ == '__main__':
+    main()