emmi_inference.models.loader

Functions

open_checkpoint_file(checkpoint_path[, device, ...])

Open a checkpoint file (regular PyTorch file).

load_model(checkpoint, *[, model_factory, device, ...])

Load a PyTorch model from a checkpoint.

Module Contents

emmi_inference.models.loader.open_checkpoint_file(checkpoint_path, device='cpu', weights_only=False)

Open a checkpoint file (regular PyTorch file).

Parameters:
  • checkpoint_path (pathlib.Path) – Path to a checkpoint file.

  • device (str) – Device onto which tensors are mapped when loading (e.g. "cpu", "cuda"). Defaults to "cpu".

  • weights_only (bool) – If True, restrict loading to primitive types such as dicts and tensors (PyTorch's safe weights-only mode). Defaults to False.

Returns:

  • A dictionary mapping checkpoint keys to arbitrary values (e.g. a state_dict and metadata).

Return type:

dict[str, Any]
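A minimal sketch of what this loading step might look like, assuming open_checkpoint_file is a thin wrapper around torch.load with device forwarded as map_location. The name open_checkpoint_file_sketch is hypothetical; this is an illustration under that assumption, not the library's implementation.

```python
from pathlib import Path
from typing import Any

import torch


def open_checkpoint_file_sketch(
    checkpoint_path: Path,
    device: str = "cpu",
    weights_only: bool = False,
) -> dict[str, Any]:
    # Hypothetical re-implementation: assumes a regular torch.save() file and
    # that `device` is passed to torch.load as map_location.
    # With weights_only=True, torch.load refuses to unpickle arbitrary objects
    # and only reconstructs tensors, primitive types, and containers.
    return torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
```

For example, a dictionary holding a state_dict of tensors and an integer epoch counter, written with torch.save, can be read back with weights_only=True because it contains only tensors and primitive values.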

emmi_inference.models.loader.load_model(checkpoint, *, model_factory=None, device='cpu', weights_only=False)

Load a PyTorch model from a checkpoint.

Parameters:
  • checkpoint (pathlib.Path | dict[str, Any]) – Path to a checkpoint file, or a dictionary with the fields and values needed to construct the model.

  • model_factory (emmi_inference.models.registry.ModelFactory | None) – Optional factory used to build the model. Defaults to None.

  • device (str) – Device onto which tensors are mapped when loading (e.g. "cpu", "cuda"). Defaults to "cpu".

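  • weights_only (bool) – If True, restrict loading to primitive types such as dicts and tensors, as in open_checkpoint_file. Defaults to False.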

Returns:

The loaded PyTorch model.

Raises:

TypeError – If loading the state_dict into the model fails.

Return type:

torch.nn.Module
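The Raises entry states that a failed state_dict load surfaces as TypeError. A hypothetical sketch of that step follows, assuming the weights are stored under a "state_dict" key and that the architecture comes from a caller-supplied callable (build_model, standing in for whatever model_factory provides). Both the key name and load_state_dict_sketch are assumptions for illustration. PyTorch's own nn.Module.load_state_dict reports missing, unexpected, or mis-shaped keys with RuntimeError, so the sketch re-raises it as TypeError to match the documented behavior.

```python
from typing import Any, Callable

from torch import nn


def load_state_dict_sketch(
    checkpoint: dict[str, Any],
    build_model: Callable[[], nn.Module],
    device: str = "cpu",
) -> nn.Module:
    # Hypothetical: the architecture comes from a factory callable and the
    # weights are assumed to live under a "state_dict" key.
    model = build_model()
    try:
        # strict=True (the default) rejects missing, unexpected, or mis-shaped keys.
        model.load_state_dict(checkpoint["state_dict"])
    except (KeyError, RuntimeError) as exc:
        # Surface any failed load as TypeError, mirroring the documented Raises entry.
        raise TypeError(
            f"could not load state_dict into {type(model).__name__}"
        ) from exc
    # nn.Module.to() moves parameters in place and returns the module itself.
    return model.to(device)
```

Keeping the architecture in a factory and only the weights in the checkpoint means the checkpoint stays loadable with weights_only=True, since it holds only tensors and primitive values.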