Skip to content

Commit ce37be9

Browse files
authored
[s2s] warn if --fp16 for torch 1.6 (#6977)
1 parent f72fe1f commit ce37be9

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

Diff for: examples/seq2seq/finetune.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import logging
44
import os
55
import time
6+
import warnings
67
from collections import defaultdict
78
from pathlib import Path
89
from typing import Dict, List, Tuple
910

1011
import numpy as np
1112
import pytorch_lightning as pl
1213
import torch
14+
from packaging import version
1315
from torch.utils.data import DataLoader
1416

1517
from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
354356
model: SummarizationModule = SummarizationModule(args)
355357
else:
356358
model: SummarizationModule = TranslationModule(args)
357-
359+
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
360+
warnings.warn("FP16 only seems to work with torch 1.5+apex")
358361
dataset = Path(args.data_dir).name
359362
if (
360363
args.logger_name == "default"

0 commit comments

Comments
 (0)