Control adaptive filters with neural networks.

Meta-AF: Meta-Learning for Adaptive Filters

Jonah Casebeer1*, Nicholas J. Bryan2, and Paris Smaragdis1

1 Department of Computer Science, University of Illinois at Urbana-Champaign
2 Adobe Research, Lead advisor
*Work performed while an intern at Adobe Research

Abstract

Adaptive filtering algorithms are pervasive throughout signal processing and have had a material impact on a wide variety of domains including audio processing, telecommunications, biomedical sensing, astrophysics and cosmology, seismology, and many more. Adaptive filters typically operate via specialized online, iterative optimization methods such as least-mean squares or recursive least squares and aim to process signals in unknown or nonstationary environments. Such algorithms, however, can be slow and laborious to develop, require domain expertise to create, and necessitate mathematical insight for improvement. In this work, we seek to improve upon hand-derived adaptive filter algorithms and present a comprehensive framework for learning online, adaptive signal processing algorithms or update rules directly from data. To do so, we frame the development of adaptive filters as a meta-learning problem in the context of deep learning and use a form of self-supervision to learn online iterative update rules for adaptive filters. To demonstrate our approach, we focus on audio applications and systematically develop meta-learned adaptive filters for five canonical audio problems including system identification, acoustic echo cancellation, blind equalization, multi-channel dereverberation, and beamforming. We compare our approach against common baselines and/or recent state-of-the-art methods. We show we can learn high-performing adaptive filters that operate in real-time and, in most cases, significantly outperform each method we compare against -- all using a single general-purpose configuration of our approach.

For more details, please see: "Meta-AF: Meta-Learning for Adaptive Filters", Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, arXiv, 2022. Or, our talk:

Lecture Video

If you use ideas or code from this work, please cite our paper:

@article{casebeer2022meta,
  title={Meta-AF: Meta-Learning for Adaptive Filters},
  author={Casebeer, Jonah and Bryan, Nicholas J and Smaragdis, Paris},
  journal={arXiv preprint arXiv:2204.11942},
  year={2022}
}

Demos

For audio demonstrations of the work and metaaf package in action, please check out our demo website. You'll be able to find demos for the five core adaptive filtering problems.

Code

We open source all code for the work via our metaaf Python pip package. The metaaf package enables meta-learning optimizers for near-arbitrary adaptive filters and any differentiable objective. metaaf automatically manages online overlap-save and overlap-add for single/multi-channel and single/multi-frame filters. We also include generic implementations of LMS, RMSProp, NLMS, and RLS for benchmarking purposes. Finally, metaaf includes implementations of generic GRU-based optimizers, which are compatible with any filter defined in the metaaf format. Below, you can find example usage, usage for several common adaptive filter tasks (in the adaptive filter zoo), and installation instructions.
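For orientation, the kind of hand-derived baseline named above can be sketched in a few lines of NumPy. The following is a minimal time-domain NLMS loop for system identification. It is an illustration only, not the metaaf implementation: the function name `nlms_identify` and the parameters `step` and `eps` are assumptions made for this sketch.

```python
import numpy as np

def nlms_identify(u, d, order, step=0.5, eps=1e-8):
    """Estimate an FIR system from input u and desired output d with NLMS."""
    w = np.zeros(order)            # current filter estimate
    x = np.zeros(order)            # most recent inputs, newest first
    err = np.zeros(len(u))
    for n in range(len(u)):
        x = np.roll(x, 1)
        x[0] = u[n]
        y = w @ x                  # filter output
        err[n] = d[n] - y          # a priori error
        # normalized gradient step on the squared error
        w = w + step * err[n] * x / (x @ x + eps)
    return w, err

rng = np.random.default_rng(0)
true_w = rng.normal(size=8) / 8
u = rng.normal(size=4000)
d = np.convolve(true_w, u)[: len(u)]
w_hat, err = nlms_identify(u, d, order=8)
print("max coefficient error:", np.max(np.abs(w_hat - true_w)))
```

The update line is the hand-written rule. Per the abstract, Meta-AF learns such update rules from data instead of deriving them by hand.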

The metaaf package is relatively small: about a dozen files, which enable much more functionality than we demo here. The core meta-learning code is in core.py, the buffered and online filter implementations are in filter.py, and the RNN-based optimizers are in optimizer_gru.py and optimizer_fgru.py. meta.py contains a class for managing training. The remaining files hold utilities and generic implementations of baseline optimizers.

Installation

To install the metaaf Python package, you will need a working JAX install. You can set one up by following the official directions here. Below is an example of the commands we use to set up a new conda environment called metaenv in which we install metaaf and any dependencies.

GPU Setup

### GPU
# Install all the cuda and cudnn prerequisites
conda create -yn metaenv python=3.7 &&
conda install -yn metaenv cudatoolkit=11.1.1 -c pytorch -c conda-forge &&
conda install -yn metaenv cudatoolkit-dev=11.1.1 -c pytorch -c conda-forge &&
conda install -yn metaenv cudnn=8.2 -c nvidia -c pytorch -c anaconda -c conda-forge &&
conda install -yn metaenv pytorch cpuonly -c pytorch -y
conda activate metaenv

# Actually install jax
# You may need to change the cuda/cudnn version numbers depending on your machine
pip install "jax[cuda11_cudnn82]==0.3.15" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install Haiku (pin a release tag compatible with jax 0.3.15)
pip install git+https://github.com/deepmind/dm-haiku@<tag>

CPU Setup

### CPU. x86 only
conda create -yn metaenv python=3.7 && 
conda install -yn metaenv pytorch torchvision torchaudio -c pytorch && 
conda install -yn metaenv pytorch cpuonly -c pytorch -y
conda activate metaenv

# Actually install jax (CPU-only build)
pip install --upgrade pip
pip install --upgrade "jax[cpu]"==0.3.15

# Install Haiku (pin a release tag compatible with jax 0.3.15)
pip install git+https://github.com/deepmind/dm-haiku@<tag>

Finally, with the prerequisites done, you can install metaaf by cloning the repo, moving into the base directory, and running pip install -e ./. This pip install adds the remaining dependencies. To run the demo notebook, you also need to:

# Add the conda env to your jupyter session
conda install -yn metaenv ipykernel 
ipython kernel install --user --name=metaenv

# Install plotting
pip install matplotlib

# Install widgets for a progress bar
pip install ipywidgets
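After installation, a quick check that the main packages are visible to the active interpreter can save debugging time. The sketch below uses only the standard library. The helper name `check_packages` is made up for this example, and the import names `jax`, `haiku`, and `metaaf` are assumed to match the packages installed above.

```python
import importlib.util

def check_packages(names=("jax", "haiku", "metaaf")):
    """Return {package: bool} telling whether each package can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

status = check_packages()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

Run it inside the metaenv environment. A MISSING entry usually means the wrong environment is active.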

Example Usage

The metaaf package provides several important modules to facilitate training. The first is the MetaAFTrainer, a class which manages training. To use the MetaAFTrainer, we need to define a filter architecture and a dataset. metaaf adopts several conventions to simplify training and automate procedures like buffering. In this notebook, we walk through this process and demonstrate it on a toy system-identification task. In this section, we explain that toy task and the automatic argparse utilities. To see a full-scale example, proceed to the next section, where we describe the Meta-AF Zoo.

First, you need to make a dataset using a regular PyTorch dataset. The dataset must return a dictionary with two keys: "signals" and "metadata". The "signals" are automatically indexed and sliced, and each should have shape (samples, channels).

import numpy as np
from torch.utils.data import Dataset
from metaaf.data import NumpyLoader

class SystemIDDataset(Dataset):
    def __init__(self, N=4096, sys_order=32):
        self.N = N
        self.sys_order = sys_order

    def __len__(self):
        return 256

    def __getitem__(self, idx):
        # the system
        w = np.random.normal(size=self.sys_order) / self.sys_order

        # the input
        u = np.random.normal(size=self.N)

        # the output
        d = np.convolve(w, u)[: self.N]

        return {
            "signals": {
                "u": u[:, None], # time X channels
                "d": d[:, None], # time X channels
            },  
            "metadata": {},
        }
train_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
val_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
test_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
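Before handing a dataset to the trainer, it can help to check one item against the format described above. The following sketch does that with plain NumPy. `check_item` is a hypothetical helper written for this example, not part of metaaf, and it only checks the two conventions stated in the text: the two top-level keys and the (samples, channels) layout.

```python
import numpy as np

def check_item(item):
    """Validate one dataset item and return the shape of every signal."""
    if set(item) != {"signals", "metadata"}:
        raise ValueError('expected exactly the keys "signals" and "metadata"')
    shapes = {}
    for name, sig in item["signals"].items():
        if np.ndim(sig) != 2:
            raise ValueError(f"{name!r} should have shape (samples, channels)")
        shapes[name] = np.shape(sig)
    return shapes

# Build one item the same way the toy dataset above does.
rng = np.random.default_rng(0)
w = rng.normal(size=32) / 32
u = rng.normal(size=4096)
d = np.convolve(w, u)[:4096]
item = {"signals": {"u": u[:, None], "d": d[:, None]}, "metadata": {}}
shapes = check_item(item)
print(shapes)
```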

Then, you define your filter. We're going to inherit from the metaaf overlap-save (OLS) module. When inheriting, you can return either the current result, which will be automatically buffered, or a dictionary. A returned dictionary must have a key "out", which will be buffered. All other keys are stacked and returned.

import jax.numpy as jnp
import haiku as hk
from metaaf.filter import OverlapSave

# the filter inherits from the overlap-save module
class SystemID(OverlapSave, hk.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # select the analysis window
        self.analysis_window = jnp.ones(self.window_size)

    # Since we use the OLS base class, u and d are STFT-domain inputs.
    # The filter must take the same inputs provided in its _fwd function.
    def __ols_call__(self, u, d, metadata):
        # collect a buffer sized anti-aliased filter
        w = self.get_filter(name="w")

        # this is n_frames x n_freq x channels or 1 x F x 1 here
        return (w * u)[0]
    
    @staticmethod
    def add_args(parent_parser):
        return super(SystemID, SystemID).add_args(parent_parser)

    @staticmethod
    def grab_args(kwargs):
        return super(SystemID, SystemID).grab_args(kwargs)
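The filter above is a per-frequency product `w * u`, which only equals a linear convolution in time if frames are buffered appropriately. A minimal, standalone NumPy sketch of overlap-save shows the idea. This is not the metaaf OverlapSave class: the function name `overlap_save` and the `hop` parameter are assumptions for this illustration, and metaaf's own window and hop conventions may differ.

```python
import numpy as np

def overlap_save(u, w, hop):
    """Filter u with the FIR filter w block by block using FFTs."""
    order = len(w)
    n_fft = hop + order - 1                    # new samples plus history
    W = np.fft.rfft(w, n_fft)                  # zero-padded filter spectrum
    padded = np.concatenate([np.zeros(order - 1), u])
    out = []
    for start in range(0, len(u) - hop + 1, hop):
        frame = padded[start : start + n_fft]
        y = np.fft.irfft(np.fft.rfft(frame) * W, n_fft)
        out.append(y[order - 1 :])             # drop the circularly aliased part
    return np.concatenate(out)

rng = np.random.default_rng(0)
u = rng.normal(size=1024)
w = rng.normal(size=32)
y_block = overlap_save(u, w, hop=64)
y_direct = np.convolve(u, w)[: len(u)]
print("max difference:", np.max(np.abs(y_block - y_direct)))
```

Each frame reuses the previous `order - 1` input samples, and only the last `hop` output samples of each frame are kept.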

Haiku converts objects to functions. We need to provide a wrapper to do this. The wrapper function MUST take as input the same named values from your dataset.

def _SystemID_fwd(u, d, metadata=None, init_data=None, **kwargs):
    f = SystemID(**kwargs)
    return f(u=u, d=d)

Then, we define an adaptive filter loss. Here, just the MSE. An adaptive filter loss must be written in this form, so that metaaf can automatically take its gradient and pass it around.

def filter_loss(out, data_samples, metadata):
    e =  out - data_samples["d"]
    return jnp.vdot(e, e) / (e.size)
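For readers less used to `vdot`, the loss above is the ordinary mean squared error, written so that it also works for complex frequency-domain errors. A NumPy sketch of the equivalence follows. The array shapes are arbitrary choices for this example, and `np.vdot` is used here as a stand-in for `jnp.vdot`.

```python
import numpy as np

rng = np.random.default_rng(0)
out = rng.normal(size=(64, 1)) + 1j * rng.normal(size=(64, 1))
d = rng.normal(size=(64, 1)) + 1j * rng.normal(size=(64, 1))

e = out - d
loss_vdot = np.vdot(e, e) / e.size      # vdot conjugates its first argument
loss_mean = np.mean(np.abs(e) ** 2)     # the usual mean squared error
print(loss_vdot.real, loss_mean)
```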

We can construct the meta-train and meta-val losses in a similar fashion.

def meta_train_loss(losses, outputs, data_samples, metadata, outer_index, outer_learnable):
    out = jnp.concatenate(outputs["out"], 0)
    return jnp.log(jnp.mean(jnp.abs(out - data_samples["d"]) ** 2) +  1e-9)

def meta_val_loss(losses, outputs, data_samples, metadata, outer_learnable):
    out = jnp.reshape(
        outputs["out"],
        (outputs["out"].shape[0], -1, outputs["out"].shape[-1]),
    )
    d = data_samples["d"]
    min_len = min(out.shape[1], d.shape[1])
    return jnp.log(jnp.mean(jnp.abs(out[:, :min_len] - d[:, :min_len]) ** 2) +  1e-9)
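The validation loss above flattens the framed outputs back into a time axis, trims both signals to a common length, and takes the log of the mean squared error. The NumPy sketch below mirrors those steps on toy arrays. The (batch, frames, hop, channels) output layout is an assumption made for this example, and `log_mse` is a name invented here.

```python
import numpy as np

def log_mse(out, d, eps=1e-9):
    """Log mean squared error over the overlapping part of out and d."""
    out = out.reshape(out.shape[0], -1, out.shape[-1])   # (batch, time, channels)
    min_len = min(out.shape[1], d.shape[1])
    err = out[:, :min_len] - d[:, :min_len]
    return np.log(np.mean(np.abs(err) ** 2) + eps)

# Batch of 2, 3 frames of 4 samples, 1 channel, against a 10-sample target.
d = np.ones((2, 10, 1))
perfect = np.ones((2, 3, 4, 1))
silent = np.zeros((2, 3, 4, 1))
print(log_mse(perfect, d), log_mse(silent, d))
```

The small `eps` keeps the loss finite when the output matches the target exactly.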

With everything defined, we can set up the Meta-Trainer and start training.

import argparse
import jax
from metaaf.meta import MetaAFTrainer
from metaaf.optimizer_gru import EGRU

# Collect arguments
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser = EGRU.add_args(parser)
parser = SystemID.add_args(parser)
parser = MetaAFTrainer.add_args(parser)
kwargs = vars(parser.parse_args())

# Setup trainer
system = MetaAFTrainer(
    _filter_fwd=_SystemID_fwd,
    filter_kwargs=SystemID.grab_args(kwargs),
    filter_loss=filter_loss,
    meta_train_loss=meta_train_loss,
    meta_val_loss=meta_val_loss,
    optimizer_kwargs=EGRU.grab_args(kwargs),
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)
# Train
key = jax.random.PRNGKey(0)
outer_learned, losses = system.train(
    **MetaAFTrainer.grab_args(kwargs),
    key=key,
)
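The `add_args` / `grab_args` calls above follow a simple convention: each component registers its own command-line flags, then picks its own values back out of the parsed dictionary. The standalone sketch below illustrates the pattern with plain argparse. `ToyFilter` and its `--n_taps` flag are hypothetical and do not correspond to real metaaf options.

```python
import argparse

class ToyFilter:
    """Hypothetical component following an add_args / grab_args convention."""

    def __init__(self, n_taps):
        self.n_taps = n_taps

    @staticmethod
    def add_args(parent_parser):
        # Register this component's flags on the shared parser.
        group = parent_parser.add_argument_group("ToyFilter")
        group.add_argument("--n_taps", type=int, default=32)
        return parent_parser

    @staticmethod
    def grab_args(kwargs):
        # Keep only the entries this component understands.
        keys = ["n_taps"]
        return {k: kwargs[k] for k in keys}

parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser = ToyFilter.add_args(parser)
kwargs = vars(parser.parse_args(["--name", "demo", "--n_taps", "64"]))
toy = ToyFilter(**ToyFilter.grab_args(kwargs))
print(toy.n_taps)
```

Because each component filters the shared dictionary, several components can share one parser without seeing each other's options.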

That is it! For more advanced options, check out the zoo, where we demonstrate callbacks, customized filters, and more.

Meta-AF Zoo

The Meta-AF Zoo contains implementations for system identification, acoustic echo cancellation, equalization, weighted prediction error dereverberation, and a generalized sidelobe canceller beamformer, all in the metaaf framework. You can find instructions for how to set up, run, and evaluate those models here. For trained weights and tuned baselines, please see the tagged release zip file here.

License

All core utility code within the metaaf folder is licensed via the University of Illinois Open Source License. All code within the zoo folder and model weights are licensed via the Adobe Research License. Copyright (c) Adobe Systems Incorporated. All rights reserved.

Related Works

An extension of this work using metaaf:

"Meta-Learning for Adaptive Filters with Higher-Order Frequency Dependencies", Junkai Wu, Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, IWAENC, 2022.

@inproceedings{wu2022metalearning,
  title={Meta-Learning for Adaptive Filters with Higher-Order Frequency Dependencies},
  author={Wu, Junkai and Casebeer, Jonah and Bryan, Nicholas J. and Smaragdis, Paris},
  booktitle={IEEE International Workshop on Acoustic Signal Enhancement (IWAENC)},
  year={2022}
}

An early version of this work:

"Auto-DSP: Learning to Optimize Acoustic Echo Cancellers", Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, WASPAA, 2021.

@inproceedings{casebeer2021auto,
  title={Auto-DSP: Learning to Optimize Acoustic Echo Cancellers},
  author={Casebeer, Jonah and Bryan, Nicholas J. and Smaragdis, Paris},
  booktitle={2021 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
  pages={291--295},
  year={2021},
  organization={IEEE}
}
