emmi_inference.models.pipelines.field_decoder

Classes

FieldDecoderCollator

Collates a field to be used in a UPT-style decoder.

Module Contents

class emmi_inference.models.pipelines.field_decoder.FieldDecoderCollator(position_item, target_items, optional=False)

Collates a field to be used in a UPT-style decoder. It requires:

  • Positions as a dense tensor (used as the query for the Perceiver decoder).

  • Targets as a sparse tensor (e.g., used for calculating a loss).

  • An unbatch mask to convert the dense output of the Perceiver decoder into a sparse tensor, so that it can be compared to the targets (i.e., to calculate a loss).
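The relationship between the dense positions, the sparse targets, and the unbatch mask can be illustrated with a small sketch. This is not the library's implementation: the functions collate_field and unbatch, the item names "pos" and "pressure", and the zero padding are all assumptions made for illustration, and plain Python lists stand in for tensors.

```python
# Hypothetical sketch, NOT the emmi_inference implementation.
# Plain lists stand in for tensors; all names below are illustrative assumptions.

def collate_field(samples, position_item, target_items):
    """Pad positions to a dense batch, concatenate targets, and build an unbatch mask."""
    max_len = max(len(s[position_item]) for s in samples)
    dim = len(samples[0][position_item][0])

    dense_positions = []  # shape: (batch, max_len, dim), padded with zeros (assumption)
    unbatch_mask = []     # shape: (batch, max_len), True where a real point exists
    for s in samples:
        pos = s[position_item]
        pad = max_len - len(pos)
        dense_positions.append([list(p) for p in pos] + [[0.0] * dim for _ in range(pad)])
        unbatch_mask.append([True] * len(pos) + [False] * pad)

    # Sparse targets: all points of all samples concatenated, with no padding.
    sparse_targets = {
        name: [value for s in samples for value in s[name]]
        for name in target_items
    }
    return dense_positions, sparse_targets, unbatch_mask


def unbatch(dense_output, unbatch_mask):
    """Drop padded entries so the dense decoder output lines up with the sparse targets."""
    return [
        value
        for row, mask_row in zip(dense_output, unbatch_mask)
        for value, keep in zip(row, mask_row)
        if keep
    ]


# Two samples with 3 points and 1 point respectively.
samples = [
    {"pos": [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], "pressure": [1.0, 2.0, 3.0]},
    {"pos": [[0.5, 0.5]], "pressure": [4.0]},
]
positions, targets, mask = collate_field(samples, "pos", ["pressure"])

# Stand-in for a dense decoder output of shape (batch, max_len).
dense_output = [[10.0, 20.0, 30.0], [40.0, -1.0, -1.0]]
sparse_output = unbatch(dense_output, mask)
```

After unbatching, sparse_output has one entry per real query point, so it can be compared element-wise with targets["pressure"]; the padded entries (here the -1.0 values) never reach the loss.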

Initializes the FieldDecoderCollator.

Parameters:
  • position_item (str) – Identifier for the position.

  • target_items (list[str]) – Identifiers for the targets. Multiple target_items can be used if multiple values are predicted with the same decoder.

  • optional (bool)

position_item
target_items
optional = False