Quickstart
Installation
Install the packages using pip:
$ pip install -r requirements.txt
Usage
To train a speech enhancement SNN model, run the following command:
python main.py --model_name UNetSNN --train_flag --nb_epochs 5 --batch_size 4 --learning_rate 0.0004 --train_neuron_parameters --recurrent_flag --detach_reset --use_ddp --pin_memory --empty_cache --evaluate_flag --perceptual_metric_flag --save_mem --deterministic
To test a pretrained speech enhancement SNN model, run the following command:
python main.py --model_name UNetSNN --pretrained_flag --batch_size 4 --train_neuron_parameters --recurrent_flag --detach_reset --use_ddp --pin_memory --empty_cache --evaluate_flag --perceptual_metric_flag --save_mem --deterministic
Note
If the input data (for example, the STFT representation) has not been computed yet, add the following argument:
--compute_representation
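The default --representation_dir_name value, STFT_4s_nfft=512_wl=512_hl=256, appears to encode the STFT settings (4-second clips, n_fft=512, window length 512, hop length 256). The sketch below parses names of that assumed pattern and estimates the resulting spectrogram shape. The 16 kHz sampling rate is an assumption (it is not stated here), and center=True mirrors the default framing of torch.stft and librosa.stft.

```python
import re


def parse_representation_dir(name: str) -> dict:
    """Parse names like 'STFT_4s_nfft=512_wl=512_hl=256' (assumed naming pattern)."""
    match = re.fullmatch(r"(\w+?)_(\d+)s_nfft=(\d+)_wl=(\d+)_hl=(\d+)", name)
    if match is None:
        raise ValueError(f"unrecognized representation directory name: {name}")
    kind, seconds, n_fft, win_length, hop_length = match.groups()
    return {
        "kind": kind,
        "seconds": int(seconds),
        "n_fft": int(n_fft),
        "win_length": int(win_length),
        "hop_length": int(hop_length),
    }


def stft_shape(seconds, n_fft, hop_length, sample_rate=16_000, center=True):
    """Return (frequency bins, frames) of a one-sided STFT of a clip."""
    n_samples = seconds * sample_rate
    n_freq_bins = n_fft // 2 + 1  # one-sided spectrum of a real signal
    if center:
        # The signal is padded by n_fft // 2 on both sides before framing.
        n_frames = 1 + n_samples // hop_length
    else:
        n_frames = 1 + (n_samples - n_fft) // hop_length
    return n_freq_bins, n_frames
```

With the default directory name this gives 257 frequency bins and 251 frames per 4-second clip under the stated assumptions.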
Note
To use debug mode, add the following argument:
--debug_flag
Arguments
The following arguments can be set from the command line.
This program enables users to train and test different speech enhancement models based on spiking neural networks.
usage:
python main.py
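The commands above pass flags such as --train_flag without a value, and those flags are listed below with "Default: False"; this is consistent with argparse store_true options, while arguments like --transform_name take lists. The following is a hypothetical, reduced parser covering a few of the arguments, assuming nargs="+" for list arguments. It is a sketch for orientation, not the actual main.py parser.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of the main.py command-line interface."""
    parser = argparse.ArgumentParser(
        description="Train/test speech enhancement models (reduced sketch).")
    parser.add_argument("-m", "--model_name", default="UNetSNN",
                        choices=["FCSNN", "CSNN", "UNetSNN", "ResBottleneckUNetSNN",
                                 "CNN", "UNet", "ResBottleneckUNet"])
    # List-valued argument: assumed to accept one or more names.
    parser.add_argument("-tn", "--transform_name", nargs="+",
                        default=["log_power", "standardize"],
                        choices=["maxabs", "quantile_maxabs", "normalize",
                                 "standardize", "shift", "log_power", "log10"])
    parser.add_argument("-e", "--nb_epochs", type=int, default=30)
    parser.add_argument("-b", "--batch_size", type=int, default=32)
    parser.add_argument("-lr", "--learning_rate", type=float, default=0.0002)
    # Boolean switches: False unless the flag is present.
    parser.add_argument("-tr", "--train_flag", action="store_true")
    parser.add_argument("-r", "--recurrent_flag", action="store_true")
    return parser
```

Parsing `["--train_flag", "--nb_epochs", "5"]` with this sketch yields `train_flag=True` and `nb_epochs=5`, while omitted options keep their listed defaults.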
Named Arguments
- -m, --model_name
Possible choices: FCSNN, CSNN, UNetSNN, ResBottleneckUNetSNN, CNN, UNet, ResBottleneckUNet
Name of the model
Default: "UNetSNN"
- -msk, --use_mask
Boolean that indicates whether to use the masking approach instead of the direct approach
Default: False
- -rec, --reconstruct_flag
Boolean that indicates whether to reconstruct clean speech instead of enhancing noisy speech
Default: False
- -rn, --representation_name
Representation name
Default: "stft"
- -rdn, --representation_dir_name
Representation directory name
Default: "STFT_4s_nfft=512_wl=512_hl=256"
- -tn, --transform_name
Possible choices: maxabs, quantile_maxabs, normalize, standardize, shift, log_power, log10
List of data transform names
Default: ['log_power', 'standardize']
- -cr, --compute_representation
Boolean that indicates whether to compute the input data representation
Default: False
- --k
KLIF neuron model k constant
Default: 1.0
- --tau
PLIF neuron model time constant
Default: 20.0
- --tau_syn
Spiking layer input current time constant
Default: 0.00034
- --tau_mem
Spiking layer membrane potential time constant
Default: 0.00034
- --tau_syn_out
Readout layer input current time constant
Default: 0.00034
- --tau_mem_out
Readout layer membrane potential time constant
Default: 0.00034
- --time_step
SNN time step
Default: 0.001
- --membrane_threshold
Spiking layer membrane threshold
Default: 1.0
- -di, --decay_input
Boolean that indicates whether the input decays
Default: False
- -spk, --spiking_mode
Possible choices: binary, graded
Spiking mode
Default: "binary"
- -rst, --reset_mode
Possible choices: hard_reset, soft_reset
Reset mode: reset to zero (hard_reset) or reset by subtraction (soft_reset)
Default: "soft_reset"
- -dr, --detach_reset
Boolean that indicates whether to detach the reset term from the computation graph in the backward pass
Default: False
- -wm, --weight_mean
Mean of weight initialization
Default: 0.0
- -ws, --weight_std
Standard deviation of weight initialization
Default: 0.2
- -wg, --weight_gain
Gain of weight initialization
Default: 5.0
- -wi, --weight_init_dist
Possible choices: normal_, uniform_, kaiming_normal_, kaiming_uniform_, xavier_uniform_
Weight initialization distribution
Default: "normal_"
- -inp, --input_dim
Input layer dimension
Default: 256
- -hl, --hidden_dim_list
List of hidden (linear) layer dimensions
- -hc, --hidden_channels_list
List of hidden (convolutional) layer channel dimensions
- -k, --kernel_size
List of convolutional layer kernel sizes
Default: (3, 3)
- -s, --stride
List of convolutional layer strides
Default: (1, 1)
- -p, --padding
List of convolutional layer paddings
- -d, --dilation
List of convolutional layer dilations
Default: (1, 1)
- -bs, --bias
Boolean that indicates whether to use a bias term in the convolutional layers
Default: False
- -cpm, --padding_mode
Possible choices: zeros, reflect, replicate, circular
Convolutional layers padding mode
Default: "zeros"
- -pf, --pooling_flag
Boolean that indicates whether to use a pooling layer for downsampling
Default: False
- -pt, --pooling_type
Possible choices: max, avg
Pooling layer type: max, avg
Default: "max"
- -sl, --use_same_layer
Boolean that indicates whether to add input and output layers with the same shape
Default: False
- -r, --recurrent_flag
Boolean that indicates whether to add a recurrence term to the input current equation
Default: False
- -nm, --neuron_model
Possible choices: lif, plif, if
Neuron model name
Default: "lif"
- -tnp, --train_neuron_parameters
Boolean that indicates whether to train the neuron parameters
Default: False
- -npd, --neuron_parameters_init_dist
Possible choices: constant_, normal_, uniform_
Neuron model parameters initialization distribution
Default: "normal_"
- -up, --upsample_mode
Possible choices: nearest, bilinear
Upsampling mode: nearest, bilinear
Default: "nearest"
- -tmf, --scale_flag
Boolean that indicates whether to train the scaling layer
Default: False
- -mf, --scale_factor
Constant value for the scaling layer
Default: 1.0
- -bnf, --bn_flag
Boolean that indicates whether to add a batch normalization layer
Default: False
- -df, --dropout_flag
Boolean that indicates whether to add a dropout layer
Default: False
- -dp, --dropout_p
Probability of an element being zeroed by dropout
Default: 0.1
- -skp, --skip_connection_type
Skip connection type
Default: "cat_"
- -rsd, --nb_residual_block
Number of transitional blocks for the residual block
Default: 1.0
- -rskp, --residual_skip_connection_type
Possible choices: add_, and_, iand_
Residual skip connection type
Default: "add_"
- -out, --use_intermediate_output
Boolean that indicates whether to add rescaled outputs from intermediate layers
Default: False
- -e, --nb_epochs
Number of training epochs
Default: 30
- -we, --nb_warmup_epochs
Number of warmup epochs
Default: 1.0
- -b, --batch_size
Number of data samples per batch
Default: 32
- -ts, --nb_steps_bin
Number of forward pass time steps per bin
- -tbptt, --truncated_bptt_ratio
Truncated Backpropagation Through Time (BPTT) ratio over time steps
- -surr, --surrogate_name
Possible choices: SuperSpike, SigmoidDerivative, ATan, PiecewiseLinear
Surrogate gradient function name
Default: "ATan"
- -ss, --surrogate_scale
Surrogate gradient scale
- -act, --activation_fn
Possible choices: None, sigmoid, relu, lrelu, prelu, tanh
List of ANN activation function names for the input, hidden and output layers
Default: ['lrelu']
- -ln, --loss_name
Possible choices: mse_loss, l1_loss, lsd_loss, time_mse_loss, huber_loss, stoi_loss, si_snr_loss, si_sdr_loss
List of loss function names
Default: ['mse_loss']
- -lw, --loss_weight
List of loss function weights
Default: [1.0]
- -lb, --loss_bias
List of loss function biases
Default: [0.0]
- -o, --optimizer_name
Optimizer name
Default: "Adam"
- -lr, --learning_rate
Learning rate of the model during training
Default: 0.0002
- -bt, --betas
Optimizer betas parameter
Default: (0.5, 0.9)
- -sch, --scheduler_name
Possible choices: OneCycleLR, StepLR, MultiStepLR, ExponentialLR
Scheduler name
- -lrm, --lr_scheduler_max_lr
Learning rate scheduler max_lr parameter
Default: 0.002
- -lrg, --lr_scheduler_gamma
Learning rate scheduler gamma parameter
Default: 0.8
- -ct, --clip_type
Possible choices: None, value, norm
Gradient clipping type
- -cv, --clip_value
Gradient clipping value
Default: 10.0
- -tr, --train_flag
Boolean that indicates whether to train the model
Default: False
- -pr, --pretrained_flag
Boolean that indicates whether to load a pretrained model
Default: False
- -mid, --model_file_name_id
Pretrained model file id
- -sme, --save_mem
Boolean that indicates whether to log hidden activations
Default: False
- -smd, --save_model
Boolean that indicates whether to save a checkpoint of the trained model
Default: False
- -pmf, --perceptual_metric_flag
Boolean that indicates whether to compute a perceptual metric
Default: False
- -dbg, --debug_flag
Boolean that indicates whether to use the debugging dataset
Default: False
- --workspace
Comet ML experiment instance workspace argument
- --api_key
Comet ML experiment instance api_key argument
- --project_name
Comet ML experiment instance project_name argument
- -hpc, --hpc_flag
Boolean that indicates whether to use the HPC configuration
Default: False
- -dist, --use_ddp
Boolean that indicates whether to use the PyTorch Distributed Data Parallel (DDP) library
Default: False
- -db, --dist_backend
Possible choices: nccl, gloo
DDP backend name
Default: "nccl"
- -eval, --evaluate_flag
Boolean that indicates whether to evaluate the model on the validation set during training
Default: False
- -cast, --use_amp
Boolean that indicates whether to use PyTorch Automatic Mixed Precision (AMP)
Default: False
- -zero, --use_zero
Boolean that indicates whether to use the PyTorch Zero Redundancy Optimizer (ZeRO)
Default: False
- -pm, --pin_memory
Value of the pin_memory parameter for the data loader
Default: False
- -nw, --num_workers
Number of worker processes (num_workers parameter) for the data loader
Default: 0.0
- -prf, --prefetch_factor
Prefetch factor parameter for the data loader
Default: 2
- -det, --deterministic
Boolean that indicates whether to use deterministic mode for reproducibility
Default: False
- -sd, --seed
Reproducibility seed
Default: 1.0
- -ec, --empty_cache
Boolean that indicates whether to empty the cache
Default: False
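For the default --transform_name pipeline (log_power, then standardize), a minimal stdlib sketch follows. It assumes log_power is the natural log of the power |X|^2 plus a small eps, and that standardize rescales to zero mean and unit variance over the whole input; the code base may use per-frequency statistics or other conventions instead.

```python
import math


def log_power(magnitudes, eps=1e-8):
    """Assumed definition: natural log of the power |X|^2, offset by eps."""
    return [math.log(m * m + eps) for m in magnitudes]


def standardize(values, eps=1e-8):
    """Rescale to zero mean and (approximately) unit variance."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [(v - mean) / (std + eps) for v in values]


def apply_transforms(values, names):
    """Apply the named transforms in order (only two are sketched here)."""
    registry = {"log_power": log_power, "standardize": standardize}
    for name in names:
        values = registry[name](values)  # KeyError for transforms not sketched
    return values
```

The transforms are applied in list order, so `["log_power", "standardize"]` standardizes the log-power values rather than the raw magnitudes.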
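The neuron arguments (--tau_syn, --tau_mem, --time_step, --membrane_threshold, --reset_mode and --recurrent_flag) can be pictured with a single current-based LIF neuron. The sketch below assumes the common exponential discretization alpha = exp(-dt / tau_syn), beta = exp(-dt / tau_mem); whether the code base uses exactly this form and update order is an assumption. With the default time constants (0.00034) and time step (0.001), the per-step decay is roughly 0.053.

```python
import math


def decay_factor(tau: float, dt: float) -> float:
    """Per-step decay exp(-dt / tau) (assumed exponential discretization)."""
    return math.exp(-dt / tau)


def lif_step(syn, mem, input_current, alpha, beta, threshold=1.0,
             reset_mode="soft_reset", recurrent_current=0.0):
    """One time step of a single current-based LIF neuron (illustrative only)."""
    if reset_mode not in ("hard_reset", "soft_reset"):
        raise ValueError(f"unknown reset mode: {reset_mode}")
    # Input current: decayed previous current plus feed-forward input, plus an
    # optional recurrent term (the kind of term --recurrent_flag adds).
    syn = alpha * syn + input_current + recurrent_current
    # Membrane potential: decayed previous potential plus input current.
    mem = beta * mem + syn
    # Binary spike when the membrane potential reaches the threshold.
    spike = 1.0 if mem >= threshold else 0.0
    if spike:
        # hard_reset: reset to zero; soft_reset: reset by subtraction.
        mem = 0.0 if reset_mode == "hard_reset" else mem - threshold
    return syn, mem, spike
```

Soft reset keeps the supra-threshold residue (1.5 becomes 0.5 with a threshold of 1.0), whereas hard reset discards it.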
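To see how --kernel_size, --stride, --padding and --dilation change feature-map sizes, the output-size formula documented for torch.nn.Conv2d can be applied per spatial dimension. This assumes the convolutional models use standard Conv2d layers; since no padding default is listed above, the sketch takes padding explicitly.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv2d_output_shape(shape, kernel_size=(3, 3), stride=(1, 1),
                        padding=(0, 0), dilation=(1, 1)):
    """Apply the 1-D formula independently to each spatial dimension."""
    return tuple(
        conv_output_size(s, k, st, p, d)
        for s, k, st, p, d in zip(shape, kernel_size, stride, padding, dilation)
    )
```

With a (3, 3) kernel, stride (1, 1) and padding (1, 1), the spatial shape is preserved, which is convenient for U-Net-style skip connections; stride (2, 2) roughly halves each dimension.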
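The --surrogate_name choices replace the derivative of the non-differentiable spike function during backpropagation. The sketch below gives commonly cited forms of each, evaluated at x = membrane potential minus threshold, with --surrogate_scale as the sharpness parameter. The exact scaling conventions in the code base may differ (the ATan form follows the one popularized by SpikingJelly, and the piecewise-linear form is one of several variants).

```python
import math


def _sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def super_spike(x, scale):
    """Fast-sigmoid derivative: 1 / (scale * |x| + 1)^2."""
    return 1.0 / (scale * abs(x) + 1.0) ** 2


def sigmoid_derivative(x, scale):
    """scale * s * (1 - s) with s = sigmoid(scale * x)."""
    s = _sigmoid(scale * x)
    return scale * s * (1.0 - s)


def atan(x, scale):
    """Derivative of (1/pi) * atan(pi/2 * scale * x) + 1/2."""
    return scale / (2.0 * (1.0 + (math.pi / 2.0 * scale * x) ** 2))


def piecewise_linear(x, scale):
    """Triangular window: max(0, 1 - scale * |x|)."""
    return max(0.0, 1.0 - scale * abs(x))


SURROGATES = {
    "SuperSpike": super_spike,
    "SigmoidDerivative": sigmoid_derivative,
    "ATan": atan,
    "PiecewiseLinear": piecewise_linear,
}
```

All four peak at the threshold (x = 0) and decay symmetrically away from it; a larger scale concentrates the gradient closer to the threshold.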
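The loss arguments take parallel lists (--loss_name, --loss_weight, --loss_bias). The sketch below shows the standard scale-invariant SNR definition, assuming si_snr_loss is its negative, together with one hypothetical way of combining several losses, total = sum_i (w_i * L_i + b_i). How the code base actually applies the bias term is not documented here, so `combined_loss` is an illustrative assumption.

```python
import math


def si_snr(estimate, target, eps=1e-8):
    """Scale-invariant signal-to-noise ratio in dB (standard definition)."""
    n = len(target)
    # Remove the means so the measure is also insensitive to DC offsets.
    est_mean, tgt_mean = sum(estimate) / n, sum(target) / n
    est = [e - est_mean for e in estimate]
    tgt = [t - tgt_mean for t in target]
    # Project the estimate onto the target: s_target = <est, tgt> / ||tgt||^2 * tgt.
    dot = sum(e * t for e, t in zip(est, tgt))
    tgt_energy = sum(t * t for t in tgt) + eps
    projection = [dot / tgt_energy * t for t in tgt]
    # Everything not explained by the scaled target counts as noise.
    noise = [e - p for e, p in zip(est, projection)]
    ratio = (sum(p * p for p in projection) + eps) / (sum(v * v for v in noise) + eps)
    return 10.0 * math.log10(ratio)


def si_snr_loss(estimate, target):
    """Negative SI-SNR, so that minimizing the loss maximizes SI-SNR."""
    return -si_snr(estimate, target)


def combined_loss(losses, weights, biases):
    """Hypothetical combination: sum_i (weights[i] * losses[i] + biases[i])."""
    if not (len(losses) == len(weights) == len(biases)):
        raise ValueError("loss names, weights and biases must have the same length")
    return sum(w * l + b for l, w, b in zip(losses, weights, biases))
```

Because the estimate is projected onto the target, rescaling the estimate leaves SI-SNR essentially unchanged, which is why it is often paired with a magnitude-sensitive loss such as mse_loss.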
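The two --clip_type options correspond to the usual value clipping and norm clipping. The dependency-free sketch below mirrors the arithmetic of torch.nn.utils.clip_grad_value_ and torch.nn.utils.clip_grad_norm_ on a flat list of gradients, assuming --clip_value serves as the threshold in both modes.

```python
import math


def clip_by_value(grads, clip_value):
    """Clamp every gradient element to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients so that their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Scale factor is capped at 1.0, so small gradients are left untouched.
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads]
```

Value clipping can change the gradient direction, while norm clipping preserves it and only shrinks its magnitude.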
Experiment tracking using Comet ML
Sign up on comet.ml and add the following arguments to the command line:
--workspace <Your Workspace> --api_key <Your API Key> --project_name <Your Project Name>
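To avoid typing the API key directly into the shell, the three arguments can be assembled from environment variables. The sketch below reads COMET_WORKSPACE, COMET_API_KEY and COMET_PROJECT_NAME (the variable names Comet ML itself recognizes); using them this way is a convenience assumption, not something main.py requires.

```python
import os


def comet_cli_args(env=None):
    """Build ['--workspace', ..., '--api_key', ..., '--project_name', ...] from env vars."""
    env = os.environ if env is None else env
    mapping = {
        "--workspace": "COMET_WORKSPACE",
        "--api_key": "COMET_API_KEY",
        "--project_name": "COMET_PROJECT_NAME",
    }
    args = []
    for flag, var in mapping.items():
        value = env.get(var)
        if not value:
            raise KeyError(f"environment variable {var} is not set")
        args += [flag, value]
    return args
```

For example, `["python", "main.py", "--model_name", "UNetSNN", "--train_flag", *comet_cli_args()]` can be passed to `subprocess.run(..., check=True)`. Note that the key still appears in the process's argument list while it runs.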