|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -__version__ = "2024.10.0" |
| 15 | +__version__ = "2024.10.1" |
16 | 16 |
|
17 | 17 | __all__ = [ |
18 | 18 | "prepare_model_for_kbit_training", |
|
43 | 43 | "accelerate_new_send_to_device", |
44 | 44 | "patch_gradient_checkpointing", |
45 | 45 | "unpatch_gradient_checkpointing", |
| 46 | + "patch_gradient_accumulation_fix", |
46 | 47 | ] |
47 | 48 |
|
48 | 49 | import torch |
@@ -1138,3 +1139,63 @@ def test_mask_creation(): |
1138 | 1139 | assert(torch.all(correct_mask == our_mask)) |
1139 | 1140 | pass |
1140 | 1141 | pass |
| 1142 | + |
| 1143 | + |
| 1144 | +def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): |
| 1145 | + batch_samples = [] |
| 1146 | + num_items_in_batch = None |
| 1147 | + for _ in range(num_batches): |
| 1148 | + try: |
| 1149 | + batch_samples += [next(epoch_iterator)] |
| 1150 | + except StopIteration: |
| 1151 | + break |
| 1152 | + if len(batch_samples) > 0 and "labels" in batch_samples[0]: |
| 1153 | + try: |
| 1154 | + num_items_in_batch = sum( |
| 1155 | + [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] |
| 1156 | + ) |
| 1157 | + except TypeError: |
| 1158 | + pass |
| 1159 | + return batch_samples, num_items_in_batch |
| 1160 | +pass |
| 1161 | + |
| 1162 | + |
| 1163 | +def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): |
| 1164 | + if "num_items_in_batch" in kwargs: |
| 1165 | + if "num_items_in_batch" not in inputs: |
| 1166 | + inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] |
| 1167 | + pass |
| 1168 | + pass |
| 1169 | + return self._old_compute_loss(model, inputs, *args, **kwargs)
| 1170 | +pass |
| 1171 | + |
| 1172 | + |
| 1173 | +def patch_gradient_accumulation_fix(Trainer): |
| 1174 | + # Fixes gradient accumulation so the loss is normalized over all non-masked tokens in a full accumulation step, not per micro-batch
| 1175 | + if hasattr(Trainer, "get_batch_samples"): |
| 1176 | + from inspect import getsource |
| 1177 | + if \ |
| 1178 | + not getsource(Trainer.get_batch_samples).strip()\ |
| 1179 | + .endswith("return batch_samples, num_items_in_batch"): |
| 1180 | + |
| 1181 | + raise NotImplementedError("Unsloth: Please make a Github issue immediately!!") |
| 1182 | + else: |
| 1183 | + if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples": |
| 1184 | + Trainer.get_batch_samples = _unsloth_get_batch_samples |
| 1185 | + pass |
| 1186 | + |
| 1187 | + # Also fix passing in num_items_in_batch |
| 1188 | + if not hasattr(Trainer, "_old_compute_loss"): |
| 1189 | + Trainer._old_compute_loss = Trainer.compute_loss |
| 1190 | + Trainer.compute_loss = _unsloth_pre_compute_loss |
| 1191 | + pass |
| 1192 | + pass |
| 1193 | + else: |
| 1194 | + logger.warning_once( |
| 1195 | + "Unsloth: We fixed a gradient accumulation bug, "\ |
| 1196 | + "but it seems like you don't have the latest transformers version!\n"\ |
| 1197 | + "Please update transformers via:\n"\ |
| 1198 | + '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`' |
| 1199 | + ) |
| 1200 | + pass |
| 1201 | +pass |
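
For context, a minimal usage sketch of the new patch. This is not part of the commit; the import path and the choice of `transformers.Trainer` as the patch target are assumptions:

```python
# Hypothetical usage sketch; the unsloth import path below is assumed,
# not confirmed by this diff.
from transformers import Trainer
from unsloth.models._utils import patch_gradient_accumulation_fix  # assumed path

# Monkey-patches the Trainer class in place:
# - get_batch_samples now also returns num_items_in_batch, the count of label
#   tokens != -100 (shifted by one position) summed over the micro-batches of
#   one gradient-accumulation step.
# - compute_loss receives that count via inputs["num_items_in_batch"], so the
#   loss can be normalized per accumulation step rather than per micro-batch.
patch_gradient_accumulation_fix(Trainer)
```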