Created
August 1, 2024 08:32
-
-
Save GibsonRuitiari/4e2d94b68ad7e41fa5043d3150ccb4c1 to your computer and use it in GitHub Desktop.
A simple, fast download manager that supports concurrent ("spatial", i.e. segmented multi-connection) downloads
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.example | |
import kotlinx.coroutines.* | |
import okhttp3.* | |
import okhttp3.HttpUrl.Companion.toHttpUrl | |
import org.example.DownloadProgress.Companion.calculateDownloadPercentageFromDownloadProgress | |
import org.example.Utils.ATTEMPT_COUNT | |
import org.example.Utils.DedicatedBlockingDispatcher | |
import org.example.Utils.combineSegmentsToFile | |
import org.example.Utils.copyToOutputStreamAsynchronously | |
import org.example.Utils.createDownloadSegments | |
import org.example.Utils.createNetworkRequestObject | |
import org.example.Utils.defaultHttpClient | |
import org.example.Utils.deleteIfExists | |
import org.example.Utils.doesServerSupportRange | |
import org.example.Utils.executeAsynchronously | |
import org.example.Utils.stepDuration | |
import java.io.* | |
import java.net.SocketTimeoutException | |
import java.time.LocalDateTime | |
import java.time.format.DateTimeFormatter | |
import java.util.* | |
import java.util.concurrent.TimeUnit | |
import kotlin.coroutines.CoroutineContext | |
import kotlin.coroutines.EmptyCoroutineContext | |
import kotlin.coroutines.resumeWithException | |
import kotlin.time.Duration.Companion.milliseconds | |
import kotlin.time.times | |
/**
 * Demo entry point: downloads one of the sample links below with [DownloadManager].
 *
 * Pausing is not supported yet, but adding it should be straightforward.
 */
fun main() {
    val sampleLinks = listOf(
        "https://download.scdn.co/SpotifySetup.exe",
        "https://1111-releases.cloudflareclient.com/windows/Cloudflare_WARP_Release-x64.msi",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Mount_Massive.jpg/1200px-Mount_Massive.jpg",
        "https://github.com/rainmeter/rainmeter/releases/download/v4.5.17.3700/Rainmeter-4.5.17.exe",
        "https://releases.ubuntu.com/24.04/ubuntu-24.04-desktop-amd64.iso?_ga=2.199833447.835604183.1722166732-1299686684.1722166732&_gl=1*14ws2rw*_gcl_au*MTI3NzM5ODM2Ny4xNzIyMTY5NDEw",
    )
    // Index 1 is the Cloudflare WARP installer.
    val chosenLink = sampleLinks[1]
    runBlocking {
        DownloadManager.downloadFile(chosenLink)
    }
}
/**
 * Downloads files over HTTP(S).
 *
 * If the server honours `Range` requests (probed with [doesServerSupportRange]) and advertises a usable
 * `Content-Length`, the file is fetched as parallel segments. Each segment is written to the working
 * directory as `<index>_<fileName>`, and the segments are merged with [combineSegmentsToFile]. Otherwise
 * the whole file is streamed over a single connection.
 */
object DownloadManager {
    // Used when the URL has no usable last path segment (e.g. `https://host/`).
    private const val FALLBACK_FILE_NAME = "download"

    /**
     * Downloads [downloadLink] into the working directory.
     *
     * The file is named after the last non-empty encoded path segment of the URL.
     *
     * @throws IOException if the range probe fails, if a download still fails after [ATTEMPT_COUNT]
     * attempts, or if the downloaded segments cannot be combined.
     */
    suspend fun downloadFile(downloadLink: String) {
        val fileName = downloadLink.toHttpUrl().encodedPathSegments
            .lastOrNull { it.isNotEmpty() } ?: FALLBACK_FILE_NAME
        if (!doesServerSupportRange(downloadLink)) {
            logDebug { "server lacks pausing capability. Doing sequential download" }
            downloadSequentially(downloadLink, fileName)
            return
        }
        logDebug { "server has pausing capability.Doing spatial download" }
        val contentLength = fetchContentLength(downloadLink)
        // Segment offsets are Ints. A missing or unknown (-1), empty, or >2 GiB length cannot be split into
        // valid byte ranges, so use the single-connection path, which handles all of those.
        if (contentLength == null || contentLength <= 0L || contentLength > Int.MAX_VALUE) {
            logWarning { "content length $contentLength cannot be split into segments, doing sequential download instead" }
            downloadSequentially(downloadLink, fileName)
            return
        }
        val fileSegments = createDownloadSegments(contentLength.toInt()).toList()
        spatiallyDownloadCompleteFile(downloadLink, fileName, fileSegments)
        // Combine the segment files into the final file and delete them.
        combineSegmentsToFile(file = File(fileName), fileSegments.size)
    }

    /**
     * Issues a plain GET for [downloadLink] and returns the advertised body length.
     *
     * Returns `-1` when the length is unknown and `null` when there is no body. The body itself is never read;
     * the response is closed right away.
     */
    private suspend fun fetchContentLength(downloadLink: String): Long? {
        val response = defaultHttpClient().newCall(createNetworkRequestObject(downloadLink)).executeAsynchronously()
        return response.use { it.body?.contentLength() }
    }

    /** Runs [retriableSequentialDownload], logging every status update at debug level. */
    private suspend fun downloadSequentially(downloadLink: String, fileName: String) {
        retriableSequentialDownload(downloadLink, fileName) { downloadStatus ->
            val logMessage = when (downloadStatus) {
                is Downloading -> "downloading... ${Dumper.DOWN_TIME_ICON}${downloadStatus.percentage}"
                is Failed -> "download failed because ${downloadStatus.failureReason}"
                Finished -> "download finished"
                Paused -> "download paused"
                Retrying -> "retrying download"
            }
            logDebug { logMessage }
        }
    }

    /* download file in segments using multiple connections */
    private suspend fun spatiallyDownloadCompleteFile(
        downloadLink: String,
        downloadFileName: String,
        fileSegments: List<Segment>,
    ) = coroutineScope {
        fileSegments.mapIndexed { index, segment ->
            async(DedicatedBlockingDispatcher) {
                retriableSpatiallyDownload(downloadLink, downloadFileName, index, segment = segment) { downloadStatus ->
                    val logMessage = when (downloadStatus) {
                        is Downloading -> "segment> ${downloadStatus.fileSegment}| downloading..${downloadStatus}"
                        is Failed -> "download failed because ${downloadStatus.failureReason}"
                        Finished -> "download finished"
                        Paused -> "download paused"
                        Retrying -> "retrying download"
                    }
                    logInfo { logMessage }
                }
            }
        }.awaitAll() // a segment that ultimately fails throws here and cancels its siblings
    }

    /**
     * Runs [download] up to [ATTEMPT_COUNT] times. Attempt `i` is followed by a wait of `i * stepDuration`.
     *
     * Status reporting:
     * - [Finished] on success.
     * - [Retrying] after each failed attempt that will be retried.
     * - [Failed] once all attempts are used up. The last [IOException] is then rethrown, so callers never
     *   go on to use a partially downloaded file.
     */
    private suspend fun withRetries(
        downloadStatusUpdate: (DownloadStatus) -> Unit,
        warningMessage: (IOException) -> String,
        download: suspend () -> Unit,
    ) {
        var lastFailure: IOException? = null
        for (attempt in 0 until ATTEMPT_COUNT) {
            try {
                download()
                downloadStatusUpdate(Finished)
                return
            } catch (ex: IOException) {
                lastFailure = ex
                logWarning { warningMessage(ex) }
                if (attempt < ATTEMPT_COUNT - 1) {
                    downloadStatusUpdate(Retrying)
                    delay(stepDuration * attempt)
                }
            }
        }
        val reason = "gave up after $ATTEMPT_COUNT attempts: ${lastFailure?.message}"
        downloadStatusUpdate(Failed(reason))
        throw IOException(reason, lastFailure)
    }

    /* calls [downloadFileSegmentAndGetProgress] but adds retrying capability */
    private suspend fun retriableSpatiallyDownload(
        downloadUrl: String,
        downloadFileName: String,
        segmentIndex: Int,
        segment: Segment,
        downloadStatusUpdate: (DownloadStatus) -> Unit,
    ) {
        withRetries(
            downloadStatusUpdate,
            warningMessage = { ex -> "Encountered the following error while doing spatial download ${ex.message}, retrying." },
        ) {
            downloadFileSegmentAndGetProgress(downloadUrl, downloadFileName, segmentIndex, segment) { downloadProgress ->
                downloadStatusUpdate(
                    Downloading(
                        fileSegment = segmentIndex,
                        downloadProgress = downloadProgress,
                        percentage = downloadProgress.calculateDownloadPercentageFromDownloadProgress(segmentIndex),
                    )
                )
            }
        }
    }

    /* calls [sequentiallyDownloadCompleteFile] but adds retrying capability */
    private suspend fun retriableSequentialDownload(
        downloadLink: String,
        downloadFileName: String,
        downloadStatusUpdate: (DownloadStatus) -> Unit,
    ) {
        withRetries(
            downloadStatusUpdate,
            warningMessage = { ex -> "Encountered the following error ${ex.message}, retrying." },
        ) {
            sequentiallyDownloadCompleteFile(downloadLink, downloadFileName) { downloadProgress ->
                downloadStatusUpdate(
                    Downloading(
                        downloadProgress = downloadProgress,
                        percentage = downloadProgress.calculateDownloadPercentageFromDownloadProgress(segmentIndex = 0),
                    )
                )
            }
        }
    }

    /* download whole file using one connection */
    private suspend fun sequentiallyDownloadCompleteFile(
        downloadLink: String,
        downloadFileName: String,
        downloadProgressUpdate: (DownloadProgress) -> Unit,
    ) {
        val networkResponse = defaultHttpClient().newCall(createNetworkRequestObject(downloadLink)).executeAsynchronously()
        networkResponse.use { response ->
            val body = response.body ?: throw IOException("response for $downloadLink has no body")
            val downloadFile = File(downloadFileName)
            /* Such downloads are not resumable: the server cannot give us partial content, so a
             * previously partially downloaded file is of no use. Delete it and start fresh. */
            downloadFile.deleteIfExists()
            body.downloadByteStreamToFileWithProgress(downloadFile, downloadProgressUpdate)
        }
    }

    /* download just one segment of a file */
    private suspend fun downloadFileSegmentAndGetProgress(
        downloadUrl: String,
        downloadFileName: String,
        segmentIndex: Int,
        segment: Segment,
        downloadProgressUpdate: (DownloadProgress) -> Unit,
    ) {
        val networkObjectRequest = createNetworkRequestObject(
            downloadUrl,
            headers = mapOf("Range" to "bytes=${segment.segmentStartPosition}-${segment.segmentEndPosition}"),
        )
        val networkResponse = defaultHttpClient().newCall(networkObjectRequest).executeAsynchronously()
        networkResponse.use { response ->
            // Any code other than 206 means the server ignored the Range header and is sending the whole file.
            // Writing that into a segment file would corrupt the combined output.
            if (response.code != 206) {
                throw IOException("expected 206 Partial Content for segment $segmentIndex but got ${response.code}")
            }
            val body = response.body ?: throw IOException("segment $segmentIndex response has no body")
            body.downloadByteStreamToFileWithProgress(File("${segmentIndex}_$downloadFileName"), downloadProgressUpdate)
        }
    }

    /**
     * Streams this body into [recipientFile], overwriting it.
     *
     * [downloadProgressUpdate] is called whenever the whole-percent progress increases. When the length is
     * unknown there is no percentage to throttle on, so it is called for every chunk instead. One final update
     * is always sent with the last byte count.
     */
    private suspend fun ResponseBody.downloadByteStreamToFileWithProgress(
        recipientFile: File,
        downloadProgressUpdate: (DownloadProgress) -> Unit,
    ) {
        val contentLength = contentLength().toInt()
        // The output stream's scope must outlive the input stream's scope.
        recipientFile.outputStream().use { output ->
            byteStream().use { input ->
                var lastReportedPercent = -1 // -1 so the first chunk is always reported
                var downloadProgress = DownloadProgress(0, contentLength)
                input.copyToOutputStreamAsynchronously(output) { bytesRead: Int ->
                    downloadProgress = DownloadProgress(bytesRead, contentLength)
                    val percent = if (contentLength > 0) (bytesRead.toLong() * 100 / contentLength).toInt() else -1
                    if (contentLength <= 0 || percent > lastReportedPercent) {
                        lastReportedPercent = percent
                        downloadProgressUpdate(downloadProgress)
                    }
                }
                downloadProgressUpdate(downloadProgress)
            }
        }
    }
}
/**
 * Inclusive byte range of the remote file fetched by one connection.
 *
 * The two bounds are sent as the HTTP header `Range: bytes=<segmentStartPosition>-<segmentEndPosition>`.
 */
data class Segment(
    val segmentStartPosition: Int,
    val segmentEndPosition: Int,
)
/**
 * Snapshot of how much of a single stream (a whole file or one segment) has been read.
 *
 * @property bytesRead bytes received so far.
 * @property totalBytesToRead expected total. It may be `-1` or `0` when the server did not report a length.
 */
data class DownloadProgress(val bytesRead: Int, val totalBytesToRead: Int) {
    companion object {
        // Highest percentage seen per segment index. Progress callbacks for different segments run
        // concurrently on a multi-threaded dispatcher, so the map must tolerate concurrent writers.
        // NOTE(review): entries are never cleared, so state carries over to later downloads in the same process.
        private val segmentProgressMap = java.util.concurrent.ConcurrentHashMap<Int, Int>()

        /**
         * Formats this progress as a percentage string such as `"42%"`.
         *
         * The value never drops below the one previously returned for the same [segmentIndex], so a retry that
         * restarts from byte 0 does not make the reported progress jump backwards. An unknown or zero total
         * reports `0%`. The value is clamped to `0..100`, and integer arithmetic avoids floating-point
         * truncation errors (e.g. 29/100 must be 29%, not 28%).
         */
        fun DownloadProgress.calculateDownloadPercentageFromDownloadProgress(segmentIndex: Int): String {
            val progress = if (totalBytesToRead > 0) {
                (bytesRead.toLong() * 100 / totalBytesToRead).toInt().coerceIn(0, 100)
            } else {
                0
            }
            // merge() is atomic per key: store progress, or keep the larger of the old and new values.
            val highest = segmentProgressMap.merge(segmentIndex, progress) { old, new -> maxOf(old, new) }
            return "$highest%"
        }
    }
}
/** Status events passed to the download callbacks. */
sealed class DownloadStatus

/** The download (or segment) completed. */
data object Finished : DownloadStatus()

/** Reserved for pausing support. */
data object Paused : DownloadStatus()

/** The download failed; [failureReason] is a human-readable explanation. */
data class Failed(val failureReason: String) : DownloadStatus()

/** A failed attempt is about to be retried. */
data object Retrying : DownloadStatus()

/**
 * Progress update.
 *
 * @property fileSegment index of the segment being reported, or `null` for a single-connection download.
 * @property downloadProgress raw byte counts.
 * @property percentage preformatted percentage text such as `"42%"`.
 */
data class Downloading(
    val fileSegment: Int? = null,
    val downloadProgress: DownloadProgress,
    val percentage: String,
) : DownloadStatus()
/**
 * Helpers shared by [DownloadManager]: retry tuning, segmentation, HTTP plumbing and
 * coroutine-friendly stream copying.
 */
object Utils {
    // Retry attempt `i` (0-based) waits `i * stepDuration` before the next try. The longest single wait is
    // therefore (ATTEMPT_COUNT - 1) * 10ms ≈ 1 second, and an exhausted retry loop waits about 50 seconds in total.
    const val ATTEMPT_COUNT = 100
    val stepDuration: kotlin.time.Duration = 10.milliseconds

    // A view of Dispatchers.IO with uncapped parallelism, used for reads that may block for a long time.
    val DedicatedBlockingDispatcher = Dispatchers.IO.limitedParallelism(Int.MAX_VALUE)

    // NOTE(review): not referenced anywhere in the visible code.
    private const val MINIMUM_STREAM_SIZE = 1024 * 1024

    /* Number of coroutines (threads, if this were implemented with threads) to create per download.
     * Anything between 2 and 32 works, but 8 is the better choice; don't go beyond 8 or the server will throttle you. */
    private const val THREAD_CONNECTIONS = 4

    /**
     * Splits `[0, responseBodyLength)` into contiguous, inclusive byte ranges.
     *
     * The number of ranges is [preferredThreadConnections], clamped to `1..responseBodyLength`, so no range is
     * ever empty. The last range absorbs the remainder of the integer division, so every byte is covered exactly
     * once. A zero-length body yields no segments.
     *
     * @throws IllegalArgumentException if [responseBodyLength] is negative (e.g. an unknown `-1` length).
     */
    suspend fun createDownloadSegments(
        responseBodyLength: Int,
        preferredThreadConnections: Int = THREAD_CONNECTIONS,
    ): Sequence<Segment> {
        require(responseBodyLength >= 0) { "cannot segment a body of unknown/negative length $responseBodyLength" }
        if (responseBodyLength == 0) return emptySequence()
        val connections = preferredThreadConnections.coerceIn(1, responseBodyLength)
        val segmentLength = responseBodyLength / connections
        return sequence {
            for (i in 0 until connections) {
                val startPosition = i * segmentLength
                val endPosition = if (i == connections - 1) responseBodyLength - 1 else startPosition + segmentLength - 1
                yield(Segment(startPosition, endPosition))
            }
        }
    }

    /** Builds a GET [Request] for [url], adding each of [headers]. */
    fun createNetworkRequestObject(url: String, headers: Map<String, String> = emptyMap()): Request =
        Request.Builder()
            .url(url)
            .apply { headers.forEach { (key, value) -> addHeader(key, value) } }
            .build()

    // OkHttp clients are meant to be shared. Every instance owns its own connection pool and dispatcher
    // threads, so building one per request wastes threads and defeats connection reuse.
    private val sharedHttpClient: OkHttpClient by lazy {
        OkHttpClient.Builder()
            .connectTimeout(40, TimeUnit.SECONDS)
            .readTimeout(40, TimeUnit.SECONDS)
            .writeTimeout(40, TimeUnit.SECONDS)
            .followRedirects(true)
            .protocols(Collections.singletonList(Protocol.HTTP_1_1))
            .addInterceptor { chain ->
                val userAgentValue = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36"
                val newRequest = chain.request().newBuilder().addHeader("User-Agent", userAgentValue)
                    .build()
                chain.proceed(newRequest)
            }
            .build()
    }

    /** Returns the process-wide HTTP client (HTTP/1.1, 40 s timeouts, follows redirects, browser User-Agent). */
    fun defaultHttpClient(): OkHttpClient = sharedHttpClient

    /**
     * Enqueues this call and suspends until it completes.
     *
     * Cancelling the coroutine cancels the call. A successful response is returned to the caller, who must
     * close it. An unsuccessful response is closed here and surfaced as an [IOException].
     */
    suspend fun Call.executeAsynchronously() = suspendCancellableCoroutine<Response> { cancellableContinuation ->
        // Register before enqueueing so a cancellation that races with the request still aborts the call.
        cancellableContinuation.invokeOnCancellation { cancel() }
        val callback = object : Callback {
            override fun onFailure(call: Call, e: IOException) {
                cancellableContinuation.resumeWithException(e)
            }

            override fun onResponse(call: Call, response: Response) {
                if (!response.isSuccessful) {
                    // The caller never receives this response, so release its connection here.
                    response.close()
                    cancellableContinuation.resumeWithException(IOException("Network request was not successful"))
                    return
                }
                cancellableContinuation.resume(response) { cause: Throwable, value: Response, _: CoroutineContext ->
                    logWarning { "exception occurred while performing network request ${cause.message}" }
                    // The coroutine was cancelled before taking ownership of the response; nobody else will close it.
                    value.close()
                }
            }
        }
        enqueue(callback)
    }

    /**
     * Concatenates the part files `0_<name>`, `1_<name>`, … (relative to the working directory) into [file]
     * in index order, then deletes the parts.
     *
     * @throws IOException if any part is missing. In that case the output is not touched and no part is deleted.
     */
    suspend fun combineSegmentsToFile(file: File, segmentSize: Int) {
        val segmentFiles = List(segmentSize.coerceAtLeast(0)) { index -> File("${index}_${file.name}") }
        // Check every part before opening the output, so a missing segment neither truncates an existing
        // file nor destroys the parts that were downloaded.
        if (segmentFiles.any { !it.exists() }) {
            logWarning { "temporary file does not exist. Aborting!" }
            throw IOException("Error while combining segments to file. not all segments were downloaded! retry again!")
        }
        file.outputStream().use { out ->
            // Sequential on purpose: the parts must land in order to keep the file intact.
            segmentFiles.forEach { segmentFile ->
                segmentFile.inputStream().use { input -> input.copyToOutputStreamAsynchronously(out) }
            }
        }
        segmentFiles.forEach { it.deleteIfExists() }
    }

    /**
     * Deletes this file if it exists.
     *
     * Returns `true` when the file is gone afterwards: it either never existed or was deleted. Returns `false`
     * when deletion failed. `File.delete` reports failure through its return value, not an exception.
     */
    fun File.deleteIfExists(): Boolean = try {
        !exists() || delete()
    } catch (ex: SecurityException) {
        logError(throwable = ex)
        false
    }

    /**
     * Probes [urlToFollow] with `Range: bytes=0-0`.
     *
     * Returns `true` only when the server answers `206 Partial Content`.
     *
     * @throws IOException if the request fails or is unsuccessful; it is logged and rethrown for the caller.
     */
    suspend fun doesServerSupportRange(urlToFollow: String): Boolean {
        val networkRequestObject = createNetworkRequestObject(urlToFollow, headers = mapOf("Range" to "bytes=0-0"))
        return try {
            val networkResponse = defaultHttpClient().newCall(networkRequestObject).executeAsynchronously()
            networkResponse.use { response -> response.code == 206 }
        } catch (ex: IOException) {
            // Propagate the error so the top-level function can handle it.
            logWarning { "Does Server Support Range: the following exception was encountered ${ex.message} " }
            throw ex
        }
    }

    /* Runs an intense IO operation that might block for an indefinite amount of time, e.g. updating DBs or
     * downloading a huge file. Designed for calls that may hang, such as Java input-stream reads.
     * Using Dispatchers.IO directly would risk starving it. */
    @OptIn(DelicateCoroutinesApi::class)
    private suspend fun <T> performIntenseIO(
        context: CoroutineContext = EmptyCoroutineContext,
        action: suspend CoroutineScope.() -> T,
    ): T {
        /* The dispatcher is unbounded, so IO dispatchers cannot be starved. The action does not inherit the
         * caller's context, so it can be managed independently of the calling coroutine. */
        val deferredAction = GlobalScope.async(DedicatedBlockingDispatcher + context, block = action)
        return try {
            deferredAction.await()
        } catch (ex: CancellationException) {
            /* This deliberately breaks structured concurrency: we do not wait for the action to finish.
             * When the caller is cancelled, the action is cancelled too. This prevents leaks when a download
             * hangs indefinitely or the user cancels, without waiting for the action to complete. */
            deferredAction.cancel(ex)
            throw ex
        }
    }

    /**
     * Non-blocking counterpart of Kotlin's `copyTo`.
     *
     * The blocking reads run on [DedicatedBlockingDispatcher] in a coroutine that may outlive the caller's
     * context. [onUpdate] receives the running byte total, clamped to `Int.MAX_VALUE`. It is called after each
     * non-empty chunk is counted and before that chunk is written.
     */
    suspend fun InputStream.copyToOutputStreamAsynchronously(
        output: OutputStream,
        bufferSize: Int = DEFAULT_BUFFER_SIZE,
        limit: Long = Long.MAX_VALUE,
        onUpdate: ((bytesRead: Int) -> Unit)? = null,
    ) {
        performIntenseIO(context = CoroutineName("copyToAsync:$this => $output")) {
            val buffer = ByteArray(bufferSize)
            var totalRead = 0L
            var lastProgress = 0
            while (totalRead < limit) {
                yield() // prevent one coroutine from monopolising the thread; also a cancellation point
                val read = try {
                    read(buffer, 0, minOf(limit - totalRead, buffer.size.toLong()).toInt())
                } catch (_: SocketTimeoutException) {
                    // NOTE(review): timeouts are retried indefinitely, so a stalled connection keeps this loop
                    // running until the coroutine is cancelled.
                    continue
                }
                when {
                    read < 0 -> break // EOF: no more bytes to read
                    read > 0 -> {
                        totalRead += read
                        // Progress is an Int; clamp so bodies larger than 2 GiB don't wrap to negative values.
                        val progress = totalRead.coerceAtMost(Int.MAX_VALUE.toLong()).toInt()
                        if (progress > lastProgress) {
                            onUpdate?.invoke(progress)
                            lastProgress = progress
                        }
                        yield()
                        output.write(buffer, 0, read)
                    }
                    else -> Unit
                }
            }
        }
    }
}
/**
 * Minimal ANSI-coloured console logger.
 *
 * Each line has the form `<timestamp>  <level badge>  <message>`. Warning, debug and info lines go to
 * standard output; error lines go to standard error.
 */
object Dumper {
    const val UP_TIME_ICON = "⇡"
    const val DOWN_TIME_ICON = "⇣"
    private const val BOLD_ERROR_CODE = "\u001B[1;31m"
    private const val BACKGROUND_RED_CODE = "\u001B[41m"
    private const val ERROR_CHARACTER_CODE = 'E'
    private const val INFO_CHARACTER_CODE = 'I'
    private const val DEBUG_CHARACTER_CODE = 'D'
    private const val WARNING_CHARACTER_CODE = 'W'
    private const val RESET_CODE = "\u001B[0m"
    private const val BACKGROUND_YELLOW_CODE = "\u001B[43m"
    private const val FOREGROUND_YELLOW_CODE = "\u001B[33m"
    private const val BACKGROUND_BLUE_CODE = "\u001B[104m"
    private const val FOREGROUND_BLACK_CODE = "\u001B[30m"
    private const val FOREGROUND_BLUE_CODE = "\u001B[94m"
    private const val BACKGROUND_GREEN_CODE = "\u001b[42m"
    private const val FOREGROUND_GREEN_CODE = "\u001b[32m"
    private const val FOREGROUND_WHITE_CODE = "\u001B[97m" // \u001B[1;37m

    // `MM` is month-of-year (`mm` would be minute-of-hour). `HH` is the 24-hour clock, since no AM/PM marker is printed.
    private val timestampFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")

    private fun currentDate(): String = timestampFormatter.format(LocalDateTime.now())

    private fun breather(len: Int = 2) = " ".repeat(len)

    /** Assembles one log line: white timestamp, then the coloured level [badge], then the already-coloured [body]. */
    private fun formatLine(badge: String, body: String): String = buildString {
        append("$FOREGROUND_WHITE_CODE${currentDate()} $RESET_CODE")
        append(breather())
        append(badge)
        append(breather())
        append(body)
    }

    /**
     * Prints [line] to [stream] and flushes.
     *
     * A single `println` on the shared stream locks it for the whole line, so lines logged concurrently from
     * several threads are not interleaved.
     */
    private fun emit(stream: PrintStream, line: String) {
        stream.println(line)
        stream.flush()
    }

    /** Logs [message] in yellow with a `W` badge to standard output. */
    fun logWarning(message: () -> String) {
        val line = formatLine(
            badge = "$BACKGROUND_YELLOW_CODE$FOREGROUND_BLACK_CODE $WARNING_CHARACTER_CODE $RESET_CODE",
            body = "$FOREGROUND_YELLOW_CODE${message()}$RESET_CODE",
        )
        emit(System.out, line)
    }

    /** Logs [message] in green with a `D` badge to standard output. */
    fun logDebug(message: () -> String) {
        val line = formatLine(
            badge = "$BACKGROUND_GREEN_CODE$FOREGROUND_WHITE_CODE $DEBUG_CHARACTER_CODE $RESET_CODE",
            body = "$FOREGROUND_GREEN_CODE${message()}$RESET_CODE",
        )
        emit(System.out, line)
    }

    /** Logs [message] in blue with an `I` badge to standard output. */
    fun logInfo(message: () -> String) {
        val line = formatLine(
            badge = "$BACKGROUND_BLUE_CODE$FOREGROUND_WHITE_CODE $INFO_CHARACTER_CODE $RESET_CODE",
            body = "$FOREGROUND_BLUE_CODE${message()}$RESET_CODE",
        )
        emit(System.out, line)
    }

    /**
     * Logs [throwable] in bold red with an `E` badge to standard error.
     *
     * The throwable's own message is printed. [message] is only used as a fallback when the throwable has none.
     */
    fun logError(throwable: Throwable, message: (() -> String)? = null) {
        val errorMessage = throwable.message ?: message?.invoke() ?: ""
        val line = formatLine(
            badge = "$BACKGROUND_RED_CODE$FOREGROUND_BLACK_CODE $ERROR_CHARACTER_CODE $RESET_CODE",
            body = "$BOLD_ERROR_CODE$errorMessage$RESET_CODE${breather()}",
        )
        emit(System.err, line)
    }
}
// Top-level shortcuts that forward to Dumper, so call sites can write `logInfo { "..." }` directly.

/** Forwards to [Dumper.logError]. */
fun logError(throwable: Throwable, message: (() -> String)? = null) {
    Dumper.logError(throwable, message)
}

/** Forwards to [Dumper.logWarning]. */
fun logWarning(message: () -> String) {
    Dumper.logWarning(message)
}

/** Forwards to [Dumper.logDebug]. */
fun logDebug(message: () -> String) {
    Dumper.logDebug(message)
}

/** Forwards to [Dumper.logInfo]. */
fun logInfo(message: () -> String) {
    Dumper.logInfo(message)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment