NeuroTorch

neurotorch.callbacks package

Submodules

neurotorch.callbacks.base_callback module

class neurotorch.callbacks.base_callback.BaseCallback(priority: int | None = None, name: str | None = None, save_state: bool = True, load_state: bool | None = None, **kwargs)

Bases: object

Base class used to create callbacks that monitor or modify the training process.

Training Phases:
  • Iteration: One full pass through the training dataset and the validation dataset.

  • Epoch: One full pass through the training dataset or the validation dataset.

  • Batch: One forward pass through the network.

  • Train: One full pass through the training dataset.

  • Validation: One full pass through the validation dataset.

Callback methods are called in an order determined by the training phases above and by each callback's priority.
Note:

The special method get_checkpoint_state() is called by the CheckpointManager object to save the state of the callback in the checkpoint file. When this method is called is determined by the CheckpointManager, if one is present in the trainer callbacks. In the same way, the method load_checkpoint_state() is called by the CheckpointManager to load the state of the callback from the checkpoint file, if one is present in the trainer callbacks.

Note:

The special method __del__() is called when the callback is deleted. This is used to call the close() method if it was not called before.

Attributes:
  • priority (int): The priority of the callback. The lower the priority, the earlier the callback is called. Default is 10.

DEFAULT_HIGH_PRIORITY = 0
DEFAULT_LOW_PRIORITY = 100
DEFAULT_MEDIUM_PRIORITY = 50
DEFAULT_PRIORITY = 10
UNPICKEABLE_ATTRIBUTES = ['trainer']
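
As a minimal sketch of this API, the callback below times each iteration and exposes a pickleable state for a CheckpointManager. The trainer is treated as an opaque handle, and the keying of the checkpoint dict by the callback's name follows the descriptions above; any other detail is an illustrative assumption, not a guarantee of the Trainer API.

    import time

    from neurotorch.callbacks.base_callback import BaseCallback


    class IterationTimer(BaseCallback):
        """Log the wall-clock duration of every iteration."""

        def __init__(self, **kwargs):
            # Lower priority value -> called earlier (DEFAULT_HIGH_PRIORITY = 0).
            super().__init__(priority=BaseCallback.DEFAULT_HIGH_PRIORITY, **kwargs)
            self._t0 = None
            self.durations = []

        def on_iteration_begin(self, trainer, **kwargs):
            self._t0 = time.perf_counter()

        def on_iteration_end(self, trainer, **kwargs):
            self.durations.append(time.perf_counter() - self._t0)

        def get_checkpoint_state(self, trainer, **kwargs):
            # Pickleable state, saved by a CheckpointManager under this callback's name.
            return {"durations": self.durations}

        def load_checkpoint_state(self, trainer, checkpoint: dict, **kwargs):
            # Assumes the state was stored under self.name, as described above.
            state = checkpoint.get(self.name, None)
            if state is not None:
                self.durations = state["durations"]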
__init__(priority: int | None = None, name: str | None = None, save_state: bool = True, load_state: bool | None = None, **kwargs)
Parameters:
  • priority (int, optional) – The priority of the callback. The lower the priority, the earlier the callback is called. At the beginning of the training, the callbacks' priorities are reversed for the load_state() method. Default is 10.

  • name (str, optional) – The name of the callback. If None, the name is set to the class name. Default is None.

  • save_state (bool, optional) – If True, the state of the callback is saved in the checkpoint file. Default is True.

  • load_state (bool, optional) – If True, the state of the callback is loaded from the checkpoint file. Default is equal to save_state.

close(trainer, **kwargs)

Called when the training ends. This is the last callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

extra_repr() str
get_checkpoint_state(trainer, **kwargs) object

Get the state of the callback. This is called when the checkpoint manager saves the state of the trainer. Then this state is saved in the checkpoint file with the name of the callback as the key.

Parameters:

trainer (Trainer) – The trainer.

Returns:

The state of the callback.

Return type:

A pickleable object.

instance_counter = 0
load_checkpoint_state(trainer, checkpoint: dict, **kwargs)

Loads the state of the callback from a dictionary.

Parameters:
  • trainer (Trainer) – The trainer.

  • checkpoint (dict) – The dictionary containing all the states of the trainer.

Returns:

None

on_batch_begin(trainer, **kwargs)

Called when a batch starts. The batch is defined as one forward pass through the network.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_batch_end(trainer, **kwargs)

Called when a batch ends. The batch is defined as one forward pass through the network.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_epoch_begin(trainer, **kwargs)

Called when an epoch starts. An epoch is defined as one full pass through the training dataset or the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_epoch_end(trainer, **kwargs)

Called when an epoch ends. An epoch is defined as one full pass through the training dataset or the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_iteration_begin(trainer, **kwargs)

Called when an iteration starts. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_optimization_begin(trainer, **kwargs)

Called when the optimization phase of an iteration starts. The optimization phase is defined as the moment where the model weights are updated.

Parameters:
  • trainer (Trainer) – The trainer.

  • kwargs – Additional arguments.

Keyword Arguments:
  • x – The input data.

  • y – The target data.

  • pred – The predicted data.

Returns:

None

on_optimization_end(trainer, **kwargs)

Called when the optimization phase of an iteration ends. The optimization phase is defined as the moment where the model weights are updated.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_pbar_update(trainer, **kwargs) dict

Called when the progress bar is updated.

Parameters:
  • trainer (Trainer) – The trainer.

  • kwargs – Additional arguments.

Returns:

A dict of values used to update the progress bar.

on_train_begin(trainer, **kwargs)

Called when the train phase of an iteration starts. The train phase is defined as a full pass through the training dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_train_end(trainer, **kwargs)

Called when the train phase of an iteration ends. The train phase is defined as a full pass through the training dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_trajectory_end(trainer, trajectory, **kwargs) List[Dict[str, Any]]

Called when a trajectory ends. This is used in reinforcement learning to update the trajectory loss and metrics. Must return a list of dictionaries containing the trajectory metrics. The list must have the same length as the trajectory. Each item in the list will update the attribute others of the corresponding Experience.

Parameters:
  • trainer (Trainer) – The trainer.

  • trajectory (Trajectory) – The trajectory i.e. the sequence of Experiences.

  • kwargs – Additional arguments.

Returns:

A list of dictionaries containing the trajectory metrics.
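
A sketch of an override is shown below. The trajectory is treated as an iterable sequence of Experience objects, per the description above; the reward attribute on Experience is an assumption made for illustration, and the discounted return is just one example of a per-experience metric.

    from neurotorch.callbacks.base_callback import BaseCallback


    class TrajectoryReturnLogger(BaseCallback):
        """Attach the discounted return-to-go to each Experience's `others` dict."""

        def __init__(self, gamma: float = 0.99, **kwargs):
            super().__init__(**kwargs)
            self.gamma = gamma

        def on_trajectory_end(self, trainer, trajectory, **kwargs):
            # Must return one dict per Experience, same length as the trajectory.
            returns, g = [], 0.0
            for experience in reversed(list(trajectory)):
                g = experience.reward + self.gamma * g  # `reward` is assumed here
                returns.append({"return_to_go": g})
            return list(reversed(returns))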

on_validation_batch_begin(trainer, **kwargs)

Called when the validation batch starts. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.

Parameters:
  • trainer (Trainer) – The trainer.

  • kwargs – Additional arguments.

Keyword Arguments:
  • x – The input data.

  • y – The target data.

  • pred – The predicted data.

Returns:

None

on_validation_batch_end(trainer, **kwargs)

Called when the validation batch ends. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_validation_begin(trainer, **kwargs)

Called when the validation phase of an iteration starts. The validation phase is defined as a full pass through the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_validation_end(trainer, **kwargs)

Called when the validation phase of an iteration ends. The validation phase is defined as a full pass through the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

start(trainer, **kwargs)

Called when the training starts. This is the first callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

class neurotorch.callbacks.base_callback.CallbacksList(callbacks: Iterable[BaseCallback] | None = None)

Bases: object

This class is used to store the callbacks that are used during the training. The callbacks are called in the order in which they are stored in the list.

Attributes:
  • callbacks (List[BaseCallback]): The callbacks to use.

__init__(callbacks: Iterable[BaseCallback] | None = None)

Constructor of the CallbacksList class.

Parameters:

callbacks (Iterable[BaseCallback]) – The callbacks to use.

append(callback: BaseCallback)

Append a callback to the list.

Parameters:

callback (BaseCallback) – The callback to append.

Returns:

None

close(trainer, **kwargs)

Called when the trainer closes.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

get_checkpoint_state(trainer, **kwargs) Dict[str, Any]

Collates the states of the callbacks. This is called when the checkpoint manager saves the state of the trainer. Then those states are saved in the checkpoint file with the name of the callback as the key.

Parameters:

trainer (Trainer) – The trainer.

Returns:

The collated states of the callbacks.

Return type:

A pickleable dict.

load_checkpoint_state(trainer, checkpoint: dict, **kwargs)

Loads the state of the callback from a dictionary.

Parameters:
  • trainer (Trainer) – The trainer.

  • checkpoint (dict) – The dictionary containing all the states of the trainer.

Returns:

None

on_batch_begin(trainer, **kwargs)

Called when a batch starts. The batch is defined as one forward pass through the network.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_batch_end(trainer, **kwargs)

Called when a batch ends. The batch is defined as one forward pass through the network.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_epoch_begin(trainer, **kwargs)

Called when an epoch starts. An epoch is defined as one full pass through the training dataset or the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_epoch_end(trainer, **kwargs)

Called when an epoch ends. An epoch is defined as one full pass through the training dataset or the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_iteration_begin(trainer, **kwargs)

Called when an iteration starts. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_optimization_begin(trainer, **kwargs)

Called when the optimization phase of an iteration starts. The optimization phase is defined as the moment where the model weights are updated.

Parameters:
  • trainer (Trainer) – The trainer.

  • kwargs – Additional arguments.

Returns:

None

on_optimization_end(trainer, **kwargs)

Called when the optimization phase of an iteration ends. The optimization phase is defined as the moment where the model weights are updated.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_pbar_update(trainer, **kwargs) dict

Called when the progress bar is updated.

Parameters:
  • trainer (Trainer) – The trainer.

  • kwargs – Additional arguments.

Returns:

A dict of values used to update the progress bar.

on_train_begin(trainer, **kwargs)

Called when the train phase of an iteration starts. The train phase is defined as a full pass through the training dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_train_end(trainer, **kwargs)

Called when the train phase of an iteration ends. The train phase is defined as a full pass through the training dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_trajectory_end(trainer, trajectory, **kwargs) List[Dict[str, Any]]

Called when a trajectory ends. This is used in reinforcement learning to update the trajectory loss and metrics. Must return a list of dictionaries containing the trajectory metrics. The list must have the same length as the trajectory. Each item in the list will update the attribute others of the corresponding Experience.

Note:

If the callbacks return the same keys in the dictionaries, the values will be updated, so the last callback will prevail.

Parameters:
  • trainer (Trainer) – The trainer.

  • trajectory (Trajectory) – The trajectory i.e. the sequence of Experiences.

  • kwargs – Additional arguments.

Returns:

A list of dictionaries containing the trajectory metrics.

on_validation_batch_begin(trainer, **kwargs)

Called when the validation batch starts. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.

Parameters:
  • trainer (Trainer) – The trainer.

  • kwargs – Additional arguments.

Keyword Arguments:
  • x – The input data.

  • y – The target data.

  • pred – The predicted data.

Returns:

None

on_validation_batch_end(trainer, **kwargs)

Called when the validation batch ends. The validation batch is defined as one forward pass through the network on the validation dataset. This is used to update the batch loss and metrics on the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_validation_begin(trainer, **kwargs)

Called when the validation phase of an iteration starts. The validation phase is defined as a full pass through the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_validation_end(trainer, **kwargs)

Called when the validation phase of an iteration ends. The validation phase is defined as a full pass through the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

remove(callback: BaseCallback)

Remove a callback from the list.

Parameters:

callback (BaseCallback) – The callback to remove.

Returns:

None

sort_callbacks_(reverse: bool = False) CallbacksList

Sorts the callbacks by their priority.

Parameters:

reverse (bool) – If True, the callbacks are sorted in descending order.

Returns:

self

Return type:

CallbacksList
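
A short usage sketch combining classes documented in this package (default priorities: TrainingHistory is 0, CheckpointManager is 100):

    from neurotorch.callbacks.base_callback import CallbacksList
    from neurotorch.callbacks.checkpoints_manager import CheckpointManager
    from neurotorch.callbacks.history import TrainingHistory

    callbacks = CallbacksList([CheckpointManager(), TrainingHistory()])
    # Ascending priority: TrainingHistory (0) now precedes CheckpointManager (100).
    callbacks.sort_callbacks_()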

start(trainer, **kwargs)

Called when the trainer starts.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

neurotorch.callbacks.checkpoints_manager module

class neurotorch.callbacks.checkpoints_manager.CheckpointManager(checkpoint_folder: str = './checkpoints', *, checkpoints_meta_path: str | None = None, meta_path_prefix: str = 'network', metric: str = 'val_loss', minimise_metric: bool = True, save_freq: int = 1, save_best_only: bool = False, start_save_at: int = 0, verbose: bool = False, **kwargs)

Bases: BaseCallback

This class is used to manage and create the checkpoints of a model.

Attributes:
  • checkpoint_folder (str): The folder to save the checkpoints to.

  • meta_path_prefix (str): The prefix to use for the checkpoint’s metadata file.

  • metric (str): The name of the metric to collect the best checkpoint on.

  • minimise_metric (bool): Whether to minimise the metric or maximise it.

  • curr_best_metric (float): The current best metric value.

CHECKPOINTS_META_SUFFIX: str = 'checkpoints'
CHECKPOINT_BEST_KEY: str = 'best'
CHECKPOINT_FILE_STRUCT: Dict[str, str | Dict[int, str]] = {'best': 'save_path', 'iterations': {0: 'save_path'}}
CHECKPOINT_ITRS_KEY: str = 'iterations'
CHECKPOINT_ITR_KEY: str = 'itr'
CHECKPOINT_METRICS_KEY: str = 'metrics'
CHECKPOINT_OPTIMIZER_STATE_DICT_KEY: str = 'optimizer_state_dict'
CHECKPOINT_SAVE_PATH_KEY: str = 'save_path'
CHECKPOINT_STATE_DICT_KEY: str = 'model_state_dict'
CHECKPOINT_TRAINING_HISTORY_KEY: str = 'training_history'
DEFAULT_PRIORITY = 100
SAVE_EXT: str = '.pth'
SUFFIX_SEP: str = '-'
__init__(checkpoint_folder: str = './checkpoints', *, checkpoints_meta_path: str | None = None, meta_path_prefix: str = 'network', metric: str = 'val_loss', minimise_metric: bool = True, save_freq: int = 1, save_best_only: bool = False, start_save_at: int = 0, verbose: bool = False, **kwargs)

Initialises the checkpoint manager.

Parameters:
  • checkpoint_folder (str) – The folder to save the checkpoints to.

  • checkpoints_meta_path (Optional[str]) – The path to the checkpoints metadata file. If None, will use the checkpoint_folder and meta_path_prefix to create the path.

  • meta_path_prefix (str) – The prefix to use for the checkpoint’s metadata file.

  • metric (str) – The name of the metric to collect the best checkpoint on.

  • minimise_metric (bool) – Whether to minimise the metric or maximise it.

  • save_freq (int) – The frequency at which to save checkpoints. If set to <= 0, will save at the end of the training.

  • save_best_only (bool) – Whether to only save the best checkpoint.

  • start_save_at (int) – The iteration at which to start saving checkpoints.

  • verbose (bool) – Whether to print out the trace of the checkpoint manager.

  • kwargs – The keyword arguments to pass to the BaseCallback.
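
A configuration sketch using only the parameters documented above; the commented trainer wiring is an assumption about the surrounding Trainer API, not part of this module:

    from neurotorch.callbacks.checkpoints_manager import CheckpointManager

    checkpoint_manager = CheckpointManager(
        checkpoint_folder="./checkpoints/run_0",
        metric="val_loss",      # select the best checkpoint on validation loss
        minimise_metric=True,   # lower val_loss is better
        save_freq=5,            # save every 5 iterations ...
        start_save_at=10,       # ... but only from iteration 10 onward
        verbose=True,
    )
    # trainer = Trainer(..., callbacks=[checkpoint_manager])  # assumed wiring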

property checkpoints_meta_path: str

Gets the path to the checkpoints metadata file.

Returns:

The path to the checkpoints metadata file.

Return type:

str

close(trainer, **kwargs)

Called when the training is finished. Saves the current checkpoint if the current iteration is lower than the total number of iterations, i.e. there is progress that has not been saved yet.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

extra_repr() str
get_checkpoint_filename(itr: int = -1)

Generate the filename for the checkpoint at the given iteration.

Parameters:

itr (int) – The iteration to generate the filename for.

Returns:

The filename for the checkpoint at the given iteration.

Return type:

str

static get_save_name_from_checkpoints(checkpoints_meta: Dict[str, str | Dict[Any, str]], load_checkpoint_mode: LoadCheckpointMode = LoadCheckpointMode.BEST_ITR) str

Gets the save name from the checkpoint’s metadata given the load checkpoint mode.

Parameters:
  • checkpoints_meta (Dict[str, Union[str, Dict[Any, str]]]) – The checkpoint’s metadata.

  • load_checkpoint_mode (LoadCheckpointMode) – The load checkpoint mode.

Returns:

The save name.

Return type:

str

load_checkpoint(load_checkpoint_mode: LoadCheckpointMode = LoadCheckpointMode.BEST_ITR) dict

Loads the checkpoint at the given load_checkpoint_mode.

Parameters:

load_checkpoint_mode (LoadCheckpointMode) – The load_checkpoint_mode to use.

Returns:

The loaded checkpoint.

Return type:

dict

load_mode_to_suffix: Dict[LoadCheckpointMode, str] = {<LoadCheckpointMode.BEST_ITR: 0>: 'BEST_ITR', <LoadCheckpointMode.LAST_ITR: 1>: 'LAST_ITR'}
on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset. The checkpoint is saved if the current constraints are met.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

save_checkpoint(itr: int, itr_metrics: Dict[str, Any], best: bool = False, state_dict: Dict[str, Any] | None = None, optimizer_state_dict: Dict[str, Any] | None = None, training_history: Any | None = None, **other_states) str

Saves a checkpoint of the model and optimizer states at the given iteration.

Parameters:
  • itr (int) – The iteration number.

  • itr_metrics (Dict[str, Any]) – The metrics at the given iteration.

  • best (bool) – Whether this is the best iteration so far.

  • state_dict (Optional[Dict[str, Any]]) – The state dict of the model.

  • optimizer_state_dict (Optional[Dict[str, Any]]) – The state dict of the optimizer.

  • training_history (Optional[Any]) – The training history object.

Returns:

The path to the saved checkpoint.

Return type:

str

save_checkpoints_meta(new_info: dict)

Saves the new checkpoints’ metadata.

Parameters:

new_info (dict) – The new checkpoints’ metadata.

Returns:

None

save_on(trainer) bool

Saves the checkpoint if the current iteration is a checkpoint iteration.

Parameters:

trainer (Trainer) – The trainer.

Returns:

Whether the checkpoint was saved.

Return type:

bool

start(trainer, **kwargs)

Called at the beginning of the training by the Trainer. Loads the checkpoint based on the load_checkpoint_mode of the trainer and updates the current_training_state of the trainer.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

class neurotorch.callbacks.checkpoints_manager.LoadCheckpointMode(value)

Bases: Enum

Enum for the different modes of loading a checkpoint.

Attributes:
  • BEST_ITR (int): Indicates that the iteration with the best metric will be loaded.

  • LAST_ITR (int): Indicates that the last iteration will be loaded.

BEST_ITR = 0
LAST_ITR = 1
static from_str(mode_name: str) LoadCheckpointMode

Converts a string to a LoadCheckpointMode instance.

Parameters:

mode_name (str) – The name of the mode.

Returns:

The corresponding LoadCheckpointMode instance.

Return type:

LoadCheckpointMode
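
A loading sketch based on the keys and signatures documented above; the string accepted by from_str() is assumed to be the enum member name, and the model line is illustrative:

    from neurotorch.callbacks.checkpoints_manager import (
        CheckpointManager,
        LoadCheckpointMode,
    )

    manager = CheckpointManager(checkpoint_folder="./checkpoints/run_0")
    mode = LoadCheckpointMode.from_str("BEST_ITR")  # same as LoadCheckpointMode.BEST_ITR
    checkpoint = manager.load_checkpoint(load_checkpoint_mode=mode)
    state_dict = checkpoint[CheckpointManager.CHECKPOINT_STATE_DICT_KEY]
    # model.load_state_dict(state_dict)  # assumed torch.nn.Module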

neurotorch.callbacks.convergence module

class neurotorch.callbacks.convergence.ConvergenceTimeGetter(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)

Bases: BaseCallback

Monitor the training process and return the time it took to pass the threshold.

__init__(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)

Constructor for ConvergenceTimeGetter class.

Parameters:
  • metric (str) – Name of the metric to monitor.

  • threshold (float) – Threshold value for the metric.

  • minimize_metric (bool) – Whether to minimize or maximize the metric.

  • kwargs – The keyword arguments to pass to the BaseCallback.
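
For example, a sketch that records how long it takes a hypothetical val_accuracy metric to reach 95%:

    from neurotorch.callbacks.convergence import ConvergenceTimeGetter

    convergence_timer = ConvergenceTimeGetter(
        metric="val_accuracy",   # hypothetical metric name
        threshold=0.95,
        minimize_metric=False,   # accuracy is maximized
    )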

close(trainer, **kwargs)

Called when the training ends. This is the last callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

load_checkpoint_state(trainer, checkpoint: dict, **kwargs)

Loads the state of the callback from a dictionary.

Parameters:
  • trainer (Trainer) – The trainer.

  • checkpoint (dict) – The dictionary containing all the states of the trainer.

Returns:

None

on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

save_on(trainer, **kwargs)
start(trainer, **kwargs)

Called when the training starts. This is the first callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

neurotorch.callbacks.early_stopping module

class neurotorch.callbacks.early_stopping.EarlyStopping(patience: int = 5, tol: float = 0.01)

Bases: BaseCallback

__init__(patience: int = 5, tol: float = 0.01)
Parameters:
  • patience (int) – Number of iterations to wait before stopping the training. Default is 5.

  • tol (float) – The tolerance for the metric. Default is 0.01.

class neurotorch.callbacks.early_stopping.EarlyStoppingOnNaN(metric: str, **kwargs)

Bases: BaseCallback

Monitor the training process and set the stop_training_flag to True when the metric is NaN.

DEFAULT_PRIORITY = 100
__init__(metric: str, **kwargs)

Constructor for EarlyStoppingOnNaN class.

Parameters:
  • metric (str) – Name of the metric to monitor.

  • kwargs – The keyword arguments to pass to the BaseCallback.

extra_repr() str
on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

class neurotorch.callbacks.early_stopping.EarlyStoppingOnStagnation(metric: str, patience: int = 10, tol: float = 0.0001, start_with_history: bool = True, **kwargs)

Bases: BaseCallback

Monitor the training process and set the stop_training_flag to True when the metric stagnates. The metric is considered to stagnate when the mean of the absolute differences over its last patience values is less than tol.

DEFAULT_PRIORITY = 100
__init__(metric: str, patience: int = 10, tol: float = 0.0001, start_with_history: bool = True, **kwargs)

Constructor for EarlyStoppingOnStagnation class.

Parameters:
  • metric (str) – Name of the metric to monitor.

  • patience (int) – Number of iterations to wait before stopping the training.

  • tol (float) – The tolerance for the metric.

  • start_with_history (bool) – Whether to start the monitor with the history of the metric.

  • kwargs – The keyword arguments to pass to the BaseCallback.
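
The stagnation criterion described above can be written out explicitly; the following is a minimal sketch of the test, not the library's internal implementation:

    import numpy as np

    def is_stagnating(values, patience: int = 10, tol: float = 1e-4) -> bool:
        """True when the mean absolute difference over the last `patience` values is < tol."""
        if len(values) < patience:
            return False
        window = np.asarray(values[-patience:])
        return float(np.mean(np.abs(np.diff(window)))) < tol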

extra_repr() str
get_value()
on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

start(trainer, **kwargs)

Called when the training starts. This is the first callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

class neurotorch.callbacks.early_stopping.EarlyStoppingOnTimeLimit(*, delta_seconds: float = 600.0, resume_on_load: bool = True, **kwargs)

Bases: BaseCallback

Monitor the training process and set the stop_training_flag to True when the time limit is reached.

CURRENT_SECONDS_COUNT_KEY = 'current_seconds_count'
DELTA_SECONDS_KEY = 'delta_seconds'
__init__(*, delta_seconds: float = 600.0, resume_on_load: bool = True, **kwargs)

Constructor for EarlyStoppingOnTimeLimit class.

Parameters:
  • delta_seconds (float) – The number of seconds to wait before stopping the training.

  • resume_on_load (bool) – Whether to resume the time when loading a checkpoint. If False, the time will be reset to 0.

  • kwargs – The keyword arguments to pass to the BaseCallback.

extra_repr() str
get_checkpoint_state(trainer, **kwargs) object

Get the state of the callback. This is called when the checkpoint manager saves the state of the trainer. Then this state is saved in the checkpoint file with the name of the callback as the key.

Parameters:

trainer (Trainer) – The trainer.

Returns:

The state of the callback.

Return type:

A pickleable object.

load_checkpoint_state(trainer, checkpoint: dict, **kwargs)

Loads the state of the callback from a dictionary.

Parameters:
  • trainer (Trainer) – The trainer.

  • checkpoint (dict) – The dictionary containing all the states of the trainer.

Returns:

None

on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

start(trainer, **kwargs)

Called when the training starts. This is the first callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

update_flags(trainer, **kwargs)
class neurotorch.callbacks.early_stopping.EarlyStoppingThreshold(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)

Bases: BaseCallback

Monitor the training process and set the stop_training_flag to True when the threshold is met.

__init__(*, metric: str, threshold: float, minimize_metric: bool, **kwargs)

Constructor for EarlyStoppingThreshold class.

Parameters:
  • metric (str) – Name of the metric to monitor.

  • threshold (float) – Threshold value for the metric.

  • minimize_metric (bool) – Whether to minimize or maximize the metric.

  • kwargs – The keyword arguments to pass to the BaseCallback.

on_iteration_end(trainer, **kwargs)

Called when an iteration ends. An iteration is defined as one full pass through the training dataset and the validation dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None
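
A combined sketch of the early-stopping callbacks in this module; the metric names are hypothetical and depend on what the trainer records:

    from neurotorch.callbacks.early_stopping import (
        EarlyStoppingOnNaN,
        EarlyStoppingThreshold,
    )

    stoppers = [
        EarlyStoppingOnNaN(metric="train_loss"),   # abort as soon as the loss is NaN
        EarlyStoppingThreshold(
            metric="val_loss",
            threshold=0.01,       # stop once val_loss drops below 0.01
            minimize_metric=True,
        ),
    ]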

neurotorch.callbacks.events module

class neurotorch.callbacks.events.EventOnMetricThreshold(metric_name: str, threshold: float, *, event: Callable, event_args: tuple | None = None, event_kwargs: dict | None = None, minimize_metric: bool = True, do_once: bool = False, **kwargs)

Bases: BaseCallback

This class is a callback that calls an event if the metric value reaches a given threshold at the end of each iteration. Note that the trainer will be passed as the first argument to the event. In addition, if the given metric is not present in the training state, the event will not be called. Finally, the output of the event will be passed to the on_pbar_update() method.

Attributes:
  • metric_name (str): The name of the metric to monitor.

  • threshold (float): The threshold value.

  • event (Callable): The event to call.

  • event_args (tuple): The arguments to pass to the event.

  • event_kwargs (dict): The keyword arguments to pass to the event.

__init__(metric_name: str, threshold: float, *, event: Callable, event_args: tuple | None = None, event_kwargs: dict | None = None, minimize_metric: bool = True, do_once: bool = False, **kwargs)

Constructor for the EventOnMetricThreshold class.

Parameters:
  • metric_name (str) – The name of the metric to monitor.

  • threshold (float) – The threshold value.

  • event (Callable) – The event to call.

  • event_args (tuple) – The arguments to pass to the event.

  • event_kwargs (dict) – The keyword arguments to pass to the event.

  • kwargs – The keyword arguments to pass to the base class.
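
A sketch with a simple event; per the class description, the trainer is passed as the first argument and the returned dict feeds on_pbar_update(). The metric name and event body are hypothetical:

    from neurotorch.callbacks.events import EventOnMetricThreshold

    def notify(trainer, run_name):
        # Any callable works; its output is forwarded to on_pbar_update().
        return {"milestone": f"{run_name}: val_loss threshold reached"}

    event_cb = EventOnMetricThreshold(
        metric_name="val_loss",
        threshold=0.1,
        event=notify,
        event_args=("run_0",),
        minimize_metric=True,
        do_once=True,  # fire only on the first crossing
    )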

on_iteration_end(trainer, **kwargs)

Check if the metric value reaches the threshold.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

on_pbar_update(trainer, **kwargs) dict

Update the progress bar with the output of the event.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

The progress bar update.

Return type:

dict

neurotorch.callbacks.history module

class neurotorch.callbacks.history.TrainingHistory(container: Dict[str, List[float]] | None = None, default_value=nan, **kwargs)

Bases: BaseCallback

This class is used to store some metrics over the training process.

Attributes:
  • default_value (float): The default value to use to equalize the lengths of the container’s items.

DEFAULT_PRIORITY = 0
__init__(container: Dict[str, List[float]] | None = None, default_value=nan, **kwargs)

Initialize the training history with the given container.

Parameters:
  • container (Dict[str, List[float]]) – The initial container of metric values.

  • default_value (float) – The default value to use to equalize the lengths of the container’s items.

  • kwargs – The keyword arguments to pass to the BaseCallback.

append(key, value)

Add the given value to the given key. Increase the size of the container by one.

Parameters:
  • key (str) – The key to add the value to.

  • value (float) – The value to add.

Returns:

None
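
A usage sketch based on the methods documented in this class:

    from neurotorch.callbacks.history import TrainingHistory

    history = TrainingHistory()
    for loss in (0.9, 0.5, 0.3):
        history.append("train_loss", loss)

    print(history.min("train_loss"))       # 0.3
    print(history.min_item("train_loss"))  # all metrics of the best iteration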

concat(other)
create_plot(**kwargs) Tuple[Figure, Dict[str, Axes], Dict[str, Line2D]]

Create a plot of the metrics in the container.

Parameters:

kwargs – Keyword arguments.

Keyword Arguments:
  • figsize (Tuple[float, float]) – The size of the figure.

  • linewidth (int) – The width of the lines.

Returns:

The figure, axes and lines of the plot.

Return type:

Tuple[plt.Figure, Dict[str, plt.Axes], Dict[str, plt.Line2D]]

extra_repr()
get(key, default=None)

Get the values of the given key.

Parameters:
  • key (str) – The key to get the values of.

  • default (Any) – The default value to return if the key is not in the container.

Returns:

The values of the given key.

Return type:

List[float]

get_item_at(idx: int = -1)

Get all the metrics of the iteration at the given index.

Parameters:

idx (int) – The index to get the metrics of.

Returns:

All the metrics of the iteration at the given index.

Return type:

Dict[str, float]

Raises:

ValueError – If the index is out of bounds.

insert(index: int, other)

Increase the size of the container's items up to the given index and insert the given metrics into the container.

Parameters:
  • index (int) – The index to insert the other at.

  • other (Dict[str, float]) – The metrics to insert.

Returns:

None

items()
keys()
max(key=None, default=-inf)

Get the maximum value of the given key.

Parameters:
  • key (str) – The key to get the maximum value of. If None, the first key is used.

  • default (float) – The default value to return if the key is not in the container.

Returns:

The maximum value of the given key.

Return type:

float

max_item(key=None)

Get all the metrics of the iteration with the maximum value of the given key.

Parameters:

key (str) – The key to get the maximum value of. If None, the first key is used.

Returns:

All the metrics of the iteration with the maximum value of the given key.

Return type:

Dict[str, float]

Raises:

ValueError – If the key is not in the container.

min(key=None, default: float = inf) float

Get the minimum value of the given key.

Parameters:
  • key (str) – The key to get the minimum value of. If None, the first key is used.

  • default (float) – The default value to return if the key is not in the container.

Returns:

The minimum value of the given key.

Return type:

float

min_item(key=None) Dict[str, float]

Get all the metrics of the iteration with the minimum value of the given key.

Parameters:

key (str) – The key to get the minimum value of. If None, the first key is used.

Returns:

All the metrics of the iteration with the minimum value of the given key.

Return type:

Dict[str, float]

Raises:

ValueError – If the key is not in the container.

on_iteration_end(trainer, **kwargs)

Insert the current metrics into the container.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

plot(save_path: str | None = None, show: bool = False, **kwargs) Tuple[Figure, Dict[str, Axes], Dict[str, Line2D]]

Plot the metrics in the container.

Parameters:
  • save_path (Optional[str]) – The path to save the plot to. If None, the plot is not saved.

  • show (bool) – Whether to show the plot.

  • kwargs – Keyword arguments.

Keyword Arguments:
  • figsize (Tuple[float, float]) – The size of the figure.

  • linewidth (int) – The width of the lines.

  • dpi (int) – The resolution of the figure.

  • block (bool) – Whether to block execution until the plot is closed.

  • close (bool) – Whether to close the plot at the end.

Returns:

The figure, axes and lines of the plot.

Return type:

Tuple[plt.Figure, Dict[str, plt.Axes], Dict[str, plt.Line2D]]

update_fig(fig: Figure, axes: Dict[str, Axes], lines: Dict[str, Line2D], **kwargs) Tuple[Figure, Dict[str, Axes], Dict[str, Line2D]]

Update the plot of the metrics in the container.

Parameters:
  • fig (plt.Figure) – The figure to update.

  • axes (Dict[str, plt.Axes]) – The axes to update.

  • lines (Dict[str, plt.Line2D]) – The lines to update.

  • kwargs – Keyword arguments.

Returns:

The figure, axes and lines of the plot.

Return type:

Tuple[plt.Figure, Dict[str, plt.Axes], Dict[str, plt.Line2D]]

neurotorch.callbacks.lr_schedulers module

class neurotorch.callbacks.lr_schedulers.LRSchedulerOnMetric(metric: str, metric_schedule: Iterable[float], *, minimize_metric: bool | None = None, lr_decay: float | Sequence[float] | None = None, min_lr: float | Sequence[float] | None = None, lr_start: float | Sequence[float] | None = None, retain_progress: bool = True, optimizer: Optimizer | None = None, **kwargs)

Bases: BaseCallback

Class to schedule the learning rate of the optimizer based on the metric value. Each time the metric reaches the next value of the schedule, the learning rate is multiplied by the given decay. The learning rate is also capped at the given minimum value.

Attributes:
  • metric (str): The metric to use to schedule the learning rate.

  • metric_schedule (Iterable[float]): The schedule of the metric.

  • minimize_metric (bool): Whether to minimize the metric or maximize it.

  • lr_decay (float): The decay factor to use when the metric reaches the next value of the schedule.

  • min_lr (float): The minimum learning rate to use.

  • lr_start (float): The learning rate to use at the beginning of the training.

  • lr (float): The current learning rate.

  • retain_progress (bool): If True, the current step of the scheduler will only increase when the metric reaches the next value of the schedule. If False, the current step will increase or decrease depending on the metric.

  • step (int): The current step of the scheduler.

DEFAULT_LR_START = 0.001
DEFAULT_MIN_LR = 1e-12
__init__(metric: str, metric_schedule: Iterable[float], *, minimize_metric: bool | None = None, lr_decay: float | Sequence[float] | None = None, min_lr: float | Sequence[float] | None = None, lr_start: float | Sequence[float] | None = None, retain_progress: bool = True, optimizer: Optimizer | None = None, **kwargs)

Initialize the scheduler with the given metric and metric schedule.

Parameters:
  • metric (str) – The metric to use to schedule the learning rate.

  • metric_schedule (Iterable[float]) – The schedule of the metric.

  • minimize_metric (Optional[bool]) – Whether to minimize the metric or maximize it. If None, infer from the metric schedule.

  • lr_decay (Optional[Union[float, Sequence[float]]]) – The decay factor to use when the metric reaches the next value of the schedule. If None, the decay is computed automatically as (lr_start - min_lr) / len(metric_schedule).

  • min_lr (Optional[Union[float, Sequence[float]]]) – The minimum learning rate to use.

  • lr_start (Optional[Union[float, Sequence[float]]]) – The learning rate to use at the beginning of the training. If None, it is taken automatically from the first parameter group of the optimizer.

  • retain_progress (bool) – If True, the current step of the scheduler will only increase when the metric reaches the next value of the schedule. If False, the current step will increase or decrease depending on the metric.

  • optimizer (Optional[Optimizer]) – The optimizer whose learning rate will be scheduled. If None, the optimizer is taken from the trainer; note that in this case the first optimizer of the trainer’s callbacks will be used.

  • kwargs – The keyword arguments to pass to the BaseCallback.

Keyword Arguments:

log_lr_to_history (bool) – Whether to log the learning rate to the training history. Default: True.
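
A configuration sketch; the metric name is hypothetical, and the decreasing schedule matches a minimized metric as described above:

    from neurotorch.callbacks.lr_schedulers import LRSchedulerOnMetric

    lr_scheduler = LRSchedulerOnMetric(
        metric="val_loss",
        metric_schedule=[0.5, 0.2, 0.1, 0.05],  # decay the LR at each crossing
        minimize_metric=True,
        min_lr=1e-6,
        retain_progress=True,
    )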

extra_repr() str
on_iteration_end(trainer, **kwargs)

Update the learning rate of the optimizer based on the metric value.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

on_pbar_update(trainer, **kwargs) dict

Return the learning rate to display in the progress bar.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

The dictionary to update the progress bar.

Return type:

dict

start(trainer, **kwargs)

Initialize the learning rate of the optimizer and the lr_start attribute if necessary.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

update_step(last_metric: float) int

Update the current step of the scheduler based on the metric value.

Parameters:

last_metric (float) – The last value of the metric.

Returns:

The new step.

Return type:

int

class neurotorch.callbacks.lr_schedulers.LinearLRScheduler(lr_start: float | Sequence[float], lr_end: float | Sequence[float], n_steps: int, optimizer: Optimizer | None = None, **kwargs)

Bases: BaseCallback

This class is a callback that implements a linear learning rate decay. This is useful to decrease the learning rate over iterations. The learning rate is decreased linearly from lr_start to lr_end over n_steps iterations.

Attributes:
  • lr_start (float): The initial learning rate.

  • lr_end (float): The final learning rate.

  • n_steps (int): The number of steps over which the learning rate is decreased.

  • lr (float): The current learning rate.

  • lr_decay (float): The learning rate decay per step.

DEFAULT_LR_START = 0.001
DEFAULT_MIN_LR = 1e-12
__init__(lr_start: float | Sequence[float], lr_end: float | Sequence[float], n_steps: int, optimizer: Optimizer | None = None, **kwargs)

Constructor for the LinearLRScheduler class.

Parameters:
  • lr_start (Union[float, Sequence[float]]) – The initial learning rate. If a sequence is given, each entry of the sequence is used for each parameter group.

  • lr_end (Union[float, Sequence[float]]) – The final learning rate. If a sequence is given, each entry of the sequence is used for each parameter group.

  • n_steps (int) – The number of steps over which the learning rate is decreased.
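
For instance, a sketch decaying the learning rate from 1e-3 to 1e-5 over 100 iterations:

    from neurotorch.callbacks.lr_schedulers import LinearLRScheduler

    linear_lr = LinearLRScheduler(lr_start=1e-3, lr_end=1e-5, n_steps=100)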

extra_repr() str
on_iteration_end(trainer, **kwargs)

Decrease the learning rate linearly.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

start(trainer, **kwargs)

Initialize the learning rate of the optimizer and the lr_start attribute if necessary.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

neurotorch.callbacks.training_visualization module

class neurotorch.callbacks.training_visualization.TrainingHistoryVisualizationCallback(temp_folder: str = '~/temp/', **kwargs)

Bases: BaseCallback

This callback is used to visualize the training history in real time.

Attributes:
  • fig (plt.Figure): The figure to plot.

  • axes (plt.Axes): The axes to plot.

  • lines (list): The list of lines to plot.

__init__(temp_folder: str = '~/temp/', **kwargs)

Create a new callback to visualize the training history.

Parameters:
  • temp_folder (str) – The folder where to save the training history.

  • kwargs – The keyword arguments to pass to the base callback.

close(trainer, **kwargs)

Close the process and delete the temporary file.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_iteration_end(trainer, **kwargs)

Update the training history.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

start(trainer, **kwargs)

Start the process.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

Module contents

class neurotorch.callbacks.ForesightTimeStepUpdaterOnTarget(**kwargs)

Bases: BaseCallback

Updates the foresight time step of the model to the length of the target sequence. This is useful for models that are trained with a fixed foresight time step, but are evaluated on sequences of different lengths.

DEFAULT_PRIORITY = 0
__init__(**kwargs)

Constructor for ForesightTimeStepUpdaterOnTarget class.

Keyword Arguments:
  • time_steps_multiplier (int) – The multiplier to use to determine the time steps. Defaults to 1.

  • target_skip_first (bool) – Whether to skip the first time step of the target sequence. Defaults to True.

  • update_val_loss_freq (int) – The frequency at which to update the validation loss. Defaults to 1.

  • start_intensive_val_at (float) – The fraction of the training epochs at which to start intensive validation. An intensive validation is when the validation is performed at each iteration. Defaults to 0.0. If the value is an integer, it is interpreted as the number of iterations at which to start intensive validation.

  • hh_memory_size_strategy (str) –

    The strategy to use to determine the hidden history memory size. The available strategies are:

    • "out_memory_size": The hidden history memory size is equal to the output memory size.

    • "foresight_time_steps": The hidden history memory size is equal to the foresight time steps.

    • <Number>: The hidden history memory size is equal to the specified number.

    Defaults to "out_memory_size".
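
A construction sketch using the keyword arguments documented above:

    from neurotorch.callbacks import ForesightTimeStepUpdaterOnTarget

    foresight_updater = ForesightTimeStepUpdaterOnTarget(
        time_steps_multiplier=1,
        target_skip_first=True,
        hh_memory_size_strategy="out_memory_size",
    )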

get_hh_memory_size_from_y_batch(y_batch) int
on_batch_begin(trainer, **kwargs)

Called when a batch starts. The batch is defined as one forward pass through the network.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_batch_end(trainer, **kwargs)

Called when a batch ends. The batch is defined as one forward pass through the network.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

on_train_end(trainer, **kwargs)

Called when the train phase of an iteration ends. The train phase is defined as a full pass through the training dataset.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None

start(trainer, **kwargs)

Called when the training starts. This is the first callback called.

Parameters:

trainer (Trainer) – The trainer.

Returns:

None