neurotorch.callbacks package
Submodules
neurotorch.callbacks.base_callback module
- class neurotorch.callbacks.base_callback.BaseCallback(priority: int | None = None, name: str | None = None, save_state: bool = True, load_state: bool | None = None, **kwargs)
Bases:
object
Class used to create a callback that can be used to monitor or modify the training process.
- Training Phases:
Iteration: One full pass through the training dataset and the validation dataset.
Epoch: One full pass through the training dataset or the validation dataset.
Batch: One forward pass through the network.
Train: One full pass through the training dataset.
Validation: One full pass through the validation dataset.
- Callbacks methods are called in the following order:
-
- Executes n_iterations times:
-
- Executes n_epochs times:
-
- Executes n_batches times:
- Executes n_batches times:
- Note:
The special method get_checkpoint_state() is called by the CheckpointManager object to save the state of the callback in the checkpoint file. The timing of this call is therefore determined by the CheckpointManager, if one is included in the trainer callbacks. In the same way, the method load_checkpoint_state() is called by the CheckpointManager to load the state of the callback from the checkpoint file, if one is included in the trainer callbacks.
- Note:
The special method __del__() is called when the callback is deleted. It is used to call the close() method if it was not called before.
- Attributes:
priority (int): The priority of the callback. The lower the priority, the earlier the callback is called. Default is 10.
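The close-once behaviour described in the second note can be illustrated with a minimal sketch. ClosingCallback, its _closed flag and the closed_names log are hypothetical names used only for this example; this is not neurotorch's implementation.

```python
from typing import List

closed_names: List[str] = []  # hypothetical log of callbacks that were closed


class ClosingCallback:
    """Hypothetical callback whose close() runs at most once."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._closed = False

    def close(self, trainer=None, **kwargs) -> None:
        if self._closed:
            return  # already closed: a second call is a no-op
        self._closed = True
        closed_names.append(self.name)

    def __del__(self) -> None:
        # Mirror the documented behaviour: close only if nobody did it before.
        if not self._closed:
            self.close()


explicit = ClosingCallback("explicit")
explicit.close()
explicit.close()  # ignored
implicit = ClosingCallback("implicit")
del implicit  # on CPython, __del__ runs as soon as the last reference disappears
print(closed_names)  # on CPython: ['explicit', 'implicit']
```

Calling close() explicitly is still preferable: outside CPython's reference counting, the moment __del__ runs is not guaranteed.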
- DEFAULT_HIGH_PRIORITY = 0
- DEFAULT_LOW_PRIORITY = 100
- DEFAULT_MEDIUM_PRIORITY = 50
- DEFAULT_PRIORITY = 10
- UNPICKEABLE_ATTRIBUTES = ['trainer']
- __init__(priority: int | None = None, name: str | None = None, save_state: bool = True, load_state: bool | None = None, **kwargs)
- Parameters:
priority (int, optional) – The priority of the callback. The lower the priority, the earlier the callback is called. At the beginning of the training, the priorities of the callbacks are reversed for the load_state() method. Default is 10.
name (str, optional) – The name of the callback. If None, the name is set to the class name. Default is None.
save_state (bool, optional) – If True, the state of the callback is saved in the checkpoint file. Default is True.
load_state (bool, optional) – If True, the state of the callback is loaded from the checkpoint file. Default is equal to save_state.
- close(trainer, **kwargs)
Called when the training ends. This is the last callback called.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- extra_repr() → str
- get_checkpoint_state(trainer, **kwargs) → object
Get the state of the callback. This is called when the checkpoint manager saves the state of the trainer. Then this state is saved in the checkpoint file with the name of the callback as the key.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
The state of the callback.
- Return type:
A picklable object.
- instance_counter = 0
- load_checkpoint_state(trainer, checkpoint: dict, **kwargs)
Loads the state of the callback from a dictionary.
- Parameters:
trainer (Trainer) – The trainer.
checkpoint (dict) – The dictionary containing all the states of the trainer.
- Returns:
None
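The get_checkpoint_state() / load_checkpoint_state() pair can be sketched with a hypothetical CounterCallback. The sketch assumes a flat checkpoint dict keyed by callback name, following the description of get_checkpoint_state() above; the real checkpoint may nest these states under another key, so treat the layout as illustrative.

```python
import pickle
from typing import Any, Dict


class CounterCallback:
    """Hypothetical callback whose only state is a count of finished iterations."""

    def __init__(self, name: str = "CounterCallback") -> None:
        self.name = name
        self.n_iterations_seen = 0

    def on_iteration_end(self, trainer, **kwargs) -> None:
        self.n_iterations_seen += 1

    def get_checkpoint_state(self, trainer, **kwargs) -> Dict[str, Any]:
        # Must be picklable: plain containers and numbers are safe.
        return {"n_iterations_seen": self.n_iterations_seen}

    def load_checkpoint_state(self, trainer, checkpoint: dict, **kwargs) -> None:
        # Assumption: this callback's state is stored under its name.
        state = checkpoint.get(self.name)
        if state is not None:
            self.n_iterations_seen = state["n_iterations_seen"]


callback = CounterCallback()
for _ in range(3):
    callback.on_iteration_end(trainer=None)

# Simulate writing and reading a checkpoint file, in memory.
blob = pickle.dumps({callback.name: callback.get_checkpoint_state(trainer=None)})
restored = CounterCallback()
restored.load_checkpoint_state(trainer=None, checkpoint=pickle.loads(blob))
print(restored.n_iterations_seen)  # 3
```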
- on_batch_begin(trainer, **kwargs)
Called when a batch starts. The batch is defined as one forward pass through the network.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_batch_end(trainer, **kwargs)
Called when a batch ends. The batch is defined as one forward pass through the network.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_epoch_begin(trainer, **kwargs)
Called when an epoch starts. An epoch is defined as one full pass through the training dataset or the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_epoch_end(trainer, **kwargs)
Called when an epoch ends. An epoch is defined as one full pass through the training dataset or the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_iteration_begin(trainer, **kwargs)
Called when an iteration starts. An iteration is defined as one full pass through the training dataset and the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_iteration_end(trainer, **kwargs)
Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_optimization_begin(trainer, **kwargs)
Called when the optimization phase of an iteration starts. The optimization phase is defined as the moment where the model weights are updated.
- Parameters:
trainer (Trainer) – The trainer.
kwargs – Additional arguments.
- Keyword Arguments:
x – The input data.
y – The target data.
pred – The predicted data.
- Returns:
None
- on_optimization_end(trainer, **kwargs)
Called when the optimization phase of an iteration ends. The optimization phase is defined as the moment where the model weights are updated.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_pbar_update(trainer, **kwargs) → dict
Called when the progress bar is updated.
- Parameters:
trainer (Trainer) – The trainer.
kwargs – Additional arguments.
- Returns:
A dictionary of values to display in the progress bar.
- on_train_begin(trainer, **kwargs)
Called when the train phase of an iteration starts. The train phase is defined as a full pass through the training dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_train_end(trainer, **kwargs)
Called when the train phase of an iteration ends. The train phase is defined as a full pass through the training dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_trajectory_end(trainer, trajectory, **kwargs) → List[Dict[str, Any]]
Called when a trajectory ends. This is used in reinforcement learning to update the trajectory loss and metrics. Must return a list of dictionaries containing the trajectory metrics. The list must have the same length as the trajectory. Each item in the list will update the attribute others of the corresponding Experience.
- Parameters:
trainer (Trainer) – The trainer.
trajectory (Trajectory) – The trajectory i.e. the sequence of Experiences.
kwargs – Additional arguments.
- Returns:
A list of dictionaries containing the trajectory metrics.
- on_validation_batch_begin(trainer, **kwargs)
Called when the validation batch starts. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
kwargs – Additional arguments.
- Keyword Arguments:
x – The input data.
y – The target data.
pred – The predicted data.
- Returns:
None
- on_validation_batch_end(trainer, **kwargs)
Called when the validation batch ends. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_validation_begin(trainer, **kwargs)
Called when the validation phase of an iteration starts. The validation phase is defined as a full pass through the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- class neurotorch.callbacks.base_callback.CallbacksList(callbacks: Iterable[BaseCallback] | None = None)
Bases:
object
This class is used to store the callbacks that are used during the training. Each callback of the list is called in the order they are stored in the list.
- Attributes:
callbacks (List[BaseCallback]): The callbacks to use.
- __init__(callbacks: Iterable[BaseCallback] | None = None)
Constructor of the CallbacksList class.
- Parameters:
callbacks (Iterable[BaseCallback]) – The callbacks to use.
- append(callback: BaseCallback)
Append a callback to the list.
- Parameters:
callback (BaseCallback) – The callback to append.
- Returns:
None
- close(trainer, **kwargs)
Called when the trainer closes.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- get_checkpoint_state(trainer, **kwargs) → Dict[str, Any]
Collates the states of the callbacks. This is called when the checkpoint manager saves the state of the trainer. Those states are then saved in the checkpoint file, each under the name of its callback as the key.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
The states of the callbacks.
- Return type:
A picklable dict.
- load_checkpoint_state(trainer, checkpoint: dict, **kwargs)
Loads the states of the callbacks from a dictionary.
- Parameters:
trainer (Trainer) – The trainer.
checkpoint (dict) – The dictionary containing all the states of the trainer.
- Returns:
None
- on_batch_begin(trainer, **kwargs)
Called when a batch starts. The batch is defined as one forward pass through the network.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_batch_end(trainer, **kwargs)
Called when a batch ends. The batch is defined as one forward pass through the network.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_epoch_begin(trainer, **kwargs)
Called when an epoch starts. An epoch is defined as one full pass through the training dataset or the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_epoch_end(trainer, **kwargs)
Called when an epoch ends. An epoch is defined as one full pass through the training dataset or the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_iteration_begin(trainer, **kwargs)
Called when an iteration starts. An iteration is defined as one full pass through the training dataset and the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_iteration_end(trainer, **kwargs)
Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_optimization_begin(trainer, **kwargs)
Called when the optimization phase of an iteration starts. The optimization phase is defined as the moment where the model weights are updated.
- Parameters:
trainer (Trainer) – The trainer.
kwargs – Additional arguments.
- Returns:
None
- on_optimization_end(trainer, **kwargs)
Called when the optimization phase of an iteration ends. The optimization phase is defined as the moment where the model weights are updated.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_pbar_update(trainer, **kwargs) → dict
Called when the progress bar is updated.
- Parameters:
trainer (Trainer) – The trainer.
kwargs – Additional arguments.
- Returns:
A dictionary of values to display in the progress bar.
- on_train_begin(trainer, **kwargs)
Called when the train phase of an iteration starts. The train phase is defined as a full pass through the training dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_train_end(trainer, **kwargs)
Called when the train phase of an iteration ends. The train phase is defined as a full pass through the training dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_trajectory_end(trainer, trajectory, **kwargs) → List[Dict[str, Any]]
Called when a trajectory ends. This is used in reinforcement learning to update the trajectory loss and metrics. Must return a list of dictionaries containing the trajectory metrics. The list must have the same length as the trajectory. Each item in the list will update the attribute others of the corresponding Experience.
- Note:
If the callbacks return the same keys in the dictionaries, the values will be updated, so the last callback will prevail.
- Parameters:
trainer (Trainer) – The trainer.
trajectory (Trajectory) – The trajectory i.e. the sequence of Experiences.
kwargs – Additional arguments.
- Returns:
A list of dictionaries containing the trajectory metrics.
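A minimal sketch of how the on_trajectory_end() outputs of several callbacks could be merged into the experiences, with the last callback prevailing on duplicate keys as the note states. Experience and apply_trajectory_metrics are hypothetical stand-ins, not neurotorch classes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Experience:
    """Hypothetical stand-in for an Experience; only the `others` dict matters here."""
    reward: float
    others: Dict[str, Any] = field(default_factory=dict)


def apply_trajectory_metrics(trajectory: List[Experience],
                             outputs_per_callback: List[List[Dict[str, Any]]]) -> None:
    """Update each Experience.others with the dicts returned by each callback, in callback order."""
    for outputs in outputs_per_callback:
        if len(outputs) != len(trajectory):
            raise ValueError("A callback must return one dict per experience of the trajectory.")
        for experience, metrics in zip(trajectory, outputs):
            experience.others.update(metrics)  # a later callback overwrites duplicate keys


trajectory = [Experience(reward=1.0), Experience(reward=0.5)]
first_callback = [{"tag": "a", "step": 0}, {"tag": "a", "step": 1}]
second_callback = [{"tag": "b"}, {"tag": "b"}]
apply_trajectory_metrics(trajectory, [first_callback, second_callback])
print(trajectory[0].others)  # {'tag': 'b', 'step': 0}
```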
- on_validation_batch_begin(trainer, **kwargs)
Called when the validation batch starts. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
kwargs – Additional arguments.
- Keyword Arguments:
x – The input data.
y – The target data.
pred – The predicted data.
- Returns:
None
- on_validation_batch_end(trainer, **kwargs)
Called when the validation batch ends. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_validation_begin(trainer, **kwargs)
Called when the validation phase of an iteration starts. The validation phase is defined as a full pass through the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_validation_end(trainer, **kwargs)
Called when the validation phase of an iteration ends. The validation phase is defined as a full pass through the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- remove(callback: BaseCallback)
Remove a callback from the list.
- Parameters:
callback (BaseCallback) – The callback to remove.
- Returns:
None
- sort_callbacks_(reverse: bool = False) → CallbacksList
Sorts the callbacks by their priority.
- Parameters:
reverse (bool) – If True, the callbacks are sorted in descending order.
- Returns:
self
- Return type:
CallbacksList
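A sketch of priority sorting that follows the documented contract of sort_callbacks_ (ascending priority by default, descending when reverse=True, returns self). SimpleCallback and CallbacksListSketch are hypothetical; the priorities reuse the DEFAULT_PRIORITY values listed on this page.

```python
from typing import Iterable, List, Optional


class SimpleCallback:
    """Hypothetical callback carrying only a name and a priority."""

    def __init__(self, name: str, priority: int) -> None:
        self.name = name
        self.priority = priority


class CallbacksListSketch:
    """Hypothetical stand-in for CallbacksList, limited to sort_callbacks_."""

    def __init__(self, callbacks: Optional[Iterable[SimpleCallback]] = None) -> None:
        self.callbacks: List[SimpleCallback] = list(callbacks or [])

    def sort_callbacks_(self, reverse: bool = False) -> "CallbacksListSketch":
        # Ascending priority means lower values are called first; list.sort() is
        # stable, so callbacks with equal priority keep their insertion order.
        self.callbacks.sort(key=lambda cb: cb.priority, reverse=reverse)
        return self  # returning self allows chaining


callbacks = CallbacksListSketch([
    SimpleCallback("checkpoints", 100),  # CheckpointManager.DEFAULT_PRIORITY
    SimpleCallback("custom", 10),        # BaseCallback.DEFAULT_PRIORITY
    SimpleCallback("history", 0),        # TrainingHistory.DEFAULT_PRIORITY
])
ascending = [cb.name for cb in callbacks.sort_callbacks_().callbacks]
descending = [cb.name for cb in callbacks.sort_callbacks_(reverse=True).callbacks]
print(ascending)   # ['history', 'custom', 'checkpoints']
print(descending)  # ['checkpoints', 'custom', 'history']
```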
neurotorch.callbacks.checkpoints_manager module
- class neurotorch.callbacks.checkpoints_manager.CheckpointManager(checkpoint_folder: str = './checkpoints', *, checkpoints_meta_path: str | None = None, meta_path_prefix: str = 'network', metric: str = 'val_loss', minimise_metric: bool = True, save_freq: int = 1, save_best_only: bool = False, start_save_at: int = 0, verbose: bool = False, **kwargs)
Bases:
BaseCallback
This class is used to manage and create the checkpoints of a model.
- Attributes:
checkpoint_folder (str): The folder to save the checkpoints to.
meta_path_prefix (str): The prefix to use for the checkpoint’s metadata file.
metric (str): The name of the metric to collect the best checkpoint on.
minimise_metric (bool): Whether to minimise the metric or maximise it.
curr_best_metric (float): The current best metric value.
- CHECKPOINTS_META_SUFFIX: str = 'checkpoints'
- CHECKPOINT_BEST_KEY: str = 'best'
- CHECKPOINT_FILE_STRUCT: Dict[str, str | Dict[int, str]] = {'best': 'save_path', 'iterations': {0: 'save_path'}}
- CHECKPOINT_ITRS_KEY: str = 'iterations'
- CHECKPOINT_ITR_KEY: str = 'itr'
- CHECKPOINT_METRICS_KEY: str = 'metrics'
- CHECKPOINT_OPTIMIZER_STATE_DICT_KEY: str = 'optimizer_state_dict'
- CHECKPOINT_SAVE_PATH_KEY: str = 'save_path'
- CHECKPOINT_STATE_DICT_KEY: str = 'model_state_dict'
- CHECKPOINT_TRAINING_HISTORY_KEY: str = 'training_history'
- DEFAULT_PRIORITY = 100
- SAVE_EXT: str = '.pth'
- SUFFIX_SEP: str = '-'
- __init__(checkpoint_folder: str = './checkpoints', *, checkpoints_meta_path: str | None = None, meta_path_prefix: str = 'network', metric: str = 'val_loss', minimise_metric: bool = True, save_freq: int = 1, save_best_only: bool = False, start_save_at: int = 0, verbose: bool = False, **kwargs)
Initialises the checkpoint manager.
- Parameters:
checkpoint_folder (str) – The folder to save the checkpoints to.
checkpoints_meta_path (Optional[str]) – The path to the checkpoints metadata file. If None, will use the checkpoint_folder and meta_path_prefix to create the path.
meta_path_prefix (str) – The prefix to use for the checkpoint’s metadata file.
metric (str) – The name of the metric to collect the best checkpoint on.
minimise_metric (bool) – Whether to minimise the metric or maximise it.
save_freq (int) – The frequency at which to save checkpoints. If set to <= 0, will save at the end of the training.
save_best_only (bool) – Whether to only save the best checkpoint.
start_save_at (int) – The iteration at which to start saving checkpoints.
verbose (bool) – Whether to print out the trace of the checkpoint manager.
kwargs – The keyword arguments to pass to the BaseCallback.
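How save_freq, save_best_only, start_save_at and minimise_metric might combine is sketched below. SavePolicy is a hypothetical class, and the exact decision rule used by CheckpointManager.on_iteration_end() is an assumption (for example, whether the best metric is already tracked before start_save_at).

```python
import math


class SavePolicy:
    """Hypothetical decision rule combining the CheckpointManager saving parameters."""

    def __init__(self, *, minimise_metric: bool = True, save_freq: int = 1,
                 save_best_only: bool = False, start_save_at: int = 0) -> None:
        self.minimise_metric = minimise_metric
        self.save_freq = save_freq
        self.save_best_only = save_best_only
        self.start_save_at = start_save_at
        # Mirrors curr_best_metric: start from the worst possible value.
        self.curr_best_metric = math.inf if minimise_metric else -math.inf

    def update_best(self, value: float) -> bool:
        """Record `value` and report whether it is the best value seen so far."""
        if self.minimise_metric:
            better = value < self.curr_best_metric
        else:
            better = value > self.curr_best_metric
        if better:
            self.curr_best_metric = value
        return better

    def should_save(self, itr: int, value: float) -> bool:
        best = self.update_best(value)
        if itr < self.start_save_at:
            return False
        if self.save_best_only:
            return best
        # save_freq <= 0: never save here; the final save is left to close().
        return self.save_freq > 0 and itr % self.save_freq == 0


val_losses = [0.9, 0.7, 0.8, 0.5]
every_other = SavePolicy(save_freq=2, start_save_at=1)
every_other_flags = [every_other.should_save(i, v) for i, v in enumerate(val_losses)]
best_only = SavePolicy(save_best_only=True)
best_only_flags = [best_only.should_save(i, v) for i, v in enumerate(val_losses)]
print(every_other_flags)  # [False, False, True, False]
print(best_only_flags)    # [True, True, False, True]
```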
- property checkpoints_meta_path: str
Gets the path to the checkpoints metadata file.
- Returns:
The path to the checkpoints metadata file.
- Return type:
str
- close(trainer, **kwargs)
Called when the training is finished. Saves the current checkpoint if the current iteration is lower than the number of iterations, i.e. if some progress has not been saved yet.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- extra_repr() → str
- get_checkpoint_filename(itr: int = -1)
Generate the filename for the checkpoint at the given iteration.
- Parameters:
itr (int) – The iteration to generate the filename for.
- Returns:
The filename for the checkpoint at the given iteration.
- Return type:
str
- static get_save_name_from_checkpoints(checkpoints_meta: Dict[str, str | Dict[Any, str]], load_checkpoint_mode: LoadCheckpointMode = LoadCheckpointMode.BEST_ITR) → str
Gets the save name from the checkpoint’s metadata given the load checkpoint mode.
- Parameters:
checkpoints_meta (Dict[str, Union[str, Dict[Any, str]]]) – The checkpoint’s metadata.
load_checkpoint_mode (LoadCheckpointMode) – The load checkpoint mode.
- Returns:
The save name.
- Return type:
str
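The metadata layout in CHECKPOINT_FILE_STRUCT suggests how a save name can be picked for each mode. The sketch below assumes that BEST_ITR reads the 'best' entry and LAST_ITR reads the entry of the highest iteration; get_save_name and the file names are hypothetical.

```python
from enum import Enum
from typing import Any, Dict, Union


class LoadCheckpointMode(Enum):
    """Local copy of the two documented modes."""
    BEST_ITR = 0
    LAST_ITR = 1


def get_save_name(checkpoints_meta: Dict[str, Union[str, Dict[Any, str]]],
                  load_checkpoint_mode: LoadCheckpointMode = LoadCheckpointMode.BEST_ITR) -> str:
    """Hypothetical: pick a save path from metadata shaped like CHECKPOINT_FILE_STRUCT."""
    if load_checkpoint_mode == LoadCheckpointMode.BEST_ITR:
        return checkpoints_meta["best"]
    if load_checkpoint_mode == LoadCheckpointMode.LAST_ITR:
        iterations = checkpoints_meta["iterations"]
        # int() makes "10" sort after "9" if the keys were stored as strings (e.g. via JSON).
        return iterations[max(iterations, key=int)]
    raise ValueError(f"Unsupported mode: {load_checkpoint_mode}")


# Hypothetical file names; the real naming scheme is not documented here.
meta = {
    "best": "network-itr3.pth",
    "iterations": {"1": "network-itr1.pth", "3": "network-itr3.pth", "10": "network-itr10.pth"},
}
best_path = get_save_name(meta)
last_path = get_save_name(meta, LoadCheckpointMode.LAST_ITR)
print(best_path, last_path)  # network-itr3.pth network-itr10.pth
```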
- load_checkpoint(load_checkpoint_mode: LoadCheckpointMode = LoadCheckpointMode.BEST_ITR) → dict
Loads the checkpoint selected by the given load_checkpoint_mode.
- Parameters:
load_checkpoint_mode (LoadCheckpointMode) – The load_checkpoint_mode to use.
- Returns:
The loaded checkpoint.
- Return type:
dict
- load_mode_to_suffix: Dict[LoadCheckpointMode, str] = {<LoadCheckpointMode.BEST_ITR: 0>: 'BEST_ITR', <LoadCheckpointMode.LAST_ITR: 1>: 'LAST_ITR'}
- on_iteration_end(trainer, **kwargs)
Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset. The checkpoint is saved if the current constraints are met.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- save_checkpoint(itr: int, itr_metrics: Dict[str, Any], best: bool = False, state_dict: Dict[str, Any] | None = None, optimizer_state_dict: Dict[str, Any] | None = None, training_history: Any | None = None, **other_states) → str
Saves a checkpoint of the model and optimizer states at the given iteration.
- Parameters:
itr (int) – The iteration number.
itr_metrics (Dict[str, Any]) – The metrics at the given iteration.
best (bool) – Whether this is the best iteration so far.
state_dict (Optional[Dict[str, Any]]) – The state dict of the model.
optimizer_state_dict (Optional[Dict[str, Any]]) – The state dict of the optimizer.
training_history (Optional[Any]) – The training history object.
- Returns:
The path to the saved checkpoint.
- Return type:
str
- save_checkpoints_meta(new_info: dict)
Saves the new checkpoints’ metadata.
- Parameters:
new_info (dict) – The new checkpoints’ metadata.
- Returns:
None
- class neurotorch.callbacks.checkpoints_manager.LoadCheckpointMode(value)
Bases:
Enum
Enum for the different modes of loading a checkpoint.
- Attributes:
BEST_ITR (int): Indicates that the iteration with the best metric will be loaded.
LAST_ITR (int): Indicates that the last iteration will be loaded.
- BEST_ITR = 0
- LAST_ITR = 1
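Converting a mode name into one of the two members above can rely on standard Enum name lookup. The sketch below re-declares the documented members locally; the case-insensitive matching and the ValueError on unknown names are assumptions, not necessarily what from_str() does.

```python
from enum import Enum


class LoadCheckpointMode(Enum):
    """Local re-declaration of the documented modes (hypothetical sketch)."""

    BEST_ITR = 0
    LAST_ITR = 1

    @staticmethod
    def from_str(mode_name: str) -> "LoadCheckpointMode":
        # Assumption: names are matched case-insensitively, ignoring surrounding spaces.
        try:
            return LoadCheckpointMode[mode_name.strip().upper()]
        except KeyError as err:
            raise ValueError(f"Unknown load checkpoint mode: {mode_name!r}") from err


mode = LoadCheckpointMode.from_str("last_itr")
print(mode.name)  # LAST_ITR
```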
- static from_str(mode_name: str) → LoadCheckpointMode
Converts a string to a LoadCheckpointMode instance.
- Parameters:
mode_name (str) – The name of the mode.
- Returns:
The corresponding LoadCheckpointMode instance.
- Return type:
LoadCheckpointMode
neurotorch.callbacks.convergence module
- class neurotorch.callbacks.convergence.ConvergenceTimeGetter(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)
Bases:
BaseCallback
Monitor the training process and return the time it took to pass the threshold.
- __init__(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)
Constructor for ConvergenceTimeGetter class.
- Parameters:
metric (str) – Name of the metric to monitor.
threshold (float) – Threshold value for the metric.
minimize_metric (bool) – Whether to minimize or maximize the metric.
kwargs – The keyword arguments to pass to the BaseCallback.
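A minimal sketch of the convergence-time idea: remember the iteration and elapsed time at which the metric first passes the threshold. ConvergenceTimer, itr_convergence and time_convergence are hypothetical names, and treating a value equal to the threshold as passing is an assumption.

```python
import math
import time
from typing import Optional


class ConvergenceTimer:
    """Hypothetical sketch: record when a metric first passes a threshold."""

    def __init__(self, *, threshold: float, minimize_metric: bool) -> None:
        self.threshold = threshold
        self.minimize_metric = minimize_metric
        self.start_time = time.perf_counter()
        self.itr_convergence: Optional[int] = None
        self.time_convergence: Optional[float] = None  # seconds since start

    def passed(self, value: float) -> bool:
        if math.isnan(value):
            return False  # a NaN metric never counts as converged
        if self.minimize_metric:
            return value <= self.threshold
        return value >= self.threshold

    def on_iteration_end(self, itr: int, value: float) -> None:
        # Only the first crossing is kept.
        if self.itr_convergence is None and self.passed(value):
            self.itr_convergence = itr
            self.time_convergence = time.perf_counter() - self.start_time


timer = ConvergenceTimer(threshold=0.1, minimize_metric=True)
for itr, val_loss in enumerate([0.5, 0.2, 0.08, 0.05]):
    timer.on_iteration_end(itr, val_loss)
print(timer.itr_convergence)  # 2
```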
- close(trainer, **kwargs)
Called when the training ends. This is the last callback called.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- load_checkpoint_state(trainer, checkpoint: dict, **kwargs)
Loads the state of the callback from a dictionary.
- Parameters:
trainer (Trainer) – The trainer.
checkpoint (dict) – The dictionary containing all the states of the trainer.
- Returns:
None
- on_iteration_end(trainer, **kwargs)
Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- save_on(trainer, **kwargs)
neurotorch.callbacks.early_stopping module
- class neurotorch.callbacks.early_stopping.EarlyStopping(patience: int = 5, tol: float = 0.01)
Bases:
BaseCallback
- __init__(patience: int = 5, tol: float = 0.01)
- Parameters:
priority (int, optional) – The priority of the callback. The lower the priority, the earlier the callback is called. At the beginning of the training, the priorities of the callbacks are reversed for the load_state() method. Default is 10.
name (str, optional) – The name of the callback. If None, the name is set to the class name. Default is None.
save_state (bool, optional) – If True, the state of the callback is saved in the checkpoint file. Default is True.
load_state (bool, optional) – If True, the state of the callback is loaded from the checkpoint file. Default is equal to save_state.
- class neurotorch.callbacks.early_stopping.EarlyStoppingOnNaN(metric: str, **kwargs)
Bases:
BaseCallback
Monitor the training process and set the stop_training_flag to True when the metric is NaN.
- DEFAULT_PRIORITY = 100
- __init__(metric: str, **kwargs)
Constructor for EarlyStoppingOnNaN class.
- Parameters:
metric (str) – Name of the metric to monitor.
kwargs – The keyword arguments to pass to the BaseCallback.
- extra_repr() → str
- class neurotorch.callbacks.early_stopping.EarlyStoppingOnStagnation(metric: str, patience: int = 10, tol: float = 0.0001, start_with_history: bool = True, **kwargs)
Bases:
BaseCallback
Monitor the training process and set the stop_training_flag to True when the metric stagnates. The metric is considered to stagnate when the mean of the absolute differences between consecutive values of the metric over the last patience iterations is less than tol.
- DEFAULT_PRIORITY = 100
- __init__(metric: str, patience: int = 10, tol: float = 0.0001, start_with_history: bool = True, **kwargs)
Constructor for EarlyStoppingOnStagnation class.
- Parameters:
metric (str) – Name of the metric to monitor.
patience (int) – Number of iterations to wait before stopping the training.
tol (float) – The tolerance for the metric.
start_with_history (bool) – Whether to start the monitor with the history of the metric.
kwargs – The keyword arguments to pass to the BaseCallback.
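The stagnation criterion described above (mean absolute difference between consecutive values over the last patience iterations, compared with tol) can be written as a small pure function. is_stagnating is hypothetical; the library's get_value() may compute the statistic slightly differently, for example as the absolute value of the mean difference.

```python
from typing import Sequence


def is_stagnating(values: Sequence[float], patience: int = 10, tol: float = 1e-4) -> bool:
    """Hypothetical: True when the mean absolute change over the last `patience` steps is below `tol`."""
    if patience < 1 or len(values) <= patience:
        return False  # not enough history to compare
    window = list(values)[-(patience + 1):]  # patience + 1 values give patience differences
    diffs = [abs(b - a) for a, b in zip(window, window[1:])]
    return sum(diffs) / len(diffs) < tol


history = [1.0, 0.5, 0.30, 0.30005, 0.30004, 0.30006]
short_window = is_stagnating(history, patience=3, tol=1e-4)  # only the flat tail
long_window = is_stagnating(history, patience=5, tol=1e-4)   # includes the early drop
print(short_window, long_window)  # True False
```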
- extra_repr() → str
- get_value()
- class neurotorch.callbacks.early_stopping.EarlyStoppingOnTimeLimit(*, delta_seconds: float = 600.0, resume_on_load: bool = True, **kwargs)
Bases:
BaseCallback
Monitor the training process and set the stop_training_flag to True when the elapsed training time reaches delta_seconds.
- CURRENT_SECONDS_COUNT_KEY = 'current_seconds_count'
- DELTA_SECONDS_KEY = 'delta_seconds'
- __init__(*, delta_seconds: float = 600.0, resume_on_load: bool = True, **kwargs)
Constructor for EarlyStoppingOnTimeLimit class.
- Parameters:
delta_seconds (float) – The number of seconds to wait before stopping the training.
resume_on_load (bool) – Whether to resume the time when loading a checkpoint. If False, the time will be reset to 0.
kwargs – The keyword arguments to pass to the BaseCallback.
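A sketch of a time budget that can survive a checkpoint reload, reusing the two documented state keys. TimeLimit and its methods are hypothetical, and the use of time.perf_counter() as the clock is an assumption.

```python
import time
from typing import Any, Dict


class TimeLimit:
    """Hypothetical time budget whose elapsed time can be carried across checkpoints."""

    CURRENT_SECONDS_COUNT_KEY = "current_seconds_count"
    DELTA_SECONDS_KEY = "delta_seconds"

    def __init__(self, *, delta_seconds: float = 600.0, resume_on_load: bool = True) -> None:
        self.delta_seconds = delta_seconds
        self.resume_on_load = resume_on_load
        self.previous_seconds = 0.0  # time spent in earlier runs, restored from a checkpoint
        self.start_time = time.perf_counter()

    def elapsed(self) -> float:
        return self.previous_seconds + (time.perf_counter() - self.start_time)

    def should_stop(self) -> bool:
        return self.elapsed() >= self.delta_seconds

    def get_checkpoint_state(self) -> Dict[str, Any]:
        return {self.CURRENT_SECONDS_COUNT_KEY: self.elapsed(),
                self.DELTA_SECONDS_KEY: self.delta_seconds}

    def load_checkpoint_state(self, state: Dict[str, Any]) -> None:
        # resume_on_load=False discards the saved count, so the clock restarts at 0.
        self.previous_seconds = state[self.CURRENT_SECONDS_COUNT_KEY] if self.resume_on_load else 0.0
        self.start_time = time.perf_counter()


saved = {"current_seconds_count": 700.0, "delta_seconds": 600.0}
resumed = TimeLimit(delta_seconds=600.0)
resumed.load_checkpoint_state(saved)
restarted = TimeLimit(delta_seconds=600.0, resume_on_load=False)
restarted.load_checkpoint_state(saved)
stop_flags = (resumed.should_stop(), restarted.should_stop())
print(stop_flags)  # (True, False)
```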
- extra_repr() → str
- get_checkpoint_state(trainer, **kwargs) → object
Get the state of the callback. This is called when the checkpoint manager saves the state of the trainer. Then this state is saved in the checkpoint file with the name of the callback as the key.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
The state of the callback.
- Return type:
A picklable object.
- load_checkpoint_state(trainer, checkpoint: dict, **kwargs)
Loads the state of the callback from a dictionary.
- Parameters:
trainer (Trainer) – The trainer.
checkpoint (dict) – The dictionary containing all the states of the trainer.
- Returns:
None
- on_iteration_end(trainer, **kwargs)
Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- start(trainer, **kwargs)
Called when the training starts. This is the first callback called.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- update_flags(trainer, **kwargs)
- class neurotorch.callbacks.early_stopping.EarlyStoppingThreshold(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)
Bases:
BaseCallback
Monitor the training process and set the stop_training_flag to True when the threshold is met.
- __init__(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)
Constructor for EarlyStoppingThreshold class.
- Parameters:
metric (str) – Name of the metric to monitor.
threshold (float) – Threshold value for the metric.
minimize_metric (bool) – Whether to minimize or maximize the metric.
kwargs – The keyword arguments to pass to the BaseCallback.
neurotorch.callbacks.events module
- class neurotorch.callbacks.events.EventOnMetricThreshold(metric_name: str, threshold: float, *, event: Callable, event_args: tuple | None = None, event_kwargs: dict | None = None, minimize_metric: bool = True, do_once: bool = False, **kwargs)
Bases:
BaseCallback
This class is a callback that calls an event when the metric value reaches a given threshold at the end of an iteration. Note that the trainer is passed as the first argument to the event. In addition, if the given metric is not present in the training state, the event is not called. Finally, the output of the event is passed to the on_pbar_update() method.
- Attributes:
metric_name (str): The name of the metric to monitor.
threshold (float): The threshold value.
event (Callable): The event to call.
event_args (tuple): The arguments to pass to the event.
event_kwargs (dict): The keyword arguments to pass to the event.
- __init__(metric_name: str, threshold: float, *, event: Callable, event_args: tuple | None = None, event_kwargs: dict | None = None, minimize_metric: bool = True, do_once: bool = False, **kwargs)
Constructor for the EventOnMetricThreshold class.
- Parameters:
metric_name (str) – The name of the metric to monitor.
threshold (float) – The threshold value.
event (Callable) – The event to call.
event_args (tuple) – The arguments to pass to the event.
event_kwargs (dict) – The keyword arguments to pass to the event.
kwargs – The keyword arguments to pass to the base class.
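The documented behaviour (trainer as first argument, no call when the metric is missing, optional single trigger with do_once) is sketched below. MetricThresholdEvent is hypothetical: it receives the training state explicitly instead of reading it from a trainer, and it treats minimize_metric=True as "trigger when value <= threshold"; both are assumptions.

```python
from typing import Any, Callable, Dict, List, Optional


class MetricThresholdEvent:
    """Hypothetical sketch of EventOnMetricThreshold's documented behaviour."""

    def __init__(self, metric_name: str, threshold: float, *, event: Callable,
                 event_args: Optional[tuple] = None, event_kwargs: Optional[dict] = None,
                 minimize_metric: bool = True, do_once: bool = False) -> None:
        self.metric_name = metric_name
        self.threshold = threshold
        self.event = event
        self.event_args = event_args or ()
        self.event_kwargs = event_kwargs or {}
        self.minimize_metric = minimize_metric
        self.do_once = do_once
        self.n_calls = 0
        self.last_output: Any = None  # would be forwarded to on_pbar_update()

    def on_iteration_end(self, trainer: Any, state: Dict[str, float]) -> None:
        if self.metric_name not in state:
            return  # metric missing: the event is not called
        if self.do_once and self.n_calls > 0:
            return
        value = state[self.metric_name]
        reached = value <= self.threshold if self.minimize_metric else value >= self.threshold
        if reached:
            # The trainer is always the first positional argument of the event.
            self.last_output = self.event(trainer, *self.event_args, **self.event_kwargs)
            self.n_calls += 1


fired: List[str] = []


def log_event(trainer: Any, tag: str) -> Dict[str, str]:
    fired.append(tag)
    return {"event": tag}


callback = MetricThresholdEvent("val_loss", 0.1, event=log_event, event_args=("reached",), do_once=True)
for state in ({"train_loss": 0.2}, {"val_loss": 0.3}, {"val_loss": 0.08}, {"val_loss": 0.05}):
    callback.on_iteration_end(trainer=None, state=state)
print(fired, callback.last_output)  # ['reached'] {'event': 'reached'}
```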
neurotorch.callbacks.history module
- class neurotorch.callbacks.history.TrainingHistory(container: Dict[str, List[float]] | None = None, default_value=nan, **kwargs)
Bases:
BaseCallback
This class is used to store some metrics over the training process.
- Attributes:
default_value (float): The default value to use to equalize the lengths of the container’s items.
- DEFAULT_PRIORITY = 0
- __init__(container: Dict[str, List[float]] | None = None, default_value=nan, **kwargs)
Initialize the history with the given container.
- Parameters:
container (Dict[str, List[float]]) – The initial content of the history, mapping each metric name to its list of values.
default_value (float) – The default value to use to equalize the lengths of the container’s items.
kwargs – The keyword arguments to pass to the BaseCallback.
- append(key, value)
Add the given value to the given key. Increase the size of the container by one.
- Parameters:
key (str) – The key to add the value to.
value (float) – The value to add.
- Returns:
None
- concat(other)
- create_plot(**kwargs) → Tuple[Figure, Dict[str, Axes], Dict[str, Line2D]]
Create a plot of the metrics in the container.
- Parameters:
kwargs – Keyword arguments.
- Keyword Arguments:
figsize (Tuple[float, float]) – The size of the figure.
linewidth (int) – The width of the lines.
- Returns:
The figure, axes and lines of the plot.
- Return type:
Tuple[plt.Figure, Dict[str, plt.Axes], Dict[str, plt.Line2D]]
- extra_repr()
- get(key, default=None)
Get the values of the given key.
- Parameters:
key (str) – The key to get the values of.
default (Any) – The default value to return if the key is not in the container.
- Returns:
The values of the given key.
- Return type:
List[float]
- get_item_at(idx: int = -1)
Get all the metrics of the iteration at the given index.
- Parameters:
idx (int) – The index to get the metrics of.
- Returns:
All the metrics of the iteration at the given index.
- Return type:
Dict[str, float]
- Raises:
ValueError – If the index is out of bounds.
- insert(index: int, other)
Extend the container's items up to the given index and insert the metrics of other at that index.
- Parameters:
index (int) – The index at which to insert other.
other (Dict[str, float]) – The metrics to insert.
- Returns:
None
- items()
- keys()
- max(key=None, default=-inf)
Get the maximum value of the given key.
- Parameters:
key (str) – The key to get the maximum value of. If None, the first key is used.
default (float) – The default value to return if the key is not in the container.
- Returns:
The maximum value of the given key.
- Return type:
float
- max_item(key=None)
Get all the metrics of the iteration with the maximum value of the given key.
- Parameters:
key (str) – The key to get the maximum value of. If None, the first key is used.
- Returns:
All the metrics of the iteration with the maximum value of the given key.
- Return type:
Dict[str, float]
- Raises:
ValueError – If the key is not in the container.
- min(key=None, default: float = inf) → float
Get the minimum value of the given key.
- Parameters:
key (str) – The key to get the minimum value of. If None, the first key is used.
default (float) – The default value to return if the key is not in the container.
- Returns:
The minimum value of the given key.
- Return type:
float
- min_item(key=None) → Dict[str, float]
Get all the metrics of the iteration with the minimum value of the given key.
- Parameters:
key (str) – The key to get the minimum value of. If None, the first key is used.
- Returns:
All the metrics of the iteration with the minimum value of the given key.
- Return type:
Dict[str, float]
- Raises:
ValueError – If the key is not in the container.
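The padding behaviour of insert() and the lookup done by min_item() can be sketched with a plain dict of lists. History is a hypothetical stand-in for TrainingHistory; overwriting an existing entry on insert and skipping NaN padding in min_item() are assumptions of this sketch.

```python
import math
from typing import Dict, List, Optional


class History:
    """Hypothetical dict-of-lists metric history, padded with `default_value`."""

    def __init__(self, default_value: float = math.nan) -> None:
        self.default_value = default_value
        self.container: Dict[str, List[float]] = {}

    def _equalize(self) -> None:
        # Pad every list to the same length with default_value.
        length = max((len(v) for v in self.container.values()), default=0)
        for values in self.container.values():
            values.extend([self.default_value] * (length - len(values)))

    def insert(self, index: int, other: Dict[str, float]) -> None:
        for key, value in other.items():
            values = self.container.setdefault(key, [])
            values.extend([self.default_value] * (index + 1 - len(values)))
            values[index] = value  # an existing entry at `index` is overwritten in this sketch
        self._equalize()

    def min_item(self, key: Optional[str] = None) -> Dict[str, float]:
        if key is None:
            if not self.container:
                raise ValueError("The history is empty.")
            key = next(iter(self.container))  # first key, as documented
        if key not in self.container:
            raise ValueError(f"{key!r} is not in the history.")
        values = self.container[key]
        candidates = [i for i, v in enumerate(values) if not math.isnan(v)]  # skip NaN padding
        if not candidates:
            raise ValueError(f"{key!r} has no recorded values.")
        idx = min(candidates, key=lambda i: values[i])
        return {k: v[idx] for k, v in self.container.items()}


history = History()
history.insert(0, {"train_loss": 0.9, "val_loss": 1.0})
history.insert(1, {"train_loss": 0.6})  # val_loss is padded with nan
history.insert(2, {"train_loss": 0.7, "val_loss": 0.8})
best_train = history.min_item("train_loss")  # {'train_loss': 0.6, 'val_loss': nan}
best_val = history.min_item("val_loss")      # {'train_loss': 0.7, 'val_loss': 0.8}
```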
- on_iteration_end(trainer, **kwargs)
Insert the current metrics into the container.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- plot(save_path: str | None = None, show: bool = False, **kwargs) → Tuple[Figure, Dict[str, Axes], Dict[str, Line2D]]
Plot the metrics in the container.
- Parameters:
save_path (Optional[str]) – The path to save the plot to. If None, the plot is not saved.
show (bool) – Whether to show the plot.
kwargs – Keyword arguments.
- Keyword Arguments:
figsize (Tuple[float, float]) – The size of the figure.
linewidth (int) – The width of the lines.
dpi (int) – The resolution of the figure.
block (bool) – Whether to block execution until the plot is closed.
close (bool) – Whether to close the plot at the end.
- Returns:
The figure, axes and lines of the plot.
- Return type:
Tuple[plt.Figure, Dict[str, plt.Axes], Dict[str, plt.Line2D]]
- update_fig(fig: Figure, axes: Dict[str, Axes], lines: Dict[str, Line2D], **kwargs) → Tuple[Figure, Dict[str, Axes], Dict[str, Line2D]]
Update the plot of the metrics in the container.
- Parameters:
fig (plt.Figure) – The figure to update.
axes (Dict[str, plt.Axes]) – The axes to update.
lines (Dict[str, plt.Line2D]) – The lines to update.
kwargs – Keyword arguments.
- Returns:
The figure, axes and lines of the plot.
- Return type:
Tuple[plt.Figure, Dict[str, plt.Axes], Dict[str, plt.Line2D]]
neurotorch.callbacks.lr_schedulers module¶
- class neurotorch.callbacks.lr_schedulers.LRSchedulerOnMetric(metric: str, metric_schedule: Iterable[float], *, minimize_metric: bool | None = None, lr_decay: float | Sequence[float] | None = None, min_lr: float | Sequence[float] | None = None, lr_start: float | Sequence[float] | None = None, retain_progress: bool = True, optimizer: Optimizer | None = None, **kwargs)¶
Bases:
BaseCallback
Class to schedule the learning rate of the optimizer based on the metric value. Each time the metric reaches the next value of the schedule, the learning rate is multiplied by the given decay. The learning rate never falls below the given minimum value.
- Attributes:
metric (str): The metric to use to schedule the learning rate.
metric_schedule (Iterable[float]): The schedule of the metric.
minimize_metric (bool): Whether to minimize the metric or maximize it.
lr_decay (float): The decay factor to use when the metric reaches the next value of the schedule.
min_lr (float): The minimum learning rate to use.
lr_start (float): The learning rate to use at the beginning of the training.
lr (float): The current learning rate.
retain_progress (bool): If True, the current step of the scheduler only increases when the metric reaches the next value of the schedule. If False, the current step increases or decreases depending on the metric.
step (int): The current step of the scheduler.
- DEFAULT_LR_START = 0.001¶
- DEFAULT_MIN_LR = 1e-12¶
- __init__(metric: str, metric_schedule: Iterable[float], *, minimize_metric: bool | None = None, lr_decay: float | Sequence[float] | None = None, min_lr: float | Sequence[float] | None = None, lr_start: float | Sequence[float] | None = None, retain_progress: bool = True, optimizer: Optimizer | None = None, **kwargs)¶
Initialize the scheduler with the given metric and metric schedule.
- Parameters:
metric (str) – The metric to use to schedule the learning rate.
metric_schedule (Iterable[float]) – The schedule of the metric.
minimize_metric (Optional[bool]) – Whether to minimize the metric or maximize it. If None, it is inferred from the metric schedule.
lr_decay (Optional[Union[float, Sequence[float]]]) – The decay factor to use when the metric reaches the next value of the schedule. If None, the decay is computed automatically as (lr_start - min_lr) / len(metric_schedule).
min_lr (Optional[Union[float, Sequence[float]]]) – The minimum learning rate to use.
lr_start (Optional[Union[float, Sequence[float]]]) – The learning rate to use at the beginning of the training. If None, it is taken automatically from the learning rate of the optimizer's first parameter group.
retain_progress (bool) – If True, the current step of the scheduler only increases when the metric reaches the next value of the schedule. If False, the current step increases or decreases depending on the metric.
optimizer (Optional[Optimizer]) – The optimizer whose learning rate will be scheduled. If None, the optimizer is obtained from the trainer. Note that in this case the first optimizer of the trainer's callbacks will be used.
kwargs – The keyword arguments to pass to the BaseCallback.
- Keyword Arguments:
log_lr_to_history (bool) – Whether to log the learning rate to the training history. Default: True.
- extra_repr() str ¶
- on_iteration_end(trainer, **kwargs)¶
Update the learning rate of the optimizer based on the metric value.
- Parameters:
trainer (Trainer) – The trainer object.
- Returns:
None
- on_pbar_update(trainer, **kwargs) dict ¶
Return the learning rate to display in the progress bar.
- Parameters:
trainer (Trainer) – The trainer object.
- Returns:
The dictionary to update the progress bar.
- Return type:
dict
- start(trainer, **kwargs)¶
Initialize the learning rate of the optimizer and the lr_start attribute if necessary.
- Parameters:
trainer (Trainer) – The trainer object.
- Returns:
None
- update_step(last_metric: float) int ¶
Update the current step of the scheduler based on the metric value.
- Parameters:
last_metric (float) – The last value of the metric.
- Returns:
The new step.
- Return type:
int
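The update_step() logic can be sketched as follows, under assumptions that the documentation above does not spell out: the metric "reaches" a schedule value when it is at or below it (minimized metric) or at or above it (maximized metric), the step counts how many consecutive schedule values have been reached, and minimize_metric is inferred as True when the schedule is decreasing. MetricStepScheduler is a hypothetical name. The sketch requires an explicit lr_decay (it does not reproduce the automatic default described above) and only computes lr without touching an optimizer.

```python
from typing import Optional, Sequence


class MetricStepScheduler:
    """Hypothetical sketch of a metric-driven, step-wise learning-rate schedule."""

    def __init__(
        self,
        metric_schedule: Sequence[float],
        lr_start: float,
        lr_decay: float,
        min_lr: float = 1e-12,
        minimize_metric: Optional[bool] = None,
        retain_progress: bool = True,
    ) -> None:
        self.metric_schedule = list(metric_schedule)
        if minimize_metric is None:
            # Assumption: a decreasing schedule means the metric is minimized.
            minimize_metric = self.metric_schedule[0] > self.metric_schedule[-1]
        self.minimize_metric = minimize_metric
        self.lr_start = lr_start
        self.lr_decay = lr_decay
        self.min_lr = min_lr
        self.retain_progress = retain_progress
        self.step = 0
        self.lr = lr_start

    def _targets_reached(self, last_metric: float) -> int:
        # Count how many consecutive schedule values the metric has reached.
        count = 0
        for target in self.metric_schedule:
            if self.minimize_metric:
                reached = last_metric <= target
            else:
                reached = last_metric >= target
            if not reached:
                break
            count += 1
        return count

    def update_step(self, last_metric: float) -> int:
        reached = self._targets_reached(last_metric)
        # With retain_progress the step never goes back down.
        self.step = max(self.step, reached) if self.retain_progress else reached
        # Apply the decay once per step reached, never going below min_lr.
        self.lr = max(self.lr_start * self.lr_decay ** self.step, self.min_lr)
        return self.step


sched = MetricStepScheduler([0.5, 0.7, 0.9], lr_start=1e-3, lr_decay=0.5)
sched.update_step(0.75)  # two targets reached: step 2, lr = 1e-3 * 0.5 ** 2
sched.update_step(0.60)  # still step 2 because progress is retained
```

With retain_progress=False, the second call would instead move the step back to 1, which is the "increase or decrease depending on the metric" behaviour described above.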
- class neurotorch.callbacks.lr_schedulers.LinearLRScheduler(lr_start: float | Sequence[float], lr_end: float | Sequence[float], n_steps: int, optimizer: Optimizer | None = None, **kwargs)¶
Bases:
BaseCallback
This callback implements a linear learning rate decay, which is useful to decrease the learning rate over iterations. The learning rate decreases linearly from lr_start to lr_end over n_steps iterations.
- Attributes:
lr_start (float): The initial learning rate.
lr_end (float): The final learning rate.
n_steps (int): The number of steps over which the learning rate is decreased.
lr (float): The current learning rate.
lr_decay (float): The learning rate decay per step.
- DEFAULT_LR_START = 0.001¶
- DEFAULT_MIN_LR = 1e-12¶
- __init__(lr_start: float | Sequence[float], lr_end: float | Sequence[float], n_steps: int, optimizer: Optimizer | None = None, **kwargs)¶
Constructor for the LinearLRScheduler class.
- Parameters:
lr_start (Union[float, Sequence[float]]) – The initial learning rate. If a sequence is given, each entry of the sequence is used for each parameter group.
lr_end (Union[float, Sequence[float]]) – The final learning rate. If a sequence is given, each entry of the sequence is used for each parameter group.
n_steps (int) – The number of steps over which the learning rate is decreased.
- extra_repr() str ¶
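A linear decay with one learning rate per parameter group can be sketched as below. It assumes a constant decrement lr_decay = (lr_start - lr_end) / n_steps per step, which matches the description of a linear decrease from lr_start to lr_end over n_steps; the class LinearDecaySketch is hypothetical and is not the neurotorch implementation. In PyTorch, optimizer.param_groups is a list of dicts holding an "lr" entry, so plain dicts stand in for it here to keep the example self-contained.

```python
from typing import Dict, List, Sequence, Union


class LinearDecaySketch:
    """Hypothetical linear learning-rate decay over a list of parameter groups."""

    def __init__(
        self,
        param_groups: List[Dict[str, float]],
        lr_start: Union[float, Sequence[float]],
        lr_end: Union[float, Sequence[float]],
        n_steps: int,
    ) -> None:
        n_groups = len(param_groups)
        # A scalar applies to every group; a sequence gives one value per group.
        self.lr_start = self._per_group(lr_start, n_groups)
        self.lr_end = self._per_group(lr_end, n_groups)
        self.n_steps = n_steps
        # Constant decrement per step, one per group.
        self.lr_decay = [(s - e) / n_steps for s, e in zip(self.lr_start, self.lr_end)]
        self.param_groups = param_groups
        self.step_count = 0
        self.lr = list(self.lr_start)
        self._apply()

    @staticmethod
    def _per_group(value: Union[float, Sequence[float]], n_groups: int) -> List[float]:
        if isinstance(value, (int, float)):
            return [float(value)] * n_groups
        values = [float(v) for v in value]
        if len(values) != n_groups:
            raise ValueError("expected one value per parameter group")
        return values

    def _apply(self) -> None:
        # Write the current rates into the groups, as one would for param_groups.
        for group, lr in zip(self.param_groups, self.lr):
            group["lr"] = lr

    def step(self) -> None:
        # Called once per iteration; the rate stops moving after n_steps.
        self.step_count = min(self.step_count + 1, self.n_steps)
        self.lr = [s - d * self.step_count for s, d in zip(self.lr_start, self.lr_decay)]
        self._apply()


groups = [{"lr": 0.0}, {"lr": 0.0}]  # stand-in for optimizer.param_groups
linear = LinearDecaySketch(groups, lr_start=[1e-3, 1e-2], lr_end=1e-4, n_steps=9)
for _ in range(12):  # more calls than n_steps: the rate stays at lr_end
    linear.step()
```

Recomputing lr from lr_start and the step count, rather than subtracting lr_decay repeatedly, avoids accumulating floating-point drift.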
neurotorch.callbacks.training_visualization module¶
- class neurotorch.callbacks.training_visualization.TrainingHistoryVisualizationCallback(temp_folder: str = '~/temp/', **kwargs)¶
Bases:
BaseCallback
This callback is used to visualize the training history in real time.
- Attributes:
fig (plt.Figure): The figure to plot.
axes (plt.Axes): The axes to plot.
lines (list): The list of lines to plot.
- __init__(temp_folder: str = '~/temp/', **kwargs)¶
Create a new callback to visualize the training history.
- Parameters:
temp_folder (str) – The folder in which to save the training history.
kwargs – The keyword arguments to pass to the base callback.
- close(trainer, **kwargs)¶
Close the process and delete the temporary file.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
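The temporary-file part of this callback's lifecycle can be sketched with the standard library. This is a hypothetical illustration (TempHistoryFile is not a neurotorch class) that leaves out the plotting process entirely. It shows why a default such as '~/temp/' needs os.path.expanduser, since Python's open() does not expand "~" by itself (whether the library does this is not stated above), and how close() removes the file.

```python
import json
import os
import tempfile
from typing import Dict, List


class TempHistoryFile:
    """Hypothetical sketch: persist history to a temp file, delete it on close."""

    def __init__(self, temp_folder: str = "~/temp/") -> None:
        # Expand "~" so the default folder resolves to the user's home directory.
        self.temp_folder = os.path.expanduser(temp_folder)
        os.makedirs(self.temp_folder, exist_ok=True)
        # Create a uniquely named file inside the folder.
        fd, self.path = tempfile.mkstemp(suffix=".json", dir=self.temp_folder)
        os.close(fd)
        self.history: List[Dict[str, float]] = []

    def on_iteration_end(self, metrics: Dict[str, float]) -> None:
        # Rewrite the whole history so a reader process always sees valid JSON.
        self.history.append(dict(metrics))
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.history, f)

    def close(self) -> None:
        # Delete the temporary file if it still exists.
        if os.path.exists(self.path):
            os.remove(self.path)


cb = TempHistoryFile(tempfile.mkdtemp())  # avoid writing into the home folder
cb.on_iteration_end({"loss": 1.0})
cb.close()
```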
Module contents¶
- class neurotorch.callbacks.ForesightTimeStepUpdaterOnTarget(**kwargs)¶
Bases:
BaseCallback
Updates the foresight time step of the model to the length of the target sequence. This is useful for models that are trained with a fixed foresight time step, but are evaluated on sequences of different lengths.
- DEFAULT_PRIORITY = 0¶
- __init__(**kwargs)¶
Constructor for the ForesightTimeStepUpdaterOnTarget class.
- Keyword Arguments:
time_steps_multiplier (int) – The multiplier to use to determine the time steps. Defaults to 1.
target_skip_first (bool) – Whether to skip the first time step of the target sequence. Defaults to True.
update_val_loss_freq (int) – The frequency at which to update the validation loss. Defaults to 1.
start_intensive_val_at (float) – The fraction of the training epochs at which to start intensive validation, where intensive validation means that validation is performed at every iteration. Defaults to 0.0. If the value is an integer, it is interpreted as a number of iterations instead of a fraction.
hh_memory_size_strategy (str) –
The strategy to use to determine the hidden history memory size. The available strategies are:
"out_memory_size": The hidden history memory size is equal to the output memory size.
"foresight_time_steps": The hidden history memory size is equal to the foresight time steps.
<Number>: The hidden history memory size is equal to the specified number.
Defaults to "out_memory_size".
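Two of the keyword arguments above select between alternatives, and their resolution can be sketched as plain functions. The function names are hypothetical. Combining update_val_loss_freq with start_intensive_val_at as shown (validate every update_val_loss_freq iterations before the intensive phase, then at every iteration) is an assumption drawn from the descriptions above, not confirmed library behaviour.

```python
from typing import Union


def resolve_intensive_val_start(start_intensive_val_at: Union[int, float], n_iterations: int) -> int:
    # An int is read as an iteration count; a float as a fraction of n_iterations.
    if isinstance(start_intensive_val_at, int):
        return start_intensive_val_at
    return int(start_intensive_val_at * n_iterations)


def should_update_val_loss(
    iteration: int,
    n_iterations: int,
    update_val_loss_freq: int = 1,
    start_intensive_val_at: Union[int, float] = 0.0,
) -> bool:
    start = resolve_intensive_val_start(start_intensive_val_at, n_iterations)
    if iteration >= start:
        return True  # intensive phase: validate at every iteration
    return iteration % update_val_loss_freq == 0


def resolve_hh_memory_size(
    strategy: Union[str, int, float], out_memory_size: int, foresight_time_steps: int
) -> int:
    # Map the documented strategies to a concrete hidden history memory size.
    if strategy == "out_memory_size":
        return out_memory_size
    if strategy == "foresight_time_steps":
        return foresight_time_steps
    if isinstance(strategy, (int, float)) and not isinstance(strategy, bool):
        return int(strategy)
    raise ValueError(f"unknown hh_memory_size_strategy: {strategy!r}")
```

With the documented default start_intensive_val_at=0.0, the intensive phase starts at iteration 0, so update_val_loss_freq only matters once a later start is configured.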
- get_hh_memory_size_from_y_batch(y_batch) int ¶
- on_batch_begin(trainer, **kwargs)¶
Called when a batch starts. The batch is defined as one forward pass through the network.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None
- on_batch_end(trainer, **kwargs)¶
Called when a batch ends. The batch is defined as one forward pass through the network.
- Parameters:
trainer (Trainer) – The trainer.
- Returns:
None