|
1 | 1 | from collections import OrderedDict
|
2 | 2 | from abc import abstractmethod
|
3 | 3 | from typing import Optional, Dict, Any, List
|
| 4 | +import os |
4 | 5 | import torch
|
5 | 6 | from torch.optim.optimizer import Optimizer
|
6 | 7 | from torch.optim.lr_scheduler import _LRScheduler
|
@@ -133,6 +134,24 @@ def set_input(self, input, device):
|
133 | 134 | """
|
134 | 135 | raise NotImplementedError
|
135 | 136 |
|
| 137 | + def load_state_dict_with_same_shape(self, weights, strict=False): |
| 138 | + model_state = self.state_dict() |
| 139 | + filtered_weights = {k: v for k, v in weights.items() if k in model_state and v.size() == model_state[k].size()} |
| 140 | + log.info("Loading weights:" + ", ".join(filtered_weights.keys())) |
| 141 | + self.load_state_dict(filtered_weights, strict=strict) |
| 142 | + |
| 143 | + def set_pretrained_weights(self): |
| 144 | + path_pretrained = getattr(self.opt, "path_pretrained", None) |
| 145 | + weight_name = getattr(self.opt, "weight_name", "latest") |
| 146 | + |
| 147 | + if path_pretrained is not None: |
| 148 | + if not os.path.exists(path_pretrained): |
| 149 | + log.warning("The path does not exist, it will not load any model") |
| 150 | + else: |
| 151 | + log.info("load pretrained weights from {}".format(path_pretrained)) |
| 152 | + m = torch.load(path_pretrained)["models"][weight_name] |
| 153 | + self.load_state_dict_with_same_shape(m, strict=False) |
| 154 | + |
136 | 155 | def get_labels(self):
|
137 | 156 | """ returns a trensor of size ``[N_points]`` where each value is the label of a point
|
138 | 157 | """
|
|
0 commit comments