Skip to content

Commit

Permalink
Widen mutable.BitSet operations to accept any BitSet
Browse files Browse the repository at this point in the history
Add other optimizations for:
* mutable.BitSeta.addAll
* Fix implementation of mutable.BitSet.subsetOf
  • Loading branch information
joshlemer committed Feb 6, 2019
1 parent 596c123 commit c196be3
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 14 deletions.
111 changes: 100 additions & 11 deletions src/library/scala/collection/mutable/BitSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import scala.collection.immutable.Range
import BitSetOps.{LogWL, MaxSize}
import scala.annotation.implicitNotFound


/**
* A class for mutable bitsets.
*
Expand Down Expand Up @@ -106,34 +105,47 @@ class BitSet(protected[collection] final var elems: Array[Long])
* @param other the bitset to form the union with.
* @return the bitset itself.
*/
def |= (other: BitSet): this.type = {
def |= (other: collection.BitSet): this.type = {
ensureCapacity(other.nwords - 1)
for (i <- Range(0, other.nwords))
var i = 0
val othernwords = other.nwords
while (i < othernwords) {
elems(i) = elems(i) | other.word(i)
i += 1
}
this
}
/** Updates this bitset to the intersection with another bitset by performing a bitwise "and".
*
* @param other the bitset to form the intersection with.
* @return the bitset itself.
*/
def &= (other: BitSet): this.type = {
def &= (other: collection.BitSet): this.type = {
// Different from other operations: no need to ensure capacity because
// anything beyond the capacity is 0. Since we use other.word which is 0
// off the end, we also don't need to make sure we stay in bounds there.
for (i <- Range(0, nwords))
var i = 0
val thisnwords = nwords
while (i < thisnwords) {
elems(i) = elems(i) & other.word(i)
i += 1
}
this
}
/** Updates this bitset to the symmetric difference with another bitset by performing a bitwise "xor".
*
* @param other the bitset to form the symmetric difference with.
* @return the bitset itself.
*/
def ^= (other: BitSet): this.type = {
def ^= (other: collection.BitSet): this.type = {
ensureCapacity(other.nwords - 1)
for (i <- Range(0, other.nwords))
var i = 0
val othernwords = other.nwords
while (i < othernwords) {

elems(i) = elems(i) ^ other.word(i)
i += 1
}
this
}
/** Updates this bitset to the difference with another bitset by performing a bitwise "and-not".
Expand All @@ -142,17 +154,16 @@ class BitSet(protected[collection] final var elems: Array[Long])
* @return the bitset itself.
*/
def &~= (other: collection.BitSet): this.type = {
val words = Math.min(nwords, other.nwords)
var i = 0
while (i < words) {
val max = Math.min(nwords, other.nwords)
while (i < max) {
elems(i) = elems(i) & ~other.word(i)
i += 1
}
this
}

override def clone(): BitSet =
new BitSet(java.util.Arrays.copyOf(elems, elems.length))
override def clone(): BitSet = new BitSet(java.util.Arrays.copyOf(elems, elems.length))

def toImmutable: immutable.BitSet = immutable.BitSet.fromBitMask(elems)

Expand All @@ -172,6 +183,84 @@ class BitSet(protected[collection] final var elems: Array[Long])
override def zip[B](that: IterableOnce[B])(implicit @implicitNotFound(collection.BitSet.zipOrdMsg) ev: Ordering[(Int, B)]): SortedSet[(Int, B)] =
super.zip(that)

override def addAll(xs: IterableOnce[Int]): this.type = xs match {
case bs: collection.BitSet =>
this |= bs
case range: Range =>
if (range.nonEmpty) {
val start = range.min
if (start >= 0) {
val end = range.max
val endIdx = end >> LogWL
ensureCapacity(endIdx)

if (range.step == 1 || range.step == -1) {
val startIdx = start >> LogWL
val wordStart = startIdx * BitSetOps.WordLength
val wordMask = -1L << (start - wordStart)

if (endIdx > startIdx) {
elems(startIdx) |= wordMask
java.util.Arrays.fill(elems, startIdx + 1, endIdx, -1L)
elems(endIdx) |= -1L >>> (BitSetOps.WordLength - (end - endIdx * BitSetOps.WordLength) - 1)
} else elems(endIdx) |= (wordMask & (-1L >>> (BitSetOps.WordLength - (end - wordStart) - 1)))
} else super.addAll(range)
} else super.addAll(range)
}
this

case sorted: collection.SortedSet[Int] =>
// if `sorted` is using the regular Int ordering, ensure capacity for the largest
// element up front to avoid multiple resizing allocations
if (sorted.nonEmpty) {
val ord = sorted.ordering
if (ord eq Ordering.Int) {
ensureCapacity(sorted.lastKey >> LogWL)
} else if (ord eq Ordering.Int.reverse) {
ensureCapacity(sorted.firstKey >> LogWL)
}
val iter = sorted.iterator
while (iter.hasNext) {
addOne(iter.next())
}
}

this

case other =>
super.addAll(other)
}

override def subsetOf(that: collection.Set[Int]): Boolean = that match {
case bs: collection.BitSet =>
val thisnwords = this.nwords
val bsnwords = bs.nwords
val minWords = Math.min(thisnwords, bsnwords)

// if any bits are set to `1` in words out of range of `bs`, then this is not a subset. Start there
var i = bsnwords
while (i < thisnwords) {
if (word(i) != 0L) return false
i += 1
}

// the higher range of `this` is all `0`s, fall back to lower range
var j = 0
while (j < minWords) {
if ((word(j) & ~bs.word(j)) != 0L) return false
j += 1
}

true
case other =>
super.subsetOf(other)
}

override def subtractAll(xs: IterableOnce[Int]): this.type = xs match {
case bs: collection.BitSet => this &~= bs
case other => super.subtractAll(other)
}

protected[this] def writeReplace(): AnyRef = new BitSet.SerializationProxy(this)

override def diff(that: collection.Set[Int]): BitSet = that match {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package scala.collection.mutable

import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra._
import java.util.concurrent.TimeUnit

import scala.collection.mutable

@BenchmarkMode(Array(Mode.AverageTime))
@Fork(1)
@Threads(1)
@Warmup(iterations = 6)
@Measurement(iterations = 6)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
class BitSetBenchmark {
@Param(Array("0", "3", "5", "10", "1000", "1000000"))
var size: Int = _

val bitSet = (1 to 1000).to(mutable.BitSet)

var bs: mutable.BitSet = _

var range: Range = _

val clones: Array[mutable.BitSet] = new Array(100)

@Setup(Level.Iteration) def initializeRange(): Unit = {
range = (10 to (10 + size))
}
@Setup(Level.Invocation) def initializeClones(): Unit = {
(0 until 100) foreach (i => clones(i) = bitSet.clone())
}

@Benchmark def addAll(bh: Blackhole): Unit = {
clones.foreach{ c =>
bh consume c.addAll(range)
}
}
}
6 changes: 6 additions & 0 deletions test/junit/scala/collection/mutable/BitSetTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,10 @@ class BitSetTest {
assert(BitSet().diff(a) == BitSet())
assert(BitSet().diff(BitSet()) == BitSet())
}

@Test def buildFromRange(): Unit = {
import scala.util.chaining._
assert((1 to 1000).to(BitSet) == BitSet().tap(bs => (1 to 1000).foreach(bs.addOne)))

}
}
35 changes: 32 additions & 3 deletions test/scalacheck/scala/collection/mutable/BitSetProperties.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package scala.collection.mutable

import org.scalacheck._
import org.scalacheck.Prop._
import Gen._
object BitSetProperties extends Properties("mutable.BitSet") {

object BitSetProperties extends Properties("mutable.BitSet") {
override def overrideParameters(p: Test.Parameters): Test.Parameters =
p.withMinSuccessfulTests(1000)
p.withMinSuccessfulTests(500)
.withInitialSeed(42L)

// the top of the range shouldn't be too high, else we may not get enough overlap
Expand All @@ -18,6 +17,17 @@ object BitSetProperties extends Properties("mutable.BitSet") {
)
)

/** the max number to include in generated BitSets */
val highestNum = 15000

implicit val nonNegativeRange: Arbitrary[Range] = Arbitrary(
for {
start <- chooseNum(0, highestNum)
end <- chooseNum(start, highestNum)
by <- oneOf(-1, 1, 5)
} yield start to end by by
)

property("diff") = forAll { (left: BitSet, right: BitSet) =>
(left.diff(right): Set[Int]) ?= left.to(HashSet).diff(right.to(HashSet))
}
Expand All @@ -39,4 +49,23 @@ object BitSetProperties extends Properties("mutable.BitSet") {
val (left, right) = bs.partition(p)
(left ?= bs.filter(p)) && (right ?= bs.filterNot(p))
}


property("addAll(Range)") = forAll{ (bs: BitSet, range: Range) =>
val bsClone1 = bs.clone()
val bsClone2 = bs.clone()
range.foreach(bsClone2.add)
bsClone1.addAll(range) ?= bsClone2
}

property("subsetOf(BitSet) equivalent to slow implementation") = forAll{ (left: BitSet, right: BitSet) =>
(Prop(left.subsetOf(right)) ==> left.forall(right)) &&
(Prop(left.forall(right)) ==> Prop(left.subsetOf(right)))
}

property("left subsetOf (left union right) && right subsetOf (left union right)") = forAll{ (left: BitSet, right: BitSet) =>
val leftUnionRight = left concat right
left.subsetOf(leftUnionRight) && right.subsetOf(leftUnionRight)
}

}

0 comments on commit c196be3

Please sign in to comment.