U-Net SNN class
- class src.model.SpikingModel.UNetSNN(input_dim: int, hidden_channels_list: list, output_dim: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, bias: bool, padding_mode: str, pooling_flag: bool, pooling_type: str, use_same_layer: bool, nb_steps: int, truncated_bptt_ratio: int, spike_fn: SuperSpike | SigmoidDerivative | PiecewiseLinear | ATan, neuron_model: str, neuron_parameters: dict, weight_init: dict, upsample_mode: str = 'bilinear', scale_flag: bool = True, scale_factor: float = 1.0, bn_flag: bool = False, dropout_flag: bool = False, dropout_p: float = 0.5, device: device = 'cpu', dtype: dtype = torch.float32, skip_connection_type: str = 'cat_', use_intermediate_output: bool = False)
Bases: SNNBase
Class that implements a U-Net spiking neural network (SNN) architecture built from a predefined list of convolutional layers.
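The class is configured by the convolution geometry (kernel_size, stride, padding, dilation) plus optional pooling. Because a U-Net merges encoder feature maps with upsampled decoder maps, it can help to check spatial sizes before instantiating the model. The sketch below is a minimal, standalone check under two assumptions: each convolution follows the torch.nn.Conv2d output-size formula, and pooling is a 2x2 window with stride 2 (the source does not state the pooling window). The helpers conv2d_out and encoder_shapes are hypothetical and not part of src.model.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
               dilation: int = 1) -> int:
    """Output length along one spatial axis (torch.nn.Conv2d size formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def encoder_shapes(height: int, width: int, n_levels: int,
                   kernel=(3, 3), stride=(1, 1), padding=(1, 1),
                   dilation=(1, 1), pooling: bool = True) -> list:
    """Spatial (H, W) after the convolution of each encoder level.

    Pooling between levels is assumed (hypothetically) to be a 2x2 window
    with stride 2.
    """
    shapes = []
    h, w = height, width
    for _ in range(n_levels):
        h = conv2d_out(h, kernel[0], stride[0], padding[0], dilation[0])
        w = conv2d_out(w, kernel[1], stride[1], padding[1], dilation[1])
        shapes.append((h, w))
        if pooling:
            h, w = h // 2, w // 2  # floor, as in MaxPool2d(2) / AvgPool2d(2)
    return shapes


print(encoder_shapes(64, 64, n_levels=3))                  # [(64, 64), (32, 32), (16, 16)]
print(encoder_shapes(64, 64, n_levels=3, padding=(0, 0)))  # [(62, 62), (29, 29), (12, 12)]
```

With padding=(0, 0) the deepest map is 12x12, and a scale-2 upsample yields 24x24, which no longer matches the 29x29 encoder map it should be merged with; "same"-style padding keeps every level a clean power-of-two reduction.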
- forward(spk: Tensor) → Tuple[Tensor, List[Tensor]]
Method that defines the computation performed during the forward pass.
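An SNN forward pass unrolls the network over nb_steps time steps and records each layer's spikes, which matches the shape of the return value, Tuple[Tensor, List[Tensor]]. The sketch below illustrates that loop with plain Python lists for a chain of fully connected layers, not the actual convolutional U-Net. It assumes a discrete-time leaky integrate-and-fire (LIF) neuron with decay beta, threshold 1.0 and reset by subtraction, and it reads the output as per-neuron spike counts; the real neuron_model, neuron_parameters and readout may differ. No gradients are computed here. In the actual model, spike_fn presumably supplies the surrogate derivative of the threshold during backpropagation through time.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def lif_step(mem: List[float], current: List[float], beta: float = 0.9,
             threshold: float = 1.0) -> Tuple[List[float], List[float]]:
    """One LIF update with reset by subtraction (assumed neuron model)."""
    new_mem, spikes = [], []
    for m, i in zip(mem, current):
        m = beta * m + i                       # leaky integration
        s = 1.0 if m >= threshold else 0.0     # Heaviside spike
        new_mem.append(m - s * threshold)      # subtract threshold on spike
        spikes.append(s)
    return new_mem, spikes


def matvec(weights: Matrix, x: List[float]) -> List[float]:
    """Dense layer without bias: one output current per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def forward(spk_in: List[List[float]], layers: List[Matrix],
            beta: float = 0.9) -> Tuple[List[float], List[List[List[float]]]]:
    """Unroll over len(spk_in) steps; return output spike counts and spike records."""
    mems = [[0.0] * len(w) for w in layers]   # one membrane vector per layer
    spk_rec = [[] for _ in layers]            # spk_rec[k][t] = spikes of layer k at step t
    for x_t in spk_in:                        # loop over time steps
        x = x_t
        for k, w in enumerate(layers):
            mems[k], x = lif_step(mems[k], matvec(w, x), beta)
            spk_rec[k].append(x)
    n_out = len(layers[-1])
    out = [sum(step[j] for step in spk_rec[-1]) for j in range(n_out)]
    return out, spk_rec


out, rec = forward([[1.0, 1.0]] * 3, [[[1.0, 0.0], [0.0, 0.5]], [[1.0, 1.0]]])
print(out)  # [3.0]
```

Keeping a per-layer spike record is what makes skip connections possible later: a decoder stage can look up the encoder spikes it needs by index.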
- spk_skip_connection(spk: Tensor, spk_rec_k_index: int) → Tensor
Method that creates the output of a spiking skip connection.
- Parameters:
spk (torch.Tensor) -- Input spike tensor of the corresponding layer.
spk_rec_k_index (int) -- Index of the spike tensor within the spike records.
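A skip connection merges the spikes of the current (decoder) layer with a spike tensor looked up in the spike records by index. The sketch below shows the two common merge rules on plain lists, with each feature map stored as channels x flattened spatial positions. The mode names "cat" and "add" are hypothetical; the default skip_connection_type='cat_' presumably selects concatenation, but the source does not define the accepted values. merge_skip is an illustrative helper, not the library's spk_skip_connection.

```python
from typing import List

FeatureMap = List[List[float]]  # channels x flattened spatial positions


def merge_skip(spk: FeatureMap, spk_rec: List[FeatureMap], k: int,
               mode: str = "cat") -> FeatureMap:
    """Merge decoder spikes `spk` with the recorded spikes spk_rec[k] (hypothetical helper)."""
    skip = spk_rec[k]
    if len(skip[0]) != len(spk[0]):
        raise ValueError("spatial sizes differ; upsample or crop before merging")
    if mode == "cat":
        # Concatenate along the channel axis: C_out = C_skip + C_spk.
        return [list(ch) for ch in skip] + [list(ch) for ch in spk]
    if mode == "add":
        if len(skip) != len(spk):
            raise ValueError("channel counts must match for an additive skip")
        # Element-wise sum; merged values can exceed 1 where both inputs spike.
        return [[a + b for a, b in zip(cs, cx)] for cs, cx in zip(skip, spk)]
    raise ValueError(f"unknown mode: {mode}")


rec = [[[1.0, 0.0, 1.0, 0.0]], [[0.0, 1.0]]]            # two recorded layers
dec = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]       # 2-channel decoder spikes
print(len(merge_skip(dec, rec, 0)))                       # 3 channels after concatenation
print(merge_skip([[1.0, 1.0, 0.0, 0.0]], rec, 0, "add"))  # [[2.0, 1.0, 1.0, 0.0]]
```

Concatenation preserves binary spike values but grows the channel count the next convolution must accept, while addition keeps the channel count fixed at the cost of non-binary inputs.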
- training: bool