File tree 1 file changed +4
-1
lines changed
1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 3
3
import logging
4
4
import os
5
5
import time
6
+ import warnings
6
7
from collections import defaultdict
7
8
from pathlib import Path
8
9
from typing import Dict , List , Tuple
9
10
10
11
import numpy as np
11
12
import pytorch_lightning as pl
12
13
import torch
14
+ from packaging import version
13
15
from torch .utils .data import DataLoader
14
16
15
17
from lightning_base import BaseTransformer , add_generic_args , generic_train
@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
354
356
model : SummarizationModule = SummarizationModule (args )
355
357
else :
356
358
model : SummarizationModule = TranslationModule (args )
357
-
359
+ if version .parse (torch .__version__ ) == version .parse ("1.6" ) and args .fp16 :
360
+ warnings .warn ("FP16 only seems to work with torch 1.5+apex" )
358
361
dataset = Path (args .data_dir ).name
359
362
if (
360
363
args .logger_name == "default"
You can’t perform that action at this time.
0 commit comments