ksuit.models.single
Classes
Model – Base class for single models, i.e., one model with one optimizer, as opposed to CompositeModel.
Module Contents
- class ksuit.models.single.Model(model_config, is_frozen=False, update_counter=None, path_provider=None, data_container=None, static_context=None)
Bases: ksuit.models.model_base.ModelBase
Base class for all neural network modules (docstring inherited from torch.nn.Module).
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
```python
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```
Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
- Variables:
training (bool) – Boolean represents whether this module is in training or evaluation mode.
Base class for single models, i.e., one model with one optimizer, as opposed to CompositeModel.
- Parameters:
model_config (ksuit.schemas.models.ModelBaseConfig) – Model configuration.
is_frozen (bool) – If True, sets requires_grad to False for all parameters and puts the model into eval mode (e.g., so that Dropout and BatchNorm layers use their evaluation behavior).
path_provider (ksuit.providers.path_provider.PathProvider | None) – A path provider used by the initializer to store or retrieve checkpoints.
data_container (ksuit.utils.data.data_container.DataContainer | None) – The data container holding the data and dataloader. Currently unused by the model; it is only intended as a convenience for quick prototyping, e.g., evaluating forward in debug mode.
static_context (dict[str, Any] | None) – The static context used to pass information between submodules, e.g., patch_size, latent_dim.
update_counter (ksuit.utils.training.UpdateCounter | None)
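The static_context parameter is described only as a way to pass values such as patch_size or latent_dim between submodules. A minimal sketch of that idea, assuming the context is a plain dict that one submodule writes and a later one reads; PatchEncoder and PatchDecoder are hypothetical classes, not part of ksuit:

```python
from dataclasses import dataclass, field
from typing import Any


# Hypothetical stand-ins: these classes are NOT ksuit APIs; they only
# illustrate how a shared static_context dict can flow between submodules.
@dataclass
class PatchEncoder:
    patch_size: int
    latent_dim: int
    static_context: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Publish shape information that downstream submodules need.
        self.static_context["patch_size"] = self.patch_size
        self.static_context["latent_dim"] = self.latent_dim


@dataclass
class PatchDecoder:
    static_context: dict[str, Any]

    def __post_init__(self) -> None:
        # Read the values instead of duplicating them in the decoder's own config.
        self.patch_size = self.static_context["patch_size"]
        self.latent_dim = self.static_context["latent_dim"]


ctx: dict[str, Any] = {}
encoder = PatchEncoder(patch_size=16, latent_dim=192, static_context=ctx)
decoder = PatchDecoder(static_context=ctx)
print(decoder.patch_size, decoder.latent_dim)  # 16 192
```

Sharing one mutable dict means construction order matters: the writer must be built before the reader.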
- property device: torch.device
- Return type:
torch.device
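The device property is documented only by its return type. One common way such a property is implemented in PyTorch code, shown here as an assumption rather than ksuit's actual implementation, is to report the device of the first parameter (TinyModel is hypothetical):

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model; only the `device` property is the point here."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    @property
    def device(self) -> torch.device:
        # Assumes all parameters live on one device, so the first one is representative.
        return next(self.parameters()).device


model = TinyModel()
x = torch.randn(3, 4, device=model.device)  # allocate inputs where the model lives
print(model.device)
```

This shortcut fails for a model without parameters (next() raises StopIteration), so a real implementation may need a fallback.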
- get_named_models()
Returns a dict of {model_name: model}, e.g., to log all learning rates of all models/submodels.
- Return type:
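The naming scheme behind get_named_models is not documented here. The sketch below assumes each model registers under its own name and a parent merges the dicts of its children, which is enough to log one learning rate per model; FakeModel is a hypothetical stand-in, not a ksuit class:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class FakeModel:
    """Hypothetical stand-in for a model; not a ksuit class."""

    name: str
    lr: float
    submodels: dict[str, FakeModel] = field(default_factory=dict)

    def get_named_models(self) -> dict[str, FakeModel]:
        # A model contributes itself; nested models contribute their own entries.
        named = {self.name: self}
        for sub in self.submodels.values():
            named.update(sub.get_named_models())
        return named


encoder = FakeModel("encoder", lr=1e-4)
decoder = FakeModel("decoder", lr=3e-4)
root = FakeModel("root", lr=1e-3, submodels={"encoder": encoder, "decoder": decoder})

# Log one learning rate per model, as the docstring suggests.
for name, model in root.get_named_models().items():
    print(f"lr/{name}: {model.lr}")
```

For a single model with no submodels, this reduces to a one-entry dict.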
- initialize_weights()
Freezes the weights of the model by setting requires_grad to False if self.is_frozen is True.
- Return type:
Self
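As described for is_frozen, freezing combines two standard PyTorch steps: clearing requires_grad on every parameter and switching the module to eval mode. The freeze helper below is a hypothetical sketch of that combination, not ksuit's implementation:

```python
import torch
from torch import nn


def freeze(model: nn.Module) -> nn.Module:
    """Hypothetical helper mirroring the documented is_frozen behavior."""
    for param in model.parameters():
        param.requires_grad = False  # exclude from gradient computation
    model.eval()  # Dropout becomes a no-op, BatchNorm uses running statistics
    return model


net = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Dropout(0.5))
freeze(net)
print(all(not p.requires_grad for p in net.parameters()), net.training)  # True False
```

Note that a later call to train() on a parent module would switch the frozen model back to training mode, so frozen models typically need to be re-set to eval mode afterwards.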
- apply_initializers()
Apply the initializers to the model, calling initializer.init_weights and initializer.init_optim.
- Return type:
Self
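The docs name only the two hooks, init_weights and init_optim. A minimal sketch of the calling pattern, assuming each initializer is an object exposing both methods, receives the model as its argument, and is applied in list order; ToyModel and ConstantInitializer are hypothetical:

```python
from dataclasses import dataclass, field
from typing import Protocol


class Initializer(Protocol):
    """Assumed interface, inferred only from the two method names in the docs."""

    def init_weights(self, model: "ToyModel") -> None: ...

    def init_optim(self, model: "ToyModel") -> None: ...


@dataclass
class ToyModel:
    weights: list[float] = field(default_factory=lambda: [0.0, 0.0])
    optim_state: dict[str, int] = field(default_factory=dict)
    initializers: list[Initializer] = field(default_factory=list)

    def apply_initializers(self) -> "ToyModel":
        # Each initializer may set weights and optimizer state; order matters
        # because later initializers can overwrite earlier ones.
        for initializer in self.initializers:
            initializer.init_weights(self)
            initializer.init_optim(self)
        return self  # return self so calls can be chained


class ConstantInitializer:
    """Hypothetical initializer that fills weights and resets the step counter."""

    def __init__(self, value: float) -> None:
        self.value = value

    def init_weights(self, model: ToyModel) -> None:
        model.weights = [self.value] * len(model.weights)

    def init_optim(self, model: ToyModel) -> None:
        model.optim_state["step"] = 0


model = ToyModel(initializers=[ConstantInitializer(0.5)]).apply_initializers()
print(model.weights, model.optim_state)
```

Returning self matches the documented return type and allows chaining after construction.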
- initialize_optimizer()
Initialize the optimizer.
- Return type:
None
- optimizer_step(grad_scaler)
Perform an optimization step.
- Parameters:
grad_scaler (torch.cuda.amp.GradScaler | None)
- Return type:
None
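Because grad_scaler is typed as optional, the step presumably handles both plain and mixed-precision training. The sketch below assumes that split and uses only the standard GradScaler.step and GradScaler.update calls; the function name and signature are illustrative, not ksuit's code:

```python
import torch
from torch import nn


def optimizer_step(
    optimizer: torch.optim.Optimizer,
    grad_scaler: "torch.cuda.amp.GradScaler | None" = None,
) -> None:
    """Hypothetical helper: step with or without mixed-precision loss scaling."""
    if grad_scaler is None:
        optimizer.step()
    else:
        # step() unscales gradients and skips the update if they contain inf/NaN;
        # update() then adjusts the scale factor for the next iteration.
        grad_scaler.step(optimizer)
        grad_scaler.update()


model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()
model(torch.ones(4, 2)).sum().backward()
optimizer_step(optimizer, grad_scaler=None)  # plain FP32 path on CPU
print(torch.equal(before, model.weight))  # False: the weights were updated
```

When a scaler is used, the loss must also have been scaled before backward (grad_scaler.scale(loss).backward()) for the unscaling in step() to be correct.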
- optimizer_schedule_step()
Perform a step of the learning rate scheduler.
- Return type:
None
- optimizer_zero_grad(set_to_none=True)
Zero the gradients of the optimizer.
- Parameters:
set_to_none (bool)
- Return type:
None
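The three optimizer methods correspond to the usual PyTorch update order: zero the gradients, run forward and backward, step the optimizer, then step the scheduler. The sketch uses plain torch.optim objects in place of ksuit's wrappers and assumes the scheduler is stepped once per update; the docs do not say whether ksuit steps per update or per epoch:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
inputs, targets = torch.randn(8, 3), torch.randn(8, 1)

for _ in range(4):
    # set_to_none=True drops the .grad tensors instead of filling them with zeros.
    optimizer.zero_grad(set_to_none=True)
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Step the scheduler after the optimizer, once per update in this sketch.
    scheduler.step()

optimizer.zero_grad(set_to_none=True)
print(model.weight.grad)  # None
print(scheduler.get_last_lr())
```

Setting gradients to None saves a memory write per parameter and lets optimizers skip parameters that received no gradient.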