U-Net SNN class

class src.model.SpikingModel.UNetSNN(input_dim: int, hidden_channels_list: list, output_dim: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, bias: bool, padding_mode: str, pooling_flag: bool, pooling_type: str, use_same_layer: bool, nb_steps: int, truncated_bptt_ratio: int, spike_fn: SuperSpike | SigmoidDerivative | PiecewiseLinear | ATan, neuron_model: str, neuron_parameters: dict, weight_init: dict, upsample_mode: str = 'bilinear', scale_flag: bool = True, scale_factor: float = 1.0, bn_flag: bool = False, dropout_flag: bool = False, dropout_p: float = 0.5, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32, skip_connection_type: str = 'cat_', use_intermediate_output: bool = False)

Bases: SNNBase

Class that implements a U-Net spiking neural network (SNN) architecture with a predefined list of convolutional layers.
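The four spike_fn types in the signature appear to be surrogate-gradient spike functions, which keep the Heaviside step in the forward pass and substitute a smooth pseudo-derivative in the backward pass. The framework-free sketch below shows commonly cited forms of those pseudo-derivatives as functions of u, the membrane potential minus the firing threshold. The function names and the scale parameters beta, gamma, and alpha (with their defaults) are assumptions for illustration; the classes in this package may use different formulas or scaling.

```python
import math

# Hypothetical sketches of common surrogate-gradient derivatives.
# Names mirror the spike_fn options; exact formulas and scales here are assumptions.

def _sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |x|).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def superspike_grad(u: float, beta: float = 10.0) -> float:
    """SuperSpike (fast-sigmoid) pseudo-derivative: 1 / (beta * |u| + 1)^2."""
    return 1.0 / (beta * abs(u) + 1.0) ** 2

def sigmoid_grad(u: float, beta: float = 5.0) -> float:
    """Derivative of sigmoid(beta * u) with respect to u."""
    s = _sigmoid(beta * u)
    return beta * s * (1.0 - s)

def piecewise_linear_grad(u: float, gamma: float = 1.0) -> float:
    """Triangular pseudo-derivative: zero for |u| >= gamma, peak 1/gamma at u = 0."""
    return max(0.0, 1.0 - abs(u) / gamma) / gamma

def atan_grad(u: float, alpha: float = 2.0) -> float:
    """Derivative of (1/pi) * atan(pi/2 * alpha * u) + 1/2 with respect to u."""
    return (alpha / 2.0) / (1.0 + (math.pi / 2.0 * alpha * u) ** 2)
```

All four peak at the threshold (u = 0) and decay away from it, so gradients flow mainly through neurons whose membrane potential is close to firing; the scale parameter controls how sharply that window narrows.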

forward(spk: Tensor) → Tuple[Tensor, List[Tensor]]

Method that defines the computation performed during the forward pass.
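forward returns a tensor together with a list of tensors. One plausible reading, consistent with the spike records that spk_skip_connection indexes into, is the network output plus one spike record per layer. The toy sketch below illustrates that pattern in plain Python; it assumes leaky integrate-and-fire (LIF) neurons with reset by subtraction and fully connected layers instead of convolutions, and the names snn_forward, matvec, beta, and threshold are hypothetical rather than part of this class.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # shape (n_out, n_in)

def matvec(w: Matrix, x: List[float]) -> List[float]:
    # Dense synaptic input: one weighted sum per output neuron.
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def snn_forward(spk_in: List[List[float]], weights: List[Matrix],
                beta: float = 0.9, threshold: float = 1.0
                ) -> Tuple[List[List[float]], List[List[List[float]]]]:
    """Hypothetical time-stepped forward pass through LIF layers.

    spk_in[t] is the input spike vector at step t (len(spk_in) plays the role of nb_steps).
    Returns the last layer's spikes per step and one spike record per layer.
    """
    mems = [[0.0] * len(w) for w in weights]              # membrane potential per layer
    spk_rec: List[List[List[float]]] = [[] for _ in weights]
    for t in range(len(spk_in)):
        x = spk_in[t]
        for k, w in enumerate(weights):
            # Leaky integration: decay the membrane, then add the synaptic input.
            mems[k] = [beta * m + i for m, i in zip(mems[k], matvec(w, x))]
            # Emit a spike wherever the membrane reaches the threshold.
            spk = [1.0 if m >= threshold else 0.0 for m in mems[k]]
            # Reset by subtraction for the neurons that spiked.
            mems[k] = [m - threshold * s for m, s in zip(mems[k], spk)]
            spk_rec[k].append(spk)
            x = spk  # spikes drive the next layer within the same time step
    return spk_rec[-1], spk_rec
```

Keeping every layer's spike record is what makes U-Net-style skip connections possible: a decoder layer can later look up the encoder spikes stored at a given index. A tensor implementation would also carry a batch dimension and would usually stack the per-step spikes of each layer into a single tensor.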

spk_skip_connection(spk: Tensor, spk_rec_k_index: int) → Tensor

Method that creates the output of a spiking skip connection.

Parameters:
  • spk (torch.Tensor) -- Input spike tensor of the corresponding layer.

  • spk_rec_k_index (int) -- Index of the spike tensor within the spike records.
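The default skip_connection_type value 'cat_' suggests that, as in a conventional U-Net, the decoder spikes are concatenated along the channel axis with the encoder spike record selected by spk_rec_k_index. The sketch below contrasts that with element-wise addition on channel-first nested lists (batch omitted). The function merge_skip, the helper shape, and the mode strings 'cat' and 'add' are hypothetical and are not the values accepted by this class.

```python
from typing import List, Tuple

# Hypothetical channel-first feature map: fmap[c][h][w] (batch dimension omitted).
FeatureMap = List[List[List[float]]]

def shape(fmap: FeatureMap) -> Tuple[int, int, int]:
    # Assumes a non-empty, rectangular nested list.
    return len(fmap), len(fmap[0]), len(fmap[0][0])

def merge_skip(spk: FeatureMap, spk_rec: List[FeatureMap], k: int,
               mode: str = "cat") -> FeatureMap:
    """Merge decoder spikes `spk` with the encoder spike record at index k."""
    skip = spk_rec[k]
    if mode == "cat":
        # Concatenate along the channel axis; spatial sizes must match.
        if shape(spk)[1:] != shape(skip)[1:]:
            raise ValueError("spatial sizes differ")
        return spk + skip
    if mode == "add":
        # Element-wise sum; the full shapes must match.
        if shape(spk) != shape(skip):
            raise ValueError("shapes differ")
        return [[[a + b for a, b in zip(row_a, row_b)]
                 for row_a, row_b in zip(ch_a, ch_b)]
                for ch_a, ch_b in zip(spk, skip)]
    raise ValueError(f"unknown mode: {mode!r}")
```

Concatenation grows the channel count to C_spk + C_skip, so the following convolution must accept the wider input; addition keeps the channel count but requires identical shapes.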

training: bool