diff --git a/improved_diffusion/fp16_util.py b/improved_diffusion/fp16_util.py index 23e0418153..fa9fcab60c 100644 --- a/improved_diffusion/fp16_util.py +++ b/improved_diffusion/fp16_util.py @@ -65,7 +65,7 @@ def unflatten_master_params(model_params, master_params): """ Unflatten the master parameters to look like model_params. """ - return _unflatten_dense_tensors(master_params[0].detach(), model_params) + return _unflatten_dense_tensors(master_params[0].detach(), tuple(tensor for tensor in model_params)) def zero_grad(model_params):