Training/Testing

class src.model.TrainValidTestManager.TrainValidTestManager(model: FCSNN | CSNN | UNetSNN | ResBottleneckUNetSNN, model_name: str, model_file_dir: str | Path, exp_model_file_dir: str | Path, use_mask: bool, dataloader_manager: DataLoader, dataloader_manager_valid: DataLoader | None, representation_name: str, transform_manager: TransformManager, visualization_manager: VisualizationManager, weights: Tensor | None, loss_name: List[str], loss_weight: List[float], loss_bias: List[float], optimizer_name: str, learning_rate: float, betas: tuple, scheduler_name: str | None, lr_scheduler_max_lr: float, lr_scheduler_gamma: float, clip_type: str | None, clip_value: float, nb_epochs: int, nb_warmup_epochs: int, nb_steps: int, nb_steps_bin: int, experiment: Experiment | None = None, save_mem: bool = True, save_model: bool = False, train_neuron_parameters: bool = False, log_lr: bool = True, log_batch_loss: bool = False, log_grad_norm: bool = False, log_hist_flag: bool = False, use_ddp: bool = False, device: device = 'cpu', dtype: dtype = torch.float32, evaluate_flag: bool = False, use_amp: bool = True, use_zero: bool = True, weight_decay: float = 0.0, amsgrad_flag: bool = False, empty_cache: bool = False, pretrained_flag: bool = False, perceptual_metric_flag: bool = False, perceptual_metric_name: str = 'stoi', nb_digits: int = 5)

Bases: object

Class that implements the training, validation and testing functions.

load_checkpoint(filepath: str | Path) -> None

Method that loads the model state dictionaries from a specified file.

Parameters:

filepath (Union[str, Path]) -- Path to the file containing the model state dictionaries.
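
A minimal sketch of what loading such a checkpoint might look like, assuming the file was written with `torch.save` and holds a dictionary of state dictionaries. The function name `load_checkpoint_sketch` and the keys `model_state_dict` and `optimizer_state_dict` are hypothetical; the layout actually used by `TrainValidTestManager` is not documented here.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def load_checkpoint_sketch(model, optimizer, filepath):
    """Restore model and optimizer state from a checkpoint file (assumed layout)."""
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a CPU-only machine.
    checkpoint = torch.load(Path(filepath), map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])          # hypothetical key
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # hypothetical key


# Write a small checkpoint, then restore it into a freshly initialised model.
source = nn.Linear(4, 2)
source_opt = torch.optim.Adam(source.parameters(), lr=1e-3)
target = nn.Linear(4, 2)
target_opt = torch.optim.Adam(target.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "checkpoint.pt"
    torch.save(
        {
            "model_state_dict": source.state_dict(),
            "optimizer_state_dict": source_opt.state_dict(),
        },
        path,
    )
    load_checkpoint_sketch(target, target_opt, path)
```

After the call, `target` carries the same weights as `source`.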

prepare_optimizer(betas: Tuple[float, float] = (0.9, 0.999), weight_decay: float = 0, amsgrad_flag: bool = False)

Method that prepares the model optimizer.

Parameters:
  • betas (Tuple[float, float]) -- Coefficients used for computing running averages of the gradient and its square.

  • weight_decay (float) -- Weight decay coefficient.

  • amsgrad_flag (bool) -- Boolean that indicates whether to use the AMSGrad variant.
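
The parameters above map directly onto the arguments of PyTorch's Adam-family optimizers. The following is a sketch only: the helper `prepare_optimizer_sketch` and the accepted name strings `"Adam"` and `"AdamW"` are assumptions, since the values that `optimizer_name` may take are not listed in this reference.

```python
import torch
from torch import nn


def prepare_optimizer_sketch(model, optimizer_name, learning_rate,
                             betas=(0.9, 0.999), weight_decay=0.0,
                             amsgrad_flag=False):
    """Build an optimizer from a name string (the names are hypothetical)."""
    optimizer_classes = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}
    if optimizer_name not in optimizer_classes:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
    return optimizer_classes[optimizer_name](
        model.parameters(),
        lr=learning_rate,
        betas=betas,                # running-average coefficients
        weight_decay=weight_decay,
        amsgrad=amsgrad_flag,       # enables the AMSGrad variant
    )


optimizer = prepare_optimizer_sketch(
    nn.Linear(8, 1), "Adam", learning_rate=1e-3, amsgrad_flag=True
)
```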

prepare_scheduler(lr_scheduler_max_lr: float, lr_scheduler_gamma: float)

Method that prepares the learning rate scheduler.

Parameters:
  • lr_scheduler_max_lr (float) -- Upper learning rate boundary in the cycle.

  • lr_scheduler_gamma (float) -- Multiplicative factor of learning rate decay.
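
The two parameters suggest that a cyclic scheduler (which needs a maximum learning rate) and a decaying scheduler (which needs a multiplicative factor) are both supported. The sketch below illustrates that idea with `OneCycleLR` and `ExponentialLR` from `torch.optim.lr_scheduler`; the helper name, the name strings, and the `total_steps` argument are assumptions, not the documented behaviour.

```python
import torch
from torch import nn


def prepare_scheduler_sketch(optimizer, scheduler_name, total_steps,
                             lr_scheduler_max_lr, lr_scheduler_gamma):
    """Select a scheduler by name (names are hypothetical); None disables it."""
    if scheduler_name is None:
        return None
    if scheduler_name == "OneCycleLR":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr_scheduler_max_lr, total_steps=total_steps
        )
    if scheduler_name == "ExponentialLR":
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_scheduler_gamma
        )
    raise ValueError(f"Unsupported scheduler: {scheduler_name}")


model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = prepare_scheduler_sketch(
    optimizer, "ExponentialLR", total_steps=10,
    lr_scheduler_max_lr=0.1, lr_scheduler_gamma=0.5,
)

# Record the learning rate over three epochs; each epoch multiplies it by gamma.
lrs = []
for _ in range(3):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # no gradients yet, so this only preserves the call order
    scheduler.step()
```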

save_checkpoint(epoch: int, filepath: str | Path, log_model_flag: bool = False) -> None

Method that saves the model state dictionaries to a specified file path.

Parameters:
  • epoch (int) -- Last completed epoch.

  • filepath (Union[str, Path]) -- Path of the file to which the model state dictionaries are written.

  • log_model_flag (bool) -- Boolean that indicates whether to log the file to the Comet ML experiment.
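
A checkpoint of this kind is commonly a single dictionary that bundles the epoch with the state dictionaries. The sketch below assumes that layout; `save_checkpoint_sketch` and the dictionary keys are hypothetical, and the optional upload to Comet ML controlled by `log_model_flag` is left out.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def save_checkpoint_sketch(model, optimizer, epoch, filepath):
    """Write epoch and state dictionaries to one file (assumed layout)."""
    filepath = Path(filepath)
    # Create the parent directory so the first save of a run does not fail.
    filepath.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        filepath,
    )


model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "run" / "checkpoint.pt"
    save_checkpoint_sketch(model, optimizer, epoch=7, filepath=path)
    saved = torch.load(path, map_location="cpu")

saved_epoch = saved["epoch"]
```

Storing the epoch alongside the weights allows training to resume from the following epoch.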

save_coefficients(mem: Tensor, index: Tensor, enhanced_coefficients_dir: str) -> None

Method that saves the enhanced encoded output data to disk.

Parameters:
  • mem (torch.Tensor) -- Enhanced encoded data.

  • index (torch.Tensor) -- File name index tensor.

  • enhanced_coefficients_dir (str) -- Enhanced encoded data directory.
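
One plausible reading of these parameters is that each sample of the batch `mem` is written to its own file, named after the matching entry of `index`. The sketch below follows that reading; the helper name, the `.pt` extension, and the zero-padding width `pad_width` are assumptions made for illustration.

```python
import tempfile
from pathlib import Path

import torch


def save_coefficients_sketch(mem, index, enhanced_coefficients_dir, pad_width=5):
    """Save each sample of a batch under a file name derived from its index."""
    out_dir = Path(enhanced_coefficients_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for sample, file_index in zip(mem, index):
        # Zero-pad the index so that files sort in numeric order.
        filename = f"{int(file_index):0{pad_width}d}.pt"
        # clone() stores only this sample rather than the whole batch storage.
        torch.save(sample.detach().cpu().clone(), out_dir / filename)
        written.append(filename)
    return written


mem = torch.randn(3, 2, 16)            # batch of three encoded outputs
index = torch.tensor([4, 17, 123])     # file name indices for the batch

with tempfile.TemporaryDirectory() as tmp_dir:
    filenames = save_coefficients_sketch(mem, index, tmp_dir)
    reloaded = torch.load(Path(tmp_dir) / filenames[1])
```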

test_model(tensor_enhanced_coefficients_dir: str, coefficients_filename: str, reconstruction_test: bool = False) -> None

Method that evaluates the model using the test set.

Parameters:
  • tensor_enhanced_coefficients_dir (str) -- Enhanced encoded data directory.

  • coefficients_filename (str) -- Enhanced encoded data filename.

  • reconstruction_test (bool) -- Boolean that indicates whether to run a reconstruction test (output_tensor = target_tensor).

train_model() -> None

Method that trains the model and saves the trained model.
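
The constructor arguments `loss_name`, `loss_weight`, `loss_bias`, `clip_type`, and `clip_value` hint at a loop that combines several losses and clips gradients before each optimizer step. The sketch below shows one way such a loop could look. The combination rule (sum of `weight * loss + bias`), the `clip_type` strings `"norm"` and `"value"`, and all helper names are assumptions; saving the trained model is omitted.

```python
import torch
from torch import nn


def combined_loss(output, target, loss_fns, loss_weight, loss_bias):
    # Assumed rule: sum over losses of (weight * loss + bias).
    return sum(
        w * fn(output, target) + b
        for fn, w, b in zip(loss_fns, loss_weight, loss_bias)
    )


def train_sketch(model, batches, optimizer, loss_fns, loss_weight, loss_bias,
                 nb_epochs, clip_type=None, clip_value=1.0):
    """Train for nb_epochs and return the mean training loss per epoch."""
    history = []
    for _ in range(nb_epochs):
        model.train()
        running = 0.0
        for input_tensor, target_tensor in batches:
            optimizer.zero_grad()
            loss = combined_loss(model(input_tensor), target_tensor,
                                 loss_fns, loss_weight, loss_bias)
            loss.backward()
            if clip_type == "norm":      # hypothetical option name
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value)
            elif clip_type == "value":   # hypothetical option name
                nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_value)
            optimizer.step()
            running += loss.item()
        history.append(running / len(batches))
    return history


torch.manual_seed(0)
inputs = torch.randn(32, 8)
targets = inputs @ torch.randn(8, 1)   # synthetic linear regression task
batches = [(inputs[i:i + 8], targets[i:i + 8]) for i in range(0, 32, 8)]

model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
history = train_sketch(model, batches, optimizer,
                       [nn.MSELoss(), nn.L1Loss()], [1.0, 0.5], [0.0, 0.0],
                       nb_epochs=20, clip_type="norm", clip_value=1.0)
```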

validate_model(epoch: int) -> None

Method that evaluates the model using the validation set.

Parameters:

epoch (int) -- Index of the current training epoch.