Repository Details

Official implementation of the Keyword Transformer: https://arxiv.org/abs/2104.00769

Keyword Transformer: A Self-Attention Model for Keyword Spotting

This is the official repository for the paper Keyword Transformer: A Self-Attention Model for Keyword Spotting, presented at Interspeech 2021. Consider citing our paper if you find this work useful.

@inproceedings{berg21_interspeech,
  author={Axel Berg and Mark O'Connor and Miguel Tairum Cruz},
  title={{Keyword Transformer: A Self-Attention Model for Keyword Spotting}},
  year=2021,
  booktitle={Proc. Interspeech 2021},
  pages={4249--4253},
  doi={10.21437/Interspeech.2021-1286}
}

Setup

Download Google Speech Commands

There are two versions of the dataset, V1 and V2. To download and extract dataset V2, run:

wget https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
mkdir data2
mv ./speech_commands_v0.02.tar.gz ./data2
cd ./data2
tar -xf ./speech_commands_v0.02.tar.gz
cd ../

And similarly for V1:

wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz
mkdir data1
mv ./speech_commands_v0.01.tar.gz ./data1
cd ./data1
tar -xf ./speech_commands_v0.01.tar.gz
cd ../
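
The same download-and-extract steps can be sketched in Python for platforms without wget or tar. This is a minimal sketch, not part of the repository: the helper name download_and_extract is hypothetical, and the URLs are copied from the commands above.

```python
import tarfile
import urllib.request
from pathlib import Path

# URLs copied from the shell commands above.
V2_URL = "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz"
V1_URL = "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz"


def download_and_extract(url, dest_dir):
    """Download a .tar.gz archive into dest_dir and unpack it there.

    Hypothetical helper: mirrors the wget / mkdir / mv / tar steps above.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    archive = dest / url.rsplit("/", 1)[-1]
    if not archive.exists():  # skip the download if the archive is already there
        urllib.request.urlretrieve(url, str(archive))
    # Use the safer "data" extraction filter where this Python version has it.
    kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest, **kwargs)
    return dest


# Usage (downloads several gigabytes, so it is left commented out):
# download_and_extract(V2_URL, "./data2")
# download_and_extract(V1_URL, "./data1")
```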

Install dependencies

Set up a new virtual environment:

pip install virtualenv
virtualenv --system-site-packages -p python3 ./venv3
source ./venv3/bin/activate

To install dependencies, run

pip install -r requirements.txt

Tested using TensorFlow 2.4.0rc1 with CUDA 11.

Note: Installing the correct TensorFlow version is important for reproducibility! With more recent versions of TensorFlow, the accuracy differs slightly each time the model is evaluated. This may be due to a change in how the random seed generator is implemented, which in turn changes how the "unknown" keyword class is sampled.

Model

The Keyword-Transformer model is defined here. It takes a mel-scale spectrogram as input. With the default settings the spectrogram has shape 98 x 40, corresponding to 98 time windows with 40 frequency coefficients each.
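
The 98 in that shape follows from simple framing arithmetic. The sketch below assumes 1-second clips sampled at 16 kHz, a 30 ms analysis window and a 10 ms stride. These values are assumptions for illustration and are not stated in this README; check the training flags for the actual defaults.

```python
def num_frames(clip_ms=1000, window_ms=30, stride_ms=10, sample_rate=16000):
    """Number of full analysis windows that fit into one clip.

    All default values are assumptions, not settings read from the repository.
    """
    clip = sample_rate * clip_ms // 1000      # samples per clip
    window = sample_rate * window_ms // 1000  # samples per window
    stride = sample_rate * stride_ms // 1000  # samples between window starts
    return 1 + (clip - window) // stride


print(num_frames())  # 98
```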

There are four variants of the Keyword-Transformer model:

  • Time-domain attention: each time window is treated as a patch, and self-attention is computed between time windows.
  • Frequency-domain attention: each frequency is treated as a patch, and self-attention is computed between frequencies.
  • Combination of both: the signal is fed into both a time-domain and a frequency-domain transformer, and the outputs are combined.
  • Patch-wise attention: similar to the vision transformer, rectangular patches are extracted from the spectrogram, so attention happens in the time and frequency domains simultaneously.
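
To make the difference between the variants concrete, here is a minimal NumPy sketch of how a 98 x 40 spectrogram could be split into tokens for each attention type. It is an illustration only, not the repository's implementation: the function to_tokens is hypothetical, and it assumes the patch size divides the spectrogram shape evenly. The combined variant would feed the "time" and "freq" token sets to two separate transformers.

```python
import numpy as np


def to_tokens(spec, attention_type="time", patch_size=(1, 40)):
    """Split a (time, freq) spectrogram into a (num_tokens, token_dim) array.

    Hypothetical helper for illustration; not taken from the repository.
    """
    t, f = spec.shape
    if attention_type == "time":
        return spec            # one token per time window
    if attention_type == "freq":
        return spec.T          # one token per frequency coefficient
    if attention_type == "patch":
        pt, pf = patch_size
        if t % pt or f % pf:
            raise ValueError("patch_size must divide the spectrogram shape")
        # Cut into a grid of (pt x pf) rectangles, then flatten each rectangle.
        blocks = spec.reshape(t // pt, pt, f // pf, pf).transpose(0, 2, 1, 3)
        return blocks.reshape(-1, pt * pf)
    raise ValueError(f"unknown attention_type: {attention_type!r}")


spec = np.arange(98 * 40, dtype=np.float32).reshape(98, 40)
print(to_tokens(spec, "time").shape)           # (98, 40)
print(to_tokens(spec, "freq").shape)           # (40, 98)
print(to_tokens(spec, "patch", (7, 8)).shape)  # (70, 56)
```

With a patch size of (1, 40), patch-wise attention produces the same tokens as time-domain attention.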

Training a model from scratch

To train KWT-3 from scratch on Speech Commands V2, run

sh train.sh

Please note that the train directory (given by the argument --train_dir) must not exist before the script is started.

The model-specific arguments for KWT are:

--num_layers 12 \ # number of sequential transformer encoders
--heads 3 \ # number of attention heads
--d_model 192 \ # embedding dimension
--mlp_dim 768 \ # MLP dimension
--dropout1 0. \ # dropout in MLP/multi-head attention blocks
--attention_type 'time' \ # attention type: 'time', 'freq', 'both' or 'patch'
--patch_size '1,40' \ # spectrogram patch size, if patch attention is used
--prenorm False \ # if False, use postnorm
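
For readers who want to reason about these settings in Python, the flags can be mirrored in a small configuration object. The class KWTConfig below is hypothetical and exists only for illustration. Its defaults are the values listed above, and the KWT-1 and KWT-2 values come from the table of pre-trained models in this README.

```python
from dataclasses import dataclass


@dataclass
class KWTConfig:
    """Hypothetical container; field names mirror the command-line flags."""
    num_layers: int = 12
    heads: int = 3
    d_model: int = 192
    mlp_dim: int = 768
    dropout1: float = 0.0
    attention_type: str = "time"
    patch_size: str = "1,40"
    prenorm: bool = False

    def patch_shape(self):
        """Parse the 'time,freq' patch size string into a tuple of ints."""
        t, f = (int(v) for v in self.patch_size.split(","))
        return t, f

    def dim_per_head(self):
        """Embedding dimension divided by the number of heads."""
        return self.d_model // self.heads


KWT_1 = KWTConfig(heads=1, d_model=64, mlp_dim=128)
KWT_2 = KWTConfig(heads=2, d_model=128, mlp_dim=256)
KWT_3 = KWTConfig()

for cfg in (KWT_1, KWT_2, KWT_3):
    print(cfg.d_model, cfg.heads, cfg.dim_per_head(), cfg.patch_shape())
```

In all three listed configurations the embedding dimension divided by the number of heads is 64.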

Training with distillation

We employ hard distillation from a convolutional model (Att-MH-RNN), similar to the approach in DeiT.

To train KWT-3 with hard distillation from a pre-trained model, run

sh distill.sh
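
As a rough illustration of what hard distillation means, the NumPy sketch below follows the DeiT formulation: the class-token output is trained against the true labels, and a separate distillation-token output is trained against the teacher's hard (argmax) predictions. The equal 0.5 weights and the function names are assumptions taken from the DeiT description, not read from distill.sh or this repository's code.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy for integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def hard_distillation_loss(class_logits, distill_logits, teacher_logits, labels):
    """DeiT-style hard distillation (sketch; the 0.5 weights are an assumption)."""
    teacher_labels = teacher_logits.argmax(axis=1)  # teacher's hard decisions
    return (0.5 * cross_entropy(class_logits, labels)
            + 0.5 * cross_entropy(distill_logits, teacher_labels))


rng = np.random.default_rng(0)
labels = rng.integers(0, 12, size=8)
loss = hard_distillation_loss(
    rng.normal(size=(8, 12)),  # student class-token logits
    rng.normal(size=(8, 12)),  # student distillation-token logits
    rng.normal(size=(8, 12)),  # teacher logits
    labels,
)
print(float(loss))
```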

Run inference using a pre-trained model

Pre-trained weights for KWT-3, KWT-2 and KWT-1 are provided in ./models_data_v2_12_labels.

Model name | embedding dim | mlp-dim | heads | depth | #params | V2-12 accuracy (%) | pre-trained
KWT-1      | 64            | 128     | 1     | 12    | 607K    | 97.7               | here
KWT-2      | 128           | 256     | 2     | 12    | 2.4M    | 98.2               | here
KWT-3      | 192           | 768     | 3     | 12    | 5.5M    | 98.7               | here

To perform inference on Google Speech Commands v2 with 12 labels, run

sh eval.sh
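
For context, the 12-label task groups every recorded word into one of ten target keywords, an "unknown" class, or silence. The sketch below shows that mapping. The label strings and their order follow the convention of the TensorFlow speech commands example and are an assumption here; this README does not list them.

```python
# The ten target keywords of the 12-label Speech Commands task.
KEYWORDS = ["yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"]

# Label names and ordering are assumed, not confirmed by this README.
LABELS = ["_silence_", "_unknown_"] + KEYWORDS


def to_12_label(word=None):
    """Map a spoken word (or None for background noise) to one of 12 labels."""
    if word is None:
        return "_silence_"
    return word if word in KEYWORDS else "_unknown_"


print(len(LABELS))            # 12
print(to_12_label("stop"))    # stop
print(to_12_label("marvin"))  # _unknown_
```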

Acknowledgements

The code heavily borrows from the KWS streaming work by Google Research. For a more detailed description of the code structure, see the original authors' README.

We also exploit training techniques from DeiT.

We thank the authors for sharing their code. Please consider citing them as well if you use our code.

License

The source files in this repository are released under the Apache 2.0 license.

Some source files are derived from the KWS streaming repository by Google Research. These are also released under the Apache 2.0 license, the text of which can be seen in the LICENSE file on their repository.
