Used in https://www.kaggle.com/models/shadiakiki1/birdnet-analyzer
Last active
April 25, 2024 00:45
-
-
Save shadiakiki1986/76bf09894c6dc7a5cea8a9614069181a to your computer and use it in GitHub Desktop.
wrapper python class for https://github.com/kahst/BirdNET-Analyzer/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Forked from BirdNET-Analyzer/main.py: functions load_model and predict.
# Upstream references:
# https://github.com/kahst/BirdNET-Analyzer/blob/main/config.py
# https://github.com/kahst/BirdNET-Analyzer/blob/main/model.py
# https://github.com/kahst/BirdNET-Analyzer/blob/main/analyze.py
try: | |
import tflite_runtime.interpreter as tflite | |
except ModuleNotFoundError: | |
from tensorflow import lite as tflite | |
from multiprocessing import cpu_count | |
import pandas as pd | |
import tensorflow as tf | |
import numpy as np | |
from pathlib import Path | |
# Directory holding the BirdNET ".tflite" model and its labels file; by default
# the directory that contains this source file.
# Alternative location when running on Kaggle:
# Path("/kaggle/input/birdnet-analyzer/tflite/birdnet_global_6k_v2.4_model_fp32-1/2/")
BASEDIR = Path(__file__).parent
# class ModelBirdnetAnalyzer: | |
class Model:
    """Wrapper around the BirdNET-Analyzer V2.4 TFLite model.

    The model file and the labels file are both loaded from ``BASEDIR``.

    Attributes:
        INTERPRETER: The TFLite interpreter holding the loaded model.
        INPUT_LAYER_INDEX: Tensor index the audio batch is written to.
        OUTPUT_LAYER_INDEX: Tensor index the result is read from.
        class_output: Whether the output tensor is the classification output.
        labels: Lines of the labels file, stripped of surrounding whitespace.
    """

    def __init__(self, class_output=True):
        """Load the TFLite model, allocate its tensors and read the labels.

        Args:
            class_output: If true, read the first output tensor reported by
                the interpreter (classification output). Otherwise read the
                tensor whose index is one less, which upstream
                BirdNET-Analyzer uses for feature embeddings.
        """
        # model_path from:
        # https://github.com/kahst/BirdNET-Analyzer/blob/main/config.py#L18C20-L18C77
        # (upstream keeps these files under "BirdNET-Analyzer/checkpoints/V2.4").
        interpreter = tflite.Interpreter(
            model_path=str(BASEDIR / "BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite"),
            num_threads=cpu_count(),
        )
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        class_index = output_details[0]["index"]

        self.INTERPRETER = interpreter
        self.INPUT_LAYER_INDEX = input_details[0]["index"]
        self.OUTPUT_LAYER_INDEX = class_index if class_output else class_index - 1
        # Remembered so predict() knows whether the labels describe its output.
        self.class_output = bool(class_output)

        # Read with an explicit encoding so label text does not depend on the
        # platform's default locale.
        labels_path = BASEDIR / "BirdNET_GLOBAL_6K_V2.4_Labels.txt"
        with open(labels_path, "r", encoding="utf-8") as f:
            self.labels = [line.strip() for line in f]

    def predict(self, data, samplerate, strict=True):
        """Run the model on a batch of audio windows.

        Args:
            data: Array-like batch of audio windows; converted to a float32
                array before being handed to the interpreter.
            samplerate: Sample rate of ``data`` in Hz.
            strict: If true, require 48 kHz audio in a 2-D batch whose rows
                are exactly three seconds long.

        Returns:
            A ``pandas.DataFrame`` with one row per window. Columns are the
            labels when the model was built with ``class_output=True``;
            otherwise they are left as the default integer range.

        Raises:
            AssertionError: If ``strict`` is true and a check fails.
        """
        batch = np.ascontiguousarray(data, dtype="float32")

        if strict:
            # Raised explicitly (rather than via ``assert``) so the checks are
            # still enforced when Python runs with -O.
            if samplerate != 48_000:
                raise AssertionError("birdnet assumes 48kHz?")
            if batch.ndim != 2:
                raise AssertionError("data should be 2d array")
            if batch.shape[1] != samplerate * 3:
                raise AssertionError("data should be 3 second windows")

        # Resize the input tensor to the batch actually supplied.
        self.INTERPRETER.resize_tensor_input(self.INPUT_LAYER_INDEX, list(batch.shape))
        self.INTERPRETER.allocate_tensors()

        # Make a prediction (audio only for now).
        self.INTERPRETER.set_tensor(self.INPUT_LAYER_INDEX, batch)
        self.INTERPRETER.invoke()
        output = self.INTERPRETER.get_tensor(self.OUTPUT_LAYER_INDEX)

        # The labels only name the columns of the classification output.
        columns = self.labels if self.class_output else None
        return pd.DataFrame(output, columns=columns)

    def predict_proba(self, *args, **kwargs):
        """Return ``predict`` output passed through a row-wise softmax.

        Accepts the same arguments as ``predict``.

        Returns:
            A ``pandas.DataFrame`` with the same index and columns as the
            ``predict`` result; each row sums to 1.

        Raises:
            ValueError: If the model was built with ``class_output=False``.
        """
        if not self.class_output:
            raise ValueError(
                "predict_proba requires a model created with class_output=True"
            )
        logits = self.predict(*args, **kwargs)
        # NOTE(review): upstream BirdNET-Analyzer appears to apply a sigmoid to
        # these scores; softmax is kept here to preserve existing behaviour —
        # confirm which is intended.
        probabilities = tf.nn.softmax(logits.to_numpy(), axis=-1).numpy()
        return pd.DataFrame(probabilities, index=logits.index, columns=logits.columns)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment