Loss

class src.model.LossManager.LossManager(representation_name: str, loss_name: list[str], loss_weight: list[float], loss_bias: list[float], metric_name: str, weights: Tensor, transform_manager: TransformManager, reduction: str = 'mean', nb_digits: int = 5)

Bases: Module

Module that computes the configured training loss(es) and the evaluation metric.

forward(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the loss between the model output and the target.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.

huber_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the Huber loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
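Huber loss is quadratic for small errors and linear for large ones, so it is less sensitive to outliers than MSE. The sketch below is minimal and makes two assumptions: the threshold `delta` is 1.0 (this method's actual threshold is not documented here), and `huber` is a hypothetical helper. It is checked against PyTorch's built-in `torch.nn.functional.huber_loss`.

```python
import torch
import torch.nn.functional as F


def huber(input_: torch.Tensor, target: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Huber loss averaged over all elements (hypothetical helper)."""
    abs_err = (input_ - target).abs()
    # Quadratic branch for small errors, linear branch beyond `delta`
    quadratic = 0.5 * abs_err**2
    linear = delta * (abs_err - 0.5 * delta)
    return torch.where(abs_err <= delta, quadratic, linear).mean()


x = torch.tensor([0.0, 0.5, 3.0])
y = torch.zeros(3)
print(huber(x, y).item())                    # 0.875
print(F.huber_loss(x, y, delta=1.0).item())  # same value from the built-in
```

The error of 3.0 contributes 2.5 (linear branch) rather than 4.5 (quadratic branch). This is the outlier damping that distinguishes Huber from MSE.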

l1_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the L1 (absolute error) loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
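The constructor's `reduction` argument (default `'mean'`) presumably selects how element-wise errors are aggregated, as in PyTorch's own loss functions; that reading is an assumption. The sketch below is a minimal L1 loss supporting `'mean'` and `'sum'`, written as a hypothetical helper `l1`.

```python
import torch
import torch.nn.functional as F


def l1(input_: torch.Tensor, target: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """L1 loss with a PyTorch-style reduction (hypothetical helper)."""
    abs_err = (input_ - target).abs()
    if reduction == "mean":
        return abs_err.mean()
    if reduction == "sum":
        return abs_err.sum()
    raise ValueError(f"unsupported reduction: {reduction!r}")


pred = torch.tensor([[1.0, -2.0], [0.5, 0.0]])
true = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
print(l1(pred, true).item())         # 1.125
print(l1(pred, true, "sum").item())  # 4.5
```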

lsd_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the log-spectral distance (LSD) loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
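Log-spectral distance is commonly computed per frame as the RMS difference of the log-power spectra across frequency, then averaged over frames. Definitions vary (natural log vs. log10, dB scaling, power vs. magnitude). This sketch assumes log10 of the power spectrum and `(..., freq, frames)` magnitude inputs. The `eps` stabilisers are illustrative, not taken from this class.

```python
import math

import torch


def lsd(input_mag: torch.Tensor, target_mag: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Log-spectral distance between (..., freq, frames) magnitude spectrograms."""
    # Log-power spectra; eps avoids log10(0) on silent bins
    log_in = torch.log10(input_mag**2 + eps)
    log_tg = torch.log10(target_mag**2 + eps)
    # RMS difference over frequency (dim -2) for each frame;
    # the inner eps keeps the sqrt gradient finite when the spectra match
    per_frame = torch.sqrt(((log_in - log_tg) ** 2).mean(dim=-2) + eps)
    # Average over frames (and batch)
    return per_frame.mean()


torch.manual_seed(0)
mag = torch.rand(2, 257, 100) + 0.5  # toy magnitudes, strictly positive
print(lsd(mag, mag).item())        # close to 0
print(lsd(2 * mag, mag).item())    # close to log10(4)
print(math.log10(4))
```

Doubling every magnitude shifts each log-power bin by log10(4), so the LSD is about 0.602 regardless of the spectrogram content.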

mse_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the mean squared error (MSE) loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
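`time_mse_loss` is documented separately as the time-domain variant, so `mse_loss` likely operates on the encoder's representation, for example STFT magnitudes. That reading is an assumption. The sketch below computes MSE on STFT magnitude spectrograms. `N_FFT`, `HOP` and the Hann window are hypothetical settings, not the ones used by `transform_manager`.

```python
import torch
import torch.nn.functional as F

N_FFT, HOP = 512, 128  # hypothetical STFT settings
WINDOW = torch.hann_window(N_FFT)


def magnitude(wave: torch.Tensor) -> torch.Tensor:
    """Magnitude spectrogram of a (batch, samples) waveform."""
    spec = torch.stft(wave, n_fft=N_FFT, hop_length=HOP, window=WINDOW, return_complex=True)
    return spec.abs()


torch.manual_seed(0)
clean = torch.randn(2, 16000)
noisy = clean + 0.1 * torch.randn(2, 16000)

# MSE computed on the STFT magnitudes rather than on the waveforms
loss = F.mse_loss(magnitude(noisy), magnitude(clean))
print(magnitude(clean).shape)  # torch.Size([2, 257, 126])
print(loss.item())
```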

perceptual_metric(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes a perceptual metric, used as the accuracy measure.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.

pesq_fn(rec_input_, rec_target, fs: int = 16000, mode: str = 'wb')

Computes the PESQ (Perceptual Evaluation of Speech Quality) score.

Parameters:
  • rec_input_ -- Reconstructed model output signal.

  • rec_target -- Reconstructed target signal.

  • fs (int) -- Sampling rate in Hz.

  • mode (str) -- PESQ mode: 'wb' (wide-band, the default) or 'nb' (narrow-band).

si_sdr_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the scale-invariant signal-to-distortion ratio (SI-SDR) loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
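SI-SDR projects the estimate onto the reference, so rescaling the model output does not change the score. Because higher SI-SDR is better, a loss is usually its negative. This is a minimal sketch on `(batch, samples)` waveforms. The helper names, the `eps` stabiliser and the batch averaging are assumptions, not this class's documented behaviour.

```python
import torch


def si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR in dB, one value per (batch, samples) row."""
    # Optimal scaling of the reference: projection of est onto ref
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    s_target = alpha * ref
    e_noise = est - s_target
    ratio = s_target.pow(2).sum(-1) / (e_noise.pow(2).sum(-1) + eps)
    return 10 * torch.log10(ratio + eps)


def si_sdr_loss(est: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Negate so that minimising the loss maximises SI-SDR
    return -si_sdr(est, ref).mean()


torch.manual_seed(0)
ref = torch.randn(2, 8000)
est = ref + 0.3 * torch.randn(2, 8000)
print(si_sdr(est, ref))      # roughly 10 dB per row
print(si_sdr(3 * est, ref))  # unchanged: the metric ignores gain
```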

si_snr_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the scale-invariant signal-to-noise ratio (SI-SNR) loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
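SI-SNR is commonly computed like SI-SDR but on zero-mean signals, which is the usual convention in speech-separation work. Whether this class removes the mean in the same way is an assumption. The minimal sketch below (hypothetical helper `si_snr`) therefore also ignores a constant DC offset in the estimate.

```python
import torch


def si_snr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SNR in dB for (batch, samples) waveforms."""
    # Remove each signal's DC offset before projecting
    est = est - est.mean(-1, keepdim=True)
    ref = ref - ref.mean(-1, keepdim=True)
    dot = (est * ref).sum(-1, keepdim=True)
    s_target = dot * ref / (ref.pow(2).sum(-1, keepdim=True) + eps)
    e_noise = est - s_target
    ratio = s_target.pow(2).sum(-1) / (e_noise.pow(2).sum(-1) + eps)
    return 10 * torch.log10(ratio + eps)


torch.manual_seed(0)
ref = torch.randn(2, 8000)
est = ref + 0.3 * torch.randn(2, 8000)
print(si_snr(est, ref))
print(si_snr(est + 0.5, ref))  # a DC offset does not change the score
loss = -si_snr(est, ref).mean()  # negated for minimisation
```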

stoi_fn(rec_input_, rec_target, fs: int = 16000, extended: bool = False)

Computes the STOI (Short-Time Objective Intelligibility) score.

Parameters:
  • rec_input_ -- Reconstructed model output signal.

  • rec_target -- Reconstructed target signal.

  • fs (int) -- Sampling rate in Hz.

  • extended (bool) -- If True, compute extended STOI (ESTOI) instead of standard STOI.

stoi_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the STOI-based loss.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.

time_mse_loss(input_: Tensor, target: Tensor, input_phase: Tensor | None = None, target_phase: Tensor | None = None) → Tensor

Computes the mean squared error (MSE) loss in the time domain.

Parameters:
  • input_ (torch.Tensor) -- Model output tensor.

  • target (torch.Tensor) -- Target tensor.

  • input_phase (torch.Tensor | None) -- Phase of the model output; used when the encoder is an STFT.

  • target_phase (torch.Tensor | None) -- Phase of the target; used when the encoder is an STFT.
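With an STFT encoder, a time-domain loss needs the phase arguments to turn magnitude spectrograms back into waveforms. This minimal sketch assumes `input_`/`target` are magnitude spectrograms and inverts them with `torch.istft`. The STFT settings and the helpers `to_wave` and `time_mse` are hypothetical; the real inversion is presumably handled by `transform_manager`.

```python
import torch
import torch.nn.functional as F

N_FFT, HOP = 512, 128  # hypothetical STFT settings
WINDOW = torch.hann_window(N_FFT)


def to_wave(mag: torch.Tensor, phase: torch.Tensor, length: int) -> torch.Tensor:
    """Rebuild a waveform from magnitude and phase spectrograms."""
    spec = torch.polar(mag, phase)  # mag * exp(1j * phase)
    return torch.istft(spec, n_fft=N_FFT, hop_length=HOP, window=WINDOW, length=length)


def time_mse(input_mag, target_mag, input_phase, target_phase, length):
    # MSE between the reconstructed waveforms, not the spectrograms
    return F.mse_loss(to_wave(input_mag, input_phase, length),
                      to_wave(target_mag, target_phase, length))


torch.manual_seed(0)
wave = torch.randn(2, 16000)
spec = torch.stft(wave, n_fft=N_FFT, hop_length=HOP, window=WINDOW, return_complex=True)
mag, phase = spec.abs(), spec.angle()
print(time_mse(mag, mag, phase, phase, 16000))        # zero for identical inputs
print(time_mse(mag, mag, phase + 0.5, phase, 16000))  # phase errors are penalised
```

Unlike a magnitude-only loss, the second call is non-zero even though both magnitudes are identical. The time-domain MSE is therefore sensitive to phase errors.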

training: bool