Last active
April 26, 2024 18:13
-
-
Save 0187773933/0081d6303be18360b88e929e3f34877d to your computer and use it in GitHub Desktop.
Runs Google MediaPipe YAMNet audio classification on microphone audio.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import queue
import time
from collections import defaultdict, deque

import numpy as np
import sounddevice as sd
import tensorflow as tf
# Model and label files (both paths are resolved against the current working
# directory):
#   https://storage.googleapis.com/mediapipe-models/audio_classifier/yamnet/float32/latest/yamnet.tflite
#   https://storage.googleapis.com/mediapipe-tasks/audio_classifier/yamnet_label_list.txt
# Background on YAMNet and the AudioSet label ontology:
#   https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/yamnet.py
#   https://research.google.com/audioset/ontology/index.html
# Patch timing reference:
#   https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/params.py#L25
MODEL_PATH = "./yamnet.tflite"
LABEL_PATH = "./yamnet_label_list.txt"

# Audio framing: SAMPLE_RATE in Hz, patch durations in seconds.
SAMPLE_RATE = 16000
PATCH_WINDOW_SECONDS = 0.975
PATCH_HOP_SECONDS = PATCH_WINDOW_SECONDS / 2.0

# Reporting.
WATCH_WINDOW_SECONDS = 30    # history (seconds) aggregated into the running totals
PRINT_WINDOW_TOTAL = 10      # maximum number of labels printed per update
MINIMUM_THRESHOLD = 0.15     # labels whose aggregated score is not above this are hidden
# Load the YAMNet TFLite model once, at import time, and cache the tensor
# metadata the main loop needs to feed inputs and read outputs.
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

# Hand-off buffer: the audio callback puts captured blocks, main() gets them.
q = queue.Queue()
def audio_callback(indata, frames, time, status):
    """Forward each captured audio block to the main loop through ``q``.

    Passed as the ``callback`` of ``sd.InputStream``. ``frames`` and ``time``
    are part of that callback signature and are not used here; note that the
    ``time`` parameter shadows the ``time`` module inside this function only.

    Args:
        indata: the block of captured samples for this callback.
        frames: frame count for this block (unused).
        time: stream timing information (unused).
        status: stream status flags; truthy when the stream flagged a problem.
    """
    if status:
        print(f"Error: {status}")
    # Enqueue a copy: the stream may reuse indata's buffer after we return.
    snapshot = indata.copy()
    q.put(snapshot)
def read_text(file_path, encoding="utf-8"):
    """Read a text file and return its lines without line terminators.

    main() zips these lines with the model's output scores, so line *i* labels
    output *i*; blank lines are therefore preserved rather than filtered out.

    Args:
        file_path: path of the file to read.
        encoding: text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's locale encoding (which is what
            ``open()`` uses when no encoding is given).

    Returns:
        A list of strings, one per line.
    """
    with open(file_path, encoding=encoding) as f:
        return f.read().splitlines()
def calculate_db(indata):
    """Return the RMS level of *indata* in dB relative to an amplitude of 1.0.

    For samples already scaled to [-1, 1] this is dBFS. Samples are promoted
    to float64 before squaring, so integer input (e.g. raw int16) cannot
    overflow. Silent or empty input is floored to an RMS of 1e-10, i.e.
    -200 dB, instead of producing -inf or NaN.

    Args:
        indata: array-like of samples, any shape; all elements are pooled.

    Returns:
        The level in decibels as a Python float.
    """
    samples = np.asarray(indata, dtype=np.float64)
    # np.mean of an empty array is NaN, and max(NaN, floor) keeps the NaN,
    # so treat "no samples" as silence explicitly.
    rms = float(np.sqrt(np.mean(np.square(samples)))) if samples.size else 0.0
    # Floor the RMS so log10 never sees zero.
    rms = max(rms, 1e-10)
    return float(20.0 * np.log10(rms))
def main():
    """Classify live microphone audio with YAMNet and print a rolling summary.

    Audio is read from an ``sd.InputStream`` (no device is specified) as
    2-channel int16 at SAMPLE_RATE, down-mixed to mono, scaled by the int16
    maximum and fed to the TFLite interpreter one PATCH_WINDOW_SECONDS window
    at a time. Each window's per-label scores are added to running totals
    covering the last WATCH_WINDOW_SECONDS; labels whose total exceeds
    MINIMUM_THRESHOLD are printed highest first, together with the window's
    level in dB. Runs until interrupted with Ctrl+C.
    """
    model_labels = read_text(LABEL_PATH)
    # label -> summed score over every window still inside the watch window.
    results = defaultdict(float)
    # (timestamp, [(label, score), ...]) for each window in the watch window.
    past_results = deque()

    # The stream delivers fixed-size blocks; each model window is built from
    # enough whole blocks and trimmed to the exact length the model expects,
    # so the input stays correct even if the window is not an exact multiple
    # of the block size.
    block_size = int(round(SAMPLE_RATE * PATCH_HOP_SECONDS))
    window_samples = int(round(SAMPLE_RATE * PATCH_WINDOW_SECONDS))
    blocks_per_window = math.ceil(window_samples / block_size)
    int16_max = np.iinfo(np.int16).max

    try:
        with sd.InputStream(
            callback=audio_callback,
            dtype="int16",
            channels=2,
            samplerate=SAMPLE_RATE,
            blocksize=block_size,
        ):
            print("Starting audio stream...")
            while True:
                current_time = time.time()
                # NOTE: every iteration consumes fresh blocks, so consecutive
                # windows are back-to-back rather than overlapping by the hop.
                blocks = [q.get() for _ in range(blocks_per_window)]
                data = np.concatenate(blocks)[:window_samples]
                # Down-mix to mono and scale int16 samples to roughly [-1, 1].
                data = np.mean(data.astype(np.float32), axis=1) / int16_max

                interpreter.set_tensor(input_details[0]["index"], data)
                interpreter.invoke()
                probabilities = interpreter.get_tensor(output_details[0]["index"]).flatten()
                scores = list(zip(model_labels, probabilities))

                # Retire windows that fell out of the watch window. They are
                # subtracted with exactly the pairs that were added, so the
                # totals stay balanced even should a label appear more than
                # once in the label file (a dict snapshot would collapse it).
                past_results.append((current_time, scores))
                cutoff = current_time - WATCH_WINDOW_SECONDS
                while past_results and past_results[0][0] < cutoff:
                    _, old_scores = past_results.popleft()
                    for label, score in old_scores:
                        results[label] -= score
                for label, score in scores:
                    results[label] += score

                # Totals are summed scores over the watch window (not
                # probabilities), so they can exceed 1.0. The threshold only
                # affects what is displayed.
                sorted_results = sorted(
                    ((label, total) for label, total in results.items() if total > MINIMUM_THRESHOLD),
                    key=lambda item: item[1],
                    reverse=True,
                )
                db_level = calculate_db(data)

                print(f"\nLIVE : DB === {db_level}")
                if sorted_results:
                    top_label, top_total = sorted_results[0]
                    print(f"LIVE : TOP === {top_label} : {top_total}")
                print(f"Last {WATCH_WINDOW_SECONDS} Seconds:")
                for label, total in sorted_results[:PRINT_WINDOW_TOTAL]:
                    print(f"\t{label}: {total}")
    except KeyboardInterrupt:
        print("\nStopping...")
if __name__ == "__main__":
    # Start streaming only when run as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment