Skip to content

Instantly share code, notes, and snippets.

@ggggggggg
Last active August 29, 2015 14:05
Show Gist options
  • Save ggggggggg/46a7e57ccff05d3a1558 to your computer and use it in GitHub Desktop.
Save ggggggggg/46a7e57ccff05d3a1558 to your computer and use it in GitHub Desktop.
import Base: isempty, copy, eltype, sizehint, push!, union!, pop!, delete!, ==, <, <=, length, start, next, done,
setdiff, symdiff, symdiff!, setdiff!, show, first, last, intersect, intersect!, symdiff, symdiff!
using Base.Test
# Set of non-negative integers backed by a growable bit vector.
type IntSetBitVector
    # membership bits: n is in the set iff b[n+1] == true
    b::BitVector
    # when true, integers past the stored bits (up to typemax(Int)-1) count as members
    fill1s::Bool

    # Build an empty set: 256 cleared bits and no implicit tail of ones.
    function IntSetBitVector()
        return new(falses(256), false)
    end
end
# Build a set holding every element produced by iterating the single argument `itr`.
function IntSetBitVector(itr)
    s = IntSetBitVector()
    for a in itr
        push!(s, a)
    end
    return s
end
# Build a set from the elements passed as separate (varargs) arguments.
function IntSetBitVector(itr...)
    s = IntSetBitVector()
    for a in itr
        push!(s, a)
    end
    return s
end
# Return a fresh, empty IntSetBitVector; the contents of `s` are not consulted.
function similar(s::IntSetBitVector)
    return IntSetBitVector()
end
# Return an independent set with the same members as `s`, built by
# unioning `s` into a newly constructed empty set.
function copy(s::IntSetBitVector)
    duplicate = IntSetBitVector()
    return union!(duplicate, s)
end
# Element type reported for an IntSetBitVector: always Int64.
function eltype(s::IntSetBitVector)
    return Int64
end
# Grow the backing bit vector so that it holds at least `top` bits, rounded up
# to a multiple of 64. Newly added bits take the value of `s.fill1s`.
# Does nothing when `top` is below the current capacity. Returns `s`.
function sizehint(s::IntSetBitVector, top::Integer)
    current = length(s.b)
    top < current && return s
    lim = iceil(top/64)*64
    if current < lim
        assert(lim%64==0) # this is assumed to be true elsewhere
        resize!(s.b, lim)
        # new slots inherit the implicit tail value
        s.b[current+1:end] = s.fill1s
    end
    return s
end
# Insert `n` into the set and return `s`.
# Throws ArgumentError for negative `n`. When `n` lies beyond the stored bits
# and the implicit tail is all ones, `n` is already a member and nothing changes;
# otherwise storage is grown to roughly 1.5 * n before the bit is set.
function push!(s::IntSetBitVector, n::Integer)
    if n < 0
        throw(ArgumentError("IntSetBitVector elements cannot be negative"))
    end
    if n >= length(s.b)
        s.fill1s && return s
        sizehint(s, int(n + div(n,2)))
    end
    Base.unsafe_bitsetindex!(s.b.chunks, true, n+1)
    return s
end
# Generic in-place union: push every element of the iterable `ns` into `s`.
# Returns `s`.
function union!(s::IntSetBitVector, ns)
    for item in ns
        push!(s, item)
    end
    return s
end
# Remove `n` from the set if present. Returns `n` when it was a member and
# `default` otherwise. If `n` lies beyond the stored bits and the implicit tail
# is all ones, storage is grown first so the bit can be cleared explicitly.
function pop!(s::IntSetBitVector, n::Integer, default)
    if n >= length(s.b)
        s.fill1s || return default
        sizehint(s, int(n + div(n,2)))
    end
    s.b[n+1] || return default
    s.b[n+1] = false
    return n
end
# Remove `n` from the set and return it; throws KeyError(n) when `n` is absent.
function pop!(s::IntSetBitVector, n::Integer)
    # n+1 can never be the "found" result of removing n, so it works as a sentinel
    sentinel = n+1
    pop!(s, n, sentinel) == sentinel && throw(KeyError(n))
    return n
end
# TODO: what should happen when fill1s == true?
# Remove and return the element reported by `last(s)`.
function pop!(s::IntSetBitVector)
    largest = last(s)
    return pop!(s, largest)
end
# Remove `n` from the set if present; silently does nothing otherwise.
# Returns `s`.
function delete!(s::IntSetBitVector, n::Integer)
    pop!(s, n, n)  # the default value is irrelevant: the result is discarded
    return s
end
# In-place set difference: delete every element of the iterable `ns` from `s`.
# Returns `s`.
function setdiff!(s::IntSetBitVector, ns)
    for item in ns
        delete!(s, item)
    end
    return s
end
# Return a new set holding the members of `a` that are not in `b`;
# neither argument is modified.
function setdiff(a::IntSetBitVector, b::IntSetBitVector)
    result = copy(a)
    return setdiff!(result, b)
end
# Return a new set holding the elements that are in exactly one of `s1`, `s2`.
# Neither argument is modified. The set with the larger backing store is the
# one copied, so the in-place symdiff! does not have to grow the copy.
# (The original compared length(s1.b) with itself, so the second branch was
# never taken.)
function symdiff(s1::IntSetBitVector, s2::IntSetBitVector)
    if length(s1.b) >= length(s2.b)
        return symdiff!(copy(s1), s2)
    else
        return symdiff!(copy(s2), s1)
    end
end
# Remove every element from `s`: clear all stored bits and drop the implicit
# all-ones tail. The backing store keeps its current size. Returns `s`.
function empty!(s::IntSetBitVector)
    # must clear the set's own bit vector (the original referenced an undefined `b`)
    fill!(s.b, false)
    s.fill1s = false
    return s
end
# Toggle membership of the single integer `n`: insert it when absent, remove it
# when present. Throws ArgumentError for negative `n`. Returns `s`.
function symdiff!(s::IntSetBitVector, n::Integer)
    if n >= length(s.b)
        # grow to roughly 1.5 * n, the same growth rule push!/pop! use
        # (the original called an undefined `dim` here instead of `div`)
        lim = int(n + div(n,2))
        sizehint(s, lim)
    elseif n < 0
        throw(ArgumentError("IntSetBitVector elements cannot be negative"))
    end
    s.b[n+1] = !s.b[n+1]
    return s
end
# Generic in-place symmetric difference: toggle membership of each element
# produced by the iterable `ns`. Returns `s`.
function symdiff!(s::IntSetBitVector, ns)
    for item in ns
        symdiff!(s, item)
    end
    return s
end
# Overwrite the contents of `to` with those of `from`: `to` is emptied first and
# then unioned with `from`. Returns the result of that union! call.
function copy!(to::IntSetBitVector, from::IntSetBitVector)
    return union!(empty!(to), from)
end
# Membership test. Negative integers are never members. Integers within the
# stored bits are looked up directly; integers beyond them are members only
# when the implicit tail is all ones.
function in(n::Integer, s::IntSetBitVector)
    n < 0 && return false
    if n < length(s.b)
        return s.b[n+1]
    end
    # max IntSetBitVector length is typemax(Int), so highest possible element is
    # typemax(Int)-1
    return s.fill1s && n < typemax(Int)
end
# Initial iteration state: scanning begins at integer 0.
function start(s::IntSetBitVector)
    return int64(0)
end
# Iteration is finished when either
#   * the set has no implicit all-ones tail and no stored member remains at or
#     after state `i`, or
#   * the state has reached typemax(Int).
function done(s::IntSetBitVector, i)
    exhausted = !s.fill1s && next(s,i)[1] >= length(s.b)
    return exhausted || i == typemax(Int)
end
# Advance the iterator from state `i`. Returns (element, next_state).
# Within the stored bits the next set bit at or after `i` is located; when none
# exists the element reported is length(s.b). Past the stored bits the state
# itself is returned as the element.
function next(s::IntSetBitVector, i)
    if i < length(s.b)
        hit = findnext(s.b, i+1)
        # findnext yields 0 when nothing is found; otherwise convert the
        # 1-based bit index back to the 0-based element value
        n = (hit == 0) ? length(s.b) : hit-1
    else
        n = int64(i)
    end
    return (n, n+1)
end
# Return true when the set has no members: no stored bit is set and there is no
# implicit all-ones tail.
# (The original used isempty(s.b), which asks whether the bit vector has zero
# length; the backing store always has at least 256 bits, so that was never true.)
function isempty(s::IntSetBitVector)
    return !s.fill1s && !any(s.b)
end
# Return the smallest stored element. Raises an error when no stored bit is set.
function first(s::IntSetBitVector)
    n = next(s,0)[1]
    # next reports a value >= length(s.b) when no stored bit was found
    n < length(s.b) || error("set must be non-empty")
    return n
end
# Remove and return the element reported by `first(s)`.
function shift!(s::IntSetBitVector)
    smallest = first(s)
    return pop!(s, smallest)
end
# Return the largest stored element. Raises an error when the set has an
# implicit all-ones tail or when no stored bit is set.
function last(s::IntSetBitVector)
    if !s.fill1s
        chunks = s.b.chunks
        i = length(chunks)
        # walk the 64-bit chunks from the top until a non-zero one is found
        while i >= 1
            if chunks[i] != 0
                # scan that chunk's bits from highest to lowest
                j = i*64
                while j >= (i-1)*64+1
                    s.b[j] && (return j-1)
                    j -= 1
                end
                error("this shouldn't be possible")
            end
            i -= 1
        end
    end
    error("set has no last element")
end
# Number of members: the count of stored set bits, plus (when the implicit tail
# is all ones) every integer from length(s.b) up to typemax(Int)-1.
function length(s::IntSetBitVector)
    stored = sum(s.b)
    tail = s.fill1s ? typemax(Int) - length(s.b) : 0
    return stored + tail
end
# Print the set as `IntSetBitVector([a, b, c])`. For a set with an implicit
# all-ones tail, listing stops just past the stored bits and the tail is
# abbreviated as `..., <typemax(Int)-1>`.
function show(io::IO, s::IntSetBitVector)
    print(io, "IntSetBitVector([")
    # renamed from `first` so the local flag does not shadow the `first` function
    isfirst = true
    for n in s
        # stop once iteration has moved past the stored bits
        n > length(s.b) && break
        isfirst || print(io, ", ")
        print(io, n)
        isfirst = false
    end
    if s.fill1s
        # close the bracket here as well; the original emitted only ")" and
        # left the "[" unbalanced
        print(io, ", ..., ", typemax(Int)-1, "])")
    else
        print(io, "])")
    end
end
# Math functions
# In-place union of two bit-vector sets, done 64 bits at a time.
# `s` is grown to at least the size of `s2`; returns `s`.
function union!(s::IntSetBitVector, s2::IntSetBitVector)
    lim = length(s2.b)
    lim > length(s.b) && sizehint(s, lim)
    for k = 1:div(lim, 64)
        s.b.chunks[k] |= s2.b.chunks[k]
    end
    # everything past s2's stored bits is a member of s2 when its tail is all ones
    if s2.fill1s
        s.b[lim+1:end] = true
    end
    s.fill1s |= s2.fill1s
    return s
end
# Union of a single set: an independent copy.
function union(s1::IntSetBitVector)
    return copy(s1)
end

# Union of two sets as a new set. The set with the larger backing store is the
# one copied, and the other is merged into that copy.
function union(s1::IntSetBitVector, s2::IntSetBitVector)
    if length(s1.b) >= length(s2.b)
        return union!(copy(s1), s2)
    end
    return union!(copy(s2), s1)
end

# Union of several sets: fold the trailing sets together first, then union
# the result with `s1`.
function union(s1::IntSetBitVector, ss::IntSetBitVector...)
    rest = union(ss...)
    return union(s1, rest)
end
# In-place intersection of two bit-vector sets, done 64 bits at a time.
# `s` is grown to at least the size of `s2`; returns `s`.
function intersect!(s::IntSetBitVector, s2::IntSetBitVector)
    lim = length(s2.b)
    lim > length(s.b) && sizehint(s, lim)
    for k = 1:div(lim, 64)
        s.b.chunks[k] &= s2.b.chunks[k]
    end
    # nothing past s2's stored bits is a member of s2 unless its tail is all ones
    if !s2.fill1s
        s.b[lim+1:end] = false
    end
    s.fill1s &= s2.fill1s
    return s
end
# Intersection of a single set: an independent copy.
function intersect(s1::IntSetBitVector)
    return copy(s1)
end

# Intersection of two sets as a new set. The set with the larger backing store
# is the one copied, and it is then intersected in place with the other.
function intersect(s1::IntSetBitVector, s2::IntSetBitVector)
    if length(s1.b) >= length(s2.b)
        return intersect!(copy(s1), s2)
    end
    return intersect!(copy(s2), s1)
end

# Intersection of several sets: fold the trailing sets together first, then
# intersect the result with `s1`.
function intersect(s1::IntSetBitVector, ss::IntSetBitVector...)
    rest = intersect(ss...)
    return intersect(s1, rest)
end
# In-place complement: invert every stored bit and flip the implicit tail flag.
# Returns `s`.
function complement!(s::IntSetBitVector)
    flipped = !s.b
    s.b = flipped
    s.fill1s = !s.fill1s
    return s
end
# Return the complement of `s` as a new set; `s` itself is not modified.
function complement(s::IntSetBitVector)
    result = copy(s)
    return complement!(result)
end
# In-place symmetric difference of two bit-vector sets, done 64 bits at a time
# (`$` is the xor operator in this dialect). `s` is grown to at least the size
# of `s2`; returns `s`.
function symdiff!(s::IntSetBitVector, s2::IntSetBitVector)
    length(s2.b) > length(s.b) && sizehint(s, length(s2.b))
    for i = 1:div(length(s2.b),64)
        s.b.chunks[i] $= s2.b.chunks[i]
    end
    if s2.fill1s
        # Past s2's stored bits every integer is a member of s2, so xor-ing with
        # it inverts each remaining chunk of s. Each chunk must be inverted from
        # its own value (the original wrote ~chunks[1] into every trailing chunk).
        for i = length(s2.b.chunks)+1:length(s.b.chunks)
            s.b.chunks[i] = ~s.b.chunks[i]
        end
    end
    s.fill1s $= s2.fill1s
    return s
end
# Two sets are equal when their implicit tails agree, their overlapping stored
# chunks are identical, and the extra stored bits of the longer set all match
# the other set's implicit tail value.
function ==(s1::IntSetBitVector, s2::IntSetBitVector)
    s1.fill1s == s2.fill1s || (return false)
    lim1, lim2 = length(s1.b), length(s2.b)
    shared = min(lim1, lim2)
    for k = 1:div(shared, 64)
        s1.b.chunks[k] != s2.b.chunks[k] && (return false)
    end
    # search the surplus region of the longer set for any bit that differs
    # from the shorter set's tail value; finding one means the sets differ
    if lim1 > lim2
        return findnext(s1.b, !s2.fill1s, lim2+1) == 0
    elseif lim2 > lim1
        return findnext(s2.b, !s1.fill1s, lim1+1) == 0
    end
    return true
end
# Subset test: `a` is a subset of `b` exactly when intersecting it with `b`
# leaves it unchanged.
function issubset(a::IntSetBitVector, b::IntSetBitVector)
    return isequal(a, intersect(a,b))
end

# Non-strict subset ordering.
function <=(a::IntSetBitVector, b::IntSetBitVector)
    return issubset(a, b)
end

# Strict subset ordering: a subset that is not equal.
function <(a::IntSetBitVector, b::IntSetBitVector)
    return (a<=b) && !isequal(a,b)
end
# --- Set-operation tests -------------------------------------------------
# Each IntSetBitVector assertion is paired with the same assertion on Base's
# Set, so the two containers are checked against identical expectations.

# Non-mutating set difference.
@test setdiff(IntSetBitVector(1, 2, 3, 4), IntSetBitVector(2, 4, 5, 6)) == IntSetBitVector(1, 3)
@test setdiff(Set(1, 2, 3, 4), Set(2, 4, 5, 6)) == Set(1, 3)
# Non-mutating symmetric difference, checked in both argument orders.
@test symdiff(IntSetBitVector(1, 2, 3, 4), IntSetBitVector(2, 4, 5, 6)) == IntSetBitVector(1, 3, 5, 6)
@test symdiff(IntSetBitVector(2, 4, 5, 6), IntSetBitVector(1, 2, 3, 4)) == IntSetBitVector(1, 3, 5, 6)
@test symdiff(Set(1, 2, 3, 4), Set(2, 4, 5, 6)) == Set(1, 3, 5, 6)
# Mutating set difference on a Base Set ...
s1 = Set(1, 2, 3, 4)
setdiff!(s1, Set(2, 4, 5, 6))
@test s1 == Set(1, 3)
# ... and on an IntSetBitVector.
s2 = IntSetBitVector(1, 2, 3, 4)
setdiff!(s2, IntSetBitVector(2, 4, 5, 6))
@test s2 == IntSetBitVector(1, 3)
# issue #7851
# Negative elements are rejected on construction and are never reported as members.
@test_throws ArgumentError IntSetBitVector(-1)
@test !(-1 in IntSetBitVector(0:10))
## IntSetBitVector
# Container behaviour: iteration order, first/last, length, and removal.
# The statements below mutate `s`, so their order matters.
contents = [0,1,10,20,200,300,1000,10000,10002]
s = IntSetBitVector(contents...)
# Iteration yields the elements in ascending order.
@test [j for j in s] == contents
@test last(s) == 10002
@test first(s) == 0
@test length(s) == 9
# pop! with no element argument removes the largest element.
@test pop!(s) == 10002
@test length(s) == 8
# shift! removes the smallest element.
@test shift!(s) == 0
@test length(s) == 7
# The removed elements are gone; an untouched element is still present.
@test !in(0,s)
@test !in(10002,s)
@test in(10000,s)
# first/last raise an error on an empty set.
@test_throws ErrorException first(IntSetBitVector())
@test_throws ErrorException last(IntSetBitVector())
# NOTE(review): the copy is resized here but no assertion follows, so the hash
# check the trailing comment describes is not actually performed.
t = copy(s)
sizehint(t, 20000) #check that hash does not depend on size of internal Array{Uint32, 1}
# --- Timing comparison: IntSetBitVector vs. Base IntSet -------------------
# Define a two-argument union for IntSet that copies whichever set has the
# larger `limit` and merges the other into the copy.
union(s1::IntSet, s2::IntSet) = (s1.limit >= s2.limit ? union!(copy(s1), s2) : union!(copy(s2), s1))
# Two sets, each built from 1000 random draws in 1:10000000.
a = IntSetBitVector(rand(1:10000000,1000));
b = IntSetBitVector(rand(1:10000000,1000));
println("IntSetBitVector")
# Time repeated equality, union, full iteration, intersection and symmetric difference.
@time [a==b for j = 1:1000];
@time [union(a,b) for j=1:100];
@time [for i in a end for j=1:1000];
@time [intersect(a,b) for j=1:100];
@time [symdiff(a,b) for j = 1:100];
# The same elements loaded into Base IntSets, timed with the same operations.
c = IntSet(collect(a));
d = IntSet(collect(b));
println("IntSet")
@time [c==d for j = 1:1000];
@time [union(c,d) for j=1:100];
@time [for i in c end for j=1:1000];
@time [intersect(c,d) for j=1:100];
@time [symdiff(c,d) for j = 1:100];
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment