End-to-End Tutorial: Wakeword Detector
In this tutorial you'll build a complete keyword-spotting model from scratch with the dg SDK: train it on Google's Speech Commands v2 dataset, validate it locally, and benchmark it on hardware with Bitweaver.
This is the same pipeline used to produce the wakeword entry in the Model Zoo, published here as a reference implementation. The Zoo will continue to grow with more reference models over time.
What you'll build
A 12-class wakeword classifier covering 10 keywords (yes, no, up, down, left, right, on, off, stop, go) plus silence and unknown, small enough to run on a microcontroller. The model is a depthwise-separable CNN over MFCC features, fully quantized via dg's `Quant*` layers.
Prerequisites
- Python 3.10+ and PyTorch 2.6+
- A GPU (CUDA, or MPS on Apple silicon) is recommended for training (CPU works but is slow)
- A Bitweaver account if you plan to deploy to hardware. Don't have one yet? Get in touch.
Install the dg SDK and the audio dependencies this tutorial needs (`torchaudio` and `soundfile`):
Create a working directory with four files (`data.py`, `model.py`, `train.py` and `validate.py`); we'll fill them in step by step:
1. The dataset
Google's Speech Commands v2 provides ~105k one-second clips of people saying common short words. We map 10 of those words to keyword classes, treat all other words as unknown, and synthesize a silence class from random windows of the dataset's background-noise recordings.
Each clip is padded or trimmed to exactly 1 second (16,000 samples at 16 kHz) and converted to 10 MFCC coefficients per frame. With a 512-sample FFT window, a 320-sample (20 ms) hop and no centre padding, that gives (16000 − 512) // 320 + 1 = 49 frames. Each clip therefore becomes a tensor of shape (1, 1, 49, 10) (batch × channels × time × features).
Paste this into data.py:
import os
import logging
import torch
import torch.nn.functional as F
import torchaudio
import soundfile as sf
logger = logging.getLogger(__name__)

# All audio is handled as mono 16 kHz, cut or padded to fixed 1-second windows.
SAMPLE_RATE = 16000
NUM_SAMPLES = SAMPLE_RATE  # exactly 1 second of audio

# Shared MFCC front end: 10 coefficients computed from 40 log-mel bands.
# 512-point FFT with a 320-sample (20 ms) hop and no centre padding, so a
# 16000-sample clip yields (16000 - 512) // 320 + 1 = 49 frames.
# The mel filterbank spans 20 Hz - 4 kHz on a magnitude (power=1.0) spectrogram.
_mfcc_transform = torchaudio.transforms.MFCC(
    sample_rate=SAMPLE_RATE,
    n_mfcc=10,
    log_mels=True,
    melkwargs={
        "n_fft": 512,
        "hop_length": 320,
        "n_mels": 40,
        "f_min": 20.0,
        "f_max": 4000.0,
        "center": False,
        "power": 1.0,
    },
)

# Lower-cased file suffixes that load_from_wavs treats as audio files.
AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}
def load_wav(path):
    """Read an audio file as a mono float32 waveform at SAMPLE_RATE.

    Multi-channel files are averaged across channels; files recorded at any
    other rate are resampled to SAMPLE_RATE.

    Args:
        path: Path of the audio file to read with soundfile.

    Returns:
        Tensor of shape (1, num_samples).
    """
    samples, rate = sf.read(path, dtype="float32")
    # soundfile returns (frames, channels) for multi-channel audio.
    mono = samples.mean(axis=1) if samples.ndim > 1 else samples
    waveform = torch.from_numpy(mono)[None, :]
    if rate == SAMPLE_RATE:
        return waveform
    return torchaudio.transforms.Resample(rate, SAMPLE_RATE)(waveform)
def pad_or_trim(waveform):
    """Force a (channels, samples) waveform to exactly NUM_SAMPLES samples.

    Longer inputs are truncated at the end; shorter ones are right-padded with
    zeros. An input that is already NUM_SAMPLES long is returned unchanged.
    """
    length = waveform.shape[1]
    if length > NUM_SAMPLES:
        return waveform[:, :NUM_SAMPLES]
    if length < NUM_SAMPLES:
        return F.pad(waveform, (0, NUM_SAMPLES - length))
    return waveform
def compute_mfcc(waveform):
    """Convert a (1, samples) waveform into a (1, 1, frames, n_mfcc) tensor.

    The waveform is first padded/trimmed to NUM_SAMPLES. Then the shared MFCC
    transform runs, the output is made time-major, and a channel axis is added
    so it can feed a Conv2d.
    """
    fixed = pad_or_trim(waveform)
    coeffs = _mfcc_transform(fixed)  # (channels, n_mfcc, frames)
    time_major = coeffs.permute(0, 2, 1)  # (channels, frames, n_mfcc)
    return time_major[:, None]
def load_from_wavs(data_dir, device="cpu"):
    """Load a class-per-folder audio dataset into a TensorDataset of MFCCs.

    Expected layout::

        data_dir/
            class_a/file1.wav, file2.wav, ...
            class_b/file3.wav, ...

    Class indices follow the sorted order of the sub-directory names, and files
    within a class are read in sorted order. Only files whose lower-cased
    extension is in AUDIO_EXTENSIONS are loaded.

    Args:
        data_dir: Root directory; ``~`` is expanded.
        device: Device the returned tensors are moved to.

    Returns:
        TensorDataset of ``(x, y)``. ``x`` stacks compute_mfcc outputs along
        dim 0 and ``y`` holds int64 class indices.

    Raises:
        ValueError: If no audio files are found under ``data_dir``.
    """
    data_dir = os.path.expanduser(data_dir)
    class_names = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    )
    x_list, y_list = [], []
    for class_idx, class_name in enumerate(class_names):
        class_dir = os.path.join(data_dir, class_name)
        files = sorted(
            f for f in os.listdir(class_dir)
            if os.path.splitext(f)[1].lower() in AUDIO_EXTENSIONS
        )
        for fname in files:
            waveform = load_wav(os.path.join(class_dir, fname))
            x_list.append(compute_mfcc(waveform))
            y_list.append(class_idx)
    if not x_list:
        # torch.cat([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError(
            f"No audio files with extensions {sorted(AUDIO_EXTENSIONS)} "
            f"found under {data_dir!r}"
        )
    x = torch.cat(x_list, dim=0)
    y = torch.tensor(y_list, dtype=torch.long)
    # Keep labels as int64 class indices: classification losses such as
    # CrossEntropyLoss require integer targets, and casting to float16 breaks them.
    return torch.utils.data.TensorDataset(x.to(device), y.to(device))
The Speech Commands loader is below. It downloads the dataset, caches the MFCC tensors, and generates silence samples from background noise. It's longer; append it to the same data.py file:
Full Speech Commands loader (append to data.py)
# The ten target keywords; every other Speech Commands word maps to "unknown".
KEYWORDS = ["yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"]
# Label order: index 0 = silence, 1 = unknown, 2.. = KEYWORDS in order.
CLASSES = ["silence", "unknown", *KEYWORDS]  # 12 classes
LABEL_TO_IDX = dict(zip(CLASSES, range(len(CLASSES))))


def _map_label(word):
    """Return the class index for a Speech Commands word-folder name.

    Words that are not in CLASSES fall back to the "unknown" class index.
    """
    fallback = LABEL_TO_IDX["unknown"]
    return LABEL_TO_IDX.get(word, fallback)
def _get_sc_path(root):
    """Return the directory that Speech Commands v2 is extracted into under root."""
    dataset_dir = os.path.join(root, "SpeechCommands")
    return os.path.join(dataset_dir, "speech_commands_v0.02")
def _ensure_downloaded(root):
    """Make sure Speech Commands v2 is extracted under root and return its path.

    If the extracted directory is missing, torchaudio's SPEECHCOMMANDS dataset
    is instantiated with ``download=True`` to fetch and unpack it.
    """
    sc_path = _get_sc_path(root)
    if os.path.isdir(sc_path):
        return sc_path
    logger.info("Downloading Speech Commands v2...")
    # Only the download side effect is wanted; the dataset object is discarded.
    torchaudio.datasets.SPEECHCOMMANDS(root, download=True, subset="testing")
    return sc_path
def _get_split_files(sc_path, subset):
    """List ``(absolute_path, word)`` pairs for one split of Speech Commands.

    Splits come from the dataset's ``validation_list.txt`` and
    ``testing_list.txt``. Their entries are "/"-separated relative paths such
    as ``yes/0a7c2a8d_nohash_0.wav``:

    * ``"testing"``    - files listed in testing_list.txt
    * ``"validation"`` - files listed in validation_list.txt
    * ``"training"``   - files in neither list
    * ``"train+val"``  - every file not in testing_list.txt

    Directories starting with ``_`` (e.g. ``_background_noise_``) are skipped.
    Results are ordered by word directory, then by file name.

    Raises:
        ValueError: If ``subset`` is not one of the names above.
    """
    valid_subsets = ("training", "validation", "testing", "train+val")
    if subset not in valid_subsets:
        raise ValueError(f"Unknown subset {subset!r}; expected one of {valid_subsets}")

    val_list, test_list = set(), set()
    for name, target in (("validation_list.txt", val_list), ("testing_list.txt", test_list)):
        p = os.path.join(sc_path, name)
        if os.path.exists(p):
            # Context manager so the list file is closed deterministically.
            with open(p, encoding="utf-8") as f:
                target.update(line.strip() for line in f)

    result = []
    for word_dir in sorted(os.listdir(sc_path)):
        word_path = os.path.join(sc_path, word_dir)
        if not os.path.isdir(word_path) or word_dir.startswith("_"):
            continue
        for fname in sorted(os.listdir(word_path)):
            if not fname.endswith(".wav"):
                continue
            # The split lists always use "/" separators. Build the lookup key the
            # same way, not with os.path.join (which yields "\" on Windows).
            rel_path = f"{word_dir}/{fname}"
            in_test = rel_path in test_list
            in_val = rel_path in val_list
            if subset == "testing":
                keep = in_test
            elif subset == "validation":
                keep = in_val
            elif subset == "training":
                keep = not in_test and not in_val
            else:  # "train+val"
                keep = not in_test
            if keep:
                result.append((os.path.join(word_path, fname), word_dir))
    return result
def _load_background_noise(sc_path):
    """Concatenate every background-noise recording into one 1-D tensor.

    Reads the ``.wav`` files in ``_background_noise_`` in sorted name order,
    keeping the first channel row of each load_wav result.

    Returns:
        The concatenated 1-D tensor, or None if the directory is missing or
        holds no ``.wav`` files.
    """
    noise_dir = os.path.join(sc_path, "_background_noise_")
    if not os.path.isdir(noise_dir):
        return None
    wav_names = [name for name in sorted(os.listdir(noise_dir)) if name.endswith(".wav")]
    if not wav_names:
        return None
    return torch.cat([load_wav(os.path.join(noise_dir, name))[0] for name in wav_names])
def _generate_silence_mfccs(sc_path, num_samples):
    """Build ``num_samples`` MFCC tensors for the synthetic "silence" class.

    Each sample is compute_mfcc applied to a random NUM_SAMPLES-long window of
    the concatenated background-noise recordings. Offsets come from torch's
    global RNG, so seed it for reproducible output. If no noise recording of at
    least NUM_SAMPLES samples exists, every sample is the MFCC of an all-zero
    waveform instead.

    Returns:
        ``num_samples`` compute_mfcc outputs concatenated along dim 0 (a
        correctly shaped, empty-along-dim-0 tensor when ``num_samples`` is 0).
    """
    noise = _load_background_noise(sc_path)
    if noise is None or len(noise) < NUM_SAMPLES:
        mfccs = compute_mfcc(torch.zeros(1, NUM_SAMPLES)).expand(num_samples, -1, -1, -1)
        return mfccs.clone()
    # randint's upper bound is exclusive. The +1 makes the final valid offset
    # reachable and keeps the range non-empty when the noise is exactly
    # NUM_SAMPLES long.
    max_start = len(noise) - NUM_SAMPLES + 1
    x_list = []
    for _ in range(num_samples):
        start = torch.randint(0, max_start, (1,)).item()
        seg = noise[start:start + NUM_SAMPLES].unsqueeze(0)
        x_list.append(compute_mfcc(seg))
    if not x_list:
        # torch.cat([]) raises; return an empty batch with the right trailing shape.
        return compute_mfcc(noise[:NUM_SAMPLES].unsqueeze(0))[:0]
    return torch.cat(x_list, dim=0)
def _build_and_cache(root, subset):
    """Compute MFCC features and labels for one split and cache them on disk.

    Every clip in the split goes through compute_mfcc and is labelled with
    _map_label. A synthetic silence class is then appended. Its size is the
    number of keyword clips divided by the number of keywords (integer
    division), or 100 when the split has no keyword clips. Both tensors are
    saved to ``<root>/speech_commands_cache/<subset>_x.pt`` and
    ``<subset>_y.pt``.

    Returns:
        ``(x, y)``: features concatenated along dim 0 and int64 labels.
    """
    sc_path = _ensure_downloaded(root)
    features, labels = [], []
    n_keyword_clips = 0
    for wav_path, word in _get_split_files(sc_path, subset):
        features.append(compute_mfcc(load_wav(wav_path)))
        labels.append(_map_label(word))
        if word in KEYWORDS:
            n_keyword_clips += 1

    # Size silence like an average keyword class.
    n_silence = 100 if n_keyword_clips == 0 else n_keyword_clips // len(KEYWORDS)
    features.append(_generate_silence_mfccs(sc_path, n_silence))
    labels += [LABEL_TO_IDX["silence"]] * n_silence

    x = torch.cat(features, dim=0)
    y = torch.tensor(labels, dtype=torch.long)

    cache_dir = os.path.join(root, "speech_commands_cache")
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(x, os.path.join(cache_dir, f"{subset}_x.pt"))
    torch.save(y, os.path.join(cache_dir, f"{subset}_y.pt"))
    return x, y
def load_speech_commands(root="~/data", subset="training", device="cpu"):
    """Return Speech Commands v2 MFCC features for one split as a TensorDataset.

    Features are read from ``<root>/speech_commands_cache/<subset>_x.pt`` and
    ``<subset>_y.pt`` when both exist. Otherwise the split is built from the raw
    dataset (downloading it if needed) and cached there. The cache is keyed
    only by subset name, so delete it after changing the MFCC settings.

    Args:
        root: Data directory; ``~`` is expanded.
        subset: "training", "validation", "testing" or "train+val".
        device: Device the returned tensors are moved to.

    Returns:
        TensorDataset of ``(x, y)``: x holds MFCC tensors and y holds int64
        class indices.
    """
    root = os.path.expanduser(root)
    cache_dir = os.path.join(root, "speech_commands_cache")
    x_path = os.path.join(cache_dir, f"{subset}_x.pt")
    y_path = os.path.join(cache_dir, f"{subset}_y.pt")
    if os.path.exists(x_path) and os.path.exists(y_path):
        x = torch.load(x_path, weights_only=True)
        y = torch.load(y_path, weights_only=True)
    else:
        x, y = _build_and_cache(root, subset)
    # Labels must stay integer class indices: CrossEntropyLoss-style losses
    # reject float16 targets of shape (batch,).
    return torch.utils.data.TensorDataset(x.to(device), y.to(device=device, dtype=torch.long))
The first call to load_speech_commands downloads Speech Commands v2 (~2 GB) into ~/data/SpeechCommands/ and caches the MFCC tensors at ~/data/speech_commands_cache/. Subsequent calls hit the cache. The cache is keyed only by split name, so delete ~/data/speech_commands_cache/ if you change the MFCC settings.
2. Define the model
Paste this into model.py:
import torch.nn as nn
from dg.base_model import DLGModel
from dg.dtype import DType
from dg.layer import (
Flatten,
Norm,
QuantAvgPool2d,
QuantConv2d,
QuantDepthwiseConv2d,
QuantLinear,
)
class Model(DLGModel):
    """Depthwise-separable CNN wakeword classifier over MFCC input.

    build_model_graph defines the layer stack:
    1. input normalization;
    2. a strided 10x4 stem convolution to 64 channels;
    3. four depthwise-separable blocks;
    4. average pooling over a 25x5 window;
    5. a linear classifier over ``num_classes``.
    """

    # Dataset statistics consumed by the Norm input layer.
    NORM_MEAN = -1.6787  # MFCC mean over training set
    NORM_STD = 7.6337  # MFCC std over training set

    def __init__(self, num_classes=12):
        super().__init__()
        self.num_classes = num_classes
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Initialise one submodule; intended for ``nn.Module.apply``.

        Conv2d: Kaiming-normal weights (fan_out, ReLU gain) and zero bias.
        BatchNorm2d: unit weight and zero bias.
        Linear: zero bias only; the weight keeps its default initialisation.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def build_model_graph(self):
        """Return the ordered list of layers that make up the network."""
        width = 64
        stem = [
            Norm(mean=self.NORM_MEAN, std=self.NORM_STD, dtype_in=DType.FLOAT32),
            QuantConv2d(1, width, (10, 4), (2, 2), (1, 1, 4, 5), bias=True, act_func="relu", bn=True),
        ]
        blocks = []
        for _ in range(4):
            blocks.extend(self._ds_block(width))
        head = [
            QuantAvgPool2d((25, 5), (25, 5), (0, 0)),
            Flatten(),
            QuantLinear(width, self.num_classes, bias=True),
        ]
        return stem + blocks + head

    def _ds_block(self, channels):
        """One depthwise-separable block: 3x3 depthwise conv, then 1x1 pointwise conv."""
        depthwise = QuantDepthwiseConv2d(channels, (3, 3), (1, 1), (1, 1), bias=True, act_func="relu", bn=True)
        pointwise = QuantConv2d(channels, channels, (1, 1), (1, 1), (0, 0), bias=True, act_func="relu", bn=True)
        return [depthwise, pointwise]
What's happening:
- `Norm` anchors the input dtype as `FLOAT32` and standardizes the MFCC features using precomputed dataset statistics.
- The first `QuantConv2d` is a 10×4 convolution with stride 2 on both axes. It downsamples the 49×10 input to a 25×5 grid and lifts it to 64 channels.
- Four depthwise-separable blocks follow. Each is a 3×3 depthwise conv followed by a 1×1 pointwise conv, and neither changes the 25×5 grid. This is the MobileNet pattern: cheap on parameters and fast on MCUs.
- `QuantAvgPool2d` collapses the remaining 25×5 spatial grid to a single 64-dimensional feature vector. `Flatten` + `QuantLinear` then produce the 12-class logits.
Every learnable layer is a `Quant*` variant, which means the model is quantization-aware from the start; no separate calibration pass is needed before deployment.
3. Train it
Paste this into train.py:
import torch
from dg import eval as dg_eval, train
from model import Model
from data import load_speech_commands
device = "cpu"  # set to "cuda" or "mps" if available

# Seed before touching the data. On the first run the silence class is
# synthesised from random background-noise windows (torch.randint), so this
# makes the cached dataset reproducible.
torch.manual_seed(42)
train_dataset = load_speech_commands(subset="train+val", device=device)
test_dataset = load_speech_commands(subset="testing", device=device)

# Re-seed so weight initialisation is identical whether the data was just built
# (which consumes random numbers) or loaded from the cache (which does not).
torch.manual_seed(42)
model = Model()

train.Trainer(
    model, train_dataset, test_dataset,
    val_epoch=1, epochs=30, batch_size=128, lr=0.001,
    optimizer=torch.optim.Adam,
    optimizer_kwargs={"weight_decay": 0.01},
    device=device,
).train()

train_acc = dg_eval.get_acc(model, train_dataset, batch_size=128, device=device)
test_acc = dg_eval.get_acc(model, test_dataset, batch_size=128, device=device)
print(f"Train acc: {100 * train_acc:.2f}%")
print(f"Test acc: {100 * test_acc:.2f}%")

# Writes model.pt, model.py, config.json and schema.json into ./pretrained.
model.save_pretrained("./pretrained")
Run it: `python train.py`
The first run downloads Speech Commands and computes MFCCs (~5 minutes on a fast disk). Training runs for 30 epochs with Adam (learning rate 1e-3, weight decay 0.01, batch size 128).
Expected results after 30 epochs:
save_pretrained("./pretrained") writes four artifacts:
| File | Purpose |
|---|---|
| `model.pt` | Trained weights |
| `model.py` | Self-contained architecture (copied from your source file) |
| `config.json` | Constructor kwargs for `Model(**config)` |
| `schema.json` | I/O contract and layer layout; this is what Bitweaver compiles |
4. Validate locally
Before sending the model to hardware, sanity-check that inference works. Paste this into validate.py:
import torch

import dg
from data import load_speech_commands

device = "cpu"  # set to "cuda" or "mps" if available
test_dataset = load_speech_commands(subset="testing", device=device)
model = dg.from_pretrained("./pretrained", device=device)
# Inference mode. Batch-norm layers (the convs are built with bn=True) should
# use their running statistics, not a one-sample batch. Gradients are not needed.
model.eval()

x, y = test_dataset[0]
with torch.no_grad():
    logits = model(x.unsqueeze(0))
pred = logits.argmax(dim=1).item()
print(f"Predicted: {pred}, Actual: {int(y)}")
Run it: `python validate.py`
`dg.from_pretrained` accepts either a hub model ID (e.g. "wakeword") or a local directory. Here we point it at the ./pretrained directory we just wrote, so it loads our freshly trained weights via model.py + config.json + model.pt.
5. Compile & benchmark on hardware
You now have ./pretrained/schema.json. Continue to Compile & Benchmark:
- Create a Bitweaver project.
- Upload your `schema.json`.
- Review on-MCU inference time, memory footprint, and flash usage.
If the measured numbers don't fit your hardware budget, come back and shrink the model (fewer channels, fewer blocks, smaller kernels), then retrain and re-export.
Next steps
- SDK Reference: every public `dg` class and function.
- Compile & Benchmark: get the model running on real MCUs.
- Supported Hardware: the boards Bitweaver targets today.