Skip to content

Instantly share code, notes, and snippets.

@GibsonRuitiari
Created August 1, 2024 08:32
Show Gist options
  • Save GibsonRuitiari/4e2d94b68ad7e41fa5043d3150ccb4c1 to your computer and use it in GitHub Desktop.
Save GibsonRuitiari/4e2d94b68ad7e41fa5043d3150ccb4c1 to your computer and use it in GitHub Desktop.
A simple and fast download manager that supports concurrent/spatial downloads
package org.example
import kotlinx.coroutines.*
import okhttp3.*
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.example.DownloadProgress.Companion.calculateDownloadPercentageFromDownloadProgress
import org.example.Utils.ATTEMPT_COUNT
import org.example.Utils.DedicatedBlockingDispatcher
import org.example.Utils.combineSegmentsToFile
import org.example.Utils.copyToOutputStreamAsynchronously
import org.example.Utils.createDownloadSegments
import org.example.Utils.createNetworkRequestObject
import org.example.Utils.defaultHttpClient
import org.example.Utils.deleteIfExists
import org.example.Utils.doesServerSupportRange
import org.example.Utils.executeAsynchronously
import org.example.Utils.stepDuration
import java.io.*
import java.net.SocketTimeoutException
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.resumeWithException
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.times
// To run this demo, use your IDE's Run action on main(),
// or click the run icon shown in the gutter next to it.
/**
 * Demo entry point: downloads one of the sample files below into the working directory.
 */
fun main() {
    // Sample download URLs; only the entry picked as `selectedLink` is fetched.
    val downloadLinks = listOf(
        "https://download.scdn.co/SpotifySetup.exe",
        "https://1111-releases.cloudflareclient.com/windows/Cloudflare_WARP_Release-x64.msi",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Mount_Massive.jpg/1200px-Mount_Massive.jpg",
        "https://github.com/rainmeter/rainmeter/releases/download/v4.5.17.3700/Rainmeter-4.5.17.exe",
        "https://releases.ubuntu.com/24.04/ubuntu-24.04-desktop-amd64.iso?_ga=2.199833447.835604183.1722166732-1299686684.1722166732&_gl=1*14ws2rw*_gcl_au*MTI3NzM5ODM2Ny4xNzIyMTY5NDEw",
    )
    val selectedLink = downloadLinks[1]
    // lacks pausing capability but supporting it should be trivial
    runBlocking { DownloadManager.downloadFile(selectedLink) }
}
object DownloadManager {
    // Without a Content-Length, progress is reported once per this many bytes instead of once per percent.
    private const val UNKNOWN_LENGTH_PROGRESS_STEP = 1024 * 1024

    /**
     * Downloads [downloadLink] into the working directory, naming the file after the last non-empty encoded path
     * segment of the URL (or `"download"` when the URL has none).
     *
     * If the server answers a `Range: bytes=0-0` probe with `206` and reports a Content-Length that fits in an [Int],
     * the file is fetched as parallel byte-range segments which are then concatenated; otherwise it is streamed over
     * a single connection.
     *
     * @throws IOException when the range probe fails or a transfer still fails after [ATTEMPT_COUNT] attempts.
     */
    suspend fun downloadFile(downloadLink: String) {
        val fileName = downloadLink.toHttpUrl().encodedPathSegments.lastOrNull { it.isNotEmpty() } ?: "download"
        if (!doesServerSupportRange(downloadLink)) {
            logDebug { "server lacks pausing capability. Doing sequential download" }
            downloadSequentially(downloadLink, fileName)
            return
        }
        logDebug { "server has pausing capability.Doing spatial download" }
        val contentLength = fetchContentLength(downloadLink)
        if (contentLength == null || contentLength <= 0L || contentLength > Int.MAX_VALUE) {
            // Segment offsets are Ints: an unknown (-1/absent) length, or one above 2 GiB, cannot be split without
            // producing invalid or overflowing ranges, so stream it in one piece instead.
            logWarning { "content length $contentLength cannot be segmented. Doing sequential download" }
            downloadSequentially(downloadLink, fileName)
            return
        }
        val fileSegments = createDownloadSegments(contentLength.toInt()).toList()
        spatiallyDownloadCompleteFile(downloadLink, fileName, fileSegments)
        // combine the per-segment temporary files into the final file (the helper also deletes them)
        combineSegmentsToFile(file = File(fileName), fileSegments.size)
    }

    /* GET issued only for its Content-Length header; the body is closed unread. Returns null if there is no body. */
    private suspend fun fetchContentLength(downloadLink: String): Long? {
        val request = createNetworkRequestObject(downloadLink)
        val response = defaultHttpClient().newCall(request).executeAsynchronously()
        return response.use { it.body?.contentLength() }
    }

    /* single-connection download whose status updates are logged at debug level */
    private suspend fun downloadSequentially(downloadLink: String, fileName: String) {
        retriableSequentialDownload(downloadLink, fileName) { downloadStatus ->
            val logMessage = when (downloadStatus) {
                is Downloading -> "downloading... ${Dumper.DOWN_TIME_ICON}${downloadStatus.percentage}"
                is Failed -> "download failed because ${downloadStatus.failureReason}"
                Finished -> "download finished"
                Paused -> "download paused"
                Retrying -> "retrying download"
            }
            logDebug { logMessage }
        }
    }

    /*
     * Download file in segments using multiple connections, one coroutine per segment.
     * If any segment finally fails, coroutineScope cancels the remaining segments and rethrows.
     */
    private suspend fun spatiallyDownloadCompleteFile(
        downloadLink: String,
        downloadFileName: String,
        fileSegments: List<Segment>,
    ) {
        coroutineScope {
            fileSegments.mapIndexed { index, segment ->
                async(DedicatedBlockingDispatcher) {
                    retriableSpatiallyDownload(downloadLink, downloadFileName, index, segment = segment) { downloadStatus ->
                        val logMessage = when (downloadStatus) {
                            is Downloading -> "segment> ${downloadStatus.fileSegment}| downloading..${downloadStatus}"
                            is Failed -> "download failed because ${downloadStatus.failureReason}"
                            Finished -> "download finished"
                            Paused -> "download paused"
                            Retrying -> "retrying download"
                        }
                        logInfo { logMessage }
                    }
                }
            }.awaitAll()
        }
    }

    /* calls [downloadFileSegmentAndGetProgress] with retrying; throws IOException once all attempts fail */
    private suspend fun retriableSpatiallyDownload(
        downloadUrl: String,
        downloadFileName: String,
        segmentIndex: Int,
        segment: Segment,
        downloadStatusUpdate: (DownloadStatus) -> Unit,
    ) {
        retrying(
            downloadStatusUpdate = downloadStatusUpdate,
            warning = { ex -> "Encountered the following error while doing spatial download ${ex.message}, retrying." },
        ) {
            downloadFileSegmentAndGetProgress(downloadUrl, downloadFileName, segmentIndex, segment) { downloadProgress ->
                val status = Downloading(
                    fileSegment = segmentIndex,
                    downloadProgress = downloadProgress,
                    percentage = downloadProgress.calculateDownloadPercentageFromDownloadProgress(segmentIndex),
                )
                downloadStatusUpdate(status)
            }
        }
    }

    /* calls [sequentiallyDownloadCompleteFile] with retrying; throws IOException once all attempts fail */
    private suspend fun retriableSequentialDownload(
        downloadLink: String,
        downloadFileName: String,
        downloadStatusUpdate: (DownloadStatus) -> Unit,
    ) {
        retrying(
            downloadStatusUpdate = downloadStatusUpdate,
            warning = { ex -> "Encountered the following error ${ex.message}, retrying." },
        ) {
            sequentiallyDownloadCompleteFile(downloadLink, downloadFileName) { downloadProgress ->
                val status = Downloading(
                    downloadProgress = downloadProgress,
                    percentage = downloadProgress.calculateDownloadPercentageFromDownloadProgress(segmentIndex = 0),
                )
                downloadStatusUpdate(status)
            }
        }
    }

    /*
     * Runs [attempt] up to ATTEMPT_COUNT times. After an IOException it logs [warning], reports Retrying and waits
     * attemptIndex * stepDuration before trying again. Reports Finished on success. After the last failure it
     * reports Failed and throws, so callers never treat a partially written file as complete.
     * Non-IO exceptions (including cancellation) propagate immediately.
     */
    private suspend fun retrying(
        downloadStatusUpdate: (DownloadStatus) -> Unit,
        warning: (IOException) -> String,
        attempt: suspend () -> Unit,
    ) {
        var lastError: IOException? = null
        for (attemptIndex in 0 until ATTEMPT_COUNT) {
            try {
                attempt()
                downloadStatusUpdate(Finished)
                return
            } catch (ex: IOException) {
                lastError = ex
                logWarning { warning(ex) }
                if (attemptIndex < ATTEMPT_COUNT - 1) {
                    downloadStatusUpdate(Retrying)
                    delay(attemptIndex * stepDuration)
                }
            }
        }
        val reason = lastError?.message ?: "unknown error"
        downloadStatusUpdate(Failed(reason))
        throw IOException("download failed after $ATTEMPT_COUNT attempts: $reason", lastError)
    }

    /* download whole file using one connection */
    private suspend fun sequentiallyDownloadCompleteFile(
        downloadLink: String,
        downloadFileName: String,
        downloadProgressUpdate: (DownloadProgress) -> Unit,
    ) {
        val networkRequestObj = createNetworkRequestObject(downloadLink)
        val networkResponse = defaultHttpClient().newCall(networkRequestObj).executeAsynchronously()
        networkResponse.use { response ->
            val body = response.body ?: throw IOException("response for $downloadLink has no body")
            val downloadFile = File(downloadFileName)
            /* such downloads are not resume-able and the server cannot give us partial content,
               so a previously partially downloaded file is of no use; delete it and start fresh */
            downloadFile.deleteIfExists()
            body.downloadByteStreamToFileWithProgress(downloadFile, downloadProgressUpdate)
        }
    }

    /* download just one segment of a file into "<segmentIndex>_<downloadFileName>" */
    private suspend fun downloadFileSegmentAndGetProgress(
        downloadUrl: String,
        downloadFileName: String,
        segmentIndex: Int,
        segment: Segment,
        downloadProgressUpdate: (DownloadProgress) -> Unit,
    ) {
        val networkObjectRequest = createNetworkRequestObject(
            downloadUrl,
            headers = mapOf("Range" to "bytes=${segment.segmentStartPosition}-${segment.segmentEndPosition}"),
        )
        val networkResponse = defaultHttpClient().newCall(networkObjectRequest).executeAsynchronously()
        networkResponse.use { response ->
            // Anything other than 206 (e.g. 200) means the server ignored the Range header and is sending the
            // whole file, which must not be written into a segment file.
            if (response.code != 206) {
                throw IOException("expected 206 Partial Content for segment $segmentIndex but got ${response.code}")
            }
            val body = response.body ?: throw IOException("segment $segmentIndex response has no body")
            body.downloadByteStreamToFileWithProgress(File("${segmentIndex}_$downloadFileName"), downloadProgressUpdate)
        }
    }

    /*
     * Streams this body into [recipientFile] (truncating it). Progress is reported each time the whole-number
     * percentage advances, or once per UNKNOWN_LENGTH_PROGRESS_STEP bytes when no usable length is known, and once
     * more with the final byte count after the copy completes.
     */
    private suspend fun ResponseBody.downloadByteStreamToFileWithProgress(
        recipientFile: File,
        downloadProgressUpdate: (DownloadProgress) -> Unit,
    ) {
        // -1 when the length is unknown; clamped because DownloadProgress stores Ints
        val contentLength = contentLength().coerceAtMost(Int.MAX_VALUE.toLong()).toInt()
        // scope of output-stream must outlive input-stream's scope
        recipientFile.outputStream().use { output ->
            byteStream().use { input ->
                var lastReportedMarker = -1
                var latestBytesRead = 0
                input.copyToOutputStreamAsynchronously(output) { bytesRead: Int ->
                    latestBytesRead = bytesRead
                    val marker = if (contentLength > 0) {
                        (bytesRead.toLong() * 100 / contentLength).toInt()
                    } else {
                        bytesRead / UNKNOWN_LENGTH_PROGRESS_STEP
                    }
                    if (marker > lastReportedMarker) {
                        lastReportedMarker = marker
                        downloadProgressUpdate(DownloadProgress(bytesRead, contentLength))
                    }
                }
                // always publish the final count, which the throttle above may have skipped
                downloadProgressUpdate(DownloadProgress(latestBytesRead, contentLength))
            }
        }
    }
}
/**
 * Inclusive byte range of a remote file, sent as `Range: bytes=<segmentStartPosition>-<segmentEndPosition>`.
 * Because both ends are inclusive, a segment spans `segmentEndPosition - segmentStartPosition + 1` bytes.
 */
data class Segment(
    val segmentStartPosition: Int,
    val segmentEndPosition: Int,
)
/** Bytes received so far ([bytesRead]) out of the expected total ([totalBytesToRead]) for one stream. */
data class DownloadProgress(val bytesRead: Int, val totalBytesToRead: Int) {
    companion object {
        // Highest percentage reported so far, keyed by segment index. Segment callbacks run in parallel on different
        // threads, so the map must be concurrent. NOTE(review): entries are never cleared, so a later download in the
        // same process starts from the previous maxima for reused indices.
        private val segmentProgressMap = ConcurrentHashMap<Int, Int>()

        /**
         * Returns this progress as a whole-number percentage string such as `"42%"`.
         *
         * The value reported for [segmentIndex] never decreases: a lower reading (e.g. after a retry restarts the
         * segment) reports the previous maximum. The percentage is clamped to 0..100, and a non-positive
         * [totalBytesToRead] (unknown length) counts as 0%.
         */
        fun DownloadProgress.calculateDownloadPercentageFromDownloadProgress(segmentIndex: Int): String {
            // integer arithmetic avoids truncation artefacts such as 0.29 * 100 = 28.999…
            val progress = if (totalBytesToRead > 0) {
                (bytesRead.toLong() * 100 / totalBytesToRead).toInt().coerceIn(0, 100)
            } else {
                0
            }
            val highest = segmentProgressMap.merge(segmentIndex, progress) { previous, current -> maxOf(previous, current) }
            return "$highest%"
        }
    }
}
/** Status events passed to the download-status callbacks of [DownloadManager]. */
sealed class DownloadStatus

/** Bytes are arriving; [fileSegment] is the segment index for segmented downloads and `null` for single-stream ones. */
data class Downloading(
    val fileSegment: Int? = null,
    val downloadProgress: DownloadProgress,
    val percentage: String,
) : DownloadStatus()

/** A new attempt is about to be made after a failure. */
data object Retrying : DownloadStatus()

/** Transfer paused. */
data object Paused : DownloadStatus()

/** The download stopped with a human-readable [failureReason]. */
data class Failed(val failureReason: String) : DownloadStatus()

/** The transfer completed. */
data object Finished : DownloadStatus()
object Utils {
    /** Number of attempts the retry loops in [DownloadManager] make before giving up. */
    // Those loops wait attemptIndex * stepDuration after a failure, so the longest single wait is about 1 second.
    const val ATTEMPT_COUNT = 100

    /** Linear back-off unit used between retry attempts. */
    val stepDuration: kotlin.time.Duration = 10.milliseconds

    /**
     * View of [Dispatchers.IO] without a parallelism cap, used for calls that may block for a long time (for example
     * reads from a stalled network stream). Per the Dispatchers.IO docs, views created with `limitedParallelism`
     * are not bound by the default IO pool limit.
     */
    val DedicatedBlockingDispatcher = Dispatchers.IO.limitedParallelism(Int.MAX_VALUE)

    /*
     * Default number of parallel segment connections (coroutines). Author's note: values between 2 and 32 work, but
     * going beyond 8 risks being throttled by the server.
     */
    private const val THREAD_CONNECTIONS = 4

    private const val DEFAULT_USER_AGENT =
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36"

    /*
     * Every OkHttpClient owns its own connection pool and thread pool, so one lazily built client is shared by all
     * requests instead of building a new one per call.
     */
    private val sharedHttpClient: OkHttpClient by lazy {
        OkHttpClient.Builder()
            .connectTimeout(40, TimeUnit.SECONDS)
            .readTimeout(40, TimeUnit.SECONDS)
            .writeTimeout(40, TimeUnit.SECONDS)
            .followRedirects(true)
            .protocols(listOf(Protocol.HTTP_1_1))
            .addInterceptor { chain ->
                val request = chain.request()
                // supply the browser-like User-Agent only when the caller has not set one
                val outgoing = if (request.header("User-Agent") == null) {
                    request.newBuilder().header("User-Agent", DEFAULT_USER_AGENT).build()
                } else {
                    request
                }
                chain.proceed(outgoing)
            }
            .build()
    }

    /**
     * Splits a body of [responseBodyLength] bytes into at most [preferredThreadConnections] contiguous, inclusive
     * [Segment]s that together cover every byte exactly once. The last segment absorbs the remainder of the integer
     * division. Returns an empty sequence for a non-positive length.
     *
     * @throws IllegalArgumentException if [preferredThreadConnections] is not positive.
     */
    suspend fun createDownloadSegments(
        responseBodyLength: Int,
        preferredThreadConnections: Int = THREAD_CONNECTIONS,
    ): Sequence<Segment> {
        require(preferredThreadConnections > 0) {
            "preferredThreadConnections must be positive but was $preferredThreadConnections"
        }
        if (responseBodyLength <= 0) return emptySequence()
        // never create more segments than bytes, otherwise some segments would be empty (end < start)
        val segmentCount = minOf(preferredThreadConnections, responseBodyLength)
        val segmentSize = responseBodyLength / segmentCount
        return sequence {
            for (i in 0 until segmentCount) {
                val startPosition = i * segmentSize
                val endPosition = if (i == segmentCount - 1) responseBodyLength - 1 else startPosition + segmentSize - 1
                yield(Segment(startPosition, endPosition))
            }
        }
    }

    /** Builds a GET [Request] for [url], adding every entry of [headers]. */
    fun createNetworkRequestObject(url: String, headers: Map<String, String> = emptyMap()): Request =
        Request.Builder()
            .url(url)
            .apply { headers.forEach { (key, value) -> addHeader(key, value) } }
            .build()

    /** Returns the shared, lazily built HTTP client (HTTP/1.1, 40 s timeouts, redirects followed). */
    fun defaultHttpClient(): OkHttpClient = sharedHttpClient

    /**
     * Enqueues this call and suspends until a response arrives. Cancelling the coroutine cancels the call.
     *
     * @return the successful (2xx) response; the caller must close it.
     * @throws IOException on network failure or a non-successful status code.
     */
    suspend fun Call.executeAsynchronously(): Response = suspendCancellableCoroutine { cancellableContinuation ->
        // registered before enqueueing so a cancellation at any point also cancels the HTTP call
        cancellableContinuation.invokeOnCancellation { cancel() }
        enqueue(object : Callback {
            override fun onFailure(call: Call, e: IOException) {
                cancellableContinuation.resumeWithException(e)
            }

            override fun onResponse(call: Call, response: Response) {
                if (!response.isSuccessful) {
                    // an unsuccessful response still holds a connection; release it before failing
                    response.close()
                    cancellableContinuation.resumeWithException(IOException("Network request was not successful"))
                    return
                }
                cancellableContinuation.resume(response) { cause: Throwable, value: Response, _: CoroutineContext ->
                    logWarning { "exception occurred while performing network request ${cause.message}" }
                    // the coroutine was cancelled before it took ownership of the response, so close it here
                    value.close()
                    cancel()
                }
            }
        })
    }

    /**
     * Concatenates the temporary files `0_<name>` … `<segmentSize-1>_<name>` into [file], in index order, then
     * deletes them.
     *
     * All segment files are checked before [file] is opened, so a missing segment neither truncates the
     * destination nor deletes the segments that were downloaded.
     *
     * @throws IOException if any segment file is missing.
     */
    suspend fun combineSegmentsToFile(file: File, segmentSize: Int) {
        val segmentFiles = List(segmentSize) { index -> File("${index}_${file.name}") }
        if (segmentFiles.any { !it.exists() }) {
            logWarning { "temporary file does not exist. Aborting!" }
            // something is wrong: not all segments were downloaded, file integrity is compromised, abort!
            throw IOException("Error while combining segments to file. not all segments were downloaded! retry again!")
        }
        file.outputStream().use { out ->
            // must be sequential: segment i holds the bytes that precede segment i + 1
            segmentFiles.forEach { segmentFile ->
                segmentFile.inputStream().use { input -> input.copyToOutputStreamAsynchronously(out) }
            }
        }
        // delete only once the combined file has been fully written and closed
        segmentFiles.forEach { it.deleteIfExists() }
    }

    /**
     * Deletes this file if present.
     *
     * @return `true` when the file no longer exists afterwards; `false` (after logging an error) when it could
     * not be deleted.
     */
    fun File.deleteIfExists(): Boolean {
        // File.delete() signals failure through its return value, not an IOException, so check it explicitly
        if (delete() || !exists()) return true
        logError(throwable = IOException("could not delete $absolutePath"))
        return false
    }

    /**
     * Returns whether the server answers a `Range: bytes=0-0` request with `206 Partial Content`.
     *
     * @throws IOException if the probe request fails (logged, then propagated to the caller).
     */
    suspend fun doesServerSupportRange(urlToFollow: String): Boolean {
        val networkRequestObject = createNetworkRequestObject(urlToFollow, headers = mapOf("Range" to "bytes=0-0"))
        return try {
            defaultHttpClient().newCall(networkRequestObject).executeAsynchronously().use { response ->
                response.code == 206
            }
        } catch (ex: IOException) {
            // propagate the error to the caller for handling by the top-level/primary function
            logWarning { "Does Server Support Range: the following exception was encountered ${ex.message} " }
            throw ex
        }
    }

    /*
     * Runs a potentially long-blocking IO action (e.g. an InputStream read that can hang) on
     * DedicatedBlockingDispatcher. The action is launched in GlobalScope, so it does not inherit the caller's Job.
     * If the awaiting coroutine is cancelled, the action is cancelled too and the caller does not wait for it to
     * finish.
     */
    @OptIn(DelicateCoroutinesApi::class)
    private suspend fun <T> performIntenseIO(
        context: CoroutineContext = EmptyCoroutineContext,
        action: suspend CoroutineScope.() -> T,
    ): T {
        val deferredAction = GlobalScope.async(DedicatedBlockingDispatcher + context, block = action)
        return try {
            deferredAction.await()
        } catch (ex: CancellationException) {
            // deliberately breaks structured concurrency: cancel the action and rethrow without waiting for it
            deferredAction.cancel(ex)
            throw ex
        }
    }

    /**
     * Copies up to [limit] bytes from this stream to [output] via [performIntenseIO], so the blocking reads do not
     * run on the caller's dispatcher.
     *
     * [onUpdate] receives the cumulative number of bytes written, after each chunk is written. The count is capped
     * at [Int.MAX_VALUE] instead of wrapping negative for streams larger than 2 GiB.
     */
    suspend fun InputStream.copyToOutputStreamAsynchronously(
        output: OutputStream,
        bufferSize: Int = DEFAULT_BUFFER_SIZE,
        limit: Long = Long.MAX_VALUE,
        onUpdate: ((bytesRead: Int) -> Unit)? = null,
    ) {
        performIntenseIO(context = CoroutineName("copyToAsync:$this => $output")) {
            val buffer = ByteArray(bufferSize)
            var totalRead = 0L
            var lastProgress = 0
            while (totalRead < limit) {
                yield() // cancellation point; also lets other coroutines run between chunks
                val chunkSize = try {
                    read(buffer, 0, minOf(limit - totalRead, buffer.size.toLong()).toInt())
                } catch (_: SocketTimeoutException) {
                    // NOTE(review): kept from the original design; a timed-out read is simply attempted again
                    continue
                }
                if (chunkSize < 0) break // EOF: no more bytes to read
                if (chunkSize == 0) continue
                output.write(buffer, 0, chunkSize)
                totalRead += chunkSize
                val progress = totalRead.coerceAtMost(Int.MAX_VALUE.toLong()).toInt()
                if (progress > lastProgress) {
                    lastProgress = progress
                    onUpdate?.invoke(progress)
                }
            }
        }
    }
}
/** Minimal coloured console logger using ANSI escape codes. */
object Dumper {
    const val UP_TIME_ICON = "⇡"
    const val DOWN_TIME_ICON = "⇣"
    private const val BOLD_ERROR_CODE = "\u001B[1;31m"
    private const val BACKGROUND_RED_CODE = "\u001B[41m"
    private const val ERROR_CHARACTER_CODE = 'E'
    private const val INFO_CHARACTER_CODE = 'I'
    private const val DEBUG_CHARACTER_CODE = 'D'
    private const val WARNING_CHARACTER_CODE = 'W'
    private const val RESET_CODE = "\u001B[0m"
    private const val BACKGROUND_YELLOW_CODE = "\u001B[43m"
    private const val FOREGROUND_YELLOW_CODE = "\u001B[33m"
    private const val BACKGROUND_BLUE_CODE = "\u001B[104m"
    private const val FOREGROUND_BLACK_CODE = "\u001B[30m"
    private const val FOREGROUND_BLUE_CODE = "\u001B[94m"
    private const val BACKGROUND_GREEN_CODE = "\u001b[42m"
    private const val FOREGROUND_GREEN_CODE = "\u001b[32m"
    private const val FOREGROUND_WHITE_CODE = "\u001B[97m" // \u001B[1;37m

    // "MM" is month-of-year and "HH" the 24-hour clock. The previous pattern "yyyy-mm-dd hh:mm" printed
    // minute-of-hour where the month belongs and a 12-hour clock without AM/PM. DateTimeFormatter is immutable and
    // thread-safe, so one instance is shared.
    private val TIMESTAMP_FORMATTER: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")

    private fun currentDate(): String = TIMESTAMP_FORMATTER.format(LocalDateTime.now())

    private fun breather(len: Int = 2) = " ".repeat(len)

    /* timestamp, gap, coloured level badge, gap, styled message, and optionally a trailing gap */
    private fun formatLine(
        badge: String,
        messageStyle: String,
        message: String,
        trailingGap: Boolean = false,
    ): String = buildString {
        append("$FOREGROUND_WHITE_CODE${currentDate()} $RESET_CODE")
        append(breather())
        append(badge)
        append(breather())
        append("$messageStyle$message$RESET_CODE")
        if (trailingGap) append(breather())
    }

    /* prints one line through a fresh PrintStream wrapper around [target] and flushes it immediately */
    private fun emit(target: PrintStream, line: String) {
        val printStream = PrintStream(target)
        printStream.println(line)
        printStream.flush()
    }

    /** Prints [message] to stdout with a yellow `W` badge. */
    fun logWarning(message: () -> String) {
        val badge = "$BACKGROUND_YELLOW_CODE$FOREGROUND_BLACK_CODE $WARNING_CHARACTER_CODE $RESET_CODE"
        emit(System.out, formatLine(badge, FOREGROUND_YELLOW_CODE, message()))
    }

    /** Prints [message] to stdout with a green `D` badge. */
    fun logDebug(message: () -> String) {
        val badge = "$BACKGROUND_GREEN_CODE$FOREGROUND_WHITE_CODE $DEBUG_CHARACTER_CODE $RESET_CODE"
        emit(System.out, formatLine(badge, FOREGROUND_GREEN_CODE, message()))
    }

    /** Prints [message] to stdout with a blue `I` badge. */
    fun logInfo(message: () -> String) {
        val badge = "$BACKGROUND_BLUE_CODE$FOREGROUND_WHITE_CODE $INFO_CHARACTER_CODE $RESET_CODE"
        emit(System.out, formatLine(badge, FOREGROUND_BLUE_CODE, message()))
    }

    /**
     * Prints an error line to stderr with a red `E` badge. The text is [throwable]'s message, or [message] when the
     * throwable has none, or empty when neither is available.
     */
    fun logError(throwable: Throwable, message: (() -> String)? = null) {
        val errorMessage = throwable.message ?: message?.invoke() ?: ""
        val badge = "$BACKGROUND_RED_CODE$FOREGROUND_BLACK_CODE $ERROR_CHARACTER_CODE $RESET_CODE"
        emit(System.err, formatLine(badge, BOLD_ERROR_CODE, errorMessage, trailingGap = true))
    }
}
// Top-level shortcuts so call sites can log without the `Dumper.` prefix; each forwards to the matching Dumper method.

/** Forwards to [Dumper.logError]. */
fun logError(throwable: Throwable, message: (() -> String)? = null) {
    Dumper.logError(throwable, message)
}

/** Forwards to [Dumper.logWarning]. */
fun logWarning(message: () -> String) {
    Dumper.logWarning(message)
}

/** Forwards to [Dumper.logDebug]. */
fun logDebug(message: () -> String) {
    Dumper.logDebug(message)
}

/** Forwards to [Dumper.logInfo]. */
fun logInfo(message: () -> String) {
    Dumper.logInfo(message)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment