#![deny(missing_docs)]
extern crate atom;
#[cfg(test)]
extern crate rand;
#[cfg(feature = "parallel")]
extern crate rayon;
mod atomic;
mod iter;
mod ops;
mod util;
pub use atomic::AtomicBitSet;
pub use iter::{BitIter, DrainBitIter};
#[cfg(feature = "parallel")]
pub use iter::{BitParIter, BitProducer};
pub use ops::{BitSetAll, BitSetAnd, BitSetNot, BitSetOr, BitSetXor};
use util::*;
/// A hierarchical bitset made of four layers of `usize` words.
///
/// `layer0` stores one bit per index. Each higher layer summarizes the one
/// below it: a set bit in `layer1`, `layer2` or `layer3` records that the
/// corresponding word one layer down is non-zero. This means an empty set
/// can be recognized from `layer3` alone, without scanning `layer0`.
///
/// The layer vectors grow on demand when larger indices are added and are
/// never shrunk, except by [`BitSet::clear`].
#[derive(Clone, Debug, Default)]
pub struct BitSet {
    /// Top summary word; a set bit means the matching `layer2` word is non-zero.
    layer3: usize,
    /// Summary of `layer1`; a set bit means the matching `layer1` word is non-zero.
    layer2: Vec<usize>,
    /// Summary of `layer0`; a set bit means the matching `layer0` word is non-zero.
    layer1: Vec<usize>,
    /// The bits themselves: one bit per stored index.
    layer0: Vec<usize>,
}
impl BitSet {
    /// Creates an empty bitset.
    ///
    /// No layer storage is allocated until the first index is added.
    pub fn new() -> BitSet {
        Default::default()
    }

    /// Panics if `max` is greater than `MAX_EID`.
    #[inline]
    fn valid_range(max: Index) {
        if (MAX_EID as u32) < max {
            panic!("Expected index to be less then {}, found {}", MAX_EID, max);
        }
    }

    /// Creates an empty bitset whose layers are already sized to hold any
    /// index up to and including `max`.
    ///
    /// # Panics
    ///
    /// Panics if `max` is greater than `MAX_EID`.
    pub fn with_capacity(max: Index) -> BitSet {
        Self::valid_range(max);
        let mut value = BitSet::new();
        value.extend(max);
        value
    }

    /// Grows every layer (with zero words) so that `id` can be stored.
    ///
    /// Marked `#[inline(never)]` so this rarely-taken growth path is not
    /// inlined into the `#[inline]` fast path of `add`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is greater than `MAX_EID`.
    #[inline(never)]
    fn extend(&mut self, id: Index) {
        Self::valid_range(id);
        let (p0, p1, p2) = offsets(id);
        Self::fill_up(&mut self.layer2, p2);
        Self::fill_up(&mut self.layer1, p1);
        Self::fill_up(&mut self.layer0, p0);
    }

    /// Zero-extends `vec` so that `upper_index` is in bounds; never shrinks it.
    fn fill_up(vec: &mut Vec<usize>, upper_index: usize) {
        if vec.len() <= upper_index {
            vec.resize(upper_index + 1, 0);
        }
    }

    /// Sets the summary bits for `id` in layers 1, 2 and 3.
    ///
    /// Only needed when `id` lands in a `layer0` word that was zero before;
    /// otherwise those summary bits are already set.
    #[inline(never)]
    fn add_slow(&mut self, id: Index) {
        let (_, p1, p2) = offsets(id);
        self.layer1[p1] |= id.mask(SHIFT1);
        self.layer2[p2] |= id.mask(SHIFT2);
        self.layer3 |= id.mask(SHIFT3);
    }

    /// Adds `id` to the set, growing the layers if necessary.
    ///
    /// Returns `true` if `id` was **already** present (the set is left
    /// unchanged) and `false` if it was newly inserted.
    ///
    /// # Panics
    ///
    /// Panics if the layers have to grow and `id` is greater than `MAX_EID`.
    #[inline]
    pub fn add(&mut self, id: Index) -> bool {
        let (p0, mask) = (id.offset(SHIFT1), id.mask(SHIFT0));

        if p0 >= self.layer0.len() {
            self.extend(id);
        }

        let old = self.layer0[p0];
        if old & mask != 0 {
            return true;
        }

        self.layer0[p0] = old | mask;
        // The word was empty before, so the upper layers don't record it yet.
        if old == 0 {
            self.add_slow(id);
        }
        false
    }

    /// Returns a mutable reference to word `idx` of layer `level`
    /// (0 = bottom, 3 = top), zero-extending that layer if needed.
    /// `idx` is ignored for layer 3.
    ///
    /// This touches only the requested layer; the summary bits of the other
    /// layers are not updated.
    ///
    /// # Panics
    ///
    /// Panics if `level` is greater than 3.
    fn layer_mut(&mut self, level: usize, idx: usize) -> &mut usize {
        match level {
            0 => {
                Self::fill_up(&mut self.layer0, idx);
                &mut self.layer0[idx]
            }
            1 => {
                Self::fill_up(&mut self.layer1, idx);
                &mut self.layer1[idx]
            }
            2 => {
                Self::fill_up(&mut self.layer2, idx);
                &mut self.layer2[idx]
            }
            3 => &mut self.layer3,
            _ => panic!("Invalid layer: {}", level),
        }
    }

    /// Removes `id` from the set.
    ///
    /// Returns `true` if `id` was present and has been removed, and `false`
    /// if it was not in the set. Storage is never shrunk.
    #[inline]
    pub fn remove(&mut self, id: Index) -> bool {
        let (p0, p1, p2) = offsets(id);
        let mask0 = id.mask(SHIFT0);

        // Short-circuit keeps the index access in bounds.
        if p0 >= self.layer0.len() || self.layer0[p0] & mask0 == 0 {
            return false;
        }

        // Clear the bit, then walk upwards. A summary bit is cleared only
        // when the word it summarizes has just become zero.
        self.layer0[p0] &= !mask0;
        if self.layer0[p0] != 0 {
            return true;
        }

        self.layer1[p1] &= !id.mask(SHIFT1);
        if self.layer1[p1] != 0 {
            return true;
        }

        self.layer2[p2] &= !id.mask(SHIFT2);
        if self.layer2[p2] != 0 {
            return true;
        }

        self.layer3 &= !id.mask(SHIFT3);
        true
    }

    /// Returns `true` if `id` is in the set.
    ///
    /// Indices beyond the allocated storage are reported as absent.
    #[inline]
    pub fn contains(&self, id: Index) -> bool {
        let p0 = id.offset(SHIFT1);
        p0 < self.layer0.len() && (self.layer0[p0] & id.mask(SHIFT0)) != 0
    }

    /// Returns `true` if every index in `other` is also in `self`,
    /// i.e. `other` is a subset of `self`.
    #[inline]
    pub fn contains_set(&self, other: &BitSet) -> bool {
        other.iter().all(|id| self.contains(id))
    }

    /// Removes every index from the set.
    ///
    /// The layer vectors are emptied but keep their allocated capacity.
    pub fn clear(&mut self) {
        self.layer0.clear();
        self.layer1.clear();
        self.layer2.clear();
        self.layer3 = 0;
    }
}
/// Read-only access to the layers of a hierarchical bitset.
///
/// Implementors expose four layers of `usize` words. `layer0` holds one bit
/// per index. In each higher layer, a set bit is expected to mean that the
/// corresponding word one layer down is non-zero. The provided methods
/// (`is_empty`, `iter`, `par_iter`) rely on that invariant.
pub trait BitSetLike {
    /// Returns word `idx` of layer `layer` (0 = bottom, 3 = top).
    /// `idx` is ignored for layer 3.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is greater than 3.
    fn get_from_layer(&self, layer: usize, idx: usize) -> usize {
        match layer {
            0 => self.layer0(idx),
            1 => self.layer1(idx),
            2 => self.layer2(idx),
            3 => self.layer3(),
            _ => panic!("Invalid layer: {}", layer),
        }
    }

    /// Returns `true` if no index is set, judged from the top layer alone.
    fn is_empty(&self) -> bool {
        self.layer3() == 0
    }

    /// Returns the single top-layer word.
    fn layer3(&self) -> usize;

    /// Returns word `i` of layer 2.
    fn layer2(&self, i: usize) -> usize;

    /// Returns word `i` of layer 1.
    fn layer1(&self, i: usize) -> usize;

    /// Returns word `i` of layer 0, the layer that stores one bit per index.
    fn layer0(&self, i: usize) -> usize;

    /// Returns `true` if index `i` is in the set.
    fn contains(&self, i: Index) -> bool;

    /// Creates an iterator over the indices in this set, seeded with the
    /// current top-layer word.
    fn iter(self) -> BitIter<Self>
    where
        Self: Sized,
    {
        let layer3 = self.layer3();
        BitIter::new(self, [0, 0, 0, layer3], [0; LAYERS - 1])
    }

    /// Creates a parallel iterator over the indices in this set.
    ///
    /// Only available with the `parallel` feature.
    #[cfg(feature = "parallel")]
    fn par_iter(self) -> BitParIter<Self>
    where
        Self: Sized,
    {
        BitParIter::new(self)
    }
}
/// A [`BitSetLike`] that also supports removing indices, which is what
/// draining iteration requires.
pub trait DrainableBitSet: BitSetLike {
    /// Removes index `i`; returns `true` if it was set before the call.
    fn remove(&mut self, i: Index) -> bool;

    /// Creates a [`DrainBitIter`] over this set, seeded with the current
    /// top-layer word.
    ///
    /// The iterator holds a mutable borrow of the set for its whole
    /// lifetime. As the name suggests, yielded indices are presumably removed
    /// from the set; see `DrainBitIter` for the exact semantics.
    fn drain<'a>(&'a mut self) -> DrainBitIter<'a, Self>
    where
        Self: Sized,
    {
        let layer3 = self.layer3();
        DrainBitIter::new(self, [0, 0, 0, layer3], [0; LAYERS - 1])
    }
}
/// A shared reference behaves like the set it points to: every layer query
/// is forwarded to the referenced value.
impl<'a, T> BitSetLike for &'a T
where
    T: BitSetLike + ?Sized,
{
    #[inline]
    fn layer3(&self) -> usize {
        T::layer3(*self)
    }

    #[inline]
    fn layer2(&self, i: usize) -> usize {
        T::layer2(*self, i)
    }

    #[inline]
    fn layer1(&self, i: usize) -> usize {
        T::layer1(*self, i)
    }

    #[inline]
    fn layer0(&self, i: usize) -> usize {
        T::layer0(*self, i)
    }

    #[inline]
    fn contains(&self, i: Index) -> bool {
        T::contains(*self, i)
    }
}
/// A mutable reference behaves like the set it points to for read-only
/// queries: every layer query is forwarded to the referenced value.
impl<'a, T> BitSetLike for &'a mut T
where
    T: BitSetLike + ?Sized,
{
    #[inline]
    fn layer3(&self) -> usize {
        T::layer3(&**self)
    }

    #[inline]
    fn layer2(&self, i: usize) -> usize {
        T::layer2(&**self, i)
    }

    #[inline]
    fn layer1(&self, i: usize) -> usize {
        T::layer1(&**self, i)
    }

    #[inline]
    fn layer0(&self, i: usize) -> usize {
        T::layer0(&**self, i)
    }

    #[inline]
    fn contains(&self, i: Index) -> bool {
        T::contains(&**self, i)
    }
}
/// A mutable reference can remove indices from the set it points to;
/// `remove` is forwarded to the referenced value.
impl<'a, T> DrainableBitSet for &'a mut T
where
    T: DrainableBitSet,
{
    #[inline]
    fn remove(&mut self, i: Index) -> bool {
        T::remove(&mut **self, i)
    }
}
impl BitSetLike for BitSet {
#[inline]
fn layer3(&self) -> usize {
self.layer3
}
#[inline]
fn layer2(&self, i: usize) -> usize {
self.layer2.get(i).map(|&x| x).unwrap_or(0)
}
#[inline]
fn layer1(&self, i: usize) -> usize {
self.layer1.get(i).map(|&x| x).unwrap_or(0)
}
#[inline]
fn layer0(&self, i: usize) -> usize {
self.layer0.get(i).map(|&x| x).unwrap_or(0)
}
#[inline]
fn contains(&self, i: Index) -> bool {
self.contains(i)
}
}
impl DrainableBitSet for BitSet {
#[inline]
fn remove(&mut self, i: Index) -> bool {
self.remove(i)
}
}
/// Two bitsets are equal when they contain the same indices.
///
/// Layers can differ in length for identical contents. Examples are a set
/// built with `with_capacity`, or one that had a large index added and then
/// removed, since layers never shrink. Words past the end of the shorter
/// layer are therefore treated as zero instead of making the sets unequal.
impl PartialEq for BitSet {
    #[inline]
    fn eq(&self, rhv: &BitSet) -> bool {
        /// Compares two layers word by word, treating missing trailing words as zero.
        fn layers_equal(a: &[usize], b: &[usize]) -> bool {
            let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
            let (head, tail) = long.split_at(short.len());
            short == head && tail.iter().all(|&word| word == 0)
        }

        // Cheapest check first: the single top-layer word.
        self.layer3 == rhv.layer3
            && layers_equal(&self.layer2, &rhv.layer2)
            && layers_equal(&self.layer1, &rhv.layer1)
            && layers_equal(&self.layer0, &rhv.layer0)
    }
}

/// Content-based set equality is a full equivalence relation.
impl Eq for BitSet {}
#[cfg(test)]
mod tests {
    use super::{BitSet, BitSetAnd, BitSetLike, BitSetNot};

    /// `add` reports "absent" on the first insert of an index and "present"
    /// on the second; every inserted index is then found by `contains`.
    #[test]
    fn insert() {
        let mut set = BitSet::new();
        for id in 0..1_000 {
            assert!(!set.add(id));
            assert!(set.add(id));
        }
        for id in 0..1_000 {
            assert!(set.contains(id));
        }
    }

    /// Same as `insert`, over a much larger range of indices.
    #[test]
    fn insert_100k() {
        let mut set = BitSet::new();
        for id in 0..100_000 {
            assert!(!set.add(id));
            assert!(set.add(id));
        }
        for id in 0..100_000 {
            assert!(set.contains(id));
        }
    }

    /// `remove` succeeds exactly once per present index.
    #[test]
    fn remove() {
        let mut set = BitSet::new();
        for id in 0..1_000 {
            assert!(!set.add(id));
        }
        for id in 0..1_000 {
            assert!(set.contains(id));
            assert!(set.remove(id));
            assert!(!set.contains(id));
            assert!(!set.remove(id));
        }
    }

    /// Iterating a dense range yields every index, in ascending order.
    #[test]
    fn iter() {
        let mut set = BitSet::new();
        for id in 0..100_000 {
            set.add(id);
        }
        let mut visited = 0;
        for (position, id) in set.iter().enumerate() {
            assert_eq!(position, id as usize);
            visited += 1;
        }
        assert_eq!(visited, 100_000);
    }

    /// Odd and even halves each hold half the indices and never intersect.
    #[test]
    fn iter_odd_even() {
        let mut odd = BitSet::new();
        let mut even = BitSet::new();
        for id in 0..100_000 {
            let target = if id % 2 == 1 { &mut odd } else { &mut even };
            target.add(id);
        }
        assert_eq!((&odd).iter().count(), 50_000);
        assert_eq!((&even).iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).iter().count(), 0);
    }

    /// The iterator yields exactly as many indices as were freshly inserted.
    #[test]
    fn iter_random_add() {
        use rand::prelude::*;
        let mut rng = thread_rng();
        let limit = 1_048_576;
        let mut set = BitSet::new();
        let mut fresh = 0;
        for _ in 0..(limit / 10) {
            let index = rng.gen_range(0, limit);
            if !set.add(index) {
                fresh += 1;
            }
        }
        assert_eq!(set.iter().count(), fresh as usize);
    }

    /// Sparse clusters spread across upper-layer words are all found.
    #[test]
    fn iter_clusters() {
        let mut set = BitSet::new();
        for a in 0..8 {
            let high = (a * 3) << (::BITS * 2);
            for b in 0..8 {
                let mid = (b * 3) << (::BITS);
                for c in 0..8 {
                    set.add(high + mid + c * 2);
                }
            }
        }
        assert_eq!(set.iter().count(), 8usize.pow(3));
    }

    /// The complement of the odd indices starts with the even indices.
    #[test]
    fn not() {
        let mut odd = BitSet::new();
        for id in 0..10_000 {
            if id % 2 == 1 {
                odd.add(id);
            }
        }
        let evens = BitSetNot(odd);
        for (n, id) in evens.iter().take(5_000).enumerate() {
            assert_eq!(n * 2, id as usize);
        }
    }
}
#[cfg(all(test, feature = "parallel"))]
mod test_parallel {
    use super::{BitSet, BitSetAnd, BitSetLike};
    use rayon::iter::ParallelIterator;
    use std::collections::HashSet;
    use std::sync::Mutex;

    /// Iterates `set` in parallel and panics unless every index in `expected`
    /// is yielded exactly once and nothing else is yielded.
    ///
    /// Indices yielded but not expected, or yielded twice, are reported as
    /// "got iterated, but that shouldn't be".
    fn assert_par_iter_yields(set: &BitSet, expected: HashSet<u32>) {
        // A borrowed `Mutex` is `Sync`, so no `Arc` is needed inside `for_each`.
        let pending = Mutex::new(expected);
        let unexpected = Mutex::new(HashSet::new());
        set.par_iter().for_each(|n| {
            let was_pending = pending.lock().unwrap().remove(&n);
            if !was_pending {
                unexpected.lock().unwrap().insert(n);
            }
        });

        let pending = pending.into_inner().unwrap();
        let unexpected = unexpected.into_inner().unwrap();
        if !pending.is_empty() && !unexpected.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}\nThere were values that got iterated, but that shouldn't be: {:?}",
                pending, unexpected
            );
        }
        if !pending.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}",
                pending
            );
        }
        if !unexpected.is_empty() {
            panic!(
                "There were values that got iterated, but that shouldn't be: {:?}",
                unexpected
            );
        }
    }

    /// A set holding a single index yields exactly one item, wherever the
    /// index falls in the index space.
    #[test]
    fn par_iter_one() {
        let step = 5000;
        let tests = 1_048_576 / step;
        for n in 0..tests {
            let mut set = BitSet::new();
            set.add(n * step);
            assert_eq!(set.par_iter().count(), 1);
        }
        let mut set = BitSet::new();
        set.add(1_048_576 - 1);
        assert_eq!(set.par_iter().count(), 1);
    }

    /// Random indices are each yielded exactly once.
    #[test]
    fn par_iter_random_add() {
        use rand::prelude::*;
        let mut rng = thread_rng();
        let limit = 1_048_576;
        let mut set = BitSet::new();
        let mut expected = HashSet::new();
        for _ in 0..(limit / 10) {
            let index = rng.gen_range(0, limit);
            set.add(index);
            expected.insert(index);
        }
        assert_par_iter_yields(&set, expected);
    }

    /// Odd and even halves each hold half the indices and never intersect.
    #[test]
    fn par_iter_odd_even() {
        let mut odd = BitSet::new();
        let mut even = BitSet::new();
        for i in 0..100_000 {
            if i % 2 == 1 {
                odd.add(i);
            } else {
                even.add(i);
            }
        }
        assert_eq!((&odd).par_iter().count(), 50_000);
        assert_eq!((&even).par_iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).par_iter().count(), 0);
    }

    /// Sparse clusters spread across upper-layer words are each yielded once.
    #[test]
    fn par_iter_clusters() {
        let mut set = BitSet::new();
        let mut expected = HashSet::new();
        for x in 0..8 {
            let x = (x * 3) << (::BITS * 2);
            for y in 0..8 {
                let y = (y * 3) << (::BITS);
                for z in 0..8 {
                    let index = x + y + z * 2;
                    set.add(index);
                    expected.insert(index);
                }
            }
        }
        assert_par_iter_yields(&set, expected);
    }
}