Last active
April 30, 2018 16:45
-
-
Save pm-dev/b1eb3e6dda3ad5064c81be44e458286a to your computer and use it in GitHub Desktop.
Better type-safety with units in Kotlin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Here's how unit type-safety can prevent you from performing invalid unit arithmetic.
fun main(args: Array<String>) {
    @Suppress("UNUSED_VARIABLE") // only referenced by the commented-out counter-examples below
    val oranges = Oranges(4)
    val apples = Apples(4)

    // Each of the following lines is rejected by the compiler, which is exactly the point
    // of unit types. Uncomment any one of them to see the compile-time error:
    // println("Can't compare apples to oranges: ${apples == oranges}")   // '==' not applicable to Apples and Oranges
    // println("Can't add two unknown things to my apples: ${apples + 2}") // no plus(Int): a bare 2 carries no unit
    // println("Can't divide apples by apples: ${apples / apples}")        // no div(Apples) is declared

    // Valid syntax (expected output shown on the right):
    // equality is value-based, so a separately constructed Apples(4) compares equal
    println("Comparing apples to apples: ${apples == Apples(4)}") // "Comparing apples to apples: true"
    println("Adding apples to apples: ${apples + apples}")        // "Adding apples to apples: 8 Apples"
    // 4 / 4.0 == 1.0, truncated back to an Int magnitude
    println("Dividing apples by 4: ${apples / 4.0}")              // "Dividing apples by 4: 1 Apples"
}
// Declaring a unit takes a single line: extend the base class for the desired numeric
// representation, pass the class itself as the type argument, and hand over a factory
// (here a constructor reference) so arithmetic results come back as the same unit.
class Oranges(value: Int) : IntUnit<Oranges>(value, ::Oranges)
class Apples(value: Int) : IntUnit<Apples>(value, ::Apples)
// Here are the classes that make it happen:

/**
 * Base class for a magnitude tagged with a compile-time unit.
 *
 * NOTE: because this class is named `Unit`, it takes precedence over the default-imported
 * `kotlin.Unit` for code in the same package; write `kotlin.Unit` explicitly if that type
 * is ever needed here.
 *
 * @param TYPE the numeric type that stores the magnitude.
 * @param THIS the concrete unit class itself (a self type), so [compareTo] only accepts
 *   values of the same unit.
 * @property value the raw magnitude; visible to subclasses only.
 */
abstract class Unit<TYPE : Number, THIS : Unit<TYPE, THIS>>(
    protected val value: TYPE
) : Number(), Comparable<THIS> {

    /**
     * Returns true only when [other] has exactly the same runtime class as this unit and
     * [compareTo] reports the two as equal.
     *
     * The exact-class check (rather than `javaClass.isInstance(other)`) keeps equality
     * symmetric if a concrete unit class is ever subclassed, and mirrors [hashCode], which
     * also mixes in the runtime class.
     */
    @Suppress("UNCHECKED_CAST") // safe: the runtime class was just checked to equal ours
    override fun equals(other: Any?): Boolean =
        other != null && other.javaClass == javaClass && compareTo(other as THIS) == 0

    /** Combines the magnitude's hash with the runtime class, mirroring the class check in [equals]. */
    override fun hashCode() = value.hashCode() xor javaClass.hashCode()

    /** Renders as "<magnitude> <UnitClassName>", e.g. `4 Apples`. */
    override fun toString() = "$value ${javaClass.simpleName}"

    // Number conversions simply delegate to the underlying magnitude.
    override fun toInt() = value.toInt()
    override fun toLong() = value.toLong()
    override fun toDouble() = value.toDouble()
    override fun toFloat() = value.toFloat()
    override fun toByte() = value.toByte()
    override fun toChar() = value.toChar()
    override fun toShort() = value.toShort()
}
/**
 * A [Unit] whose magnitude is stored as an [Int].
 *
 * Every operator returns a fresh instance built through [constructor]; instances are never
 * mutated. Scaling by a bare number keeps the unit, while [plus]/[minus] require another
 * value of the same unit [T].
 *
 * Integer arithmetic follows [Int] semantics: it wraps on overflow, and dividing by an
 * integer zero throws [ArithmeticException]. Results computed as [Long] are narrowed with
 * [Long.toInt]. Floating-point results are truncated toward zero, but a NaN or infinite
 * result (e.g. from dividing by `0.0`) throws [ArithmeticException] instead of silently
 * becoming `0` or an [Int] bound.
 *
 * @param T the concrete unit class (self type).
 * @param constructor factory that wraps a raw magnitude in a new [T].
 */
abstract class IntUnit<T : IntUnit<T>>(
    value: Int,
    private val constructor: (Int) -> T
) : Unit<Int, T>(value) {
    override fun compareTo(other: T) = value.compareTo(other.value)

    operator fun inc() = constructor(value + 1)
    operator fun dec() = constructor(value - 1)
    operator fun unaryMinus() = constructor(-value)
    operator fun unaryPlus() = constructor(+value)

    operator fun times(other: Byte) = constructor(value * other)
    operator fun times(other: Double) = fromReal(value * other)
    operator fun times(other: Float) = fromReal(value * other)
    operator fun times(other: Int) = constructor(value * other)
    operator fun times(other: Long) = constructor((value * other).toInt())
    operator fun times(other: Short) = constructor(value * other)

    operator fun div(other: Byte) = constructor(value / other)
    operator fun div(other: Double) = fromReal(value / other)
    operator fun div(other: Float) = fromReal(value / other)
    operator fun div(other: Int) = constructor(value / other)
    operator fun div(other: Long) = constructor((value / other).toInt())
    operator fun div(other: Short) = constructor(value / other)

    operator fun rem(other: Byte) = constructor(value % other)
    operator fun rem(other: Double) = fromReal(value % other)
    operator fun rem(other: Float) = fromReal(value % other)
    operator fun rem(other: Int) = constructor(value % other)
    operator fun rem(other: Long) = constructor((value % other).toInt())
    operator fun rem(other: Short) = constructor(value % other)

    operator fun plus(other: T) = constructor(value + other.value)
    operator fun minus(other: T) = constructor(value - other.value)

    /**
     * Wraps a floating-point [result] as a new [T], truncating toward zero.
     *
     * @throws ArithmeticException if [result] is NaN or infinite. [Double.toInt] would
     *   silently turn NaN into 0 and infinities into [Int.MAX_VALUE]/[Int.MIN_VALUE];
     *   throwing instead matches what dividing by an integer zero already does.
     */
    private fun fromReal(result: Double): T {
        if (!result.isFinite()) throw ArithmeticException("Non-finite result: $result")
        return constructor(result.toInt())
    }

    /** [Float] variant of [fromReal]; widening to [Double] is exact, so truncation is unchanged. */
    private fun fromReal(result: Float): T = fromReal(result.toDouble())
}
/**
 * A [Unit] whose magnitude is stored as a [Long].
 *
 * Every operator returns a fresh instance built through [constructor]; instances are never
 * mutated. Scaling by a bare number keeps the unit, while [plus]/[minus] require another
 * value of the same unit [T].
 *
 * Integer arithmetic follows [Long] semantics: it wraps on overflow, and dividing by an
 * integer zero throws [ArithmeticException]. Floating-point results are truncated toward
 * zero, but a NaN or infinite result (e.g. from dividing by `0.0`) throws
 * [ArithmeticException] instead of silently becoming `0` or a [Long] bound.
 *
 * @param T the concrete unit class (self type).
 * @param constructor factory that wraps a raw magnitude in a new [T].
 */
abstract class LongUnit<T : LongUnit<T>>(
    value: Long,
    private val constructor: (Long) -> T
) : Unit<Long, T>(value) {
    override fun compareTo(other: T) = value.compareTo(other.value)

    operator fun inc() = constructor(value + 1L)
    operator fun dec() = constructor(value - 1L)
    operator fun unaryMinus() = constructor(-value)
    operator fun unaryPlus() = constructor(+value)

    operator fun times(other: Byte) = constructor(value * other)
    operator fun times(other: Double) = fromReal(value * other)
    operator fun times(other: Float) = fromReal(value * other)
    operator fun times(other: Int) = constructor(value * other)
    operator fun times(other: Long) = constructor(value * other)
    operator fun times(other: Short) = constructor(value * other)

    operator fun div(other: Byte) = constructor(value / other)
    operator fun div(other: Double) = fromReal(value / other)
    operator fun div(other: Float) = fromReal(value / other)
    operator fun div(other: Int) = constructor(value / other)
    operator fun div(other: Long) = constructor(value / other)
    operator fun div(other: Short) = constructor(value / other)

    operator fun rem(other: Byte) = constructor(value % other)
    operator fun rem(other: Double) = fromReal(value % other)
    operator fun rem(other: Float) = fromReal(value % other)
    operator fun rem(other: Int) = constructor(value % other)
    operator fun rem(other: Long) = constructor(value % other)
    operator fun rem(other: Short) = constructor(value % other)

    operator fun plus(other: T) = constructor(value + other.value)
    operator fun minus(other: T) = constructor(value - other.value)

    /**
     * Wraps a floating-point [result] as a new [T], truncating toward zero.
     *
     * @throws ArithmeticException if [result] is NaN or infinite. [Double.toLong] would
     *   silently turn NaN into 0 and infinities into [Long.MAX_VALUE]/[Long.MIN_VALUE];
     *   throwing instead matches what dividing by an integer zero already does.
     */
    private fun fromReal(result: Double): T {
        if (!result.isFinite()) throw ArithmeticException("Non-finite result: $result")
        return constructor(result.toLong())
    }

    /** [Float] variant of [fromReal]; widening to [Double] is exact, so truncation is unchanged. */
    private fun fromReal(result: Float): T = fromReal(result.toDouble())
}
/**
 * A [Unit] whose magnitude is stored as a [Double].
 *
 * Every operator returns a fresh instance built through [constructor]; the magnitude is
 * never mutated. Scaling by any primitive number is computed in [Double], so no narrowing
 * takes place, and IEEE-754 rules apply (e.g. dividing by zero yields an infinity or NaN
 * rather than throwing).
 *
 * @param T the concrete unit class (self type).
 * @param constructor factory that wraps a raw magnitude in a new [T].
 */
abstract class DoubleUnit<T : DoubleUnit<T>>(
    value: Double,
    private val constructor: (Double) -> T
) : Unit<Double, T>(value) {
    override fun compareTo(other: T) = value.compareTo(other.value)

    // Increment, decrement and sign.
    operator fun inc() = constructor(value + 1.0)
    operator fun dec() = constructor(value - 1.0)
    operator fun unaryMinus() = constructor(-value)
    operator fun unaryPlus() = constructor(+value)

    // Scaling by a dimensionless number.
    operator fun times(other: Byte) = constructor(value * other)
    operator fun times(other: Double) = constructor(value * other)
    operator fun times(other: Float) = constructor(value * other)
    operator fun times(other: Int) = constructor(value * other)
    operator fun times(other: Long) = constructor(value * other)
    operator fun times(other: Short) = constructor(value * other)

    operator fun div(other: Byte) = constructor(value / other)
    operator fun div(other: Double) = constructor(value / other)
    operator fun div(other: Float) = constructor(value / other)
    operator fun div(other: Int) = constructor(value / other)
    operator fun div(other: Long) = constructor(value / other)
    operator fun div(other: Short) = constructor(value / other)

    operator fun rem(other: Byte) = constructor(value % other)
    operator fun rem(other: Double) = constructor(value % other)
    operator fun rem(other: Float) = constructor(value % other)
    operator fun rem(other: Int) = constructor(value % other)
    operator fun rem(other: Long) = constructor(value % other)
    operator fun rem(other: Short) = constructor(value % other)

    // Addition and subtraction only between values of the same unit.
    operator fun plus(other: T) = constructor(value + other.value)
    operator fun minus(other: T) = constructor(value - other.value)
}
/**
 * A [Unit] whose magnitude is stored as a [Float].
 *
 * Every operator returns a fresh instance built through [constructor]; the magnitude is
 * never mutated. Operations with a [Double] operand are evaluated in [Double] and then
 * rounded back to [Float]; all other operands stay in [Float]. IEEE-754 rules apply
 * (e.g. dividing by zero yields an infinity or NaN rather than throwing).
 *
 * @param T the concrete unit class (self type).
 * @param constructor factory that wraps a raw magnitude in a new [T].
 */
abstract class FloatUnit<T : FloatUnit<T>>(
    value: Float,
    private val constructor: (Float) -> T
) : Unit<Float, T>(value) {
    override fun compareTo(other: T) = value.compareTo(other.value)

    // Increment, decrement and sign.
    operator fun inc() = constructor(value + 1f)
    operator fun dec() = constructor(value - 1f)
    operator fun unaryMinus() = constructor(-value)
    operator fun unaryPlus() = constructor(+value)

    // Scaling by a dimensionless number.
    operator fun times(other: Byte) = constructor(value * other)
    operator fun times(other: Double) = constructor((value * other).toFloat())
    operator fun times(other: Float) = constructor(value * other)
    operator fun times(other: Int) = constructor(value * other)
    operator fun times(other: Long) = constructor(value * other)
    operator fun times(other: Short) = constructor(value * other)

    operator fun div(other: Byte) = constructor(value / other)
    operator fun div(other: Double) = constructor((value / other).toFloat())
    operator fun div(other: Float) = constructor(value / other)
    operator fun div(other: Int) = constructor(value / other)
    operator fun div(other: Long) = constructor(value / other)
    operator fun div(other: Short) = constructor(value / other)

    operator fun rem(other: Byte) = constructor(value % other)
    operator fun rem(other: Double) = constructor((value % other).toFloat())
    operator fun rem(other: Float) = constructor(value % other)
    operator fun rem(other: Int) = constructor(value % other)
    operator fun rem(other: Long) = constructor(value % other)
    operator fun rem(other: Short) = constructor(value % other)

    // Addition and subtraction only between values of the same unit.
    operator fun plus(other: T) = constructor(value + other.value)
    operator fun minus(other: T) = constructor(value - other.value)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment