Last active
October 1, 2024 07:58
-
-
Save yongjun21/66af2966f3569c08f8cd79b85e67a7f6 to your computer and use it in GitHub Desktop.
Custom data structure for efficient manipulation of sparse bitmasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { | |
iterFilter, | |
iterFilterIndex, | |
ArrayConstructor, | |
ArrayLike, | |
} from "./base"; | |
import { OrderedSet } from "./OrderedCollections"; | |
/**
 * Sparse bitmask that stores only a set of indices instead of every bit.
 *
 * The stored indices are interpreted according to `isRunEnds`:
 * - run-end mode (`isRunEnds === true`): every stored index is a position where
 *   the bit value toggles; bits are 0 before the first stored index.
 * - one-bit mode (`isRunEnds === false`): every stored index is the position of
 *   a 1 bit.
 *
 * `get`, `set` and `flip` act on the stored indices directly, so in run-end mode
 * they address run boundaries rather than individual bit values.
 *
 * Callbacks taking `(value, index)` receive the NEW bit value of a changed bit.
 */
export default class SparseBitmask {
  private data = new OrderedSet<number>();
  public length: number;
  public isRunEnds: boolean;

  /**
   * @param size number of bits covered by the mask
   * @param isRunEnds whether stored indices are run ends (default) or one-bit positions
   */
  constructor(size: number, isRunEnds = true) {
    this.length = size;
    this.isRunEnds = isRunEnds;
  }

  /**
   * Empties the mask and optionally resizes it.
   *
   * @param length new length (defaults to the current length)
   * @param callback invoked with `(0, index)` for every bit that was 1 before the reset;
   *   in run-end mode the one bits are expanded up to the NEW `length`
   * @returns this
   */
  reset(
    length: number = this.length,
    callback?: (value: 1 | 0, index: number) => void
  ): this {
    if (callback) {
      const indices = this.isRunEnds
        ? fromRunEnds(this.data, length)
        : this.data;
      for (const index of indices) {
        callback(0, index);
      }
    }
    // Bind a new empty set instead of clearing in place: the previous set may
    // have been handed over to another SparseBitmask via .overwriteWith(), and
    // clearing it would wipe that instance's data too.
    this.data = new OrderedSet<number>();
    this.length = length;
    return this;
  }

  /**
   * Reports whether `index` is one of the stored indices.
   * In run-end mode this tells whether `index` is a run boundary, not the bit value.
   */
  get(index: number): 1 | 0 {
    return this.data.has(index) ? 1 : 0;
  }

  /** Adds (`value` truthy) or removes (`value` falsy) `index` from the stored indices. */
  set(index: number, value: 1 | 0): void {
    const curr = this.data.has(index);
    if (value && !curr) {
      this.data.add(index);
    } else if (!value && curr) {
      this.data.delete(index);
    }
  }

  /** Toggles membership of `index` in the stored indices. */
  flip(index: number): void {
    const curr = this.data.has(index);
    if (curr) this.data.delete(index);
    else this.data.add(index);
  }

  /** Returns this mask in run-end mode; returns `this` itself when already in that mode. */
  asRunEndIndices(): SparseBitmask {
    if (this.isRunEnds) return this;
    const mask = new SparseBitmask(this.length, true);
    mask.data = OrderedSet.fromOrdered(toRunEnds(this.data, this.length));
    return mask;
  }

  /** Returns this mask in one-bit mode; returns `this` itself when already in that mode. */
  asOneBitIndices(): SparseBitmask {
    if (!this.isRunEnds) return this;
    const mask = new SparseBitmask(this.length, false);
    mask.data = OrderedSet.fromOrdered(fromRunEnds(this.data, this.length));
    return mask;
  }

  /** Converts `other` to the same index interpretation as this mask (may return `other` itself). */
  private matchEncoding(other: SparseBitmask): SparseBitmask {
    return this.isRunEnds ? other.asRunEndIndices() : other.asOneBitIndices();
  }

  /**
   * Reports changed bits to `callback` with their new `value`.
   * `changed` is in this mask's encoding; run ends are expanded to bit positions first.
   */
  private notifyChanged(
    changed: number[],
    value: 1 | 0,
    callback?: (value: 1 | 0, index: number) => void
  ): void {
    if (!callback) return;
    const bitIndices = this.isRunEnds
      ? fromRunEnds(changed, this.length)
      : changed;
    for (const index of bitIndices) {
      callback(value, index);
    }
  }

  /**
   * Symmetric difference of the stored indices (not the same as subtraction).
   * The result holds the indices present in exactly one of the two masks.
   */
  diff(other: SparseBitmask): SparseBitmask {
    const mask = new SparseBitmask(this.length, this.isRunEnds);
    const sameOther = this.matchEncoding(other);
    mask.data = OrderedSet.fromOrdered(
      iterFilter(this.data, index => !sameOther.data.has(index))
    );
    for (const index of sameOther.data) {
      if (!this.data.has(index)) mask.data.add(index);
    }
    return mask;
  }

  /** Mutates this mask by flipping every stored index of `diff`; only meaningful for a symmetric diff. */
  applyDiff(diff: SparseBitmask): void {
    for (const index of diff.data) {
      this.flip(index);
    }
  }

  /** Returns a new mask, in this mask's encoding, holding the union of both masks. */
  union(other: SparseBitmask): SparseBitmask {
    const sameOther = this.matchEncoding(other);
    // Sorting a fresh two-element array; neither mask's data is reordered.
    const [short, long] = [this.data, sameOther.data].sort(
      (a, b) => a.size - b.size
    );
    const mask = new SparseBitmask(this.length, this.isRunEnds);
    if (this.isRunEnds) {
      mask.data = OrderedSet.fromOrdered(unionRunEnds(short, long));
    } else {
      // Copy the larger set, then insert the smaller one.
      mask.data = OrderedSet.fromOrdered(long);
      for (const index of short) {
        mask.data.add(index);
      }
    }
    return mask;
  }

  /** Returns a new mask, in this mask's encoding, holding the intersection of both masks. */
  intersection(other: SparseBitmask): SparseBitmask {
    const sameOther = this.matchEncoding(other);
    const [short, long] = [this.data, sameOther.data].sort(
      (a, b) => a.size - b.size
    );
    const mask = new SparseBitmask(this.length, this.isRunEnds);
    mask.data = OrderedSet.fromOrdered(
      this.isRunEnds
        ? intersectionRunEnds(short, long)
        : iterFilter(short, index => long.has(index))
    );
    return mask;
  }

  /** Returns a new mask, in this mask's encoding, holding this mask minus `other`. */
  subtraction(other: SparseBitmask): SparseBitmask {
    const mask = new SparseBitmask(this.length, this.isRunEnds);
    const sameOther = this.matchEncoding(other);
    mask.data = OrderedSet.fromOrdered(
      this.isRunEnds
        ? subtractionRunEnds(this.data, sameOther.data)
        : iterFilter(this.data, index => !sameOther.data.has(index))
    );
    return mask;
  }

  /**
   * In-place union with `other`.
   *
   * @param callback invoked with `(1, index)` for every bit that was turned on
   * @returns this
   */
  unionWithMutation(
    other: SparseBitmask,
    callback?: (value: 1 | 0, index: number) => void
  ): this {
    const sameOther = this.matchEncoding(other);
    // Materialise before flipping: the generators read this.data lazily.
    const indices = [
      ...(this.isRunEnds
        ? unionRunEndsAsDiff(this.data, sameOther.data)
        : iterFilter(sameOther.data, index => !this.data.has(index))),
    ];
    for (const index of indices) {
      this.flip(index);
    }
    this.notifyChanged(indices, 1, callback);
    return this;
  }

  /**
   * In-place intersection with `other`.
   *
   * @param callback invoked with `(0, index)` for every bit that was turned off
   * @returns this
   */
  intersectionWithMutation(
    other: SparseBitmask,
    callback?: (value: 1 | 0, index: number) => void
  ): this {
    const sameOther = this.matchEncoding(other);
    // Materialise before flipping: the generators read this.data lazily.
    const indices = [
      ...(this.isRunEnds
        ? intersectionRunEndsAsDiff(this.data, sameOther.data)
        : iterFilter(this.data, index => !sameOther.data.has(index))),
    ];
    for (const index of indices) {
      this.flip(index);
    }
    // An intersection can only clear bits, so the new value reported is 0.
    this.notifyChanged(indices, 0, callback);
    return this;
  }

  /**
   * In-place subtraction of `other` from this mask.
   *
   * @param callback invoked with `(0, index)` for every bit that was turned off
   * @returns this
   */
  subtractionWithMutation(
    other: SparseBitmask,
    callback?: (value: 1 | 0, index: number) => void
  ): this {
    const sameOther = this.matchEncoding(other);
    // Materialise before flipping: the generators read this.data lazily.
    const indices = [
      ...(this.isRunEnds
        ? subtractionRunEndsAsDiff(this.data, sameOther.data)
        : iterFilter(this.data, index => sameOther.data.has(index))),
    ];
    for (const index of indices) {
      this.flip(index);
    }
    // A subtraction can only clear bits, so the new value reported is 0.
    this.notifyChanged(indices, 0, callback);
    return this;
  }

  /**
   * Replaces this mask's indices with those of `other`.
   *
   * WARNING: the index set is copied shallowly. When `other` already uses this
   * mask's encoding, both instances share one set afterwards and mutations of
   * one show up in the other. To avoid that, pass `other.clone()` or call
   * `.reset()` on `other` immediately after the handover.
   *
   * `this.length` is left unchanged.
   *
   * @param callback invoked with `(0, index)` for bits turned off and
   *   `(1, index)` for bits turned on
   * @returns this
   */
  overwriteWith(
    other: SparseBitmask,
    callback?: (value: 1 | 0, index: number) => void
  ): this {
    const sameOther = this.matchEncoding(other);
    if (callback) {
      if (this.isRunEnds) {
        diffRunEndsForEach(this.data, sameOther.data, this.length, callback);
      } else {
        for (const index of this.data) {
          if (!sameOther.data.has(index)) callback(0, index);
        }
        for (const index of sameOther.data) {
          if (!this.data.has(index)) callback(1, index);
        }
      }
    }
    this.data = sameOther.data;
    return this;
  }

  /** Calls `callback(value, index)` for every bit position, in increasing index order. */
  forEach(callback: (value: 1 | 0, index: number) => void): void {
    let index = 0;
    for (const value of this) {
      callback(value, index);
      index += 1;
    }
  }

  /**
   * Maps every bit through `callback` into a new array-like of size `length`.
   *
   * @param Onto constructor of the container to fill (defaults to `Array`)
   */
  map<T>(
    callback: (value: 1 | 0, index: number) => T,
    Onto: ArrayConstructor<T> = Array
  ): ArrayLike<T> {
    const arr = new Onto(this.length);
    let index = 0;
    for (const value of this) {
      arr[index] = callback(value, index);
      index += 1;
    }
    return arr;
  }

  /** Returns a mask with the same length and encoding and its own copy of the indices. */
  clone(): SparseBitmask {
    const mask = new SparseBitmask(this.length, this.isRunEnds);
    mask.data = OrderedSet.fromOrdered(this.data);
    return mask;
  }

  /** Number of stored indices (run ends or one bits, depending on the encoding). */
  getIndicesCount(): number {
    return this.data.size;
  }

  /**
   * Returns the indices in the requested encoding.
   *
   * @param asRunEnds `true` for run ends, `false` for one-bit positions; when
   *   omitted, or equal to this mask's encoding, the internal set itself is returned
   */
  getIndices(asRunEnds?: boolean): Iterable<number> {
    if (asRunEnds == null || this.isRunEnds === asRunEnds) {
      return this.data;
    }
    return asRunEnds
      ? toRunEnds(this.data, this.length)
      : fromRunEnds(this.data, this.length);
  }

  /**
   * Replaces the stored indices.
   *
   * @param indices new indices, passed to `OrderedSet.fromOrdered` after conversion
   * @param isRunEnds encoding of `indices`; converted when it differs from this mask's
   * @returns this
   */
  setIndices(indices: Iterable<number>, isRunEnds = true): this {
    const sameIndices =
      this.isRunEnds !== isRunEnds
        ? isRunEnds
          ? fromRunEnds(indices, this.length)
          : toRunEnds(indices, this.length)
        : indices;
    this.data = OrderedSet.fromOrdered(sameIndices);
    return this;
  }

  /**
   * Yields exactly `length` bit values, in increasing index order.
   * Relies on the stored indices iterating in ascending order.
   */
  *[Symbol.iterator](): Iterator<1 | 0> {
    let index = 0;
    let curr: 1 | 0 = 0;
    for (const i of this.data) {
      // A stored index at or beyond `length` (e.g. a run end that closes the
      // final run exactly at `length`) lies outside the mask; emitting a value
      // for it would produce more than `length` items.
      if (i >= this.length) break;
      while (index < i) {
        yield curr;
        index += 1;
      }
      yield (1 - curr) as 1 | 0;
      index += 1;
      // Only run ends change the value of the following bits.
      if (this.isRunEnds) curr = (1 - curr) as 1 | 0;
    }
    while (index < this.length) {
      yield curr;
      index += 1;
    }
  }

  /**
   * Builds a mask from an array: a bit is 1 where `predicate` selects the element.
   *
   * @param data source array; its length becomes the mask length
   * @param predicate selector (defaults to the element's own truthiness)
   * @param convertToRunEnds whether the mask is stored in run-end mode (default)
   */
  static fromData<T>(
    data: T[],
    predicate: (v: T, i: number) => any = v => v,
    convertToRunEnds = true
  ): SparseBitmask {
    const mask = new SparseBitmask(data.length, convertToRunEnds);
    let indices = iterFilterIndex(data, predicate);
    if (convertToRunEnds) {
      indices = toRunEnds(indices, data.length);
    }
    mask.data = OrderedSet.fromOrdered(indices);
    return mask;
  }

  /**
   * Builds a mask over a row-major `width * height` grid with 1 bits inside a
   * rectangle whose corners are inclusive.
   *
   * @param topleft `[xmin, ymin]`
   * @param bottomRight `[xmax, ymax]`
   * @param convertToRunEnds whether the mask is stored in run-end mode (default)
   */
  static fromRect(
    topleft: number[],
    bottomRight: number[],
    width: number,
    height: number,
    convertToRunEnds = true
  ): SparseBitmask {
    const mask = new SparseBitmask(width * height, convertToRunEnds);
    // The module-level fromRect() produces run ends.
    let indices = fromRect(
      topleft[0],
      topleft[1],
      bottomRight[0],
      bottomRight[1],
      width
    );
    if (!convertToRunEnds) {
      indices = fromRunEnds(indices, width * height);
    }
    mask.data = OrderedSet.fromOrdered(indices);
    return mask;
  }
}
/**
 * Expands run-end indices into the positions of the individual 1 bits.
 *
 * Boundaries are consumed alternately as "run opens here" and "run closes
 * here" (exclusive). A run still open after the last boundary extends up to,
 * but not including, `end`.
 *
 * @param indices run-end indices
 * @param end exclusive upper bound used to close a trailing open run
 */
export function* fromRunEnds(
  indices: Iterable<number>,
  end: number
): Iterable<number> {
  let cursor = -1;
  let insideRun = false;
  for (const boundary of indices) {
    if (insideRun) {
      // Closing boundary: emit every position of the run that just ended.
      for (; cursor < boundary; cursor += 1) {
        yield cursor;
      }
    } else {
      // Opening boundary: remember where the run starts.
      cursor = boundary;
    }
    insideRun = !insideRun;
  }
  if (!insideRun) return;
  for (; cursor < end; cursor += 1) {
    yield cursor;
  }
}
/**
 * Compresses ascending one-bit positions into run-end indices.
 *
 * Each run of consecutive positions contributes its first position and the
 * position just past its last one. The closing boundary of the final run is
 * omitted when the run reaches `end`.
 *
 * @param indices one-bit positions in ascending order
 * @param end length of the mask (exclusive upper bound of valid positions)
 */
export function* toRunEnds(
  indices: Iterable<number>,
  end: number
): Iterable<number> {
  // -1 doubles as the "nothing seen yet" sentinel.
  let lastIndex = -1;
  for (const index of indices) {
    if (lastIndex < 0) {
      // First position opens the first run.
      yield index;
    } else if (index > lastIndex + 1) {
      // Gap: close the previous run and open a new one.
      yield lastIndex + 1;
      yield index;
    }
    lastIndex = index;
  }
  // Close the final run only if one was opened. Without the `lastIndex >= 0`
  // guard an empty input would emit a run end at 0, which encodes an all-ones
  // mask instead of an all-zeros one.
  if (lastIndex >= 0 && end > lastIndex + 1) {
    yield lastIndex + 1;
  }
}
/**
 * Yields the run-end indices of a rectangle on a row-major grid.
 *
 * `xmin`, `xmax`, `ymin` and `ymax` are inclusive. When the rectangle spans
 * every column, all of its rows merge into one run; otherwise each row
 * contributes its own start and (exclusive) end index.
 *
 * @param width number of columns of the grid
 */
function* fromRect(
  xmin: number,
  ymin: number,
  xmax: number,
  ymax: number,
  width: number
): Iterable<number> {
  const spansEveryColumn = xmin === 0 && xmax === width - 1;
  if (spansEveryColumn) {
    // Consecutive full rows are contiguous, so one run covers them all.
    yield ymin * width;
    yield ymax * width + width;
    return;
  }
  for (let row = ymin; row <= ymax; row += 1) {
    const rowOffset = row * width;
    yield rowOffset + xmin;
    yield rowOffset + xmax + 1;
  }
}
/**
 * Merges two ascending run-end sequences into the run ends of a boolean
 * combination of the two masks.
 *
 * A four-slot flag table is kept while walking both sequences:
 *   0: curr state, 1: !curr state, 2: next state, 3: !next state
 *
 * @param curr run ends of the first mask
 * @param next run ends of the second mask
 * @param a flag slot (0 or 1) that must be set for a boundary of `next` to be emitted
 * @param b flag slot (2 or 3) that must be set for a boundary of `curr` to be emitted
 */
function* boolRunEnds(
  curr: Iterable<number>,
  next: Iterable<number>,
  a: number,
  b: number
): Iterable<number> {
  const flags = [false, true, false, true];
  const toggleCurr = (): void => {
    flags[0] = !flags[0];
    flags[1] = !flags[1];
  };
  const toggleNext = (): void => {
    flags[2] = !flags[2];
    flags[3] = !flags[3];
  };

  const cursor = next[Symbol.iterator]();
  let pending = cursor.next();

  for (const currIndex of curr) {
    // Drain every boundary of `next` that lies strictly before this one.
    for (; !pending.done && pending.value < currIndex; pending = cursor.next()) {
      if (flags[a]) yield pending.value;
      toggleNext();
    }

    if (pending.done) {
      // `next` is exhausted: stop early if boundaries of `curr` can no longer
      // be emitted, otherwise pass this one through.
      if (!flags[b]) return;
      yield currIndex;
      toggleCurr();
      continue;
    }

    if (pending.value > currIndex) {
      // Only `curr` changes at this position.
      if (flags[b]) yield currIndex;
      toggleCurr();
    } else {
      // Both sequences change at the same position.
      if (flags[a] === flags[b]) yield currIndex;
      pending = cursor.next();
      toggleCurr();
      toggleNext();
    }
  }

  // `curr` is exhausted: the rest of `next` passes through iff slot `a` is set.
  if (flags[a]) {
    for (; !pending.done; pending = cursor.next()) {
      yield pending.value;
    }
  }
}
/**
 * Run ends of the union (bitwise OR) of two run-end encoded masks.
 *
 * @param curr run ends of the first mask
 * @param next run ends of the second mask
 */
export function unionRunEnds(
  curr: Iterable<number>,
  next: Iterable<number>
): Iterable<number> {
  const emitNextWhen = 1; // boundaries of `next` count only while `curr` is 0
  const emitCurrWhen = 3; // boundaries of `curr` count only while `next` is 0
  return boolRunEnds(curr, next, emitNextWhen, emitCurrWhen);
}
/**
 * Run ends of the intersection (bitwise AND) of two run-end encoded masks.
 *
 * @param curr run ends of the first mask
 * @param next run ends of the second mask
 */
export function intersectionRunEnds(
  curr: Iterable<number>,
  next: Iterable<number>
): Iterable<number> {
  const emitNextWhen = 0; // boundaries of `next` count only while `curr` is 1
  const emitCurrWhen = 2; // boundaries of `curr` count only while `next` is 1
  return boolRunEnds(curr, next, emitNextWhen, emitCurrWhen);
}
/**
 * Run ends of `curr` minus `next` (bits set in `curr` and not in `next`).
 *
 * @param curr run ends of the mask being subtracted from
 * @param next run ends of the mask to subtract
 */
export function subtractionRunEnds(
  curr: Iterable<number>,
  next: Iterable<number>
): Iterable<number> {
  const emitNextWhen = 0; // boundaries of `next` count only while `curr` is 1
  const emitCurrWhen = 3; // boundaries of `curr` count only while `next` is 0
  return boolRunEnds(curr, next, emitNextWhen, emitCurrWhen);
}
/**
 * Run ends of the bits that a union with `next` would turn on in `curr`,
 * i.e. the bits set in `next` and not in `curr`.
 *
 * @param curr run ends of the mask being updated
 * @param next run ends of the mask being merged in
 */
function unionRunEndsAsDiff(
  curr: Iterable<number>,
  next: Iterable<number>
): Iterable<number> {
  const emitNextWhen = 1; // boundaries of `next` count only while `curr` is 0
  const emitCurrWhen = 2; // boundaries of `curr` count only while `next` is 1
  return boolRunEnds(curr, next, emitNextWhen, emitCurrWhen);
}
/**
 * Run ends of the bits that an intersection with `next` would turn off in
 * `curr`, i.e. the bits set in `curr` and not in `next`.
 *
 * @param curr run ends of the mask being updated
 * @param next run ends of the mask being intersected with
 */
function intersectionRunEndsAsDiff(
  curr: Iterable<number>,
  next: Iterable<number>
): Iterable<number> {
  const emitNextWhen = 0; // boundaries of `next` count only while `curr` is 1
  const emitCurrWhen = 3; // boundaries of `curr` count only while `next` is 0
  return boolRunEnds(curr, next, emitNextWhen, emitCurrWhen);
}
/**
 * Run ends of the bits that subtracting `next` would turn off in `curr`,
 * i.e. the bits set in both masks.
 *
 * @param curr run ends of the mask being updated
 * @param next run ends of the mask being subtracted
 */
function subtractionRunEndsAsDiff(
  curr: Iterable<number>,
  next: Iterable<number>
): Iterable<number> {
  const emitNextWhen = 0; // boundaries of `next` count only while `curr` is 1
  const emitCurrWhen = 2; // boundaries of `curr` count only while `next` is 1
  return boolRunEnds(curr, next, emitNextWhen, emitCurrWhen);
}
/**
 * Walks two run-end encoded masks side by side and reports every bit position
 * at which they differ.
 *
 * @param curr run ends of the current mask
 * @param next run ends of the mask that will replace it
 * @param length number of bit positions to cover
 * @param callback invoked with `(valueInNext, index)` for each differing position,
 *   in increasing index order
 */
function diffRunEndsForEach(
  curr: Iterable<number>,
  next: Iterable<number>,
  length: number,
  callback: (value: 1 | 0, index: number) => void
): void {
  let currBit = 0;
  let nextBit = 0;
  let position = 0;
  const cursor = next[Symbol.iterator]();
  let pending = cursor.next();

  // Moves `position` up to `boundary`, reporting the stretch in between when
  // the two masks currently disagree.
  const advanceTo = (boundary: number): void => {
    if (currBit === nextBit) {
      position = boundary;
      return;
    }
    for (; position < boundary; position += 1) {
      callback(nextBit as 1 | 0, position);
    }
  };

  // Steps past the pending boundary of `next` and toggles its state.
  const consumeNext = (): void => {
    pending = cursor.next();
    nextBit = 1 - nextBit;
  };

  for (const currIndex of curr) {
    // Boundaries of `next` strictly before this boundary of `curr`.
    while (!pending.done && pending.value < currIndex) {
      advanceTo(pending.value);
      consumeNext();
    }
    advanceTo(currIndex);
    // A boundary of `next` at the same position toggles together with `curr`.
    if (!pending.done && pending.value <= currIndex) {
      consumeNext();
    }
    currBit = 1 - currBit;
  }

  // `curr` is exhausted: finish the remaining boundaries of `next`.
  while (!pending.done) {
    advanceTo(pending.value);
    consumeNext();
  }
  advanceTo(length);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment