Checkpoints

Functions:

get_auto_model_class(model_class)

Get the AutoModel class corresponding to the model class.

Source code in tapeagents/finetune/checkpoints.py
def get_auto_model_class(
    model_class: ModelClass,
) -> Type[_BaseAutoModelClass]:
    """Get the AutoModel class corresponding to the model class."""
    match model_class:
        case "causal-language-modeling":
            return AutoModelForCausalLM
        case "seq2seq-language-modeling":
            return AutoModelForSeq2SeqLM
        case _:
            raise ValueError(f"Unsupported model class: {model_class}")
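
A minimal usage sketch: the returned AutoModel class can be used directly with from_pretrained. The checkpoint name below is only a placeholder.

from tapeagents.finetune.checkpoints import get_auto_model_class

# "causal-language-modeling" maps to transformers.AutoModelForCausalLM
auto_cls = get_auto_model_class("causal-language-modeling")
# Placeholder checkpoint name; any causal LM checkpoint would work here
model = auto_cls.from_pretrained("gpt2")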

get_temporary_folder_and_move(output_dir)

Context manager for safe checkpointing.

Creates a temporary folder ~output_dir, then renames it to the final destination.

Source code in tapeagents/finetune/checkpoints.py
@contextlib.contextmanager
def get_temporary_folder_and_move(output_dir: Path):
    """
    Context manager for safe checkpointing.

    Creates a temporary folder `~output_dir`, then renames it to the final destination.
    """
    if os.path.exists(output_dir) and not os.path.isdir(output_dir):
        raise ValueError("get_temporary_folder_and_move: output_dir is not a directory")

    output_dir = output_dir.resolve()
    temporary_path = output_dir.parent / ("~" + output_dir.name)

    if accelerator.is_main_process:
        if os.path.exists(temporary_path):
            logger.info(f"Deleting temporary directory {temporary_path}")
            shutil.rmtree(temporary_path)
        logger.info(f"Creating temporary directory {temporary_path}")
        os.makedirs(temporary_path)

    accelerator.wait_for_everyone()
    yield temporary_path
    accelerator.wait_for_everyone()

    # Move to final path
    if accelerator.is_main_process:
        # delete output_dir if it exists
        if os.path.exists(output_dir):
            logger.info(
                f" -> Deleting {output_dir}. "
                f"If this fails, manually delete it and move {temporary_path} to {output_dir}"
            )
            shutil.rmtree(output_dir)
        logger.info(f" -> Renaming {temporary_path} to {output_dir}")
        os.rename(temporary_path, output_dir)
        logger.info(f"Done moving files to {output_dir}")

load_training_checkpoint(training_state_dir, model, optimizer, lr_scheduler)

Load a checkpoint created by save_training_checkpoint() in place:

  • With deepspeed, this loads model, optimizer, and lr_scheduler states in place.
  • Without deepspeed, this only loads optimizer and lr_scheduler states in place, but not model states!

Source code in tapeagents/finetune/checkpoints.py
def load_training_checkpoint(
    training_state_dir: Path,
    model: transformers.PreTrainedModel,
    optimizer,
    lr_scheduler,
):
    """
    Load checkpoint created by save_training_checkpoint() in-place:

    - With deepspeed, this will load model, optimizer, lr_scheduler states in-place.
    - Without deepspeed, this will *only* load optimizer, lr_scheduler states in-place,
        but *not* model states!
    """
    assert (
        not os.path.exists(training_state_dir) or training_state_dir.is_dir()
    ), f"output_dir {training_state_dir} must be a directory"

    if model.__class__.__name__.endswith("DeepSpeedEngine"):
        logger.info("Load deepspeed training state")
        # This magically loads optimizer and lr_scheduler states (if they were saved)
        # (the passed optimizer and lr_scheduler arguments will be ignored)
        load_path, extra_training_state = model.load_checkpoint(
            training_state_dir,
            tag="deepspeed",
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        if load_path is None:
            raise RuntimeError(f"Loading deepspeed checkpoint from {training_state_dir} failed")
        if (
            model.lr_scheduler is None
            and extra_training_state is not None
            and "lr_scheduler_state" in extra_training_state
        ):
            # Manually load lr_scheduler states
            logger.warning(f"Manually loading ds-unsupported lr_scheduler of type {type(lr_scheduler).__name__}")
            lr_scheduler.load_state_dict(extra_training_state["lr_scheduler_state"])
        logger.info(f"Loaded deepspeed checkpoint from {training_state_dir}")
    else:  # multi_gpu (no deepspeed)
        # This needs to be called from all processes
        training_state = torch.load(training_state_dir / "training_state.pt", map_location="cpu")
        optimizer.load_state_dict(training_state["optimizer_state"])
        lr_scheduler.load_state_dict(training_state["lr_scheduler_state"])
        del training_state["optimizer_state"]
        del training_state["lr_scheduler_state"]
        extra_training_state = training_state
        logger.info(f"Loaded accelerate checkpoint from {training_state_dir}")
    return extra_training_state
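
A sketch of resuming training, assuming model, optimizer, and lr_scheduler were already created (and, without deepspeed, that the model weights were loaded separately, e.g. via from_pretrained). The path is a placeholder and the key read from the returned dict is hypothetical; it depends on what save_training_checkpoint() stored.

from pathlib import Path
from tapeagents.finetune.checkpoints import load_training_checkpoint

training_state_dir = Path("outputs/training_state")  # placeholder path
extra_state = load_training_checkpoint(
    training_state_dir,
    model=model,                # assumed to exist (deepspeed engine or prepared model)
    optimizer=optimizer,        # assumed to exist
    lr_scheduler=lr_scheduler,  # assumed to exist
)
# `extra_state` holds whatever extra fields were saved, e.g. a step counter
completed_steps = (extra_state or {}).get("completed_steps", 0)  # hypothetical key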

save_model_only(output_dir, model, unwrap=True, lora=False, safe_serialization=False)

Save model weights and config.

Creates the following files in output_dir/ :

  • config.json

and either:

  • pytorch_model.bin (single-file model), OR
  • pytorch_model-XXXXX-of-XXXXX.bin (multi-file model) and pytorch_model.bin.index.json

Note that this does not save the optimizer, lr_scheduler, scaler, etc. Use it only for later JGA evaluation, not for resuming training.

Must be called on all accelerate processes because all of them must save their shards.

Source code in tapeagents/finetune/checkpoints.py
def save_model_only(
    output_dir: Path,
    model: transformers.PreTrainedModel,
    unwrap: bool = True,
    lora: bool = False,
    safe_serialization: bool = False,
):
    """
    Save model weights and config.

    Creates the following files in output_dir/ :
        - config.json
    and either:
        - pytorch_model.bin (single-file model), OR
        - pytorch_model-XXXXX-of-XXXXX.bin (multi-file model) and pytorch_model.bin.index.json

    Note that this does not save optimizer, lr_scheduler, scaler, etc.
    Use only for later JGA evaluation, not for resuming training

    Must be called on *all* accelerate processes because all of them must save their shards.
    """
    assert not os.path.exists(output_dir) or output_dir.is_dir(), f"output_dir {output_dir} must be a directory"
    accelerator.wait_for_everyone()

    logger.info(f"Save model to {output_dir}")

    unwrapped_model = accelerator.unwrap_model(model) if unwrap else model
    if lora:
        lora_save(output_dir, unwrapped_model)
        return

    if unwrapped_model.__class__.__name__.endswith("DeepSpeedEngine"):
        unwrapped_model.save_checkpoint(
            save_dir=output_dir,
        )
        logger.info(f"Saved deepspeed checkpoint to {output_dir}")
    elif isinstance(unwrapped_model, transformers.PreTrainedModel):
        unwrapped_model.save_pretrained(  # type: ignore
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
            safe_serialization=safe_serialization,
        )
        logger.info(f"Saved model to {output_dir}")
    else:
        raise ValueError(f"model is neither a deepspeed model nor a transformers.PreTrainedModel: {type(model)}")

save_tokenizer_only(output_dir, tokenizer)

Save only the tokenizer to output_dir.

Can be called on all processes.

Source code in tapeagents/finetune/checkpoints.py
def save_tokenizer_only(
    output_dir: Path,
    tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast,
):
    """
    Save only tokenizer to output_dir

    Can be called on *all* processes.
    """
    assert not os.path.exists(output_dir) or output_dir.is_dir(), f"output_dir {output_dir} must be a directory"
    if accelerator.is_main_process:
        logger.info(f"Save tokenizer to {output_dir}")
        tokenizer.save_pretrained(output_dir)
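
A short sketch, typically paired with save_model_only so the weights and tokenizer end up in the same directory; only the main process actually writes. The path is a placeholder and `tokenizer` is assumed to exist.

from pathlib import Path
from tapeagents.finetune.checkpoints import save_tokenizer_only

save_tokenizer_only(Path("outputs/model"), tokenizer)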