Speech enhancement
- class src.model.SpeechEnhancer.SpeechEnhancer(model_name: str, use_mask: bool, representation_name: str, representation_dir_name: str, transform_name: list[str] | None, reconstruct_flag: bool, compute_representation: bool, batch_size: int, k: float, tau: float, tau_syn: float, tau_mem: float, tau_syn_out: float, tau_mem_out: float, time_step: float, membrane_threshold: float, decay_input: bool, spiking_mode: str, reset_mode: str, detach_reset: bool, weight_mean: float, weight_std: float, weight_gain: float, weight_init_dist: str, input_dim: int, hidden_dim_list: list, hidden_channels_list: list, kernel_size: tuple, stride: tuple, padding: tuple | None, dilation: tuple, bias: bool, padding_mode: str, pooling_flag: bool, pooling_type: str, use_same_layer: bool, recurrent_flag: bool, neuron_model: str, train_neuron_parameters: bool, neuron_parameters_init_dist: str, upsample_mode: str, scale_flag: bool, scale_factor: float, bn_flag: bool, dropout_flag: bool, dropout_p: float, skip_connection_type: str, nb_residual_block: int, residual_skip_connection_type: str, use_intermediate_output: bool, loss_name: list, loss_weight: list, loss_bias: list, surrogate_name: str, surrogate_scale: float | None, activation_fn: list[str], optimizer_name: str, learning_rate: float, betas: tuple, scheduler_name: str | None, lr_scheduler_max_lr: float, lr_scheduler_gamma: float, clip_type: str | None, clip_value: float, nb_epochs: int, nb_warmup_epochs: int, nb_steps_bin: int | None, truncated_bptt_ratio: int | None, pretrained_flag: bool, model_file_name_id: int | None, save_mem: bool, save_model: bool, perceptual_metric_flag: bool, train_flag: bool, hpc_flag: bool, use_ddp: bool, dist_backend: str, evaluate_flag: bool, use_amp: bool, use_zero: bool, pin_memory: bool, num_workers: int, prefetch_factor: int, deterministic: bool, seed: int, empty_cache: bool, debug_flag: bool, comet_ml_params: dict | None)
Bases: object
Class that implements the speech enhancer.
- experiment_tags() -> None
Method that logs hyperparameters to Comet ML experiment tags.
- get_params_dict()
Method that saves the hyperparameters dictionary.
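The page does not show how the hyperparameters dictionary is built. One plausible pattern, given here only as a minimal sketch, is to collect the constructor arguments stored on the instance. The class `EnhancerParamsSketch`, the `params_dict` attribute, and the argument values are hypothetical, and only four of the real constructor parameters are used.

```python
class EnhancerParamsSketch:
    """Hypothetical stand-in for SpeechEnhancer with four of its hyperparameters."""

    def __init__(self, model_name: str, batch_size: int, learning_rate: float, seed: int):
        self.model_name = model_name
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.seed = seed
        self.params_dict = None  # filled in by get_params_dict()

    def get_params_dict(self) -> dict:
        # Assumed behaviour: copy every public attribute except the cache itself.
        self.params_dict = {
            name: value
            for name, value in vars(self).items()
            if not name.startswith("_") and name != "params_dict"
        }
        return self.params_dict


# Placeholder values, not taken from the documented project.
enhancer = EnhancerParamsSketch("example_model", batch_size=16, learning_rate=1e-3, seed=0)
params = enhancer.get_params_dict()
```

A dictionary of this shape can then be passed to a logger or written to disk without touching the model itself.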
- plot_input_data(dataset_manager: DatasetManager, show_flag: bool = False, verbose: bool = False) -> None
Method that plots signal data for multiple signals.
- Parameters:
dataset_manager (DatasetManager) -- Dataset manager instance.
show_flag (bool) -- Boolean that indicates whether to display plots.
verbose (bool) -- Boolean that indicates whether to print specific output.
- plot_output_data(dataset_manager: DatasetManager, show_flag: bool = False) -> None
Method that plots signal data for output (enhanced) test signals.
- Parameters:
dataset_manager (DatasetManager) -- Dataset manager instance.
show_flag (bool) -- Boolean that indicates whether to display plots.
- plot_weights(show_flag: bool = False) -> None
Method that plots LCA weights.
- Parameters:
show_flag (bool) -- Boolean that indicates whether to display plots.
- prepare_data_manager() -> None
Method that defines the dataloader instance.
- prepare_model() -> None
Method that defines the model instance.
- save_json_file(file_name: str, file_rec_dict: dict) -> None
Method that saves file_rec_dict within a json file.
- Parameters:
file_name (str) -- File name.
file_rec_dict (dict) -- Dictionary to be saved to a file.
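Writing a dictionary to a JSON file can be sketched with the standard library as below. This is not the project's implementation: the function name `save_json_sketch`, the creation of parent directories, the `indent` and `default=str` choices, and the returned path are all assumptions made for illustration.

```python
import json
from pathlib import Path


def save_json_sketch(file_name: str, file_rec_dict: dict) -> Path:
    """Write file_rec_dict to file_name as JSON and return the path written."""
    path = Path(file_name)
    # Assumption: create the parent directory if it does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        # indent=4 keeps the file readable; default=str is a fallback for
        # values that json cannot serialize natively (for example Path objects).
        json.dump(file_rec_dict, f, indent=4, default=str)
    return path
```

The `default=str` fallback trades strict round-tripping for robustness: unsupported values are stored as strings rather than raising a `TypeError`.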
- setup(verbose: bool = True) -> None
Method that defines the DDP instance.
- Parameters:
verbose (bool) -- Boolean that indicates whether to print specific output.
- visualize_data(dist_flag: bool = True, show_flag: bool = False, verbose: bool = False, n: int = 50, epsilon: float = 1e-08) -> None
Method that plots data.
- Parameters:
dist_flag (bool) -- Boolean that indicates whether to compute the input data distribution plot.
show_flag (bool) -- Boolean that indicates whether to display plots.
verbose (bool) -- Boolean that indicates whether to print specific output.
n (int) -- Number of data subset for data distribution computation.
epsilon (float) -- Small value used to avoid numerical errors.
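The roles of `n` and `epsilon` are not spelled out on this page. The sketch below shows one common way such parameters are used: draw a random subset of at most `n` signals, then peak-normalize each one with `epsilon` in the denominator so that an all-zero signal does not cause a division by zero. The function name `normalized_subset`, the `seed` argument, and the peak normalization itself are assumptions, not the method's documented behaviour.

```python
import random


def normalized_subset(signals, n=50, epsilon=1e-8, seed=0):
    """Return up to n randomly chosen signals, each scaled by its peak amplitude."""
    rng = random.Random(seed)
    # Never ask for more signals than are available.
    subset = rng.sample(signals, min(n, len(signals)))
    normalized = []
    for signal in subset:
        peak = max((abs(x) for x in signal), default=0.0)
        # epsilon keeps the division defined when the signal is silent (peak == 0).
        normalized.append([x / (peak + epsilon) for x in signal])
    return normalized
```

With the default `epsilon`, a non-silent signal is scaled to a peak marginally below 1, while a silent signal stays at zero instead of raising an error.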