Created
May 27, 2025 05:59
-
-
Save hamees-sayed/44110f2c6cc47e2ee48adf9dc82bb6f1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import sys | |
| import time | |
| import krisp_audio | |
| from audio_utils import TestUtils | |
| from pydub import AudioSegment | |
| console_log_char_count = 50 | |
class KrispAudioTest:
    """Context manager that runs Krisp noise cancellation over framed audio.

    Constructing the object initialises the Krisp SDK globally. Entering the
    ``with`` block creates the noise-cancellation session; leaving it drops
    the session and destroys the global SDK state again.

    Args:
        model_path: Path assigned to the Krisp ``ModelInfo`` (the ``.kef`` model).
        audio_stream_info: Mapping with the keys ``'sample_rate'`` (one of
            8000, 16000, 24000, 32000, 44100, 48000), ``'sample_type'``
            (``'FLOAT'`` or ``'PCM_16'``) and ``'audio_stream'`` (a sized
            iterable of frames).
        frame_dur: Frame duration in milliseconds; one of 10, 15, 20, 30, 32.
        suppression_level: Value passed to the session's ``process`` call for
            every frame.
    """

    def __init__(self, model_path, audio_stream_info, frame_dur, suppression_level):
        self.model_path = model_path
        self.audio_stream_info = audio_stream_info
        self.frame_dur = frame_dur
        self.nc_instance = None
        self.suppression_level = suppression_level
        krisp_audio.globalInit("")
        # Kept on the instance so callers can log it if they want to.
        self.version_info = krisp_audio.getVersion()

    def __enter__(self):
        """Create the noise-cancellation session and return ``self``."""
        try:
            self.nc_instance = self._create_nc_session()
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so the global
            # initialisation done in __init__ has to be undone here.
            krisp_audio.globalDestroy()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Drop the session reference, then destroy the global SDK state."""
        # The session reference is cleared before globalDestroy() is called,
        # matching the original teardown order.
        self.nc_instance = None
        krisp_audio.globalDestroy()

    def _int_to_sample_rate(self, sample_rate):
        """Map an integer sample rate in Hz to ``krisp_audio.SamplingRate``.

        Raises:
            ValueError: If ``sample_rate`` is not one of the supported rates.
        """
        rates = {
            8000: krisp_audio.SamplingRate.Sr8000Hz,
            16000: krisp_audio.SamplingRate.Sr16000Hz,
            24000: krisp_audio.SamplingRate.Sr24000Hz,
            32000: krisp_audio.SamplingRate.Sr32000Hz,
            44100: krisp_audio.SamplingRate.Sr44100Hz,
            48000: krisp_audio.SamplingRate.Sr48000Hz,
        }
        if sample_rate not in rates:
            raise ValueError("Unsupported sample rate")
        return rates[sample_rate]

    def _int_to_frame_dur(self, frame_dur):
        """Map a frame duration in milliseconds to ``krisp_audio.FrameDuration``.

        Raises:
            ValueError: If ``frame_dur`` is not one of the supported durations.
        """
        durations = {
            10: krisp_audio.FrameDuration.Fd10ms,
            15: krisp_audio.FrameDuration.Fd15ms,
            20: krisp_audio.FrameDuration.Fd20ms,
            30: krisp_audio.FrameDuration.Fd30ms,
            32: krisp_audio.FrameDuration.Fd32ms,
        }
        if frame_dur not in durations:
            raise ValueError("Unsupported frame duration")
        return durations[frame_dur]

    def _create_nc_session(self):
        """Build the session config and create a session for the sample type.

        The output sample rate is set equal to the input sample rate.

        Raises:
            ValueError: If the sample rate, frame duration or sample type of
                the stream is not supported.
        """
        model_info = krisp_audio.ModelInfo()
        model_info.path = self.model_path

        nc_cfg = krisp_audio.NcSessionConfig()
        nc_cfg.inputSampleRate = self._int_to_sample_rate(self.audio_stream_info['sample_rate'])
        nc_cfg.inputFrameDuration = self._int_to_frame_dur(self.frame_dur)
        nc_cfg.outputSampleRate = nc_cfg.inputSampleRate
        nc_cfg.modelInfo = model_info

        sample_type = self.audio_stream_info['sample_type']
        if sample_type == 'FLOAT':
            return krisp_audio.NcFloat.create(nc_cfg)
        if sample_type == 'PCM_16':
            return krisp_audio.NcInt16.create(nc_cfg)
        raise ValueError(f"Unsupported sample type {self.audio_stream_info['sample_type']}")

    def process_stream(self):
        """Run every frame through the session and return the processed frames.

        A single-line progress indicator is rewritten on stdout after each
        frame and terminated with a newline at the end.

        Returns:
            List with one processed frame per input frame, in input order.
        """
        input_audio_stream = self.audio_stream_info['audio_stream']
        total_frames = len(input_audio_stream)
        processed_audio_stream = []
        for i, frame in enumerate(input_audio_stream):
            processed_frame = self.nc_instance.process(frame, self.suppression_level)
            processed_audio_stream.append(processed_frame)
            progress = (i + 1) / total_frames * 100
            sys.stdout.write(f"\rProcessing frame {i + 1}/{total_frames} ({progress:.2f}%)")
            sys.stdout.flush()
        sys.stdout.write("\n")
        return processed_audio_stream
| # Keep all imports and classes as-is | |
def run_krisp_processing(
    input_audio_path,
    output_audio_path,
    kef_path="/Users/hamees/Desktop/work/krisp-integration/server-models-v9.4/inb.bvc.hs.c6.w.s.23cdb3.kef",
    frame_dur=20,
    suppression_level=100,
):
    """Run Krisp noise cancellation on an audio file and write a WAV result.

    Input whose name ends in ``.mp3`` (case-insensitive) is first converted
    to a temporary WAV file. That file is created in the system temp
    directory, so it cannot clobber a file next to the input, and it is
    removed whether processing succeeds or fails.

    Args:
        input_audio_path: Path to the input ``.wav`` or ``.mp3`` file.
        output_audio_path: Path the processed WAV file is written to.
        kef_path: Path to the Krisp ``.kef`` model file.
            NOTE(review): the default is a machine-specific absolute path;
            callers on other machines must pass this explicitly.
        frame_dur: Frame duration in milliseconds.
        suppression_level: Value forwarded to the Krisp session per frame.

    Returns:
        ``output_audio_path``, unchanged.
    """
    import tempfile  # local import: only the MP3 conversion path needs it

    wav_temp_path = None
    try:
        if input_audio_path.lower().endswith('.mp3'):
            fd, wav_temp_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            audio = AudioSegment.from_mp3(input_audio_path)
            # export() hands back the open output file; close it so the
            # temp file can be reopened and later removed on every platform.
            exported = audio.export(wav_temp_path, format="wav")
            exported.close()
            input_audio_path = wav_temp_path

        audio_stream_info = TestUtils.wav_to_audio_stream(input_audio_path, frame_dur)

        with KrispAudioTest(kef_path, audio_stream_info, frame_dur, suppression_level) as krisp:
            processed_audio_stream = krisp.process_stream()

        TestUtils.audio_stream_to_wav(
            output_audio_path,
            processed_audio_stream,
            audio_stream_info['sample_rate'],
            audio_stream_info['sample_type'],
        )
    finally:
        if wav_temp_path is not None and os.path.exists(wav_temp_path):
            os.remove(wav_temp_path)

    return output_audio_path
| # Optional: preserve original CLI behavior | |
if __name__ == "__main__":
    # The original call referenced names that were never defined, so running
    # the script raised NameError. Read the paths from the command line.
    import argparse

    cli_parser = argparse.ArgumentParser(
        description="Run Krisp noise cancellation on an audio file"
    )
    cli_parser.add_argument('input_audio_path', help='Path to the input .wav or .mp3 file')
    cli_parser.add_argument('output_audio_path', help='Path of the .wav file to write')
    cli_parser.add_argument('-cfg', '--kef_path', default=None, help='Path to the .kef file')
    cli_parser.add_argument('-d', '--frame_dur', type=int, default=20,
                            help='Frame duration in ms. Supported: 10, 15, 20, 30, 32. Default: 20')
    cli_parser.add_argument('-nsl', '--suppression_level', type=int, default=100,
                            help='Noise suppression level in the range [0, 100]. Default: 100')
    cli_args = cli_parser.parse_args()

    run_kwargs = {
        'frame_dur': cli_args.frame_dur,
        'suppression_level': cli_args.suppression_level,
    }
    # Only override the function's own default model path when one is given.
    if cli_args.kef_path is not None:
        run_kwargs['kef_path'] = cli_args.kef_path

    run_krisp_processing(cli_args.input_audio_path, cli_args.output_audio_path, **run_kwargs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import soundfile as sf | |
class TestUtils:
    """Command-line and WAV I/O helpers for the Krisp sample application."""

    @staticmethod
    def parse_arguments():
        """Parse the sample application's command-line arguments.

        Returns:
            ``argparse.Namespace`` with ``kef_path``, ``input_audio_path``,
            ``frame_dur`` (default 10) and ``suppression_level`` (default 100).
        """
        parser = argparse.ArgumentParser(description="Sample usage of krisp_audio")
        parser.add_argument('-cfg', '--kef_path', type=str, required=True, help='Path to the .kef file')
        parser.add_argument('-i', '--input_audio_path', type=str, required=True, help='Path to the input .wav audio file')
        parser.add_argument('-d', '--frame_dur', type=int, default=10, help='Optional: Frame duration. Supported: 10, 15, 20, 30, 32. Default: 10')
        parser.add_argument('-nsl', '--suppression_level', type=int, default=100, help='Optional: Noise suppression level in the range [0, 100], used to control the noise canceling aggressiveness. Default: 100, i.e. full noise canceling')
        return parser.parse_args()

    @staticmethod
    def wav_to_audio_stream(wav_path, frame_dur):
        """Read a mono WAV file and split it into fixed-size frames.

        Samples after the last complete frame are discarded.

        Args:
            wav_path: Path to a mono WAV file with subtype ``PCM_16`` or ``FLOAT``.
            frame_dur: Frame duration in milliseconds; must be positive.

        Returns:
            Dict with ``'sample_rate'``, ``'sample_type'`` (the file's
            subtype) and ``'audio_stream'`` (list of frames, each holding
            ``int(sample_rate * frame_dur / 1000)`` samples).

        Raises:
            ValueError: If ``frame_dur`` is not positive, the subtype is
                unsupported, the file has more than one channel, or the
                resulting frame size is zero.
        """
        # Reject bad durations up front: zero would otherwise surface as an
        # obscure range() step error, and a negative value would silently
        # produce an empty stream.
        if frame_dur <= 0:
            raise ValueError(f"Frame duration must be positive, got: {frame_dur}")

        with sf.SoundFile(wav_path) as inputFile:
            sample_rate = inputFile.samplerate
            sample_type = inputFile.subtype
            if sample_type == 'PCM_16':
                data_type = 'int16'
            elif sample_type == 'FLOAT':
                data_type = 'float32'
            else:
                raise ValueError(f"Unsupported WAV data type: {sample_type}")
            # Validate the channel count before reading, so an unsupported
            # file is rejected without loading all of its samples.
            if inputFile.channels > 1:
                raise ValueError(f"Supports only Mono audio, provided audio channels: {inputFile.channels}")
            audio_data = inputFile.read(dtype=data_type)

        frame_size = int(sample_rate * frame_dur / 1000)
        if frame_size <= 0:
            raise ValueError(
                f"Frame duration {frame_dur} ms is too short for sample rate {sample_rate}"
            )
        audio_stream = [
            audio_data[start:start + frame_size]
            for start in range(0, len(audio_data) - frame_size + 1, frame_size)
        ]
        return {
            'sample_rate': sample_rate,
            'sample_type': sample_type,
            'audio_stream': audio_stream
        }

    @staticmethod
    def audio_stream_to_wav(wav_path, audio_stream, sample_rate, sample_type):
        """Concatenate frames and write them as a single-channel WAV file.

        Args:
            wav_path: Destination path of the WAV file.
            audio_stream: Iterable of frames, each an iterable of samples.
            sample_rate: Sample rate written to the file header.
            sample_type: soundfile subtype used for the output file.
        """
        concat_stream = [sample for frame in audio_stream for sample in frame]
        with sf.SoundFile(wav_path, mode='w', samplerate=sample_rate, channels=1, subtype=sample_type) as wav_file:
            wav_file.write(concat_stream)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment