emmi.modules.attention.utils

Functions

apply_init_method(module, proj_weight, init_method)

Apply an initialization function to all applicable sub-modules of a given module.

Module Contents

emmi.modules.attention.utils.apply_init_method(module, proj_weight, init_method)

Apply an initialization function to all applicable sub-modules of a given module.

Parameters:
  • module (torch.nn.Module) – The nn.Module instance to initialize.

  • proj_weight (torch.Tensor)

  • init_method (str) – The initialization method to apply to each applicable sub-module.

Return type:

None
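The sketch below illustrates the general pattern the summary describes: walk every sub-module of an nn.Module and run an initialization function on the applicable ones. It is a hypothetical stand-in, not emmi's implementation. The function name apply_init_sketch, the string keys in _INIT_FNS, and the choice to treat only nn.Linear layers as "applicable" are all assumptions. proj_weight is left out because this page does not document its role. The sketch relies on torch.nn.Module.apply (which visits sub-modules recursively) and on standard torch.nn.init functions.

```python
import torch
from torch import nn

# Hypothetical mapping from a method name to a torch.nn.init function.
# The names accepted by emmi's init_method are NOT documented here; these keys are assumptions.
_INIT_FNS = {
    "xavier_uniform": nn.init.xavier_uniform_,
    "kaiming_uniform": nn.init.kaiming_uniform_,
    "trunc_normal": lambda w: nn.init.trunc_normal_(w, std=0.02),
}


def apply_init_sketch(module: nn.Module, init_method: str) -> None:
    """Initialize every nn.Linear inside `module` in place (hypothetical sketch)."""
    if init_method not in _INIT_FNS:
        raise ValueError(f"unknown init_method: {init_method!r}")
    init_fn = _INIT_FNS[init_method]

    def _init(m: nn.Module) -> None:
        # Assumption: only Linear layers count as "applicable" sub-modules.
        if isinstance(m, nn.Linear):
            init_fn(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # nn.Module.apply calls _init on every sub-module recursively, then on module itself.
    module.apply(_init)


# Usage: initialize a small MLP block in place; the function returns None.
block = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
apply_init_sketch(block, "xavier_uniform")
```

Using Module.apply keeps the traversal generic, so nested containers are covered without manual recursion. Like the documented function, the sketch mutates parameters in place and returns None.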