updated ddp_pipeline #1561

Merged (5 commits) on Jun 14, 2021
advanced_source/ddp_pipeline.py: 104 changes (50 additions & 54 deletions)
@@ -169,26 +169,24 @@ def run_worker(rank, world_size):
     def print_with_rank(msg):
         print('[RANK {}]: {}'.format(rank, msg))

-    import io
-    from torchtext.utils import download_from_url, extract_archive
+    from torchtext.datasets import WikiText2
     from torchtext.data.utils import get_tokenizer
     from torchtext.vocab import build_vocab_from_iterator

-    url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
-    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=".data{}".format(rank)))
+    train_iter = WikiText2(split='train')
     tokenizer = get_tokenizer('basic_english')
-    vocab = build_vocab_from_iterator(map(tokenizer,
-                                          iter(io.open(train_filepath,
-                                                       encoding="utf8"))))
+    vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
+    vocab.set_default_index(vocab["<unk>"])

     def data_process(raw_text_iter):
-        data = [torch.tensor([vocab[token] for token in tokenizer(item)],
-                             dtype=torch.long) for item in raw_text_iter]
+        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
         return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

-    train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
-    val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
-    test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+    train_iter, val_iter, test_iter = WikiText2()
+    train_data = data_process(train_iter)
+    val_data = data_process(val_iter)
+    test_data = data_process(test_iter)

     device = torch.device(2 * rank)

     def batchify(data, bsz, rank, world_size, is_train=False):
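
Note: the hunk above drops the manual WikiText-2 download (download_from_url, extract_archive, io.open) in favor of the torchtext.datasets.WikiText2 iterators and the new-style vocab. A minimal, self-contained sketch of the 0.10-era vocab API adopted here; the toy corpus and printed values are illustrative, not from the tutorial:

    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator

    tokenizer = get_tokenizer('basic_english')
    lines = ["the quick brown fox", "jumps over the lazy dog"]  # toy corpus

    # specials=["<unk>"] reserves an unknown token; set_default_index makes
    # every out-of-vocabulary token resolve to it instead of raising an error.
    vocab = build_vocab_from_iterator(map(tokenizer, lines), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    print(vocab(tokenizer("the unseen fox")))  # list of indices; 'unseen' -> vocab["<unk>"]
    print(len(vocab))  # vocab size; replaces len(vocab.stoi) from the legacy API

The new dataset iterators are single-pass, which is why the diff rebuilds them with WikiText2() before producing the train/val/test tensors.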
@@ -264,7 +262,7 @@ def get_batch(source, i):
 # another across GPUs 2 and 3. Both pipes are then replicated using DistributedDataParallel.

 # In 'run_worker'
-    ntokens = len(vocab.stoi) # the size of vocabulary
+    ntokens = len(vocab) # the size of vocabulary
     emsize = 4096 # embedding dimension
     nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
     nlayers = 8 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
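
Note: the surrounding code (unchanged by this PR) builds each rank's model as a Pipe spanning two GPUs and then wraps it in DistributedDataParallel, as the comment above describes. A hedged sketch of that wiring with illustrative layer sizes rather than the tutorial's transformer stack; it assumes the process group and the single-node RPC framework are already initialized, as the tutorial does before this point:

    import torch.nn as nn
    from torch.distributed.pipeline.sync import Pipe
    from torch.nn.parallel import DistributedDataParallel as DDP

    def build_pipe_ddp(rank):
        # Rank 0 pipelines across GPUs 0-1, rank 1 across GPUs 2-3,
        # matching device = torch.device(2 * rank) above.
        stage0 = nn.Sequential(nn.Linear(64, 64), nn.ReLU()).cuda(2 * rank)
        stage1 = nn.Sequential(nn.Linear(64, 64)).cuda(2 * rank + 1)
        pipe = Pipe(nn.Sequential(stage0, stage1), chunks=8)  # 8 micro-batches per batch
        return DDP(pipe)  # gradients are all-reduced across the two pipe replicas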
@@ -361,7 +359,7 @@ def train():
         model.train() # Turn on the train mode
         total_loss = 0.
         start_time = time.time()
-        ntokens = len(vocab.stoi)
+        ntokens = len(vocab)

         # Train only for 50 batches to keep script execution time low.
         nbatches = min(50 * bptt, train_data.size(0) - 1)
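
Note: the 50-batch cap works because the training loop advances bptt tokens per step, so nbatches = min(50 * bptt, ...) bounds it to 50 windows. For reference, a sketch using the standard get_batch from the base transformer tutorial; this file's version (named in the hunk header above) may differ in detail, e.g. transposing for batch-first Pipe input:

    bptt = 35  # tokens per training window

    def get_batch(source, i):
        # A bptt-long slice of the batchified data; targets are shifted by one token.
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i + seq_len]
        target = source[i + 1:i + 1 + seq_len].view(-1)
        return data, target

    # The training loop then visits at most 50 windows:
    # for batch, i in enumerate(range(0, nbatches, bptt)): ...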
@@ -388,7 +386,7 @@ def train():
                 print_with_rank('| epoch {:3d} | {:5d}/{:5d} batches | '
                                 'lr {:02.2f} | ms/batch {:5.2f} | '
                                 'loss {:5.2f} | ppl {:8.2f}'.format(
-                    epoch, batch, nbatches // bptt, scheduler.get_lr()[0],
+                    epoch, batch, nbatches // bptt, scheduler.get_last_lr()[0],
                     elapsed * 1000 / log_interval,
                     cur_loss, math.exp(cur_loss)))
                 total_loss = 0
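
Note: scheduler.get_lr() warns in recent PyTorch when called outside of step() and, for closed-form schedulers like StepLR, can return a value with an extra decay applied; get_last_lr() is the supported accessor for the rate the scheduler last computed. That plausibly explains why the old log prints lr 4.51 in epoch 2 where the new log prints 4.75. A quick sketch mirroring the tutorial's 5.0 starting rate, with a gamma of 0.95 inferred from the log's 5.00 -> 4.75 -> 4.51 progression:

    import torch

    params = [torch.zeros(1, requires_grad=True)]
    opt = torch.optim.SGD(params, lr=5.0)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.95)

    opt.step()
    scheduler.step()
    print(scheduler.get_last_lr()[0])  # 4.75, i.e. the lr after one decay step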
@@ -397,7 +395,7 @@ def evaluate(eval_model, data_source):
     def evaluate(eval_model, data_source):
         eval_model.eval() # Turn on the evaluation mode
         total_loss = 0.
-        ntokens = len(vocab.stoi)
+        ntokens = len(vocab)
         # Evaluate only for 50 batches to keep script execution time low.
         nbatches = min(50 * bptt, data_source.size(0) - 1)
         with torch.no_grad():
@@ -455,8 +453,6 @@ def evaluate(eval_model, data_source):
 if __name__=="__main__":
     world_size = 2
     mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
-
-
 ######################################################################
 # Output
 # ------
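
Note: the __main__ hunk only removes two stray blank lines; for reference, a minimal sketch of the mp.spawn launch pattern it uses, with a trivial worker body standing in for the tutorial's run_worker:

    import torch.multiprocessing as mp

    def run_worker(rank, world_size):
        # Each spawned process receives its rank as the first argument,
        # followed by everything passed via args.
        print('worker {} of {} started'.format(rank, world_size))

    if __name__ == "__main__":
        world_size = 2
        mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

The refreshed output below also reflects the new data path: per-batch times drop to roughly 699-778 ms from 1246-1414 ms, and epoch times from about 69s to about 40s.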
@@ -466,52 +462,52 @@ def evaluate(eval_model, data_source):
 ######################################################################
 #.. code-block:: py
 #
 # [RANK 1]: Total parameters in model: 1,041,453,167
 # [RANK 0]: Total parameters in model: 1,041,453,167
-# [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.18 | loss 48.70 | ppl 1406154472673147092992.00
-# [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.42 | loss 48.49 | ppl 1146707511057334927360.00
-# [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 42.74 | ppl 3648812398518492672.00
-# [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 41.51 | ppl 1064844757565813248.00
-# [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 41.85 | ppl 1497706388552644096.00
-# [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 40.46 | ppl 373830103285747072.00
-# [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.76 | ppl 185159839078666368.00
-# [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.89 | ppl 211756997625874912.00
-# [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 1 | time: 69.37s | valid loss 2.92 | valid ppl 18.46
-# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 778.97 | loss 43.31 | ppl 6432469059895903232.00
+# [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 778.90 | loss 44.50 | ppl 21245447128217366528.00
+# [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 699.89 | loss 44.50 | ppl 21176949187407757312.00
+# [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 699.87 | loss 44.62 | ppl 23975861229620961280.00
+# [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 698.86 | loss 41.62 | ppl 1193312915629888256.00
+# [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 698.87 | loss 40.69 | ppl 471605759847546240.00
+# [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 698.34 | loss 45.20 | ppl 42812308420836458496.00
+# [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 698.33 | loss 45.68 | ppl 68839569686012223488.00
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 1 | time: 69.39s | valid loss 2.92 | valid ppl 18.46
+# [RANK 1]: | end of epoch 1 | time: 40.08s | valid loss 0.80 | valid ppl 2.22
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1373.91 | loss 39.77 | ppl 187532281612905856.00
-# [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1375.62 | loss 39.05 | ppl 91344349371016336.00
-# [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.62 | ppl 19917977906884.78
-# [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.48 | ppl 17250186491252.32
-# [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.14 | ppl 4534527326854.47
-# [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.43 | ppl 6035762659681.65
-# [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.54 | loss 23.11 | ppl 10869828323.89
-# [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.55 | loss 22.90 | ppl 8785318464.24
 # [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 2 | time: 69.02s | valid loss 0.94 | valid ppl 2.55
+# [RANK 0]: | end of epoch 1 | time: 40.09s | valid loss 0.80 | valid ppl 2.22
 # [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.75 | ms/batch 768.51 | loss 36.34 | ppl 6063529544668166.00
+# [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.75 | ms/batch 769.23 | loss 37.41 | ppl 17651211266236086.00
+# [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.75 | ms/batch 699.57 | loss 28.97 | ppl 3798441739584.11
+# [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.75 | ms/batch 699.56 | loss 29.28 | ppl 5203636967575.47
+# [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.75 | ms/batch 699.04 | loss 28.43 | ppl 2212498693571.25
+# [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.75 | ms/batch 699.05 | loss 28.33 | ppl 2015144761281.48
+# [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.75 | ms/batch 699.10 | loss 23.30 | ppl 13121380184.92
+# [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.75 | ms/batch 699.09 | loss 23.41 | ppl 14653799192.87
+# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | end of epoch 2 | time: 39.97s | valid loss 0.24 | valid ppl 1.27
+# [RANK 0]: -----------------------------------------------------------------------------------------
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 2 | time: 69.05s | valid loss 0.94 | valid ppl 2.55
+# [RANK 1]: | end of epoch 2 | time: 39.98s | valid loss 0.24 | valid ppl 1.27
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1380.66 | loss 12.98 | ppl 434052.59
-# [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1376.47 | loss 12.92 | ppl 410203.33
-# [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.80 | ppl 18034.58
-# [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.78 | ppl 17741.88
-# [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.89 | loss 10.37 | ppl 32016.45
-# [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.90 | loss 10.46 | ppl 34735.08
-# [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.70 | loss 10.09 | ppl 24147.61
-# [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.71 | loss 10.08 | ppl 23748.31
+# [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.51 | ms/batch 769.36 | loss 12.80 | ppl 361681.11
+# [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.51 | ms/batch 768.97 | loss 12.57 | ppl 287876.61
+# [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.51 | ms/batch 698.27 | loss 12.01 | ppl 164364.60
+# [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.51 | ms/batch 698.30 | loss 11.98 | ppl 159095.89
+# [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.51 | ms/batch 697.75 | loss 10.90 | ppl 54261.91
+# [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.51 | ms/batch 697.72 | loss 10.89 | ppl 53372.39
+# [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.51 | ms/batch 699.49 | loss 10.78 | ppl 47948.35
+# [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.51 | ms/batch 699.50 | loss 10.79 | ppl 48664.42
 # [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+# [RANK 0]: | end of epoch 3 | time: 39.96s | valid loss 0.38 | valid ppl 1.46
 # [RANK 0]: -----------------------------------------------------------------------------------------
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+# [RANK 1]: | end of epoch 3 | time: 39.96s | valid loss 0.38 | valid ppl 1.46
 # [RANK 1]: -----------------------------------------------------------------------------------------
 # [RANK 0]: =========================================================================================
-# [RANK 0]: | End of training | test loss 0.60 | test ppl 1.83
+# [RANK 0]: | End of training | test loss 0.33 | test ppl 1.39
 # [RANK 0]: =========================================================================================
 # [RANK 1]: =========================================================================================
-# [RANK 1]: | End of training | test loss 0.60 | test ppl 1.83
+# [RANK 1]: | End of training | test loss 0.33 | test ppl 1.39
 # [RANK 1]: =========================================================================================
 #