Loss
- class src.model.LossManager.LossManager(representation_name: str, loss_name: list[str], loss_weight: list[float], loss_bias: list[float], metric_name: str, weights: Tensor, transform_manager: TransformManager, reduction: str = 'mean', nb_digits: int = 5)
Bases: torch.nn.Module
Class that implements the loss functions and evaluation metrics.
- forward(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
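The constructor takes parallel lists `loss_name` and `loss_weight`, which suggests that `forward` combines several loss terms. A minimal sketch of such a weighted sum, assuming each name selects one loss and ignoring `loss_bias` (its role is not documented here). The registry keys `"l1"`, `"mse"` and `"huber"` are hypothetical; they map to PyTorch's built-in functional losses rather than to the class's own methods.

```python
import torch
import torch.nn.functional as F

# Hypothetical name -> loss registry; the real class dispatches to its own methods.
LOSSES = {
    "l1": F.l1_loss,
    "mse": F.mse_loss,
    "huber": F.huber_loss,
}


def combined_loss(input_: torch.Tensor, target: torch.Tensor,
                  loss_name: list[str], loss_weight: list[float]) -> torch.Tensor:
    """Weighted sum of the named losses (loss_bias is deliberately ignored here)."""
    if len(loss_name) != len(loss_weight):
        raise ValueError("loss_name and loss_weight must have the same length")
    total = input_.new_zeros(())  # scalar accumulator on the same device/dtype
    for name, weight in zip(loss_name, loss_weight):
        total = total + weight * LOSSES[name](input_, target)
    return total


x = torch.tensor([0.0, 1.0, 2.0])
y = torch.zeros(3)
# L1 = 1.0 and MSE = 5/3, so the weighted sum is 1.0 + 0.5 * 5/3.
print(combined_loss(x, y, ["l1", "mse"], [1.0, 0.5]))
```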
- huber_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the Huber loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
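For reference, the Huber loss is quadratic for small residuals and linear for large ones, which makes it less sensitive to outliers than MSE. A sketch of the element-wise definition with mean reduction, assuming a threshold `delta = 1.0` (the documented signature exposes no such parameter); it should agree with `torch.nn.functional.huber_loss`.

```python
import torch
import torch.nn.functional as F


def huber(input_: torch.Tensor, target: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Mean Huber loss: 0.5 * r^2 for |r| <= delta, delta * (|r| - 0.5 * delta) beyond."""
    r = (input_ - target).abs()
    quadratic = 0.5 * r ** 2
    linear = delta * (r - 0.5 * delta)
    return torch.where(r <= delta, quadratic, linear).mean()


x = torch.tensor([0.0, 0.5, 3.0])
y = torch.zeros(3)
# Per-element losses are 0, 0.125 and 2.5, so both calls give 0.875.
print(huber(x, y), F.huber_loss(x, y, delta=1.0))
```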
- l1_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the L1 loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
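The L1 loss is the mean absolute error between output and target. A minimal sketch, assuming mean reduction (the constructor default), checked against `torch.nn.functional.l1_loss`:

```python
import torch
import torch.nn.functional as F


def l1(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean absolute error over all elements."""
    return (input_ - target).abs().mean()


x = torch.tensor([1.0, -2.0, 3.0])
y = torch.zeros(3)
print(l1(x, y), F.l1_loss(x, y))  # both equal (1 + 2 + 3) / 3
```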
- lsd_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the log-spectral distance (LSD) loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
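LSD is commonly defined as the root-mean-square difference of log power spectra over frequency, averaged over frames. A sketch under these assumptions: inputs are magnitude spectrograms of shape `(..., freq, frames)`, the log is `log10` of the power (some definitions use decibels, which multiplies the result by 10), and `eps` is a hypothetical stabiliser. The exact variant used by `lsd_loss` is not stated in this reference.

```python
import torch


def lsd(input_mag: torch.Tensor, target_mag: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Log-spectral distance between magnitude spectrograms shaped (..., freq, frames)."""
    log_in = torch.log10(input_mag ** 2 + eps)
    log_tg = torch.log10(target_mag ** 2 + eps)
    # RMS over the frequency axis gives one distance per frame ...
    per_frame = torch.sqrt(((log_in - log_tg) ** 2).mean(dim=-2))
    # ... which is then averaged over frames and any leading batch dimensions.
    return per_frame.mean()


torch.manual_seed(0)
m = torch.rand(2, 65, 10) + 0.5
# Scaling the magnitude by 10 scales the power by 100: a log10 gap of 2 everywhere.
print(lsd(m, m), lsd(10 * m, m))
```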
- mse_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the mean squared error (MSE) loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
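The constructor's `reduction` argument defaults to `'mean'`. Assuming it follows PyTorch's usual convention (`'mean'`, `'sum'` or `'none'`), which this reference does not confirm, an MSE sketch that honours it could look like this:

```python
import torch
import torch.nn.functional as F


def mse(input_: torch.Tensor, target: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Squared error with PyTorch-style reduction (assumed convention)."""
    sq = (input_ - target) ** 2
    if reduction == "mean":
        return sq.mean()
    if reduction == "sum":
        return sq.sum()
    if reduction == "none":
        return sq  # element-wise losses, same shape as the inputs
    raise ValueError(f"unknown reduction: {reduction!r}")


x = torch.tensor([1.0, 2.0])
y = torch.zeros(2)
print(mse(x, y), mse(x, y, "sum"), mse(x, y, "none"), F.mse_loss(x, y))
```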
- perceptual_metric(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes a perceptual metric, reported as the accuracy score.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
- pesq_fn(rec_input_, rec_target, fs: int = 16000, mode: str = 'wb')
Method that computes the PESQ score.
- Parameters:
rec_input_ (torch.Tensor) -- Reconstructed model output signal.
rec_target (torch.Tensor) -- Reconstructed target signal.
fs (int) -- Sampling rate in Hz.
mode (str) -- PESQ mode: 'wb' (wideband) or 'nb' (narrowband).
- si_sdr_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the scale-invariant signal-to-distortion ratio (SI-SDR) loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
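SI-SDR projects the estimate onto the reference, then compares the energy of that projection with the energy of the residual. SI-SNR, as commonly implemented, uses the same formula after removing the mean of both signals. A sketch on time-domain signals (last axis = time), assuming that the loss is the negated mean SI-SDR, so that minimising it maximises SI-SDR; `eps` and `zero_mean` are hypothetical parameters.

```python
import torch


def si_sdr(estimate: torch.Tensor, reference: torch.Tensor,
           eps: float = 1e-8, zero_mean: bool = True) -> torch.Tensor:
    """SI-SDR in dB, computed along the last (time) axis."""
    if zero_mean:
        estimate = estimate - estimate.mean(dim=-1, keepdim=True)
        reference = reference - reference.mean(dim=-1, keepdim=True)
    # Optimal scale for the reference: alpha = <estimate, reference> / ||reference||^2.
    alpha = (estimate * reference).sum(dim=-1, keepdim=True) / (
        reference.pow(2).sum(dim=-1, keepdim=True) + eps
    )
    target = alpha * reference
    residual = estimate - target
    ratio = target.pow(2).sum(dim=-1) / (residual.pow(2).sum(dim=-1) + eps)
    return 10 * torch.log10(ratio + eps)


def si_sdr_loss(estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Assumed convention: negate so that lower loss means higher SI-SDR.
    return -si_sdr(estimate, reference).mean()


ref = torch.tensor([1.0, -1.0, 1.0, -1.0])
noise = torch.tensor([1.0, 1.0, -1.0, -1.0])  # orthogonal to ref
est = ref + 0.5 * noise  # projection energy 4, residual energy 1
print(si_sdr(est, ref), si_sdr(3.0 * est, ref))  # about 6.02 dB, unchanged by scaling
```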
- si_snr_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the scale-invariant signal-to-noise ratio (SI-SNR) loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
- stoi_fn(rec_input_, rec_target, fs: int = 16000, extended: bool = False)
Method that computes the STOI score.
- Parameters:
rec_input_ (torch.Tensor) -- Reconstructed model output signal.
rec_target (torch.Tensor) -- Reconstructed target signal.
fs (int) -- Sampling rate in Hz.
extended (bool) -- If True, compute extended STOI (ESTOI) instead of STOI.
- stoi_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the STOI-based loss.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
- time_mse_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor
Method that computes the MSE loss in the time domain.
- Parameters:
input_ (torch.Tensor) -- Model output tensor.
target (torch.Tensor) -- Target tensor.
input_phase (torch.Tensor | None) -- Phase of the model output, if the encoder is an STFT.
target_phase (torch.Tensor | None) -- Phase of the target, if the encoder is an STFT.
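The `input_phase` and `target_phase` arguments suggest that, with an STFT encoder, magnitudes are recombined with their phase to return to the waveform before comparing signals. A sketch of a time-domain MSE under that assumption, using `torch.polar` and `torch.istft`. `N_FFT`, `HOP` and the Hann window are hypothetical stand-ins for whatever `TransformManager` actually configures.

```python
import torch

N_FFT = 256  # hypothetical STFT size; the real value comes from TransformManager
HOP = 64     # hypothetical hop length


def to_waveform(mag: torch.Tensor, phase: torch.Tensor, length: int) -> torch.Tensor:
    """Rebuild the complex STFT as mag * exp(i * phase) and invert it."""
    window = torch.hann_window(N_FFT, device=mag.device)
    spec = torch.polar(mag, phase)
    return torch.istft(spec, n_fft=N_FFT, hop_length=HOP, window=window, length=length)


def time_mse_loss(input_mag, target_mag, input_phase, target_phase, length: int) -> torch.Tensor:
    """MSE between the two reconstructed waveforms."""
    x = to_waveform(input_mag, input_phase, length)
    y = to_waveform(target_mag, target_phase, length)
    return torch.mean((x - y) ** 2)


torch.manual_seed(0)
signal = torch.randn(2, 4096)  # batch of two waveforms
spec = torch.stft(signal, n_fft=N_FFT, hop_length=HOP,
                  window=torch.hann_window(N_FFT), return_complex=True)
mag, phase = spec.abs(), spec.angle()
# Halving the magnitude halves the waveform, so the loss is about 0.25 * mean(signal^2).
print(time_mse_loss(0.5 * mag, mag, phase, phase, 4096), 0.25 * signal.pow(2).mean())
```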
- training: bool