Skip to content

Instantly share code, notes, and snippets.

@swankjesse
Created May 28, 2020 03:00
Show Gist options
  • Save swankjesse/c164052a12aa260ebf0d2af8807e8378 to your computer and use it in GitHub Desktop.
Save swankjesse/c164052a12aa260ebf0d2af8807e8378 to your computer and use it in GitHub Desktop.
/*
* Copyright (C) 2020 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.concurrent
import java.util.concurrent.atomic.AtomicLong
/**
 * Stores 64 boolean values at indexes 0 through 63, packed into a single [AtomicLong]. This
 * permits atomic operations on the individual bits: every write is a compare-and-set of the whole
 * 64-bit word, so concurrent updates to different bits do not overwrite one another.
 */
internal class AtomicBitSet {
  private val atomicLong = AtomicLong()

  /** Returns a snapshot of the bit set; bit `i` of the result is the value at index `i`. */
  val rawBits: Long
    get() = atomicLong.get()

  /** Sets all bits to 0. */
  fun clear() {
    atomicLong.set(0L)
  }

  /**
   * Returns the bit at [index].
   *
   * @throws IllegalArgumentException if [index] is not in `0..63`.
   */
  operator fun get(index: Int): Boolean {
    val bit = maskOf(index)
    return atomicLong.get()[bit]
  }

  /**
   * Sets the bit at [index] to [newValue] and returns its previous value.
   *
   * @throws IllegalArgumentException if [index] is not in `0..63`.
   */
  fun getAndSet(index: Int, newValue: Boolean): Boolean {
    val bit = maskOf(index)
    // Retry until no other thread modified the word between our read and our write.
    // NOTE(review): this is what AtomicLong.getAndUpdate does (Java 8+, Android API 24+);
    // presumably the explicit loop is kept to avoid that API-level requirement.
    while (true) {
      val oldBits = atomicLong.get()
      val newBits = oldBits.withBit(bit, newValue)
      if (atomicLong.compareAndSet(oldBits, newBits)) return oldBits[bit]
    }
  }

  /**
   * Sets the bit at [index] to [newValue], but only if its current value is [expectedValue].
   *
   * @return true if the bit held [expectedValue] (it now holds [newValue]); false if the
   *     expectation failed and nothing was written. Note that this returns true even when
   *     [expectedValue] equals [newValue] and the bit is therefore unchanged.
   * @throws IllegalArgumentException if [index] is not in `0..63`.
   */
  fun compareAndSet(index: Int, expectedValue: Boolean, newValue: Boolean): Boolean {
    val bit = maskOf(index)
    while (true) {
      val oldBits = atomicLong.get()
      if (oldBits[bit] != expectedValue) return false // Expectation fail.
      val newBits = oldBits.withBit(bit, newValue)
      // A failed CAS means some bit (not necessarily ours) changed; re-read and re-check.
      if (atomicLong.compareAndSet(oldBits, newBits)) return true
    }
  }

  /** Validates [index] and returns a mask with only that bit set. */
  private fun maskOf(index: Int): Long {
    require(index in 0..63) { "index out of range: $index" }
    return 1L shl index
  }

  /** Returns true if any bit of the mask [bit] is set in this word. */
  private operator fun Long.get(bit: Long): Boolean = (this and bit) != 0L

  /** Returns this word with the bits of mask [bit] set to [newValue]. */
  private fun Long.withBit(bit: Long, newValue: Boolean): Long =
    if (newValue) this or bit else this and bit.inv()
}
/*
* Copyright (C) 2020 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.concurrent
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.fail
import org.junit.Test
import java.lang.IllegalArgumentException
internal class AtomicBitSetTest {
  @Test
  fun getAndSet() {
    val bits = AtomicBitSet()
    assertThat(bits.rawBits).isEqualTo(0b0L)
    assertThat(bits.getAndSet(0, true)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0b1L)
    assertThat(bits.getAndSet(3, true)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0b1001L)
    assertThat(bits.getAndSet(7, true)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0b10001001L)
    assertThat(bits.getAndSet(3, true)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0b10001001L)
    assertThat(bits.getAndSet(3, false)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0b10000001L)
    assertThat(bits.getAndSet(0, false)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0b10000000L)
    assertThat(bits.getAndSet(7, true)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0b10000000L)
    assertThat(bits.getAndSet(7, false)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0L)
  }

  @Test
  fun compareAndSet() {
    val bits = AtomicBitSet()
    assertThat(bits.compareAndSet(3, true, true)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0L)
    assertThat(bits.compareAndSet(3, true, false)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0L)
    assertThat(bits.compareAndSet(3, false, false)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0L)
    assertThat(bits.compareAndSet(3, false, true)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0b1000L)
    assertThat(bits.compareAndSet(3, false, true)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0b1000L)
    assertThat(bits.compareAndSet(3, false, false)).isFalse()
    assertThat(bits.rawBits).isEqualTo(0b1000L)
    assertThat(bits.compareAndSet(3, true, true)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0b1000L)
    assertThat(bits.compareAndSet(3, true, false)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0L)
  }

  /** Bit 63 is the sign bit of the backing Long; make sure it round-trips like any other bit. */
  @Test
  fun setBit63() {
    val rawBiggestBit = 1L shl 63
    val bits = AtomicBitSet()
    assertThat(bits.getAndSet(63, true)).isFalse()
    assertThat(bits.rawBits).isEqualTo(rawBiggestBit)
    assertThat(bits[63]).isTrue()
    assertThat(bits.getAndSet(63, false)).isTrue()
    assertThat(bits.rawBits).isEqualTo(0L)
    assertThat(bits[63]).isFalse()
  }

  @Test
  fun clear() {
    val bits = AtomicBitSet()
    bits.getAndSet(0, true)
    bits.getAndSet(1, true)
    bits.getAndSet(63, true)
    bits.clear()
    assertThat(bits.rawBits).isEqualTo(0L)
  }

  @Test
  fun bounds() {
    val bits = AtomicBitSet()
    assertRejectsIndex { bits[-1] }
    assertRejectsIndex { bits[64] }
    assertRejectsIndex { bits.getAndSet(-1, true) }
    assertRejectsIndex { bits.getAndSet(64, true) }
    assertRejectsIndex { bits.compareAndSet(-1, false, true) }
    assertRejectsIndex { bits.compareAndSet(64, false, true) }
  }

  /**
   * Three threads hammer three different bits, including the lowest, a middle one, and the sign
   * bit. Each thread's own assertions only hold if updates to other bits never clobber its bit.
   */
  @Test
  fun races() {
    val bits = AtomicBitSet()
    val indexes = intArrayOf(0, 31, 63)

    // Assertion errors thrown on a worker thread would otherwise only reach that thread's
    // uncaught-exception handler and never fail this test. Each thread records its failure in
    // its own slot; join() makes those writes visible to this thread.
    val failures = arrayOfNulls<Throwable>(indexes.size)
    val threads = indexes.mapIndexed { slot, index ->
      Thread {
        try {
          toggle(bits, index, 10_000)
        } catch (t: Throwable) {
          failures[slot] = t
        }
      }
    }

    threads.forEach { it.start() }
    threads.forEach { it.join() }

    for (failure in failures) {
      if (failure != null) throw failure
    }
    assertThat(bits.rawBits).isEqualTo(0L)
  }

  /** Cycles bit [index] through off -> on -> off [count] times, asserting every step. */
  private fun toggle(bits: AtomicBitSet, index: Int, count: Int) {
    repeat(count) {
      assertThat(bits.getAndSet(index, false)).isFalse()
      assertThat(bits.compareAndSet(index, true, true)).isFalse()
      assertThat(bits.compareAndSet(index, false, true)).isTrue()
      assertThat(bits.getAndSet(index, true)).isTrue()
      assertThat(bits.compareAndSet(index, false, false)).isFalse()
      assertThat(bits.compareAndSet(index, true, true)).isTrue()
      assertThat(bits.getAndSet(index, false)).isTrue()
    }
  }

  /** Fails unless [block] throws [IllegalArgumentException]. */
  private fun assertRejectsIndex(block: () -> Unit) {
    try {
      block()
    } catch (expected: IllegalArgumentException) {
      return
    }
    fail("expected IllegalArgumentException")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment