ab_upt
======

.. py:module:: ab_upt


Classes
-------

.. autoapisummary::

   ab_upt.AnchoredBranchedUPT


Module Contents
---------------

.. py:class:: AnchoredBranchedUPT(config)

   Bases: :py:obj:`torch.nn.Module`

   Implementation of the Anchored Branched UPT (AB-UPT) model.


   .. py:attribute:: data_specs


   .. py:attribute:: rope


   .. py:attribute:: encoder


   .. py:attribute:: geometry_blocks


   .. py:attribute:: pos_embed


   .. py:attribute:: surface_bias


   .. py:attribute:: volume_bias


   .. py:attribute:: num_perceivers
      :value: 0


   .. py:attribute:: physics_blocks


   .. py:attribute:: surface_decoder_blocks


   .. py:attribute:: volume_decoder_blocks


   .. py:attribute:: surface_decoder


   .. py:attribute:: volume_decoder


   .. py:method:: geometry_branch_forward(geometry_position, geometry_supernode_idx, geometry_batch_idx, condition, geometry_attn_kwargs)

      Forward pass through the geometry branch of the model.


   .. py:method:: physics_blocks_forward(surface_position_all, volume_position_all, geometry_encoding, physics_token_specs, physics_attn_kwargs, physics_perceiver_attn_kwargs, condition)

      Forward pass through the physics blocks of the model.

      Although the AB-UPT paper uses a perceiver block only as the first physics
      block, the physics blocks may contain additional perceiver blocks that
      attend to the geometry encoding.

      :param surface_position_all: Tensor of shape (B, N_surface_total, D_pos).
      :param volume_position_all: Tensor of shape (B, N_volume_total, D_pos).
      :param geometry_encoding: Tensor of shape (B, N_supernodes, D_hidden).
      :param physics_token_specs: List of ``TokenSpec`` objects describing the tokens processed by the physics blocks.
      :param physics_attn_kwargs: Additional attention keyword arguments for the physics transformer blocks.
      :param physics_perceiver_attn_kwargs: Additional attention keyword arguments for the physics perceiver blocks.
      :param condition: Optional conditioning tensor of shape (B, D_condition).


   .. py:method:: decoder_blocks_forward(x_physics, physics_token_specs, surface_token_specs, volume_token_specs, surface_position_all, volume_position_all, surface_decoder_attn_kwargs, volume_decoder_attn_kwargs, condition)

      Forward pass through the decoder blocks of the model. Surface and volume
      tokens are decoded by separate decoders.


   .. py:method:: create_rope_frequencies(geometry_position, geometry_supernode_idx, surface_position_all, volume_position_all)

      Create RoPE frequencies for all relevant positions.

      :param geometry_position: Tensor of shape (B * N_geometry, D_pos) in sparse layout, i.e., the points of all samples are stacked along the first dimension.
      :param geometry_supernode_idx: Tensor of shape (B * N_supernodes,) with the indices of the supernodes.
      :param surface_position_all: Tensor of shape (B, N_surface_total, D_pos).
      :param volume_position_all: Tensor of shape (B, N_volume_total, D_pos).


   .. py:method:: forward(geometry_position, geometry_supernode_idx, geometry_batch_idx, surface_anchor_position, volume_anchor_position, geometry_design_parameters = None, inflow_design_parameters = None, query_surface_position = None, query_volume_position = None)

      Forward pass of the AB-UPT model.

      :param geometry_position: Coordinates of the geometry mesh. Tensor of shape (B * N_geometry, D_pos) in sparse layout.
      :param geometry_supernode_idx: Indices of the supernodes among the geometry points. Tensor of shape (B * N_supernodes,).
      :param geometry_batch_idx: Batch index of each geometry point. Tensor of shape (B * N_geometry,). If None, all points are assumed to belong to a single sample.
      :param surface_anchor_position: Coordinates of the surface anchor points. Tensor of shape (B, N_surface_anchor, D_pos).
      :param volume_anchor_position: Coordinates of the volume anchor points. Tensor of shape (B, N_volume_anchor, D_pos).
      :param geometry_design_parameters: Optional geometry design parameters to condition on. Tensor of shape (B, D_geom).
      :param inflow_design_parameters: Optional inflow design parameters to condition on. Tensor of shape (B, D_inflow).
      :param query_surface_position: Optional coordinates of the query surface points.
      :param query_volume_position: Optional coordinates of the query volume points.
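
The ``forward`` inputs mix two layouts: geometry points are passed in a sparse layout, with the points of all samples stacked along the first dimension, while anchor positions are dense, batched tensors. The sketch below shows one way to pack per-sample point clouds into the sparse layout. It is a minimal illustration under two assumptions the docstrings do not confirm: ``geometry_supernode_idx`` holds global row indices into the stacked ``geometry_position`` tensor, and supernodes are drawn uniformly at random per sample. The helper ``build_geometry_inputs`` is hypothetical and not part of ``ab_upt``.

.. code-block:: python

   import torch


   def build_geometry_inputs(point_clouds, num_supernodes, generator=None):
       """Pack per-sample point clouds into the sparse layout used by ``forward``.

       Hypothetical helper. Assumes supernode indices are global row indices into
       the stacked positions and that supernodes are sampled uniformly at random.
       """
       positions, batch_idx, supernode_idx = [], [], []
       offset = 0
       for b, points in enumerate(point_clouds):
           n = points.shape[0]
           positions.append(points)
           # Every point of sample b receives batch index b.
           batch_idx.append(torch.full((n,), b, dtype=torch.long))
           # Draw distinct local indices, then shift them past all earlier samples.
           local = torch.randperm(n, generator=generator)[:num_supernodes]
           supernode_idx.append(local + offset)
           offset += n
       return torch.cat(positions), torch.cat(supernode_idx), torch.cat(batch_idx)


   # Two samples with different numbers of geometry points (D_pos = 3).
   clouds = [torch.rand(1000, 3), torch.rand(800, 3)]
   geometry_position, geometry_supernode_idx, geometry_batch_idx = build_geometry_inputs(
       clouds, num_supernodes=64
   )
   # Anchor positions are dense: (B, N_anchor, D_pos).
   surface_anchor_position = torch.rand(2, 512, 3)
   volume_anchor_position = torch.rand(2, 1024, 3)
   # outputs = model(geometry_position, geometry_supernode_idx, geometry_batch_idx,
   #                 surface_anchor_position, volume_anchor_position)

Because the local indices are offset by the number of points in earlier samples, ``geometry_batch_idx[geometry_supernode_idx]`` recovers the sample that each supernode belongs to. Constructing the model itself requires a ``config`` that is not documented here, so the call is left as a comment.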
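
``create_rope_frequencies`` computes rotary position embedding (RoPE) frequencies for geometry, surface, and volume positions, which are continuous coordinates with D_pos components rather than integer token indices. How the model distributes channels across axes and which base it uses is not documented here, so the following is only a generic axial-RoPE sketch under stated assumptions: the head dimension is split evenly across the D_pos axes, and each axis uses the standard RoPE schedule ``base ** (-k / P)`` for its P channel pairs. The helpers ``axial_rope_angles`` and ``apply_rope`` are hypothetical.

.. code-block:: python

   import torch


   def axial_rope_angles(positions, head_dim, base=10000.0):
       """Rotation angles for continuous coordinates (hypothetical helper).

       positions: (..., N, D_pos) -> angles: (..., N, head_dim // 2)
       """
       d_pos = positions.shape[-1]
       if head_dim % (2 * d_pos) != 0:
           raise ValueError("head_dim must be divisible by 2 * D_pos")
       pairs_per_axis = head_dim // (2 * d_pos)
       # Standard RoPE schedule per axis: base ** (-i / P) for i = 0 .. P - 1.
       idx = torch.arange(pairs_per_axis, dtype=positions.dtype, device=positions.device)
       inv_freq = base ** (-idx / pairs_per_axis)
       # (..., N, D_pos, 1) * (P,) -> (..., N, D_pos, P), then merge the last two axes.
       angles = positions.unsqueeze(-1) * inv_freq
       return angles.flatten(-2)


   def apply_rope(x, angles):
       """Rotate consecutive channel pairs of x (..., N, head_dim) by angles."""
       pairs = x.reshape(*x.shape[:-1], -1, 2)
       x0, x1 = pairs[..., 0], pairs[..., 1]
       cos, sin = angles.cos(), angles.sin()
       rotated = torch.stack((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
       return rotated.flatten(-2)


   # Queries and keys for 5 points with 3D coordinates and head_dim = 12.
   positions = torch.rand(5, 3, dtype=torch.float64) * 10
   q = torch.randn(5, 12, dtype=torch.float64)
   k = torch.randn(5, 12, dtype=torch.float64)
   angles = axial_rope_angles(positions, head_dim=12)
   scores = apply_rope(q, angles) @ apply_rope(k, angles).T

Because every channel pair is rotated by an angle that is linear in the coordinates, the score between a rotated query and a rotated key depends only on the difference of their positions; translating all positions by the same offset leaves ``scores`` unchanged.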
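
In ``physics_blocks_forward``, perceiver blocks let the physics tokens attend to ``geometry_encoding`` of shape (B, N_supernodes, D_hidden). The internals of these blocks (normalization, MLP, conditioning, RoPE usage) are not documented here, so the sketch below is a generic pre-norm cross-attention block built on ``torch.nn.MultiheadAttention``, not the model's actual implementation; ``PerceiverCrossAttentionBlock`` is a hypothetical name. It illustrates the relevant property: the output keeps the number of query tokens, while the geometry encoding only supplies keys and values.

.. code-block:: python

   import torch
   from torch import nn


   class PerceiverCrossAttentionBlock(nn.Module):
       """Hypothetical pre-norm block: query tokens attend to a context sequence."""

       def __init__(self, dim, num_heads, mlp_ratio=4):
           super().__init__()
           self.norm_q = nn.LayerNorm(dim)
           self.norm_kv = nn.LayerNorm(dim)
           self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
           self.norm_mlp = nn.LayerNorm(dim)
           self.mlp = nn.Sequential(
               nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim)
           )

       def forward(self, x, context):
           # x: (B, N_tokens, dim) queries; context: (B, N_supernodes, dim) keys/values.
           kv = self.norm_kv(context)
           attn_out, _ = self.attn(self.norm_q(x), kv, kv, need_weights=False)
           x = x + attn_out
           return x + self.mlp(self.norm_mlp(x))


   block = PerceiverCrossAttentionBlock(dim=192, num_heads=3)
   physics_tokens = torch.randn(2, 300, 192)  # stand-in for physics tokens
   geometry_encoding = torch.randn(2, 64, 192)  # (B, N_supernodes, D_hidden)
   out = block(physics_tokens, geometry_encoding)  # (2, 300, 192)

Since the attention cost scales with N_tokens times N_supernodes, attending to a fixed set of supernodes keeps this step independent of the raw size of the geometry mesh.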