Created
March 5, 2023 13:20
-
-
Save Staars/67d849e70081b00e5c4d5bc3c302af39 to your computer and use it in GitHub Desktop.
Tasmota TFL Speech example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MICROSPEECH - keyword-spotting driver built on Tasmota's TFL module.
# init() starts TFL in "MIC" mode and loads the model from "mfcc.lite".
# every_50ms() fetches new model outputs, smooths them with a rolling
# average over the last n_avg inferences (see predict) and prints the
# detected keyword label.
class MICROSPEECH : Driver
  var o             # output tensor, one byte per model output
  var model         # this var really holds the model data for the entire session
  var out_buf       # ring buffer holding the last n_avg output vectors
  var out_buf_idx   # byte offset of the next slot to write in out_buf
  var average       # rolling average per output (0..255)
  var time_out      # number of predict() calls still to skip after a detection
  var outputs       # number of model outputs - must be taken from the model
  var n_avg         # number of inferences in the rolling average
  var threshold     # averaged score an output must exceed to be reported
  var noise_idx     # index of the noise/unknown output - infer from your model
  var labels        # printable label per output index, in model order

  def init()
    import TFL
    self.outputs = 5      # this value must be taken from the model
    self.n_avg = 4        # length of the rolling average
    self.threshold = 225  # find the threshold by trial and error
    self.noise_idx = 2    # noise/unknown output - you must infer this from your model
    self.labels = ["down", "left", "unknown", "right", "up"]
    self.o = bytes(-self.outputs)  # size must match the model -> 5 outputs
    var descriptor = bytes(-10)
    descriptor[0] = 4    # i2s_channel_fmt, 4=left
    descriptor[1] = 16   # amplification factor
    descriptor[2] = 32   # slice_dur in ms
    descriptor[3] = 32   # slice_stride in ms
    descriptor[4] = 26   # mfe filter (= features if MFE mode)
    descriptor[5] = 13   # mfcc coefficients, if 0 -> compute MFE only
    descriptor[6] = 9    # 2^9 = 512 fft_bins
    descriptor[7] = 10   # max. invocations per second - find best value by testing on device
    descriptor[8] = 52   # db noisefloor -> negative value
    descriptor[9] = 0    # preemphasis
    TFL.begin("MIC", descriptor)
    # read the whole model into memory and release the file handle
    var f = open("mfcc.lite")
    self.model = f.readbytes()
    f.close()
    if self.model
      # NOTE(review): 15000 is presumably the tensor arena size in bytes - confirm
      TFL.load(self.model, self.o, 15000)
    end
    self.out_buf = bytes(-(self.n_avg * self.outputs))
    self.out_buf_idx = 0
    self.average = bytes(-self.outputs)
    self.time_out = 0    # set after a successful detection, counted down in predict()
  end

  # Store one output vector `data` (`outputs` values, each 0..255) in the
  # ring buffer and evaluate the rolling average of every output.
  # Returns the index of the output with the highest average when that
  # average exceeds `threshold`, otherwise -1. While the post-detection
  # timeout is running, returns -1 without storing `data`.
  # A detected keyword (any index other than noise_idx) clears the buffer
  # and starts the timeout; noise/unknown does neither.
  def predict(data, outputs)
    if self.time_out > 0
      self.time_out -= 1
      return -1 # nothing
    end
    var window = self.n_avg * outputs
    # write the newest vector into the current ring-buffer slot
    for i:0..(outputs-1)
      self.out_buf[self.out_buf_idx + i] = data[i]
    end
    self.out_buf_idx += outputs
    if self.out_buf_idx >= window
      self.out_buf_idx = 0
    end
    # rolling average: sum output i across all n_avg slots, then divide
    var max_val = 0
    var best = -1
    for i:0..(outputs-1)
      var acc = 0
      for k:0..(self.n_avg-1)
        acc += self.out_buf[i + k * outputs]
      end
      self.average[i] = acc / self.n_avg   # integer division keeps it in 0..255
      if self.average[i] > max_val
        max_val = self.average[i]
        best = i
      end
    end
    if max_val <= self.threshold
      return -1
    end
    if best != self.noise_idx
      # only clear the buffer after finding a keyword, not for noise/unknown
      self.out_buf = bytes(-window)
      self.time_out = 10   # find the value by trial and error - good starting point is one second
      tasmota.gc()
    end
    return best
  end

  # Periodic driver callback: when TFL.output() reports a new result,
  # convert it to 0..255 scores, run predict() and print any detected
  # keyword (noise/unknown is not printed), then print any pending
  # TFL log text.
  def every_50ms()
    import TFL
    if TFL.output(self.o)
      # geti(i, 1) reads a signed byte; +128 shifts it into 0..255
      var scores = []
      for i:0..(self.outputs-1)
        scores.push(self.o.geti(i, 1) + 128)
      end
      var r = self.predict(scores, self.outputs)
      if r >= 0 && r != self.noise_idx && r < size(self.labels)
        print("TFL: received '" + self.labels[r] + "'")
      end
    end
    var s = TFL.log()
    if s print(s) end
  end
end
# Create the single keyword-spotting instance (kept in the global `mic`)
# and register it with Tasmota so its driver callbacks get invoked.
var mic = MICROSPEECH()
tasmota.add_driver(mic)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment