Last active
March 15, 2021 22:26
-
-
Save ChristopherDavenport/3d4d84ea9ab04df3ecf343e089520da1 to your computer and use it in GitHub Desktop.
Scala Translation of a ThreadSafeBitSet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Credit to Hollow: https://github.com/Netflix/hollow/blob/b2890aab063b530b1cffe1af7b21d1358024c6e7/hollow/src/main/java/com/netflix/hollow/core/memory/ThreadSafeBitSet.java | |
| import java.util.concurrent.atomic.AtomicLongArray | |
| import java.util.concurrent.atomic.AtomicReference | |
| import scala.util.control.Breaks | |
| import scala.util.hashing.MurmurHash3 | |
| import scala.collection.BitSet | |
| import scala.collection.mutable | |
| class ThreadSafeBitSet private ( | |
| private final val numLongsPerSegment: Int, | |
| private final val log2SegmentSize: Int, | |
| private final val segmentMask: Int, | |
| private final val segments: AtomicReference[ThreadSafeBitSet.ThreadSafeBitSegments] | |
| ) { // TODO extends scala.collection.mutable.BitSet | |
| /* | |
| * -------------------- | |
| * Modifications | |
| * -------------------- | |
| */ | |
| def set(position: Long): Unit = { | |
| val segmentPosition = position >>> log2SegmentSize // which segment -- div by num bits per segment | |
| val longPosition = (position >>> 6) & segmentMask // which long in the segment -- remainder of div by num bits per segment | |
| val bitPosition = position & 0x3F // which bit in the long -- remainder of div by num bits in long (64) -- positive bits | |
| val segment = getSegment(segmentPosition.toInt) | |
| val mask = 1L << bitPosition | |
| var retry = true | |
| while (retry){ | |
| val currentLongValue = segment.get(longPosition.toInt) | |
| val newLongValue = currentLongValue | mask | |
| if (segment.compareAndSet(longPosition.toInt, currentLongValue, newLongValue)){ | |
| retry = false | |
| } | |
| } | |
| } | |
| def clear(position: Long): Unit = { | |
| val segmentPosition = position >>> log2SegmentSize // which segment -- div by num bits per segment | |
| val longPosition = (position >>> 6) & segmentMask // which long in the segment -- remainder of div by num bits per segment | |
| val bitPosition = position & 0x3F /// which bit in the long -- remainder of div by num bits in long (64) | |
| val segment = getSegment(segmentPosition.toInt) | |
| val mask = ~(1L << bitPosition) | |
| var retry = true | |
| while (retry){ | |
| val currentLongValue = segment.get(longPosition.toInt) | |
| val newLongValue = currentLongValue & mask | |
| if (segment.compareAndSet(longPosition.toInt, currentLongValue, newLongValue)){ | |
| retry = false | |
| } | |
| } | |
| } | |
| def get(position: Long): Boolean = { | |
| val segmentPosition = position >>> log2SegmentSize // which segment -- div by num bits per segment | |
| val longPosition = (position >>> 6) & segmentMask // which long in the segment -- remainder of div by num bits per segment | |
| val bitPosition = position & 0x3F /// which bit in the long -- remainder of div by num bits in long (64) | |
| val segment = getSegment(segmentPosition.toInt) | |
| val mask = 1L << bitPosition | |
| (segment.get(longPosition.toInt) & mask) != 0 | |
| } | |
| /** | |
| * Clear all bits to 0. | |
| */ | |
| def clearAll: Unit = { | |
| val visibleSegments = segments.get | |
| for { | |
| i <- 0 until visibleSegments.numSegments | |
| segment = visibleSegments.getSegment(i) | |
| j <- 0 until segment.length() | |
| } { | |
| segment.set(j, 0) | |
| } | |
| } | |
| /* | |
| * -------------------- | |
| * Informational | |
| * -------------------- | |
| */ | |
| def maxSetBit: Long = { | |
| val breaks = new Breaks | |
| val viewableSegments = segments.get() | |
| var bitPosition = -1 | |
| breaks.breakable{ | |
| for { | |
| segmentIdx <- (viewableSegments.numSegments - 1) to 0 by -1 | |
| segment = viewableSegments.getSegment(segmentIdx) | |
| longIdx <- (segment.length() - 1) to 0 by -1 | |
| } { | |
| val l = segment.get(longIdx) | |
| if (l != 0) { | |
| bitPosition = (segmentIdx << log2SegmentSize) + (longIdx * 64) + (63 - java.lang.Long.numberOfLeadingZeros(l)) | |
| breaks.break() | |
| } | |
| } | |
| } | |
| bitPosition | |
| } | |
| def nextSetBit(fromIndex: Long): Long = { | |
| require(fromIndex >= 0, s"fromIndex must be >= 0: got $fromIndex") | |
| var segmentPosition = fromIndex >>> log2SegmentSize | |
| val viewableSegments = segments.get() | |
| if (segmentPosition >= viewableSegments.numSegments) -1 | |
| else { | |
| var longPosition = (fromIndex >>> 6) & segmentMask // which long in the segment -- remainder of div by num bits per segment | |
| var bitPosition = fromIndex & 0x3F // which bit in the long -- remainder of div by num bits in long (64) | |
| var segment = viewableSegments.getSegment(segmentPosition.toInt) | |
| var word = segment.get(longPosition.toInt) & (0xffffffffffffffffL << bitPosition) | |
| var response = -1L | |
| var loop = true | |
| while (loop) { | |
| if (word != 0) { | |
| response = (segmentPosition << (log2SegmentSize)) + (longPosition << 6) + java.lang.Long.numberOfTrailingZeros(word) | |
| loop = false | |
| } else { | |
| longPosition += 1 | |
| if (longPosition > segmentMask) { | |
| segmentPosition += 1 | |
| if (segmentPosition >= viewableSegments.numSegments) { | |
| loop = false | |
| // No bits set, return - | |
| } else { | |
| segment = viewableSegments.getSegment(segmentPosition.toInt) | |
| longPosition = 0 | |
| word = segment.get(longPosition.toInt) | |
| } | |
| } else { | |
| word = segment.get(longPosition.toInt) | |
| } | |
| } | |
| } | |
| response | |
| } | |
| } | |
| /** | |
| * The numbers of bits which are set in this bit set. | |
| **/ | |
| def cardinality: Int = { | |
| val viewableSegments = segments.get() | |
| var numSetBits = 0 | |
| for { | |
| i <- 0 until viewableSegments.numSegments | |
| segment = viewableSegments.getSegment(i) | |
| j <- 0 until segment.length() | |
| } { | |
| numSetBits += java.lang.Long.bitCount(segment.get(j)) | |
| } | |
| numSetBits | |
| } | |
| /** | |
| * The number of bits which are currently specified by this bit set. This | |
| * is the maximum number which you might need to iterate if you were to | |
| * iterate over all the bits in this set. | |
| */ | |
| def currentCapacity: Int = | |
| segments.get().numSegments * (1 << log2SegmentSize) | |
| def eqv(other: ThreadSafeBitSet): Boolean = { | |
| require(other.log2SegmentSize == log2SegmentSize, "Segment sizes must be the same") | |
| val thisSegments = segments.get | |
| val otherSegments = other.segments.get | |
| var allEqual = true | |
| val breaks = new Breaks | |
| breaks.breakable{ | |
| // Check All of That Equal to All of This | |
| for { | |
| i <- 0 until thisSegments.numSegments | |
| thisArray = thisSegments.getSegment(i) | |
| otherArray = { | |
| if (i < otherSegments.numSegments) Some(otherSegments.getSegment(i)) | |
| else None | |
| } | |
| j <- 0 until thisArray.length() | |
| } { | |
| val thisLong = thisArray.get(j) | |
| val otherLong = otherArray.map(_.get(j)).getOrElse(0L) | |
| if (thisLong != otherLong) { | |
| allEqual = false | |
| breaks.break() | |
| } | |
| } | |
| // Check that anything left in that is equal to 0 | |
| for { | |
| i <- thisSegments.numSegments until otherSegments.numSegments | |
| otherArray = otherSegments.getSegment(i) | |
| j <- 0 until otherArray.length | |
| } { | |
| val l = otherArray.get(j) | |
| if (l != 0) { | |
| allEqual = false | |
| breaks.break() | |
| } | |
| } | |
| } | |
| allEqual | |
| } | |
| /** | |
| * Return a new bit set which contains all bits which are contained in this bit set, and which are NOT contained in the `other` bit set. | |
| * | |
| * In other words, return a new bit set, which is a bitwise and with the bitwise not of the other bit set. | |
| * | |
| */ | |
| def andNot(other: ThreadSafeBitSet): ThreadSafeBitSet = { | |
| require(other.log2SegmentSize == log2SegmentSize, "Segment sizes must be the same") | |
| val thisSegments = segments.get() | |
| val otherSegments = other.segments.get() | |
| val newSegments = ThreadSafeBitSet.ThreadSafeBitSegments(thisSegments.numSegments, numLongsPerSegment) | |
| for { | |
| i <- 0 until thisSegments.numSegments | |
| thisArray = thisSegments.getSegment(i) | |
| otherArray = { | |
| if (i < otherSegments.numSegments) Some(otherSegments.getSegment(i)) | |
| else None | |
| } | |
| newArray = newSegments.getSegment(i) | |
| j <- 0 until thisArray.length() | |
| } { | |
| val thisLong = thisArray.get(j) | |
| val otherLong = otherArray.fold(0L)(a => a.get(j)) | |
| newArray.set(j, thisLong & ~ otherLong) | |
| } | |
| val andNot = ThreadSafeBitSet(log2SegmentSize) | |
| andNot.segments.set(newSegments) | |
| andNot | |
| } | |
| // Get the segment at `segmentIndex`. If this segment does not yet exist, create it. | |
| def getSegment(segmentIndex: Int): AtomicLongArray = { | |
| var visibleSegments = segments.get | |
| while(visibleSegments.numSegments <= segmentIndex){ | |
| // Thread safety: newVisibleSegments contains all of the segments from the currently visible segments, plus extra. | |
| // all of the segments in the currently visible segments are canonical and will not change. | |
| val newVisibleSegments = ThreadSafeBitSet.ThreadSafeBitSegments(visibleSegments, segmentIndex + 1, numLongsPerSegment) | |
| // because we are using a compareAndSet, if this thread "wins the race" and successfully sets this variable, then the segments | |
| // which are newly defined in newVisibleSegments become canonical. | |
| if (segments.compareAndSet(visibleSegments, newVisibleSegments)) { | |
| visibleSegments = newVisibleSegments | |
| } else { | |
| // If we "lose the race" and are growing the ThreadSafeBitSet segments larger, | |
| // then we will gather the new canonical sets from the update which we missed on the next iteration of this loop. | |
| // Newly defined segments in newVisibleSegments will be discarded, they do not get to become canonical. | |
| visibleSegments = segments.get(); | |
| } | |
| } | |
| visibleSegments.getSegment(segmentIndex) | |
| } | |
| override def equals(obj: Any): Boolean = obj match { | |
| case that: ThreadSafeBitSet => eqv(that) | |
| case _ => false | |
| } | |
| override def hashCode(): Int = { | |
| 31 * log2SegmentSize + | |
| MurmurHash3.arrayHash(segments.get().segments) | |
| } | |
| // Only works if Int Capable values are there | |
| def toMutableBitSet: scala.collection.mutable.BitSet = { | |
| val resultSet = scala.collection.mutable.BitSet.empty | |
| var ordinal = nextSetBit(0) | |
| while(ordinal != -1){ | |
| resultSet.add(ordinal.toInt) | |
| ordinal = nextSetBit(ordinal + 1) | |
| } | |
| resultSet | |
| } | |
| // This is a full copy for this. | |
| override def toString(): String = { | |
| val longs = new scala.collection.mutable.ListBuffer[Long]() | |
| var ordinal = nextSetBit(0) | |
| while(ordinal != -1L){ | |
| longs.addOne(ordinal) | |
| ordinal = nextSetBit(ordinal + 1) | |
| } | |
| "ThreadSafeBitSet(" ++ longs.mkString(",") + ")" | |
| } | |
| } | |
| object ThreadSafeBitSet { | |
| val DEFAULT_LOG2_SEGMENT_SIZE_IN_BITS = 14 | |
| // TODO: Overloads | |
| // def apply(): ThreadSafeBitSet = apply(DEFAULT_LOG2_SEGMENT_SIZE_IN_BITS) | |
| // def apply(log2SegmentSize: Int): ThreadSafeBitSet = apply(log2SegmentSize, 0) | |
| def apply( | |
| log2SegmentSizeInBits: Int = DEFAULT_LOG2_SEGMENT_SIZE_IN_BITS, | |
| numBitsToPreallocate: Long = 0L | |
| ): ThreadSafeBitSet = { | |
| require(log2SegmentSizeInBits > 6, "Cannot specify fewer than 64 bits in each segment!") | |
| val log2SegmentSize = log2SegmentSizeInBits | |
| val numLongsPerSegment = (1 << (log2SegmentSizeInBits - 6)) | |
| val segmentMask = numLongsPerSegment - 1 | |
| val numBitsPerSegment = numLongsPerSegment * 64 | |
| val numSegmentsToPreallocate = | |
| if (numBitsToPreallocate == 0) 1 | |
| else ((numBitsToPreallocate - 1) / numBitsPerSegment) + 1 | |
| val segments = new AtomicReference[ThreadSafeBitSegments]() | |
| segments.set(ThreadSafeBitSegments(numSegmentsToPreallocate.toInt, numLongsPerSegment)) | |
| new ThreadSafeBitSet(numLongsPerSegment, log2SegmentSize, segmentMask, segments) | |
| } | |
| def fromBitSet( | |
| bitSet: BitSet, | |
| log2SegmentSize: Int = DEFAULT_LOG2_SEGMENT_SIZE_IN_BITS | |
| ): ThreadSafeBitSet = { | |
| val tsb = apply(log2SegmentSize, bitSet.size) | |
| bitSet.foreach(i => | |
| tsb.set(i) | |
| ) | |
| tsb | |
| } | |
| def orAll(bitSets: ThreadSafeBitSet*): ThreadSafeBitSet = { | |
| ??? | |
| } | |
| private class ThreadSafeBitSegments private (private[ThreadSafeBitSet] final val segments: Array[AtomicLongArray]){ | |
| def numSegments = segments.length | |
| def getSegment(index: Int) = segments(index) | |
| } | |
| private object ThreadSafeBitSegments { | |
| def apply(numSegments: Int, segmentLength: Int) = { | |
| val segments = new Array[AtomicLongArray](numSegments) | |
| for(i <- 0 until numSegments) { | |
| segments.update(i, new AtomicLongArray(segmentLength)) | |
| } | |
| new ThreadSafeBitSegments(segments) | |
| } | |
| def apply(copyFrom: ThreadSafeBitSegments, numSegments: Int, segmentLength: Int) = { | |
| val segments = new Array[AtomicLongArray](numSegments) | |
| for(i <- 0 until numSegments) { | |
| val set = if (i < copyFrom.numSegments) copyFrom.getSegment(i) else new AtomicLongArray(segmentLength) | |
| segments.update(i, set) | |
| } | |
| new ThreadSafeBitSegments(segments) | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment