neurotorch.trainers package¶
Submodules¶
neurotorch.trainers.classification module¶
- class neurotorch.trainers.classification.ClassificationTrainer(*args, **kwargs)¶
Bases: Trainer
- __init__(*args, **kwargs)¶
Constructor for Trainer.
- Parameters:
model – Model to train.
criterion – Loss function(s) to use. Deprecated, use learning_algorithm instead.
regularization –
Regularization(s) to use. In NeuroTorch, there are two ways to apply regularization:
1. Regularization can be specified in the layers with the ‘update_regularization_loss’ method. This regularization is performed by the same optimizer as the main loss. This way is useful when you want a regularization that depends on the model output or hidden state.
2. Regularization can be specified in the trainer with the ‘regularization’ parameter. This regularization is performed by a separate optimizer named ‘regularization_optimizer’. This way is useful when you want a regularization that depends only on the model parameters and when you want to control the learning rate of the regularization independently of the main loss.
- Note: This parameter is deprecated and will be removed in a future version. The regularization will be
specified in the learning algorithm and/or in the callbacks.
optimizer – Optimizer to use for the main loss. Deprecated. Use learning_algorithm instead.
learning_algorithm – Learning algorithm to use for the main loss. It can also be given in the callbacks list; if specified here, it will be added to the callbacks list, so make sure it is not added twice. Note that multiple learning algorithms can be used in the callbacks list.
regularization_optimizer – Optimizer to use for the regularization loss.
metrics – Metrics to compute during training.
callbacks – Callbacks to use during training. Each callback will be called at different moments; see the documentation of BaseCallback for more information.
device – Device to use for the training. Default is the device of the model.
verbose – Whether to print information during training.
kwargs – Additional arguments of the training.
- Keyword Arguments:
n_epochs (int) – The number of epochs to train at each iteration. Default is 1.
lr (float) – Learning rate of the main optimizer. Default is 1e-3.
reg_lr (float) – Learning rate of the regularization optimizer. Default is 1e-2.
weight_decay (float) – Weight decay of the main optimizer. Default is 0.0.
exec_metrics_on_train (bool) – Whether to compute metrics on the train dataset. Setting this to False is useful when you want to save time by not computing the metrics on the train dataset. Default is True.
x_transform – Transform to apply to the input data before passing it to the model.
y_transform – Transform to apply to the target data before passing it to the model. For example, this can be used to convert the target data to a one-hot encoding or to a long tensor using nt.ToTensor(dtype=torch.long).
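Below is a minimal sketch of building a ClassificationTrainer. The model architecture, the nt.BPTT learning algorithm, and the data shapes are illustrative assumptions, not requirements of the API:

    import torch
    import neurotorch as nt
    from neurotorch.trainers.classification import ClassificationTrainer

    # Any torch.nn.Module can be trained; this small classifier is illustrative.
    model = torch.nn.Sequential(
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 10),
    )

    trainer = ClassificationTrainer(
        model,
        # Assumed learning algorithm; any LearningAlgorithm can be substituted,
        # and it may also be passed through the callbacks list instead.
        learning_algorithm=nt.BPTT(),
        # Convert integer class targets to long tensors, as suggested above.
        y_transform=nt.ToTensor(dtype=torch.long),
        verbose=True,
    )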
neurotorch.trainers.regression module¶
- class neurotorch.trainers.regression.RegressionTrainer(*args, **kwargs)¶
Bases: Trainer
- __init__(*args, **kwargs)¶
Constructor for Trainer.
- Parameters:
model – Model to train.
criterion – Loss function(s) to use. Deprecated, use learning_algorithm instead.
regularization –
Regularization(s) to use. In NeuroTorch, there are two ways to apply regularization:
1. Regularization can be specified in the layers with the ‘update_regularization_loss’ method. This regularization is performed by the same optimizer as the main loss. This way is useful when you want a regularization that depends on the model output or hidden state.
2. Regularization can be specified in the trainer with the ‘regularization’ parameter. This regularization is performed by a separate optimizer named ‘regularization_optimizer’. This way is useful when you want a regularization that depends only on the model parameters and when you want to control the learning rate of the regularization independently of the main loss.
- Note: This parameter is deprecated and will be removed in a future version. The regularization will be
specified in the learning algorithm and/or in the callbacks.
optimizer – Optimizer to use for the main loss. Deprecated. Use learning_algorithm instead.
learning_algorithm – Learning algorithm to use for the main loss. It can also be given in the callbacks list; if specified here, it will be added to the callbacks list, so make sure it is not added twice. Note that multiple learning algorithms can be used in the callbacks list.
regularization_optimizer – Optimizer to use for the regularization loss.
metrics – Metrics to compute during training.
callbacks – Callbacks to use during training. Each callback will be called at different moments; see the documentation of BaseCallback for more information.
device – Device to use for the training. Default is the device of the model.
verbose – Whether to print information during training.
kwargs – Additional arguments of the training.
- Keyword Arguments:
n_epochs (int) – The number of epochs to train at each iteration. Default is 1.
lr (float) – Learning rate of the main optimizer. Default is 1e-3.
reg_lr (float) – Learning rate of the regularization optimizer. Default is 1e-2.
weight_decay (float) – Weight decay of the main optimizer. Default is 0.0.
exec_metrics_on_train (bool) – Whether to compute metrics on the train dataset. Setting this to False is useful when you want to save time by not computing the metrics on the train dataset. Default is True.
x_transform – Transform to apply to the input data before passing it to the model.
y_transform – Transform to apply to the target data before passing it to the model. For example, this can be used to convert the target data to a one-hot encoding or to a long tensor using nt.ToTensor(dtype=torch.long).
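A comparable sketch for RegressionTrainer; the linear model and the float32 x_transform are assumptions chosen only to illustrate the transform hooks:

    import torch
    import neurotorch as nt
    from neurotorch.trainers.regression import RegressionTrainer

    model = torch.nn.Linear(8, 1)
    trainer = RegressionTrainer(
        model,
        learning_algorithm=nt.BPTT(),  # assumed; any LearningAlgorithm works
        # Cast inputs to float32 before the forward pass (illustrative).
        x_transform=nt.ToTensor(dtype=torch.float32),
    )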
neurotorch.trainers.trainer module¶
- class neurotorch.trainers.trainer.CurrentTrainingState(n_iterations: int | None = None, iteration: int | None = None, n_epochs: int | None = None, epoch: int | None = None, epoch_loss: Any | None = None, batch: int | None = None, x_batch: Any | None = None, y_batch: Any | None = None, pred_batch: Any | None = None, batch_loss: Any | None = None, batch_is_train: bool | None = None, train_loss: Any | None = None, val_loss: Any | None = None, train_metrics: Any | None = None, val_metrics: Any | None = None, itr_metrics: Dict[str, Any] | None = {}, stop_training_flag: bool = False, info: Dict[str, Any] = {}, objects: Dict[str, Any] = {})¶
Bases: NamedTuple
This class stores the current training state. It is extremely useful for callbacks, which can access the current training state and personalize the training process.
- Attributes:
n_iterations (int): The total number of iterations.
iteration (int): The current iteration.
n_epochs (int): The number of epochs per iteration.
epoch (int): The current epoch.
epoch_loss (Any): The current epoch loss.
batch (int): The current batch.
x_batch (Any): The current input batch.
y_batch (Any): The current target batch.
pred_batch (Any): The current prediction.
batch_loss (float): The current loss.
batch_is_train (bool): Whether the current batch is a training batch.
train_loss (float): The current training loss.
val_loss (float): The current validation loss.
train_metrics (Any): The current training metrics.
val_metrics (Any): The current validation metrics.
itr_metrics (Dict[str, Any]): The current iteration metrics.
stop_training_flag (bool): Whether the training should be stopped.
info (Dict[str, Any]): Any additional information. This is useful to communicate between callbacks.
objects (Dict[str, Any]): Any additional objects. This is useful to manage objects between callbacks. Note: in general, the train_dataloader and val_dataloader should be stored here.
- batch: int | None¶
Alias for field number 5
- batch_is_train: bool | None¶
Alias for field number 10
- batch_loss: Any | None¶
Alias for field number 9
- epoch: int | None¶
Alias for field number 3
- epoch_loss: Any | None¶
Alias for field number 4
- static get_null_state() CurrentTrainingState ¶
- info: Dict[str, Any]¶
Alias for field number 17
- iteration: int | None¶
Alias for field number 1
- itr_metrics: Dict[str, Any] | None¶
Alias for field number 15
- n_epochs: int | None¶
Alias for field number 2
- n_iterations: int | None¶
Alias for field number 0
- objects: Dict[str, Any]¶
Alias for field number 18
- pred_batch: Any | None¶
Alias for field number 8
- stop_training_flag: bool¶
Alias for field number 16
- train_loss: Any | None¶
Alias for field number 11
- train_metrics: Any | None¶
Alias for field number 13
- update(**kwargs) CurrentTrainingState ¶
- val_loss: Any | None¶
Alias for field number 12
- val_metrics: Any | None¶
Alias for field number 14
- x_batch: Any | None¶
Alias for field number 6
- y_batch: Any | None¶
Alias for field number 7
- class neurotorch.trainers.trainer.Trainer(model: Module, *, predict_method: str = '__call__', criterion: Dict[str, Module | Callable] | Module | Callable | None = None, regularization: BaseRegularization | RegularizationList | Iterable[BaseRegularization] | None = None, optimizer: Optimizer | None = None, learning_algorithm: LearningAlgorithm | None = None, regularization_optimizer: Optimizer | None = None, metrics: List[Callable] | None = None, callbacks: List[BaseCallback] | CallbacksList | BaseCallback | None = None, device: device | None = None, verbose: bool = True, **kwargs)¶
Bases: object
Trainer class. This class is used to train a model.
TODO: Add the possibility to pass a callable as the predict_method.
- __init__(model: Module, *, predict_method: str = '__call__', criterion: Dict[str, Module | Callable] | Module | Callable | None = None, regularization: BaseRegularization | RegularizationList | Iterable[BaseRegularization] | None = None, optimizer: Optimizer | None = None, learning_algorithm: LearningAlgorithm | None = None, regularization_optimizer: Optimizer | None = None, metrics: List[Callable] | None = None, callbacks: List[BaseCallback] | CallbacksList | BaseCallback | None = None, device: device | None = None, verbose: bool = True, **kwargs)¶
Constructor for Trainer.
- Parameters:
model – Model to train.
criterion – Loss function(s) to use. Deprecated, use learning_algorithm instead.
regularization –
Regularization(s) to use. In NeuroTorch, there are two ways to apply regularization:
1. Regularization can be specified in the layers with the ‘update_regularization_loss’ method. This regularization is performed by the same optimizer as the main loss. This way is useful when you want a regularization that depends on the model output or hidden state.
2. Regularization can be specified in the trainer with the ‘regularization’ parameter. This regularization is performed by a separate optimizer named ‘regularization_optimizer’. This way is useful when you want a regularization that depends only on the model parameters and when you want to control the learning rate of the regularization independently of the main loss.
- Note: This parameter is deprecated and will be removed in a future version. The regularization will be
specified in the learning algorithm and/or in the callbacks.
optimizer – Optimizer to use for the main loss. Deprecated. Use learning_algorithm instead.
learning_algorithm – Learning algorithm to use for the main loss. It can also be given in the callbacks list; if specified here, it will be added to the callbacks list, so make sure it is not added twice. Note that multiple learning algorithms can be used in the callbacks list.
regularization_optimizer – Optimizer to use for the regularization loss.
metrics – Metrics to compute during training.
callbacks – Callbacks to use during training. Each callback will be called at different moments; see the documentation of BaseCallback for more information.
device – Device to use for the training. Default is the device of the model.
verbose – Whether to print information during training.
kwargs – Additional arguments of the training.
- Keyword Arguments:
n_epochs (int) – The number of epochs to train at each iteration. Default is 1.
lr (float) – Learning rate of the main optimizer. Default is 1e-3.
reg_lr (float) – Learning rate of the regularization optimizer. Default is 1e-2.
weight_decay (float) – Weight decay of the main optimizer. Default is 0.0.
exec_metrics_on_train (bool) – Whether to compute metrics on the train dataset. Setting this to False is useful when you want to save time by not computing the metrics on the train dataset. Default is True.
x_transform – Transform to apply to the input data before passing it to the model.
y_transform – Transform to apply to the target data before passing it to the model. For example, this can be used to convert the target data to a one-hot encoding or to a long tensor using nt.ToTensor(dtype=torch.long).
- apply_criterion_on_batch(x_batch: Tensor | Dict[str, Tensor], y_batch: Tensor | Dict[str, Tensor], pred_batch: Dict[str, Tensor] | Tensor | None) Tensor ¶
- property checkpoint_managers: CallbacksList¶
- property force_overwrite¶
- get_pred_batch(x_batch: Tensor | Dict[str, Tensor])¶
- property learning_algorithms: CallbacksList¶
- property load_checkpoint_mode¶
- load_state()¶
Load the state of the trainer from the checkpoint.
- property network¶
Alias for the model.
- Returns:
The model attribute.
- sort_callbacks_(reverse: bool = False) CallbacksList ¶
Sort the callbacks by their priority. The higher the priority, the earlier the callback is called. In general, the callbacks will be sorted in the following order:
TrainingHistory callbacks;
Other callbacks;
CheckpointManager callbacks.
- Parameters:
reverse (bool) – Whether to reverse the order of the callbacks. Default is False.
- Returns:
The sorted callbacks.
- Return type:
CallbacksList
- property state¶
Alias for the current_training_state attribute.
- Returns:
The current_training_state attribute.
- train(train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, n_iterations: int | None = None, *, n_epochs: int = 1, load_checkpoint_mode: LoadCheckpointMode | None = None, force_overwrite: bool = False, p_bar_position: int | None = None, p_bar_leave: bool | None = None, **kwargs) TrainingHistory ¶
Train the model.
- Parameters:
train_dataloader (DataLoader) – The dataloader for the training set. It contains the training data.
val_dataloader (Optional[DataLoader]) – The dataloader for the validation set. It contains the validation data.
n_iterations (Optional[int]) – The number of iterations to train the model. An iteration is a pass over the training set and the validation set. If None, the model will be trained until the training is stopped by the user.
n_epochs (int) – The number of epochs to train the model. The nomenclature here differs from what is usually used elsewhere: an epoch is a pass over the training set, while an iteration is a pass over both the training set and the validation set. In other words, if n_iterations=1 and n_epochs=10, the trainer will pass 10 times over the training set and 1 time over the validation set (this constitutes 1 iteration). If n_iterations=10 and n_epochs=1, the trainer will pass 10 times over the training set and 10 times over the validation set (this constitutes 10 iterations). The nuance between these terms is especially important when it comes to reinforcement learning. Default is 1.
load_checkpoint_mode (LoadCheckpointMode) – The mode to use when loading the checkpoint.
force_overwrite (bool) – Whether to force overwriting the checkpoint. Be careful when using this option, as it will destroy the previous checkpoint folder. Default is False.
p_bar_position (Optional[int]) – The position of the progress bar. See tqdm documentation for more information.
p_bar_leave (Optional[bool]) – Whether to leave the progress bar. See tqdm documentation for more information.
kwargs – Additional keyword arguments.
- Returns:
The training history.
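For instance (assuming a trainer built as in the sketches above, with toy tensors standing in for real data):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy regression data, purely illustrative.
    x, y = torch.randn(64, 8), torch.randn(64, 1)
    train_dataloader = DataLoader(TensorDataset(x, y), batch_size=16)

    # With n_iterations=100 and n_epochs=1 (the default), the trainer makes
    # 100 passes over the training set and, if a val_dataloader were given,
    # 100 passes over the validation set.
    history = trainer.train(train_dataloader, n_iterations=100)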
- property training_histories: CallbacksList¶
- update_info_state_(**kwargs)¶
- update_itr_metrics_state_(**kwargs)¶
- update_objects_state_(**kwargs)¶
- update_state_(**kwargs)¶
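A hedged sketch of callbacks communicating through the shared state via these helpers. The BaseCallback import path and the on_iteration_end hook name are assumptions about the callback interface; update_info_state_ is the method documented above:

    from neurotorch.callbacks.base_callback import BaseCallback  # path assumed

    class BestValLossTracker(BaseCallback):
        def on_iteration_end(self, trainer, **kwargs):  # hook name assumed
            state = trainer.current_training_state
            best = state.info.get("best_val_loss")
            if state.val_loss is not None and (best is None or state.val_loss < best):
                # Stored in `info` so that later callbacks can read it.
                trainer.update_info_state_(best_val_loss=state.val_loss)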
- neurotorch.trainers.trainer.TrainingState¶
alias of CurrentTrainingState