Spring Reactor Netty reactive (non-blocking) websocket integration that converts a text file into audio using Azure text-to-speech. It splits the file into chunks, downloads multiple binary audio fragments per chunk over the websocket (text & binary frames), merges each chunk's fragments into a chunk file, and finally merges all chunk files into the final mp3, all without blocking threads.
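A minimal usage sketch of how the downloader defined in the files below might be driven, assuming the classes are on the classpath; the input path book.txt is hypothetical, and blocking is only done at the application edge:

import com.acme.tts.converter.microsoft.ssml.MicrosoftTtsDownloader;

import java.nio.file.Path;

class TtsExample {
    public static void main(String[] args) {
        // pass true to enable Reactor Netty wiretap (hex dump) logging of the websocket traffic
        var downloader = new MicrosoftTtsDownloader(false);
        // splits book.txt into chunks, downloads the audio parts over the websocket and writes book.mp3 next to it
        downloader.convertToTts(Path.of("book.txt")).block();
    }
}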
package com.acme.tts.converter.microsoft.ssml.websocket; | |
import com.acme.tts.converter.microsoft.ssml.Prosody; | |
import com.acme.tts.converter.microsoft.ssml.Speak; | |
import com.acme.tts.converter.microsoft.ssml.Voice; | |
import com.fasterxml.jackson.databind.ObjectMapper; | |
import com.fasterxml.jackson.databind.SerializationFeature; | |
import com.fasterxml.jackson.dataformat.xml.XmlMapper; | |
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; | |
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; | |
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; | |
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; | |
import lombok.extern.log4j.Log4j2; | |
import reactor.core.publisher.Flux; | |
import reactor.core.publisher.Mono; | |
import reactor.core.publisher.Sinks; | |
import reactor.netty.http.client.HttpClient; | |
import reactor.netty.http.client.WebsocketClientSpec; | |
import java.nio.file.Path; | |
import java.time.Duration; | |
import java.time.Instant; | |
import java.time.ZoneId; | |
import java.time.ZonedDateTime; | |
import java.time.format.DateTimeFormatter; | |
import java.util.List; | |
import java.util.UUID; | |
import java.util.concurrent.atomic.AtomicBoolean; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import java.util.concurrent.atomic.AtomicReference; | |
import java.util.regex.Matcher; | |
import java.util.regex.Pattern; | |
import static com.acme.tts.converter.microsoft.ssml.websocket.Marker.INSTANCE; | |
import static io.earcam.unexceptional.Exceptional.get; | |
import static java.nio.charset.StandardCharsets.UTF_8; | |
import static java.time.Duration.ofSeconds; | |
import static reactor.core.publisher.Sinks.EmitFailureHandler.FAIL_FAST; | |
@Log4j2 | |
class ChunksAudioFetcherAndPersister { | |
public static final String WEBSOCKET_HTTPS_URL_PREFIX = | |
"https://eastus.api.speech.microsoft.com/cognitiveservices/websocket/v1?TrafficType=AzureDemo&Authorization=bearer%20undefined&X-ConnectionId="; | |
private static final String CRLF = "\r\n"; | |
private static final String DOUBLE_CRLF = CRLF + CRLF; | |
private static final Duration PING_MESSAGE_WEBSOCKET_DURATION = ofSeconds(10); | |
private static final String TURN_START_HEADER_VALUE = "turn\\.start"; | |
private static final String X_RATE_LIMIT_HEADER_MAX_ALLOWED = "X-RateLimit-Limit"; | |
private static final String X_RATE_LIMIT_HEADER_REMAINING = "X-RateLimit-Remaining"; | |
private static final String X_RATE_LIMIT_HEADER_LIMIT_RESET_AT_GMT = "X-RateLimit-Reset"; | |
private static final String X_REQUEST_HEADER_NAME = "X-RequestId"; | |
private static final Pattern X_REQUEST_ID_PATTERN = | |
Pattern.compile(X_REQUEST_HEADER_NAME + ":(?<requestId>[^\r]+)"); | |
private static final String PATH_HEADER = "Path"; | |
public static final Pattern | |
PATH_TURN_END_HEADER_PATTERN = Pattern.compile(PATH_HEADER + ":\s*turn\\.end" + CRLF); | |
private static final Pattern PATH_TURN_START_HEADER_PATTERN = | |
Pattern.compile(PATH_HEADER + ":\s*" + TURN_START_HEADER_VALUE + CRLF); | |
private final HttpClient client; | |
private final WebsocketClientSpec websocketClientSpec; | |
private final List<String> remainingChunks; | |
private final Sinks.Many<Integer> retryFromTextChunkIndexOnAbruptCompleteSink; | |
private final Path dirToDownloadTo; | |
private final int startingChunkNumberInOriginalChunks; | |
private final int totalChunkCountOfOriginalChunks; | |
private final Path inputFilePath; | |
private final ObjectMapper objectMapper = new XmlMapper(); | |
{ | |
objectMapper.enable(SerializationFeature.INDENT_OUTPUT); | |
} | |
ChunksAudioFetcherAndPersister(HttpClient client, WebsocketClientSpec websocketClientSpec, | |
List<String> remainingChunks, Sinks.Many<Integer> retryFromTextChunkIndexOnAbruptCompleteSink, | |
Path dirToDownloadTo, int startingChunkNumberInOriginalChunks, int totalChunkCountOfOriginalChunks, Path inputFilePath) { | |
this.client = client; | |
this.websocketClientSpec = websocketClientSpec; | |
this.remainingChunks = remainingChunks; | |
this.retryFromTextChunkIndexOnAbruptCompleteSink = retryFromTextChunkIndexOnAbruptCompleteSink; | |
this.dirToDownloadTo = dirToDownloadTo; | |
this.startingChunkNumberInOriginalChunks = startingChunkNumberInOriginalChunks; | |
this.totalChunkCountOfOriginalChunks = totalChunkCountOfOriginalChunks; | |
this.inputFilePath = inputFilePath; | |
} | |
Flux<Void> retrieve() { | |
final var audioChunksPersistingSink = Sinks.many().unicast().<Sinks.Many<byte[]>>onBackpressureBuffer(); | |
final var audioChunkDownloadCompleteInfiniteStreamSink = Sinks.many().multicast().<Marker>onBackpressureBuffer(); | |
final var inputRequestStreamSignalSink = Sinks.many().unicast().onBackpressureBuffer(); // this emits one item per input item, followed by onComplete
final var errorNotifyingSink = Sinks.one(); | |
final var shortCircuitCompletionSink = Sinks.one(); | |
final var allChunksDownloadComplete$ = inputRequestStreamSignalSink.asFlux() | |
.zipWith(audioChunkDownloadCompleteInfiniteStreamSink.asFlux()).cast(Object.class) // since we are zipping both input text & output mp3 generation, this will complete only when output has produced as many items as the input | |
.then(Mono.just(INSTANCE).cast(Object.class)) // once both of above are done, lets emit an item. | |
.mergeWith(errorNotifyingSink.asMono()) | |
.takeUntilOther(shortCircuitCompletionSink.asMono()); | |
final var chunkPersister$ = new ChunksPersister( | |
audioChunksPersistingSink, audioChunkDownloadCompleteInfiniteStreamSink, | |
dirToDownloadTo, startingChunkNumberInOriginalChunks, totalChunkCountOfOriginalChunks, | |
inputFilePath) | |
.asMono(); | |
final var wsHttpsUri = WEBSOCKET_HTTPS_URL_PREFIX + generateRequestId(); | |
//noinspection UnnecessaryLocalVariable | |
final var webSocketInBoundOutboundMerged$ = client.websocket(websocketClientSpec) | |
.uri(wsHttpsUri) | |
.handle((inbound, outbound) -> { | |
final var inboundHeaders = inbound.headers(); | |
boolean rateLimitHeadersExist = inboundHeaders.contains(X_RATE_LIMIT_HEADER_MAX_ALLOWED) && | |
inboundHeaders.contains(X_RATE_LIMIT_HEADER_REMAINING) && | |
inboundHeaders.contains(X_RATE_LIMIT_HEADER_LIMIT_RESET_AT_GMT); | |
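// the handshake response advertises rate-limit headers; when the remaining allowance is negative we wait until the advertised reset time and retry with a fresh websocket session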
if(rateLimitHeadersExist) { | |
final var maxAllowedRequests = Integer.parseInt(inboundHeaders.get( | |
X_RATE_LIMIT_HEADER_MAX_ALLOWED)); | |
final var remainingAllowedRequests = Integer.parseInt(inboundHeaders.get(X_RATE_LIMIT_HEADER_REMAINING)); | |
final var limitResetsAtInGmt = DateTimeFormatter.ISO_DATE_TIME.parse(inboundHeaders.get(X_RATE_LIMIT_HEADER_LIMIT_RESET_AT_GMT), ZonedDateTime::from); | |
log.trace("Rate limiting from server: maxAllowedRequests: {}, remainingAllowedRequests: {}, limitResetsAt: {}", maxAllowedRequests, remainingAllowedRequests, limitResetsAtInGmt.withZoneSameLocal( | |
ZoneId.of(ZoneId.SHORT_IDS.get("IST")))); | |
if (remainingAllowedRequests < 0) { | |
final var nowInGmt = ZonedDateTime.now(ZoneId.of(ZoneId.SHORT_IDS.get("GMT"))); | |
final var waitPeriod = Duration.between(nowInGmt, limitResetsAtInGmt); | |
return Mono.delay(waitPeriod) | |
.doOnSubscribe(__ -> log.debug("Server is rate limiting us: Next attempt would be made in {}", waitPeriod)) | |
.flatMapMany(__ -> retrieve()); | |
} | |
} | |
final var downloadCompleteChunkIndex = new AtomicInteger(-1); | |
final var requestId = new AtomicReference<String>(null); | |
final var audioStreamCollectingSink = new AtomicReference<Sinks.Many<byte[]>>(null); | |
final var turnStarted = new AtomicBoolean(false); | |
final var turnEnded = new AtomicBoolean(false); | |
final var inbound$ = inbound | |
.receiveFrames() | |
// .log("com.acme.inbound-signals") | |
// .doOnCancel(() -> log.trace("doOnCancel: We cancelled")) | |
// .doOnComplete(() -> log.trace("doOnComplete: Framework completed - reached end of session / otherside sent completion")) | |
// .doOnError(e -> System.err.println("doOnError: Framework erred - n/w error / otherside erred out / framework level issue")) | |
.handle((wsf, downstreamSink) -> { | |
log.trace("Received frame of type " + wsf.getClass()); | |
if (wsf instanceof TextWebSocketFrame textWebSocketFrame) { | |
final var frameContent = textWebSocketFrame.text(); | |
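// text frames carry CRLF-separated header lines, then a blank line (double CRLF), then a JSON body; the part before the body is treated as the header section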
// log.trace("↓↓↓↓↓↓↓↓ textWebSocketFrame.text() ↓↓↓↓↓↓↓↓\n" + frameContent); | |
final var bodyStartIndex = frameContent.indexOf(DOUBLE_CRLF) + DOUBLE_CRLF.length(); | |
final var headerSection = frameContent.substring(0, bodyStartIndex); | |
log.trace("↓↓↓↓↓↓↓ header section ↓↓↓↓↓↓\n" + headerSection); | |
boolean turnStartFrame = PATH_TURN_START_HEADER_PATTERN.matcher(headerSection).find(); | |
log.trace("Did turn start? " + turnStartFrame); | |
if (!turnStarted.get() && turnStartFrame) { | |
Matcher requestIdMatcher = X_REQUEST_ID_PATTERN.matcher(headerSection); | |
if (requestIdMatcher.find()) { | |
final var matchedRequestId = requestIdMatcher.group("requestId"); | |
requestId.set(matchedRequestId); | |
} else { | |
// lets fail application | |
log.debug("Text segment(s) that caused the error are {} \n", remainingChunks); | |
errorNotifyingSink.tryEmitError(new RuntimeException( | |
"Received turn.start TextWebsocketFrame. But it did not contain " + X_REQUEST_HEADER_NAME | |
+ " header")); | |
} | |
final var newAudioStreamCollectingSink = Sinks.many().unicast().<byte[]>onBackpressureBuffer(); | |
audioStreamCollectingSink.set(newAudioStreamCollectingSink); | |
audioChunksPersistingSink.emitNext(newAudioStreamCollectingSink, FAIL_FAST); | |
turnEnded.set(false); | |
turnStarted.set(true); | |
} | |
boolean turnEndFrame = PATH_TURN_END_HEADER_PATTERN.matcher(headerSection).find(); | |
log.trace("Did turn end? " + turnEndFrame); | |
if (!turnEnded.get() && turnEndFrame) { | |
audioStreamCollectingSink.get().emitComplete(FAIL_FAST); | |
requestId.set(null); | |
downloadCompleteChunkIndex.incrementAndGet(); | |
turnStarted.set(false); | |
turnEnded.set(true); | |
if (downloadCompleteChunkIndex.get() + 1 == remainingChunks.size()) { | |
log.info("Download of current chunks is complete."); | |
shortCircuitCompletionSink.emitValue(INSTANCE, FAIL_FAST); | |
retryFromTextChunkIndexOnAbruptCompleteSink.emitComplete(FAIL_FAST); | |
} | |
} | |
final var frameJsonBody = frameContent.substring(bodyStartIndex); | |
log.trace("↓↓↓↓↓↓↓ json body ↓↓↓↓↓↓\n" + frameJsonBody); | |
} else if (wsf instanceof BinaryWebSocketFrame binaryWebSocketFrame) { | |
final var webSocketContent = binaryWebSocketFrame.content(); | |
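// binary frame layout: a 2-byte (big-endian) header length, then the UTF-8 header section, then the raw audio payload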
final short headerLengthBytes = webSocketContent.readShort(); | |
log.trace("headerLengthBytes = " + headerLengthBytes); | |
final var headerBytes = new byte[headerLengthBytes]; | |
webSocketContent.readBytes(headerBytes); | |
final var headerSection = new String(headerBytes, UTF_8); | |
log.trace("↓↓↓↓↓↓↓ header section ↓↓↓↓↓↓\n" + headerSection); | |
boolean partIsAudio = Pattern.compile(PATH_HEADER + ":\s*audio" + CRLF).matcher(headerSection).find(); | |
boolean requestIdMatches = Pattern.compile(X_REQUEST_HEADER_NAME + ":" + requestId.get() + CRLF).matcher(headerSection).find(); | |
if (requestIdMatches) { | |
if (partIsAudio) { | |
byte[] bytesRead = new byte[webSocketContent.writerIndex() - webSocketContent.readerIndex()]; | |
webSocketContent.readBytes(bytesRead); | |
audioStreamCollectingSink.get().emitNext(bytesRead, FAIL_FAST); | |
} | |
} else { | |
// lets fail application | |
log.debug("Text segment(s) that caused the error are {} \n", remainingChunks); | |
errorNotifyingSink.tryEmitError(new RuntimeException( | |
"Received audio in a BinaryWebsocketFrame. But it did not match " + X_REQUEST_HEADER_NAME | |
+ " header")); | |
} | |
} else if (wsf instanceof CloseWebSocketFrame closeWebSocketFrame) { | |
log.debug("Received close websocket frame (might/might not be an issue). statusCode: {}, reason: {}", closeWebSocketFrame.statusCode(), closeWebSocketFrame.reasonText()); | |
// log.debug("Text segment(s) that caused the error are {} \n", textChunks); | |
// errorNotifyingSink.tryEmitError(new RuntimeException("Otherside asked us to close the stream. statusCode: " + closeWebSocketFrame.statusCode() + " reason: " + closeWebSocketFrame.reasonText())); | |
} // we don't care about Ping & Pong frames as the framework takes care of them automatically when we set `handlePing(false)`, which is the default
}) | |
.doOnComplete(() -> { | |
// at times the server sends a FIN (orderly closure of the connection from its end) about 15 seconds into the websocket conversation, in the middle of a frame transfer; Reactor Netty only ACKs the received bytes instead of replying with its own FIN, and terminates the flow while the underlying TCP connection stays open at the OS level
// the server then sends an RST (forced closure of the dangling TCP connection) after about 120 seconds
// in this case we need to restart the download for the current chunk
if (!turnEnded.get()) { | |
log.info("Received an unexpected complete (FIN) from server while downloading chunk #{} (index = {}) in current session", (downloadCompleteChunkIndex.get() + 2), (downloadCompleteChunkIndex.get() + 1)); | |
retryFromTextChunkIndexOnAbruptCompleteSink.emitNext((downloadCompleteChunkIndex.get() + 1), FAIL_FAST); | |
} else { | |
log.info("Received an unexpected complete (FIN) from server."); | |
shortCircuitCompletionSink.emitValue(INSTANCE, FAIL_FAST); | |
retryFromTextChunkIndexOnAbruptCompleteSink.emitComplete(FAIL_FAST); | |
} | |
}) | |
.then(); | |
final var requestIdAndTextWebSocketFrameForText$$ = Flux.fromIterable(remainingChunks).map(this::requestIdAndTextWebSocketFrameForText$); | |
final Flux<TextWebSocketFrame> twf$ = requestIdAndTextWebSocketFrameForText$$ | |
.zipWith(Mono.just(INSTANCE).cast(Object.class).concatWith(audioChunkDownloadCompleteInfiniteStreamSink.asFlux())) // lets emit next batch only when previous chunk is fully downloaded | |
.doOnNext(__ -> inputRequestStreamSignalSink.emitNext(INSTANCE, FAIL_FAST)) | |
.concatMap(t -> Flux.from(t.getT1()).map(RequestIdAndTextWebSocketFrame::frame)) | |
.doOnComplete(() -> inputRequestStreamSignalSink.emitComplete(FAIL_FAST)); | |
final var pingWsf$ = Flux.interval(PING_MESSAGE_WEBSOCKET_DURATION).map(__ -> new PingWebSocketFrame()); // infinite stream | |
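// outbound sequence: the one-time speech.config frame first, then per chunk a synthesis.context frame followed by an ssml frame (the next chunk's frames are only emitted once the previous chunk has fully downloaded), all interleaved with periodic pings to keep the session alive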
final var outboundWs$ = initTextWebSocketFrameToSetSpeechConfigForWebsocketSession$() | |
.concatWith(twf$).cast(Object.class) | |
.mergeWith(pingWsf$); | |
final var outbound$ = outbound.sendObject(outboundWs$, __ -> true); | |
final var closeStatus$ = inbound.receiveCloseStatus() | |
.doOnNext(status -> { | |
log.debug("Received close status. statusCode: {}; reason: {}", status.code(), status.reasonText()); | |
// log.debug("Text segment(s) that caused the error are {} \n", textChunks); | |
// errorNotifyingSink.tryEmitError(new RuntimeException("Otherside asked us to close the stream. statusCode: " + status.code() + " reason: " + status.reasonText())); | |
}) | |
.then(); | |
return Mono.when(inbound$, outbound$, chunkPersister$, closeStatus$) | |
.takeUntilOther(allChunksDownloadComplete$); // lets cancel all streams once all chunks are downloaded | |
}); | |
return webSocketInBoundOutboundMerged$; | |
} | |
Mono<TextWebSocketFrame> initTextWebSocketFrameToSetSpeechConfigForWebsocketSession$() { | |
final var initialRequestBodyTemplate = """ | |
Path: speech.config | |
X-Timestamp: =timeStamp= | |
Content-Type: application/json; charset=utf-8 | |
{"context":{"system":{"name":"SpeechSDK","version":"1.19.0","build":"JavaScript","lang":"JavaScript"},"os":{"platform":"Browser/Linux x86_64","name":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.102 Safari/537.36","version":"5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.102 Safari/537.36"}}} | |
"""; | |
return Mono.just(initialRequestBodyTemplate) | |
.map(template -> { | |
final var timeStampHeaderValue = Instant.ofEpochMilli(System.currentTimeMillis()).atZone(ZoneId.of("GMT")).format(DateTimeFormatter.ISO_DATE_TIME).replaceAll("\\[[^]]+]", ""); | |
return template.strip().trim() // get rid of indentation | |
.replaceAll("=timeStamp=", timeStampHeaderValue) | |
.replaceAll("\n", "\r\n"); // microsoft server/spec expects the lines to be separated by \r\n instead of \n | |
}) | |
.map(TextWebSocketFrame::new); | |
} | |
private Flux<RequestIdAndTextWebSocketFrame> requestIdAndTextWebSocketFrameForText$(String requestId, String textToConvert) { | |
// we need this instead of string concatenation because we have to escape special characters within xml properly | |
Speak speakXmlObj = new Speak("1.0", "en-US", new Voice("en-US", "Male", "en-GB-RyanNeural", new Prosody("0%", "0%", textToConvert)));
String ssmlSpeakXmlAsStr = get(() -> objectMapper.writeValueAsString(speakXmlObj)); | |
// we need a `synthesis.context` TextWebSocketFrame followed by the `ssml` frame
final var ttsTextWebsocketFrameBodyTemplates = List.of( | |
""" | |
Path: synthesis.context | |
X-RequestId: =requestId= | |
X-Timestamp: =timeStamp= | |
Content-Type: application/json; charset=utf-8 | |
{"synthesis":{"audio":{"metadataOptions":{"bookmarkEnabled":false,"sentenceBoundaryEnabled":false,"visemeEnabled":false,"wordBoundaryEnabled":false},"outputFormat":"audio-16khz-64kbitrate-mono-mp3"},"language":{"autoDetection":false}}} | |
""", | |
""" | |
Path: ssml | |
X-RequestId: =requestId= | |
X-Timestamp: =timeStamp= | |
Content-Type: application/ssml+xml | |
"""+ssmlSpeakXmlAsStr); | |
return Flux.fromIterable(ttsTextWebsocketFrameBodyTemplates) | |
.map(ttsTextWebsocketFrameBodyTemplate -> { | |
final var timeStampHeaderValue = generateTimeStamp(); | |
return ttsTextWebsocketFrameBodyTemplate.strip().trim() // get rid of indentation | |
.replaceAll("=requestId=", requestId) | |
.replaceAll("=timeStamp=", timeStampHeaderValue) | |
.replaceAll("=ssmlSpeakXml=", ssmlSpeakXmlAsStr) | |
.replaceAll("\n", "\r\n"); // microsoft server/spec expects the lines to be separated by \r\n instead of \n | |
}) | |
.map(frameBody -> new RequestIdAndTextWebSocketFrame(requestId, new TextWebSocketFrame(frameBody))); | |
} | |
private Flux<RequestIdAndTextWebSocketFrame> requestIdAndTextWebSocketFrameForText$(String textToConvert) { | |
final var requestId = generateRequestId(); | |
return requestIdAndTextWebSocketFrameForText$(requestId, textToConvert); | |
} | |
private static String generateTimeStamp() { | |
return Instant.ofEpochMilli(System.currentTimeMillis()).atZone(ZoneId.of("GMT")) | |
.format(DateTimeFormatter.ISO_DATE_TIME).replaceAll("\\[[^]]+]", ""); | |
} | |
private static String generateRequestId() { | |
return UUID.randomUUID().toString().replaceAll("-", "").toUpperCase(); | |
} | |
}
package com.acme.tts.converter.microsoft.ssml.websocket; | |
import lombok.extern.log4j.Log4j2; | |
import org.springframework.core.io.buffer.DataBuffer; | |
import org.springframework.core.io.buffer.DataBufferUtils; | |
import org.springframework.core.io.buffer.DefaultDataBufferFactory; | |
import reactor.core.publisher.Flux; | |
import reactor.core.publisher.Mono; | |
import reactor.core.publisher.Sinks.Many; | |
import java.nio.file.Files; | |
import java.nio.file.Path; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import static com.acme.tts.converter.ConverterUtil.getPartFileName; | |
import static com.acme.tts.converter.microsoft.ssml.websocket.Marker.INSTANCE; | |
import static io.earcam.unexceptional.Exceptional.run; | |
import static java.nio.file.StandardOpenOption.CREATE_NEW; | |
import static java.nio.file.StandardOpenOption.WRITE; | |
import static reactor.core.publisher.Sinks.EmitFailureHandler.FAIL_FAST; | |
@Log4j2 | |
class ChunksPersister { | |
private final Many<Many<byte[]>> audioChunksPersistingSink; | |
private final Many<Marker> audioChunkDownloadCompleteInfiniteStreamSink; | |
private final Path dirToDownloadTo; | |
private final AtomicInteger nextChunkNumber; | |
private final int totalChunkCount; | |
private final Path inputFilePath; | |
public ChunksPersister(Many<Many<byte[]>> audioChunksPersistingSink, | |
Many<Marker> audioChunkDownloadCompleteInfiniteStreamSink, | |
Path dirToDownloadTo, | |
int startingChunkNumber, | |
int totalChunkCount, | |
Path inputFilePath) { | |
this.audioChunksPersistingSink = audioChunksPersistingSink; | |
this.audioChunkDownloadCompleteInfiniteStreamSink = | |
audioChunkDownloadCompleteInfiniteStreamSink; | |
this.dirToDownloadTo = dirToDownloadTo; | |
this.nextChunkNumber = new AtomicInteger(startingChunkNumber);
this.totalChunkCount = totalChunkCount; | |
this.inputFilePath = inputFilePath; | |
} | |
Mono<Void> asMono() { | |
return audioChunksPersistingSink.asFlux().switchMap(audioStreamCollectingSink -> { | |
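// each inner sink carries the audio fragments of a single chunk; they are written to a numbered part file as they arrive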
final Flux<DataBuffer> audioPart$ = audioStreamCollectingSink.asFlux().cast(byte[].class) | |
.map(DefaultDataBufferFactory.sharedInstance::wrap); | |
int currChunkNumber = nextChunkNumber.getAndIncrement(); | |
final var chunkDestinationFileName = getPartFileName(currChunkNumber, totalChunkCount); | |
final var destination = dirToDownloadTo.resolve(chunkDestinationFileName); | |
// let's delete the file if it already exists
Mono<?> fileCleaner$ = Mono.empty(); | |
if (Files.exists(destination)) { | |
fileCleaner$ = Mono.create(sink -> { | |
run(() -> Files.deleteIfExists(destination)); | |
sink.success(); | |
}); | |
} | |
log.debug("[{}]: (part {} of {}) Downloading to {} …", inputFilePath.getFileName(), | |
currChunkNumber, totalChunkCount, chunkDestinationFileName); | |
return fileCleaner$.then( | |
DataBufferUtils.write(audioPart$, destination, CREATE_NEW, WRITE) | |
.doOnSuccess(__ -> { | |
audioChunkDownloadCompleteInfiniteStreamSink.tryEmitNext(INSTANCE); | |
log.debug("[{}]: (part {} of {}) Downloaded to {}", | |
inputFilePath.getFileName(), currChunkNumber, | |
totalChunkCount, chunkDestinationFileName); | |
}) | |
.doOnError(e -> audioChunkDownloadCompleteInfiniteStreamSink.emitError(e, FAIL_FAST))); | |
}) | |
.then(); | |
} | |
}
package com.acme.tts.converter.microsoft.ssml; | |
import com.acme.tts.converter.FileToTtsConverter; | |
import com.acme.tts.converter.microsoft.ssml.websocket.WebsocketBasedDownloader; | |
import com.fasterxml.jackson.databind.ObjectMapper; | |
import com.fasterxml.jackson.databind.SerializationFeature; | |
import com.fasterxml.jackson.dataformat.xml.XmlMapper; | |
import io.netty.buffer.ByteBufAllocator; | |
import lombok.extern.log4j.Log4j2; | |
import org.springframework.core.io.buffer.DataBufferUtils; | |
import org.springframework.core.io.buffer.NettyDataBufferFactory; | |
import reactor.core.publisher.Mono; | |
import reactor.core.scheduler.Schedulers; | |
import java.io.IOException; | |
import java.nio.file.Files; | |
import java.nio.file.Path; | |
import java.util.List; | |
import java.util.stream.IntStream; | |
import static com.acme.tts.common.Util.launchProcessMergeInputAndOutputStreamsIgnoreInputStream; | |
import static com.acme.tts.converter.ConverterUtil.getPartFileName; | |
import static com.acme.tts.converter.ConverterUtil.readAndSanitiseText; | |
import static com.acme.tts.converter.ConverterUtil.splitTextIntoChunks; | |
import static io.earcam.unexceptional.Exceptional.get; | |
import static io.earcam.unexceptional.Exceptional.run; | |
import static java.lang.String.join; | |
import static java.nio.file.Files.createDirectory; | |
import static java.nio.file.Files.exists; | |
import static java.nio.file.Files.move; | |
import static java.nio.file.StandardOpenOption.CREATE_NEW; | |
import static java.nio.file.StandardOpenOption.WRITE; | |
import static java.util.List.of; | |
import static java.util.stream.Collectors.joining; | |
import static org.springframework.util.FileSystemUtils.deleteRecursively; | |
import static org.springframework.util.StringUtils.stripFilenameExtension; | |
/** | |
* Utility that makes the same calls <a href="https://azure.microsoft.com/en-us/services/cognitive-services/text-to-speech/">this page</a> makes. That page uses websockets, so we have to use them too
*/ | |
@Log4j2 | |
public class MicrosoftTtsDownloader implements FileToTtsConverter { | |
private final WebsocketBasedDownloader websocketBasedDownloader; | |
public MicrosoftTtsDownloader(boolean debugEnabled) { | |
websocketBasedDownloader = new WebsocketBasedDownloader(debugEnabled); | |
} | |
@Override | |
public Mono<Void> convertToTts(Path inputFilePath) { | |
final var finalOutputMp3Path = inputFilePath | |
.resolveSibling(stripFilenameExtension(inputFilePath.getFileName().toString()) + ".mp3"); | |
if (exists(finalOutputMp3Path)) { | |
log.info("Output file " + finalOutputMp3Path + " already exists. Skipping download."); | |
return Mono.empty(); | |
} | |
final var textToTranscript = readAndSanitiseText(inputFilePath); | |
final List<String> transcribableChunks = splitTextIntoChunks(textToTranscript, 1_000); // though the service supports up to 8,000 characters, the Microsoft JS client uses 1,000-character chunks, so let's do the same
final var audioPartsDir = finalOutputMp3Path.resolveSibling("audio-parts/"); | |
// Since the file system API in Java is blocking (AsynchronousFileChannel helps only with read & write, not with directory iteration), we have to use the `subscribeOn(boundedElastic())` hack to get around that
return Mono.create(sink -> { | |
try { | |
deleteRecursively(audioPartsDir); | |
createDirectory(audioPartsDir); | |
sink.success(); | |
} catch (IOException e) { | |
sink.error(e); | |
} | |
}) | |
.subscribeOn(Schedulers.boundedElastic()) | |
.thenEmpty(websocketBasedDownloader.downloadAudioForAllChunks(transcribableChunks, audioPartsDir, inputFilePath)) | |
.then( | |
Mono.create(mergeAndCleanupSink -> { | |
// If we have more than 1 part, merge using ffmpeg | |
// To do so:
// 1. Create merge config file | |
// 2. ffmpeg -f concat -i merge-config.txt -c copy target.mp3 | |
// Sample merge-config.txt | |
// file 'p1.mp3' | |
// file 'p2.mp3' | |
if (transcribableChunks.size() > 1) { | |
var mergeConfigPath = audioPartsDir.resolve("merge-config.txt"); | |
var mergeConfigFileContent = IntStream.range(1, transcribableChunks.size() + 1) | |
.mapToObj(transcribableChunkNumber -> getPartFileName(transcribableChunkNumber, transcribableChunks.size())) | |
.map(partName -> "file " + partName) | |
.collect(joining("\n")); | |
final var bufferFactory = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); | |
final var mergeConfigDataBuffer = bufferFactory.wrap(mergeConfigFileContent.getBytes()); | |
DataBufferUtils.write(Mono.just(mergeConfigDataBuffer), mergeConfigPath, CREATE_NEW, WRITE) | |
.doOnSubscribe(__ -> log.info("[{}]: Merging all parts from {} into {}", inputFilePath.getFileName(), audioPartsDir, finalOutputMp3Path)) | |
.doOnSubscribe(__ -> log.info("[{}]: Executing {}", inputFilePath.getFileName(), join(" ", of("ffmpeg", "-f", "concat", "-i", mergeConfigPath.toAbsolutePath().toString(), "-c", "copy", finalOutputMp3Path.toAbsolutePath().toString())))) | |
.doOnSubscribe(__ -> log.info("Process output is")) | |
.thenEmpty( | |
launchProcessMergeInputAndOutputStreamsIgnoreInputStream("ffmpeg", "-f", "concat", "-i", mergeConfigPath.toAbsolutePath().toString(), "-c", "copy", finalOutputMp3Path.toAbsolutePath().toString()) | |
.doOnNext(log::info) | |
.then() | |
) | |
.thenEmpty( | |
Mono.<Void>create(deletePartDirSink -> { | |
log.info("[{}]: Merged all parts from {} into {}", inputFilePath.getFileName(), audioPartsDir, finalOutputMp3Path); | |
try { | |
deleteRecursively(audioPartsDir); | |
log.info("[{}]: Cleaned up parts from {}", inputFilePath.getFileName(), audioPartsDir); | |
deletePartDirSink.success(); | |
} catch (IOException e) { | |
deletePartDirSink.error(e); | |
} | |
}) | |
.subscribeOn(Schedulers.boundedElastic()) | |
) | |
.subscribe(__ -> {}, mergeAndCleanupSink::error, mergeAndCleanupSink::success); | |
} else { | |
Mono.<Void>create(deletePartDirSink -> { | |
final var tempFilePath = get(() -> Files.list(audioPartsDir)).findFirst().orElseThrow(() -> new IllegalStateException("Shouldn't reach this point; the file should have been downloaded by now"));
log.info("[{}]: Moving file {} to {}", inputFilePath.getFileName(), tempFilePath.getFileName(), finalOutputMp3Path.getFileName()); | |
run(() -> move(tempFilePath, finalOutputMp3Path)); | |
try { | |
deleteRecursively(audioPartsDir); | |
log.info("[{}]: Cleaned up parts from {}", inputFilePath.getFileName(), audioPartsDir); | |
deletePartDirSink.success(); | |
} catch (IOException e) { | |
deletePartDirSink.error(e); | |
} | |
}) | |
.subscribeOn(Schedulers.boundedElastic()) | |
.subscribe(__ -> {}, mergeAndCleanupSink::error, mergeAndCleanupSink::success); | |
} | |
}) | |
); | |
} | |
}
package com.acme.tts.converter.microsoft.ssml.websocket; | |
import io.netty.handler.logging.LogLevel; | |
import lombok.extern.log4j.Log4j2; | |
import reactor.core.publisher.Flux; | |
import reactor.core.publisher.Mono; | |
import reactor.core.publisher.Sinks; | |
import reactor.netty.http.client.HttpClient; | |
import reactor.netty.http.client.WebsocketClientSpec; | |
import reactor.netty.transport.logging.AdvancedByteBufFormat; | |
import java.nio.file.Path; | |
import java.time.Duration; | |
import java.time.Instant; | |
import java.util.List; | |
import java.util.concurrent.ThreadLocalRandom; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import java.util.concurrent.atomic.AtomicReference; | |
import static com.acme.tts.converter.microsoft.ssml.websocket.Marker.INSTANCE; | |
import static java.time.Duration.ofSeconds; | |
import static java.time.temporal.ChronoUnit.NANOS; | |
import static org.springframework.http.HttpHeaders.ORIGIN; | |
import static org.springframework.http.HttpHeaders.USER_AGENT; | |
import static reactor.core.publisher.Sinks.EmitFailureHandler.FAIL_FAST; | |
/** | |
* The protocol documentation has since been removed from the Microsoft website, but we can find an older version
* <a href="https://web.archive.org/web/20191210051654/https://docs.microsoft.com/en-us/azure/cognitive-services/speech/api-reference-rest/websocketprotocol#connection-establishment">here</a> | |
* <p> | |
* Other places of interest (documentation & reference implementation) | |
* <ol> | |
* <li><a href="https://github.com/thekalinga/MsEdgeTTS">thekalinga/MsEdgeTTS</a></li> | |
* <li><a href="https://github.com/thekalinga/ms-bing-speech-service">thekalinga/ms-bing-speech-service</a></li> | |
* <li><a href="https://github.com/thekalinga/cognitive-services-speech-sdk-js">thekalinga/cognitive-services-speech-sdk-js</a></li> | |
* </ol> | |
*/ | |
@Log4j2 | |
public class WebsocketBasedDownloader { | |
// as per spec | |
private static final int MAX_WEBSOCKET_FRAME_PAYLOAD_LENGTH = 8_192; | |
public static final String BROWSER_USER_AGENT_VALUE = | |
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.102 Safari/537.36"; | |
public static final String BROWSER_ORIGIN_VALUE = "https://azure.microsoft.com"; | |
private final HttpClient client; | |
private final WebsocketClientSpec websocketClientSpec; | |
public WebsocketBasedDownloader(boolean debugEnabled) { | |
HttpClient client = HttpClient.create() | |
.headers(headers -> headers.add(USER_AGENT, BROWSER_USER_AGENT_VALUE).add(ORIGIN, BROWSER_ORIGIN_VALUE)) | |
.keepAlive(true) | |
.proxyWithSystemProperties(); | |
if(debugEnabled) { | |
client = client.wiretap(getClass().getCanonicalName(), LogLevel.TRACE, AdvancedByteBufFormat.HEX_DUMP); | |
} | |
this.client = client; | |
websocketClientSpec = WebsocketClientSpec.builder() | |
.maxFramePayloadLength(MAX_WEBSOCKET_FRAME_PAYLOAD_LENGTH) // as per spec | |
.compress(true) | |
.build(); | |
} | |
/** | |
* Downloads mp3 files for all of the specified chunks. Here is how it works:
* <ol>
* <li>Opens a connection to the Microsoft Azure text-to-speech websocket server</li>
* <li>
* Sends the same requests that the Microsoft JavaScript client sends
* <ul>
* <li>On the outbound channel, sends a TextWebSocketFrame to set up the speech config (common for the whole websocket session)</li>
* </ul>
* </li>
* <li>
* For each input text chunk, we exchange multiple *WebSocketFrames that together constitute a single audio conversion
* <ul>
* <li>On the outbound channel, sends a TextWebSocketFrame that specifies the settings for the next SSML request</li>
* <li>On the outbound channel, sends a TextWebSocketFrame with the SSML xml</li>
* <li>On the inbound channel, receives a TextWebSocketFrame with the header `Path:turn.start`, which indicates we can expect the response from the server in subsequent frames</li>
* <li>On the inbound channel, receives the first BinaryWebSocketFrame with `Path:audio`; its first two bytes specify the header length, and we skip past the header to get the audio fragment</li>
* <li>On the inbound channel, receives further BinaryWebSocketFrames with `Path:audio` and extracts the audio fragments from their bodies</li>
* <li>On the inbound channel, receives a TextWebSocketFrame with the header `Path:turn.end`, indicating the audio has been fully sent and the chunk can be considered complete</li>
* </ul>
* </li>
* <li>We repeat step (3) for every chunk with a new request id</li>
* </ol>
* @param textChunks all text chunks that need to be converted to audio. The browser client uses chunks of at most 1,000 characters, so we do the same to avoid suspicion from Microsoft
* @param dirToDownloadTo the directory into which all audio parts corresponding to the text chunks will be downloaded
* @param inputFilePath the original input text file; used here only to label log messages
* @return a publisher that completes when download is complete | |
*/ | |
public Flux<Void> downloadAudioForAllChunks(List<String> textChunks, Path dirToDownloadTo, Path inputFilePath) { | |
log.debug("[{}]: Will be downloading chunks {}-{}", inputFilePath.getFileName(), 1, textChunks.size()); | |
// for debugging, enable this (if you are getting an exception & don't know where it was assembled)
// Hooks.onOperatorDebug();
// (or) if you want to run like this at runtime (without the stack-trace cost), use the lightweight `checkpoint` operator instead
// to debug a specific rx operator (runtime behaviour, i.e. which signals travel up & down the pipe), use the log operator:
// log("com.acme.tts.converter.microsoft.give-meaningful-operator-name")
final var startedAt = new AtomicReference<Instant>(); | |
final var previousStartIndex = new AtomicInteger(0); | |
final var retryFromTextChunkIndexOnAbruptCompleteSink = Sinks.many().unicast().<Integer>onBackpressureBuffer(); | |
// let's always keep at least 1 item in the first leg of retryFromTextChunkIndexOnAbruptCompleteSink so that withLatestFrom works
// the server sends an abrupt FIN (normal closure of the TCP connection) after a few chunks when the same underlying websocket session is reused
// in that case we need to start a fresh flow that downloads only the remaining segments
retryFromTextChunkIndexOnAbruptCompleteSink.emitNext(previousStartIndex.get(), FAIL_FAST); | |
return retryFromTextChunkIndexOnAbruptCompleteSink.asFlux().cast(Integer.class) | |
.withLatestFrom(Mono.just(textChunks), (relIndexToStartFrom, chunks) -> { | |
final var nextStartIndex = previousStartIndex.addAndGet(relIndexToStartFrom); | |
if (chunks.size() < nextStartIndex) { | |
return Mono.<List<String>>empty(); | |
} else { | |
log.info("Attempting to download {}-{} chunks in current file with new websocket session", nextStartIndex + 1, chunks.size()); | |
final var chunksRemaining = chunks.subList(nextStartIndex, chunks.size()); | |
return Mono.just(chunksRemaining) | |
.delayUntil(indexAnd_ -> { | |
if (nextStartIndex > 0) { | |
final var waitPeriodInSeconds = ThreadLocalRandom.current().nextInt(1, 10); | |
log.debug("Premptively delaying next attempt to not overwhelm the server for {} seconds", waitPeriodInSeconds); | |
final var timeTakenForAllPreviousDownloads = Duration.between(startedAt.get(), Instant.now()); | |
// remaining elements * average time taken per previously downloaded element + an average 5 seconds of delay per future element
final var estimatedTime = Duration.ofNanos((textChunks.size() - nextStartIndex - 1) * (timeTakenForAllPreviousDownloads.get(NANOS) / (nextStartIndex + 1))) | |
.plus(Duration.ofSeconds(5L * (textChunks.size() - nextStartIndex - 1))); | |
log.debug("Estimated time to complete download remaining chunks in current file is: {}", estimatedTime); | |
return Mono.delay(ofSeconds(waitPeriodInSeconds)) // lets delay each batch by 20 seconds | |
.doOnNext(___ -> log.debug("Preemptive wait complete. Resuming next batch download")); | |
} | |
return Mono.just(INSTANCE); | |
}); | |
} | |
}) // lets restart the process so we can download only pending ones | |
.doFirst(() -> startedAt.set(Instant.now())) | |
.switchMap(remainingChunks$ -> { | |
//noinspection CodeBlock2Expr | |
return remainingChunks$ | |
.flatMapMany(remainingChunks -> { | |
final var chunksAudioFetcherAndPersister = new ChunksAudioFetcherAndPersister(client, | |
websocketClientSpec, remainingChunks, retryFromTextChunkIndexOnAbruptCompleteSink, | |
dirToDownloadTo, previousStartIndex.get() + 1, textChunks.size(), inputFilePath); | |
return chunksAudioFetcherAndPersister.retrieve(); | |
}); | |
}); | |
} | |
}
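A sketch of driving WebsocketBasedDownloader directly, bypassing MicrosoftTtsDownloader; the chunk text and paths are hypothetical, and the parts directory is created up front since part files are opened with CREATE_NEW:

import com.acme.tts.converter.microsoft.ssml.websocket.WebsocketBasedDownloader;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

class DirectDownloadExample {
    public static void main(String[] args) throws Exception {
        var partsDir = Files.createDirectories(Path.of("audio-parts")); // numbered mp3 parts land here
        var chunks = List.of("Hello from the reactive text-to-speech pipeline.");
        var inputFile = Path.of("input.txt"); // only used to label log messages in this flow
        new WebsocketBasedDownloader(true) // true enables HEX_DUMP wiretap logging
                .downloadAudioForAllChunks(chunks, partsDir, inputFile)
                .blockLast(); // Flux<Void>: blockLast() waits for the download to finish
    }
}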