
Multi-channel Speech Separation, Denoising and Dereverberation

The official repo of:
[1] Changsheng Quan, Xiaofei Li. Multi-channel Narrow-band Deep Speech Separation with Full-band Permutation Invariant Training. In ICASSP 2022.
[2] Changsheng Quan, Xiaofei Li. Multichannel Speech Separation with Narrow-band Conformer. In Interspeech 2022.
[3] Changsheng Quan, Xiaofei Li. NBC2: Multichannel Speech Separation with Revised Narrow-band Conformer. arXiv:2212.02076.
[4] Changsheng Quan, Xiaofei Li. SpatialNet: Extensively Learning Spatial Information for Multichannel Joint Speech Separation, Denoising and Dereverberation. arXiv:2307.16516.

Audio examples can be found at https://audio.westlake.edu.cn/Research/nbss.htm and https://audio.westlake.edu.cn/Research/SpatialNet.htm. More information about our group can be found at https://audio.westlake.edu.cn.

Requirements

pip install -r requirements.txt

# gpuRIR needs to be installed separately; see https://github.com/DavidDiazGuerra/gpuRIR
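
To check that gpuRIR is correctly installed, you can simulate a single RIR with it. The following is a minimal smoke test; the room geometry and timing values are arbitrary placeholders, not settings used by this repo.

import numpy as np
import gpuRIR

room_sz = [6.0, 5.0, 3.0]              # room size in meters (x, y, z)
pos_src = np.array([[2.0, 2.5, 1.5]])  # one source position
pos_rcv = np.array([[4.0, 2.5, 1.5]])  # one microphone position
T60 = 0.4                              # target reverberation time in seconds
fs = 8000                              # SMS-WSJ audio is sampled at 8 kHz

beta = gpuRIR.beta_SabineEstimation(room_sz, T60)  # reflection coefficients from T60
Tmax = 0.8                                         # RIR length in seconds
nb_img = gpuRIR.t2n(Tmax, room_sz)                 # image-source counts for that length

rir = gpuRIR.simulateRIR(room_sz, beta, pos_src, pos_rcv, nb_img, Tmax, fs)
print(rir.shape)  # (n_src, n_rcv, Tmax * fs)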

Generate the SMS-WSJ-Plus Dataset

Generate the room impulse responses (RIRs) for the SMS-WSJ-Plus dataset, used in the SpatialNet ablation experiments:

CUDA_VISIBLE_DEVICES=0 python generate_rirs.py --rir_dir ~/datasets/SMS_WSJ_Plus_rirs --save_to configs/datasets/sms_wsj_rir_cfg.npz
cp configs/datasets/sms_wsj_plus_diffuse.npz ~/datasets/SMS_WSJ_Plus_rirs/diffuse.npz # copy diffuse parameters
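
The generated files are standard NumPy .npz archives. If you want to sanity-check them, the following snippet just lists the arrays they contain (the exact keys depend on generate_rirs.py):

import os
import numpy as np

# List the arrays stored in the RIR config and the diffuse-parameter archive.
cfg = np.load("configs/datasets/sms_wsj_rir_cfg.npz", allow_pickle=True)
print("RIR config arrays:", cfg.files)

diffuse = np.load(os.path.expanduser("~/datasets/SMS_WSJ_Plus_rirs/diffuse.npz"), allow_pickle=True)
print("diffuse parameter arrays:", diffuse.files)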

For the original SMS-WSJ dataset, please see https://github.com/fgnt/sms_wsj.

Train & Test

This project is built on the pytorch-lightning package, in particular its command line interface (CLI), so we recommend getting familiar with the Lightning CLI first. Chinese users can learn Lightning and its CLI from the beginner project pytorch_lightning_template_for_beginners.

Train SpatialNet on the 0-th GPU with the network config file configs/SpatialNet.yaml and the dataset config file configs/datasets/sms_wsj_plus.yaml (replace the RIR and clean-speech directories in the dataset config before training):

python SharedTrainer.py fit \
 --config=configs/SpatialNet.yaml \
 --config=configs/datasets/sms_wsj_plus.yaml \
 --model.channels=[0,1,2,3,4,5] \
 --model.arch.dim_input=12 \
 --model.arch.dim_output=4 \
 --model.arch.num_freqs=129 \
 --trainer.precision=bf16-mixed \
 --model.compile=true \
 --data.batch_size=[2,4] \
 --trainer.devices=0, \
 --trainer.max_epochs=100

where:
- the two --config options give the network config and the dataset config, respectively
- --model.channels: the microphone channels used
- --model.arch.dim_input: the input dim per T-F point, i.e. 2 x the number of channels
- --model.arch.dim_output: the output dim per T-F point, i.e. 2 x the number of sources
- --model.arch.num_freqs: the number of frequencies, determined by model.stft.n_fft
- --trainer.precision: bf16-mixed enables mixed-precision training; 16-mixed and 32 also work, and 32 gives the best performance
- --model.compile: compile the network (requires torch>=2.0); the compiled model trains much faster
- --data.batch_size: the batch sizes for training and validation, respectively
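
The three model.arch arguments are tied to each other and to the dataset. A small sketch of the arithmetic (n_fft=256 is inferred from num_freqs=129; SMS-WSJ mixtures contain two speakers):

# How the arch arguments above relate to each other.
channels = [0, 1, 2, 3, 4, 5]
num_sources = 2
n_fft = 256

dim_input = 2 * len(channels)   # real + imaginary parts per channel -> 12
dim_output = 2 * num_sources    # real + imaginary parts per source  -> 4
num_freqs = n_fft // 2 + 1      # one-sided STFT bins                -> 129

assert (dim_input, dim_output, num_freqs) == (12, 4, 129)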

More GPUs can be used by appending their indexes to trainer.devices, e.g. --trainer.devices=0,1,2,3.

Resume training from a checkpoint:

python SharedTrainer.py fit --config=logs/SpatialNet/version_x/config.yaml \
 --data.batch_size=[2,2] \
 --trainer.devices=0, \ 
 --ckpt_path=logs/SpatialNet/version_x/checkpoints/last.ckpt

where version_x should be replaced with the version you want to resume.

Test a trained model:

python SharedTrainer.py test --config=logs/SpatialNet/version_x/config.yaml \ 
 --ckpt_path=logs/SpatialNet/version_x/checkpoints/epochY_neg_si_sdrZ.ckpt \ 
 --trainer.devices=0,
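
The checkpoint filename encodes the epoch and the validation negative SI-SDR, so the best checkpoint is the one with the lowest value. A sketch for selecting it automatically, assuming filenames like epoch77_neg_si_sdr-12.3.ckpt (adjust the pattern to your logs):

import re
from pathlib import Path

ckpts = list(Path("logs/SpatialNet/version_x/checkpoints").glob("epoch*_neg_si_sdr*.ckpt"))

def neg_si_sdr(path: Path) -> float:
    # Extract the negative SI-SDR value encoded in the filename.
    return float(re.search(r"neg_si_sdr(-?\d+(?:\.\d+)?)", path.stem).group(1))

best = min(ckpts, key=neg_si_sdr)  # lower negative SI-SDR = better SI-SDR
print(best)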

Module Versions

network           file
SpatialNet [4]    models/arch/SpatialNet.py
NB-BLSTM [1]      models/arch/NBSS.py
NBC [2]           models/arch/NBSS.py
NBC2 [3]          models/arch/NBSS.py

Note

The dataset generation & training commands for NB-BLSTM/NBC/NBC2 are available in the NBSS branch.
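
To run them, check out that branch in your clone of this repository:

git checkout NBSS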
