Training/Testing
- class src.model.TrainValidTestManager.TrainValidTestManager(model: FCSNN | CSNN | UNetSNN | ResBottleneckUNetSNN, model_name: str, model_file_dir: str | Path, exp_model_file_dir: str | Path, use_mask: bool, dataloader_manager: DataLoader, dataloader_manager_valid: DataLoader | None, representation_name: str, transform_manager: TransformManager, visualization_manager: VisualizationManager, weights: Tensor | None, loss_name: List[str], loss_weight: List[float], loss_bias: List[float], optimizer_name: str, learning_rate: float, betas: tuple, scheduler_name: str | None, lr_scheduler_max_lr: float, lr_scheduler_gamma: float, clip_type: str | None, clip_value: float, nb_epochs: int, nb_warmup_epochs: int, nb_steps: int, nb_steps_bin: int, experiment: Experiment | None = None, save_mem: bool = True, save_model: bool = False, train_neuron_parameters: bool = False, log_lr: bool = True, log_batch_loss: bool = False, log_grad_norm: bool = False, log_hist_flag: bool = False, use_ddp: bool = False, device: device = 'cpu', dtype: dtype = torch.float32, evaluate_flag: bool = False, use_amp: bool = True, use_zero: bool = True, weight_decay: float = 0.0, amsgrad_flag: bool = False, empty_cache: bool = False, pretrained_flag: bool = False, perceptual_metric_flag: bool = False, perceptual_metric_name: str = 'stoi', nb_digits: int = 5)
Bases: object
Class that implements the training, validation and testing functions.
- load_checkpoint(filepath: str | Path) → None
Method that loads the model state dictionaries from a specified file path.
- Parameters:
filepath (Union[str, Path]) -- Path to the file containing the model state dictionaries.
- prepare_optimizer(betas: Tuple[float, float] = (0.9, 0.999), weight_decay: float = 0, amsgrad_flag: bool = False)
Method that prepares the model optimizer.
- Parameters:
betas (Tuple[float, float]) -- Coefficients used for computing running averages of the gradient and its square.
weight_decay (float) -- Weight decay coefficient.
amsgrad_flag (bool) -- Boolean that indicates whether to use the AMSGrad variant.
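The `betas`, `weight_decay`, and `amsgrad_flag` parameters mirror the `betas`, `weight_decay`, and `amsgrad` keyword arguments of PyTorch's `torch.optim.Adam` and `torch.optim.AdamW`. The following is a minimal sketch, assuming that `optimizer_name` selects one of these two classes; the helper `build_optimizer` and its name-to-class mapping are hypothetical and not taken from `TrainValidTestManager`.

```python
import torch
from torch import nn


def build_optimizer(
    model: nn.Module,
    optimizer_name: str,
    learning_rate: float,
    betas: tuple = (0.9, 0.999),
    weight_decay: float = 0.0,
    amsgrad_flag: bool = False,
) -> torch.optim.Optimizer:
    # Hypothetical name-to-class mapping; the real class may support other optimizers.
    optimizers = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}
    optimizer_cls = optimizers[optimizer_name.lower()]
    return optimizer_cls(
        model.parameters(),
        lr=learning_rate,
        betas=betas,  # running-average coefficients for the gradient and its square
        weight_decay=weight_decay,
        amsgrad=amsgrad_flag,  # use the AMSGrad variant when True
    )


model = nn.Linear(4, 2)
optimizer = build_optimizer(model, "adam", 1e-3, weight_decay=1e-5, amsgrad_flag=True)
print(optimizer.defaults["amsgrad"])  # True
```

Note that `Adam` applies `weight_decay` as an L2 penalty added to the gradient, whereas `AdamW` decouples it from the gradient update, so the same value can behave differently depending on which class is selected.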
- prepare_scheduler(lr_scheduler_max_lr: float, lr_scheduler_gamma: float)
Method that prepares the model learning-rate scheduler.
- Parameters:
lr_scheduler_max_lr (float) -- Upper learning-rate boundary in the cycle.
lr_scheduler_gamma (float) -- Multiplicative factor of learning rate decay.
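The two parameters correspond to common PyTorch schedulers: `max_lr` is the peak learning rate of `torch.optim.lr_scheduler.OneCycleLR`, and `gamma` is the decay factor of `torch.optim.lr_scheduler.ExponentialLR`. A minimal sketch, assuming that `scheduler_name` selects between these two; the helper `build_scheduler`, the names `"one_cycle"` and `"exponential"`, and the `total_steps` argument are hypothetical:

```python
import torch
from torch import nn


def build_scheduler(optimizer, scheduler_name, lr_scheduler_max_lr, lr_scheduler_gamma, total_steps):
    # Hypothetical mapping from scheduler name to a PyTorch scheduler.
    if scheduler_name is None:
        return None
    if scheduler_name == "one_cycle":
        # Learning rate rises to max_lr, then anneals over total_steps optimizer steps.
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr_scheduler_max_lr, total_steps=total_steps
        )
    if scheduler_name == "exponential":
        # Learning rate is multiplied by gamma at every scheduler.step().
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_scheduler_gamma)
    raise ValueError(f"Unknown scheduler: {scheduler_name}")


model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = build_scheduler(optimizer, "exponential", 1e-1, 0.5, total_steps=10)
for _ in range(3):
    optimizer.step()  # step the optimizer before the scheduler, as PyTorch expects
    scheduler.step()
print(scheduler.get_last_lr())  # [0.00125]
```

With `ExponentialLR`, the learning rate after t steps is lr_0 * gamma**t, so three steps with gamma = 0.5 take 1e-2 down to 1.25e-3.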
- save_checkpoint(epoch: int, filepath: str | Path, log_model_flag: bool = False) → None
Method that saves the model state dictionaries to a specified file path.
- Parameters:
epoch (int) -- Last epoch.
filepath (Union[str, Path]) -- Path of the file in which the model state dictionaries are saved.
log_model_flag (bool) -- Boolean that indicates whether to log the file to the Comet ML experiment.
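Saving and loading state dictionaries usually follows PyTorch's `torch.save`/`torch.load` pattern. A minimal round-trip sketch, assuming the checkpoint is a dictionary holding the epoch plus the model and optimizer state dictionaries; the key names and both helper functions are hypothetical, and the Comet ML logging enabled by `log_model_flag` is omitted:

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def save_checkpoint(model, optimizer, epoch, filepath):
    # Bundle the state dictionaries and the last epoch into one file.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        filepath,
    )


def load_checkpoint(model, optimizer, filepath):
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(filepath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "checkpoint.pt"
    save_checkpoint(model, optimizer, epoch=7, filepath=path)
    restored = nn.Linear(4, 2)  # freshly initialized, different weights
    restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-3)
    last_epoch = load_checkpoint(restored, restored_optimizer, path)
print(last_epoch)  # 7
```

Storing the optimizer state alongside the model weights allows training to resume with the same Adam moment estimates instead of restarting them from zero.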
- save_coefficients(mem: Tensor, index: Tensor, enhanced_coefficients_dir: str) → None
Method that saves the enhanced encoded data produced by the model.
- Parameters:
mem (torch.Tensor) -- Enhanced encoded data.
index (torch.Tensor) -- File name index tensor.
enhanced_coefficients_dir (str) -- Enhanced encoded data directory.
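A minimal sketch of saving one file per sample, assuming `mem` is batch-first and `index` holds one integer file id per sample. The file naming scheme, including zero-padding to `nb_digits` digits (borrowed from the constructor argument of that name), is an assumption for illustration only:

```python
import tempfile
from pathlib import Path

import torch


def save_coefficients(mem, index, enhanced_coefficients_dir, nb_digits=5):
    # Assumption: mem is batch-first and index holds one integer file id per sample.
    out_dir = Path(enhanced_coefficients_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for sample, file_id in zip(mem.detach().cpu(), index.tolist()):
        # clone() so each file stores only its own sample, not the whole batch storage.
        # Zero-padded names such as 00042.pt keep files sorted by id (hypothetical scheme).
        torch.save(sample.clone(), out_dir / f"{int(file_id):0{nb_digits}d}.pt")


mem = torch.randn(3, 8, 16)  # (batch, ...) enhanced encoded data
index = torch.tensor([0, 1, 42])
with tempfile.TemporaryDirectory() as tmp_dir:
    save_coefficients(mem, index, tmp_dir)
    print(sorted(p.name for p in Path(tmp_dir).iterdir()))
    # ['00000.pt', '00001.pt', '00042.pt']
```

Calling `clone()` matters here: iterating over a tensor yields views, and `torch.save` on a view writes the entire underlying storage, i.e. the whole batch, into every file.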
- test_model(tensor_enhanced_coefficients_dir: str, coefficients_filename: str, reconstruction_test: bool = False) → None
Method that evaluates the model using the test set.
- Parameters:
tensor_enhanced_coefficients_dir (str) -- Enhanced encoded data directory.
coefficients_filename (str) -- Enhanced encoded data filename.
reconstruction_test (bool) -- Boolean that indicates whether to run a reconstruction test (output_tensor = target_tensor).
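The reconstruction test replaces the model output with the target, so any remaining error comes from the rest of the evaluation pipeline (for example, decoding the encoded representation or computing metrics) rather than from the model. A minimal sketch of that switch, assuming a single loss function; `evaluate_batch` is a hypothetical helper:

```python
import torch
from torch import nn


@torch.no_grad()
def evaluate_batch(model, input_tensor, target_tensor, loss_fn, reconstruction_test=False):
    # In a reconstruction test the model is bypassed (output = target),
    # so the score reflects only the evaluation pipeline, not the model.
    output_tensor = target_tensor if reconstruction_test else model(input_tensor)
    return loss_fn(output_tensor, target_tensor).item()


model = nn.Linear(4, 4)
x, y = torch.randn(2, 4), torch.randn(2, 4)
print(evaluate_batch(model, x, y, nn.MSELoss(), reconstruction_test=True))  # 0.0
```

With a loss such as MSE, the reconstruction test should report 0.0; a non-zero value would point to a problem outside the model.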
- train_model() → None
Method that trains the model and saves the trained model.
- validate_model(epoch: int) → None
Method that evaluates the model using the validation set.
- Parameters:
epoch (int) -- Current training epoch.
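A minimal sketch of a validation pass, assuming a single loss function and a dataloader that yields `(input, target)` pairs; `validate` is a hypothetical helper, and the epoch-dependent logging of the real method is omitted:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no computation graph is needed for evaluation
def validate(model, dataloader, loss_fn):
    # eval() switches layers such as dropout and batch norm to inference behavior.
    model.eval()
    total_loss, nb_samples = 0.0, 0
    for input_tensor, target_tensor in dataloader:
        batch_size = input_tensor.shape[0]
        # Weight by batch size so a smaller last batch does not skew the mean.
        total_loss += loss_fn(model(input_tensor), target_tensor).item() * batch_size
        nb_samples += batch_size
    model.train()  # restore training mode for the next epoch
    return total_loss / nb_samples


torch.manual_seed(0)
x, y = torch.randn(10, 4), torch.randn(10, 1)
dataloader = DataLoader(TensorDataset(x, y), batch_size=4)  # batches of 4, 4, 2
model = nn.Linear(4, 1)
val_loss = validate(model, dataloader, nn.MSELoss())
print(val_loss)
```

Because each batch loss is weighted by its size, `val_loss` equals the mean-squared error over the whole validation set, even though the last batch holds only two samples.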