emmi_inference.runner

Attributes

Classes

InferencePayload

A container class that holds information about the inference run over a single batch.

InferenceRunner

Thin wrapper around a model plus an execution context.

Functions

compute_metrics(predictions, targets, ...[, target_suffix])

Compute and return a metrics dict from predictions and a target batch (typical surface-/volume-like data).

Module Contents

emmi_inference.runner.InferenceStatus
emmi_inference.runner.INFERENCE_ENGINE_VERSION = 'v1'
class emmi_inference.runner.InferencePayload

A container class that holds information about the inference run over a single batch.

status: InferenceStatus = 'ok'

Indicator if inference was successful or not.

error: str | None = None

Detailed error message in case of failure.

inputs: dict[str, Any]

Input tensors used in the forward pass.

outputs: dict[str, Any]

Model outputs.

debug: dict[str, Any]

Debugging information about the inference run, such as timings.

meta: dict[str, Any]

Meta information about the inference run, such as the model class and inference engine.

dropped: dict[str, Any]

A collection of dropped tensors (not used in the forward pass or elsewhere).

kept: dict[str, Any]

A collection of kept tensors (not used in the forward pass but needed elsewhere).

get(key, default=None)
Parameters:
  • key (str)

  • default (Any)

Return type:

Any

to_cpu(*, detach=True, numpy=False, float32=False)

Return a shallow-cloned payload with all tensors moved to CPU (optionally numpy).

Parameters:
  • detach (bool) – Detach tensors from the autograd graph before moving. Defaults to True.

  • numpy (bool) – Convert tensors to NumPy arrays. Defaults to False.

  • float32 (bool) – Cast floating-point tensors to float32. Defaults to False.
Return type:

InferencePayload

summary()

Return a lightweight, JSON-friendly summary of shapes/dtypes (no tensor payloads).

Return type:

dict[str, Any]
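
To make the payload contract concrete, here is a minimal illustrative sketch (not the actual `InferencePayload` implementation) of the documented fields, a `get()` with a default fallback, and a `summary()` that reports shapes instead of tensor payloads. The name `PayloadSketch` and the lookup order in `get()` (outputs before inputs) are assumptions for illustration only.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class PayloadSketch:
    # Fields mirror the documented InferencePayload attributes.
    status: str = "ok"
    error: str | None = None
    inputs: dict[str, Any] = field(default_factory=dict)
    outputs: dict[str, Any] = field(default_factory=dict)
    debug: dict[str, Any] = field(default_factory=dict)
    meta: dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        # Assumed lookup order: outputs first, then inputs, else default.
        for coll in (self.outputs, self.inputs):
            if key in coll:
                return coll[key]
        return default

    def summary(self) -> dict[str, Any]:
        # JSON-friendly view: shapes only, no tensor payloads.
        def describe(coll: dict[str, Any]) -> dict[str, Any]:
            return {k: getattr(v, "shape", None) for k, v in coll.items()}

        return {
            "status": self.status,
            "inputs": describe(self.inputs),
            "outputs": describe(self.outputs),
        }
```

A failed run would set `status` to a failure value and populate `error`; a successful run leaves `error` as `None`.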

class emmi_inference.runner.InferenceRunner(model, *, device, autocast, dtype=None, preprocessor=None)

Thin wrapper around a model plus an execution context.

Encapsulates device/dtype/autocast configuration and an optional preprocessor. Use run() to execute inference on a single tensor or a mapping of tensors.

Parameters:
  • model (torch.nn.Module) – The instantiated torch.nn.Module.

  • device (str) – Target device string (e.g., “cpu”, “cuda”, “mps”).

  • autocast (bool) – Whether to enable automatic mixed precision (when supported) during forward.

  • dtype (torch.dtype | None) – Optional torch.dtype to cast inputs (and context) to before forward.

  • preprocessor (object | None) – Optional callable applied to inputs before device/dtype casting.

model
device_context
preprocessor = None
run(x, batch_simplification=None, batch_keys_to_keep=None, batch_keys_to_drop=None)

Execute a forward pass using the configured context.

Parameters:
  • x (torch.Tensor | dict[str, torch.Tensor]) – Input tensor or dict of tensors.

  • batch_simplification (dict[str, int] | None) – Optional mapping of batch keys to reduction numbers used to simplify the batch.

  • batch_keys_to_keep (set[str] | None) – Optional set of batch keys to keep.

  • batch_keys_to_drop (set[str] | None) – Optional set of batch keys to drop.

Returns:

An InferencePayload wrapping the model outputs returned by the underlying module.

Return type:

InferencePayload
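
The keep/drop bookkeeping above can be sketched as follows. This is a hedged, stdlib-only sketch of the assumed control flow, not the library's actual `run()` implementation: keys in `batch_keys_to_drop` are removed, keys in `batch_keys_to_keep` are set aside for later use, and the remainder is passed to the model; the function name `run_sketch` and the dict-based result are illustrative.

```python
def run_sketch(model, batch, *, keys_to_keep=None, keys_to_drop=None):
    """Illustrative sketch of run()'s key filtering and error handling."""
    keys_to_keep = keys_to_keep or set()
    keys_to_drop = keys_to_drop or set()

    # Partition the batch: dropped entries are discarded from the forward
    # pass, kept entries are retained for use elsewhere.
    dropped = {k: v for k, v in batch.items() if k in keys_to_drop}
    kept = {k: v for k, v in batch.items() if k in keys_to_keep}
    inputs = {
        k: v
        for k, v in batch.items()
        if k not in keys_to_drop and k not in keys_to_keep
    }

    try:
        outputs = model(inputs)
        return {"status": "ok", "error": None, "inputs": inputs,
                "outputs": outputs, "kept": kept, "dropped": dropped}
    except Exception as exc:
        # Failed runs report a status indicator plus the error message.
        return {"status": "failed", "error": str(exc), "inputs": inputs,
                "outputs": {}, "kept": kept, "dropped": dropped}
```

In the real API the result is an `InferencePayload` rather than a plain dict, and device/dtype/autocast handling happens around the forward call.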

emmi_inference.runner.compute_metrics(predictions, targets, dataset_normalizers, evaluation_modes, target_suffix='_target')

Compute and return a metrics dict from predictions and a target batch (typical surface-/volume-like data).

Parameters:
  • predictions (dict) – Input predictions.

  • targets (dict) – Input batch with ground truth targets.

  • dataset_normalizers (dict) – Dataset normalization configurations.

  • evaluation_modes (list[str]) – A list of evaluation modes. Falls back to [“surface_pressure”, “volume_velocity”] when empty.

  • target_suffix (str) – Target suffix added to all input tensors. Defaults to “_target”.

Returns:

Dictionary of computed metrics, with metric names as keys and tensors as values.

Return type:

dict
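
The `target_suffix` pairing convention can be sketched as below: for each prediction key, the ground truth is looked up under the same key with the suffix appended. The metric here (a plain mean absolute error over Python lists) and the name `compute_metrics_sketch` are stand-ins for illustration; the real function operates on tensors, applies dataset normalizers, and dispatches on `evaluation_modes`.

```python
def compute_metrics_sketch(predictions, targets, target_suffix="_target"):
    """Illustrative sketch of prediction/target pairing via target_suffix."""
    metrics = {}
    for name, pred in predictions.items():
        # Ground truth for key "name" lives at "name" + target_suffix.
        truth = targets.get(name + target_suffix)
        if truth is None:
            continue  # no ground truth available for this prediction
        # Stand-in metric: mean absolute error over paired elements.
        mae = sum(abs(p - t) for p, t in zip(pred, truth)) / len(pred)
        metrics[f"{name}_mae"] = mae
    return metrics
```

Predictions without a matching `*_target` entry are simply skipped, which mirrors how a suffix-based pairing scheme degrades gracefully on partial batches.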