1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
/*
 * Released under the terms of the Apache 2.0 license with LLVM
 * exception. See `LICENSE` for details.
 */

//! Index sets: sets of integers that represent indices into a space.

use fxhash::FxHashMap;
use std::cell::Cell;

/// Maximum number of `(key, value)` pairs that `AdaptiveMap::Small`
/// stores inline before it must drop zero-valued entries or switch to
/// `AdaptiveMap::Large`.
const SMALL_ELEMS: usize = 12;

/// A hybrid large/small-mode sparse mapping from integer indices to
/// elements.
///
/// `Small` keeps up to `SMALL_ELEMS` pairs inline in two parallel arrays,
/// of which only the first `len` entries are live. `Large` keeps the pairs
/// in a heap-allocated hash map and is used once the inline arrays are
/// full and no zero-valued entry can be dropped to make room.
#[derive(Clone, Debug)]
enum AdaptiveMap {
    Small {
        /// Number of live entries at the front of `keys` / `values`.
        len: u32,
        /// Keys of the live entries; slots at or past `len` are unused.
        keys: [u32; SMALL_ELEMS],
        /// Values parallel to `keys`; slots at or past `len` are unused.
        values: [u64; SMALL_ELEMS],
    },
    /// Hash-map storage used after the inline arrays overflow.
    Large(FxHashMap<u32, u64>),
}

/// Sentinel `u32`: fills the unused key slots of a fresh small map and
/// marks an empty `IndexSet` cache entry.
const INVALID: u32 = 0xffff_ffff;

impl AdaptiveMap {
    fn new() -> Self {
        Self::Small {
            len: 0,
            keys: [INVALID; SMALL_ELEMS],
            values: [0; SMALL_ELEMS],
        }
    }

    /// Expand into `Large` mode if we are at capacity and have no
    /// zero-value pairs that can be trimmed.
    #[inline(never)]
    fn expand(&mut self) {
        match self {
            &mut Self::Small {
                ref mut len,
                ref mut keys,
                ref mut values,
            } => {
                // Note: we *may* remain as `Small` if there are any
                // zero elements. Try removing them first, before we
                // commit to a memory allocation.
                if values.iter().any(|v| *v == 0) {
                    let mut out = 0;
                    for i in 0..(*len as usize) {
                        if values[i] == 0 {
                            continue;
                        }
                        if out < i {
                            keys[out] = keys[i];
                            values[out] = values[i];
                        }
                        out += 1;
                    }
                    *len = out as u32;
                } else {
                    let mut map = FxHashMap::default();
                    for i in 0..(*len as usize) {
                        map.insert(keys[i], values[i]);
                    }
                    *self = Self::Large(map);
                }
            }
            _ => {}
        }
    }
    #[inline(always)]
    fn get_or_insert<'a>(&'a mut self, key: u32) -> &'a mut u64 {
        // Check whether the key is present and we are in small mode;
        // if no to both, we need to expand first.
        let (needs_expand, small_mode_idx) = match self {
            &mut Self::Small { len, ref keys, .. } => {
                // Perform this scan but do not return right away;
                // doing so runs into overlapping-borrow issues
                // because the current non-lexical lifetimes
                // implementation is not able to see that the `self`
                // mutable borrow on return is only on the
                // early-return path.
                let small_mode_idx = keys.iter().take(len as usize).position(|k| *k == key);
                let needs_expand = small_mode_idx.is_none() && len == SMALL_ELEMS as u32;
                (needs_expand, small_mode_idx)
            }
            _ => (false, None),
        };

        if needs_expand {
            debug_assert!(small_mode_idx.is_none());
            self.expand();
        }

        match self {
            &mut Self::Small {
                ref mut len,
                ref mut keys,
                ref mut values,
            } => {
                // If we found the key already while checking whether
                // we need to expand above, use that index to return
                // early.
                if let Some(i) = small_mode_idx {
                    return &mut values[i];
                }
                // Otherwise, the key must not be present; add a new
                // entry.
                debug_assert!(*len < SMALL_ELEMS as u32);
                let idx = *len;
                *len += 1;
                keys[idx as usize] = key;
                values[idx as usize] = 0;
                &mut values[idx as usize]
            }
            &mut Self::Large(ref mut map) => map.entry(key).or_insert(0),
        }
    }
    #[inline(always)]
    fn get_mut(&mut self, key: u32) -> Option<&mut u64> {
        match self {
            &mut Self::Small {
                len,
                ref keys,
                ref mut values,
            } => {
                for i in 0..len {
                    if keys[i as usize] == key {
                        return Some(&mut values[i as usize]);
                    }
                }
                None
            }
            &mut Self::Large(ref mut map) => map.get_mut(&key),
        }
    }
    #[inline(always)]
    fn get(&self, key: u32) -> Option<u64> {
        match self {
            &Self::Small {
                len,
                ref keys,
                ref values,
            } => {
                for i in 0..len {
                    if keys[i as usize] == key {
                        let value = values[i as usize];
                        return Some(value);
                    }
                }
                None
            }
            &Self::Large(ref map) => {
                let value = map.get(&key).cloned();
                value
            }
        }
    }
    fn iter<'a>(&'a self) -> AdaptiveMapIter<'a> {
        match self {
            &Self::Small {
                len,
                ref keys,
                ref values,
            } => AdaptiveMapIter::Small(&keys[0..len as usize], &values[0..len as usize]),
            &Self::Large(ref map) => AdaptiveMapIter::Large(map.iter()),
        }
    }
}

/// Borrowing iterator over the `(key, value)` pairs of an `AdaptiveMap`.
///
/// `Small` walks two parallel slices of live keys and values, which are
/// expected to have the same length; `Large` wraps a hash-map iterator.
enum AdaptiveMapIter<'a> {
    Small(&'a [u32], &'a [u64]),
    Large(std::collections::hash_map::Iter<'a, u32, u64>),
}

impl<'a> Iterator for AdaptiveMapIter<'a> {
    type Item = (u32, u64);

    fn next(&mut self) -> Option<(u32, u64)> {
        match self {
            AdaptiveMapIter::Small(keys, values) => {
                // The key slice decides when iteration ends; the value
                // slice is advanced in lockstep with it.
                let (&key, remaining_keys) = keys.split_first()?;
                let value = values[0];
                *keys = remaining_keys;
                *values = &values[1..];
                Some((key, value))
            }
            AdaptiveMapIter::Large(inner) => inner.next().map(|(&key, &value)| (key, value)),
        }
    }
}

/// A set of indices that allows union and efficient iteration over
/// elements.
///
/// Membership is stored sparsely as 64-bit words keyed by
/// `index / BITS_PER_WORD`. The word key is a `u32`, so the set is not
/// truly unbounded: only indices whose word number fits in a `u32` are
/// represented faithfully. A word is only stored once one of its bits has
/// been set.
#[derive(Clone)]
pub struct IndexSet {
    /// Sparse map from word index to that word's 64 membership bits.
    elems: AdaptiveMap,
    /// One-entry `(word index, bits)` read cache; `(INVALID, 0)` means
    /// empty. NOTE(review): the code visible here only ever resets this to
    /// `(INVALID, 0)` and never stores a real entry — confirm whether the
    /// cache is still needed.
    cache: Cell<(u32, u64)>,
}

/// Number of index bits packed into each stored word.
const BITS_PER_WORD: usize = 64;

impl IndexSet {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self {
            elems: AdaptiveMap::new(),
            cache: Cell::new((INVALID, 0)),
        }
    }

    /// Returns the key of the word holding bit `bit_index`.
    ///
    /// Word keys are `u32`; a bit index whose word number does not fit
    /// would be truncated and silently alias some lower word, so that case
    /// is rejected in debug builds.
    #[inline(always)]
    fn word_index(bit_index: usize) -> u32 {
        let word = bit_index / BITS_PER_WORD;
        debug_assert!(
            word <= u32::MAX as usize,
            "IndexSet: bit index {} exceeds the u32 word-index range",
            bit_index
        );
        word as u32
    }

    /// Clears the read cache if it refers to `word_index`, because the
    /// caller is about to hand out a mutable reference to that word.
    #[inline(always)]
    fn invalidate_cache(&self, word_index: u32) {
        if self.cache.get().0 == word_index {
            self.cache.set((INVALID, 0));
        }
    }

    /// Returns the word containing `bit_index`, inserting a zero word if
    /// none is stored yet.
    #[inline(always)]
    fn elem(&mut self, bit_index: usize) -> &mut u64 {
        let word_index = Self::word_index(bit_index);
        self.invalidate_cache(word_index);
        self.elems.get_or_insert(word_index)
    }

    /// Returns the word containing `bit_index` if it is stored, without
    /// inserting anything.
    #[inline(always)]
    fn maybe_elem_mut(&mut self, bit_index: usize) -> Option<&mut u64> {
        let word_index = Self::word_index(bit_index);
        self.invalidate_cache(word_index);
        self.elems.get_mut(word_index)
    }

    /// Returns a copy of the word containing `bit_index`, if stored.
    #[inline(always)]
    fn maybe_elem(&self, bit_index: usize) -> Option<u64> {
        let word_index = Self::word_index(bit_index);
        match self.cache.get() {
            // `INVALID` is the empty-cache marker, but it is also a legal
            // word key in `elems`, so it must never count as a cache hit
            // (otherwise bits in word `0xffff_ffff` would always read 0).
            (cached, bits) if cached == word_index && cached != INVALID => Some(bits),
            _ => self.elems.get(word_index),
        }
    }

    /// Adds (`val == true`) or removes (`val == false`) index `idx`.
    ///
    /// Removing an index whose word is not stored is a no-op and does not
    /// insert a word.
    #[inline(always)]
    pub fn set(&mut self, idx: usize, val: bool) {
        let mask = 1u64 << (idx % BITS_PER_WORD);
        if val {
            *self.elem(idx) |= mask;
        } else if let Some(word) = self.maybe_elem_mut(idx) {
            *word &= !mask;
        }
    }

    /// Makes `self` an exact copy of `other`.
    pub fn assign(&mut self, other: &Self) {
        self.elems = other.elems.clone();
        self.cache = other.cache.clone();
    }

    /// Returns whether index `idx` is in the set.
    #[inline(always)]
    pub fn get(&self, idx: usize) -> bool {
        let mask = 1u64 << (idx % BITS_PER_WORD);
        self.maybe_elem(idx).map_or(false, |word| (word & mask) != 0)
    }

    /// Adds every index of `other` to `self`.
    ///
    /// Returns `true` if at least one index was not already present.
    pub fn union_with(&mut self, other: &Self) -> bool {
        let mut changed = 0u64;
        for (word_idx, bits) in other.elems.iter() {
            // Skip empty words so no zero-valued entries are created here.
            if bits == 0 {
                continue;
            }
            let self_word = self.elem(word_idx as usize * BITS_PER_WORD);
            changed |= bits & !*self_word;
            *self_word |= bits;
        }
        changed != 0
    }

    /// Iterates over the indices in the set.
    ///
    /// Indices within one word come out in ascending order, but words are
    /// visited in storage order, so the overall sequence is not guaranteed
    /// to be sorted.
    pub fn iter<'a>(&'a self) -> impl Iterator<Item = usize> + 'a {
        self.elems.iter().flat_map(|(word_idx, bits)| {
            let word_idx = word_idx as usize;
            set_bits(bits).map(move |i| BITS_PER_WORD * word_idx + i)
        })
    }

    /// Is the adaptive data structure in "small" mode? This is meant
    /// for testing assertions only.
    pub(crate) fn is_small(&self) -> bool {
        matches!(self.elems, AdaptiveMap::Small { .. })
    }
}

/// Returns an iterator over the positions of the set bits in `bits`, in
/// ascending order (position 0 is the least significant bit).
fn set_bits(bits: u64) -> impl Iterator<Item = usize> {
    SetBitsIter(bits)
}

/// Iterator over the positions of the set bits of a `u64`, from least to
/// most significant. Each step yields and removes the lowest remaining
/// set bit.
pub struct SetBitsIter(u64);

impl Iterator for SetBitsIter {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        // Build an `Option<NonZeroU64>` so that on the nonzero path,
        // the compiler can optimize the trailing-zeroes operator
        // using that knowledge.
        std::num::NonZeroU64::new(self.0).map(|nz| {
            let bitidx = nz.trailing_zeros();
            self.0 &= self.0 - 1; // clear lowest set bit
            bitidx as usize
        })
    }

    /// Exact: exactly one item remains per set bit.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.0.count_ones() as usize;
        (remaining, Some(remaining))
    }
}

impl std::fmt::Debug for IndexSet {
    /// Writes the members as a list in `iter()` order, e.g. `[3, 64, 130]`.
    /// The compact list form is used even under `{:#?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let members = self.iter().collect::<Vec<usize>>();
        f.write_fmt(format_args!("{:?}", members))
    }
}

#[cfg(test)]
mod test {
    use super::IndexSet;

    // These tests use `assert!` / `assert_eq!` rather than the `debug_`
    // variants: the latter are compiled out when debug assertions are
    // disabled (e.g. `cargo test --release`), which would make every check
    // below vacuous.

    #[test]
    fn test_set_bits_iter() {
        let mut vec = IndexSet::new();
        let mut expected = Vec::new();
        for i in 0..1024 {
            if i % 17 == 0 {
                vec.set(i, true);
                expected.push(i);
            }
        }

        // Word iteration order is unspecified (large mode uses a hash
        // map), so compare the sorted output.
        let mut actual: Vec<usize> = vec.iter().collect();
        actual.sort_unstable();
        assert_eq!(actual, expected);

        for i in 0..1024 {
            assert_eq!(vec.get(i), i % 17 == 0, "membership of {}", i);
        }
    }

    #[test]
    fn test_expand_remove_zero_elems() {
        let mut vec = IndexSet::new();
        // Set 12 different words (this is the max small-mode size).
        for i in 0..12 {
            vec.set(64 * i, true);
        }
        assert!(vec.is_small());
        // Now clear a bit, and set a bit in a different word. We
        // should still be in small mode.
        vec.set(64 * 5, false);
        vec.set(64 * 100, true);
        assert!(vec.is_small());

        // Dropping the zeroed word must not disturb the other members.
        assert!(vec.get(64 * 100));
        assert!(!vec.get(64 * 5));
        for i in (0..12).filter(|&i| i != 5) {
            assert!(vec.get(64 * i), "word {} lost", i);
        }
    }

    #[test]
    fn test_union_with_reports_change() {
        let mut a = IndexSet::new();
        a.set(1, true);
        a.set(70, true);
        let mut b = IndexSet::new();
        b.set(70, true);
        b.set(200, true);

        assert!(a.union_with(&b));
        // A second union adds nothing new.
        assert!(!a.union_with(&b));

        let mut elems: Vec<usize> = a.iter().collect();
        elems.sort_unstable();
        assert_eq!(elems, vec![1, 70, 200]);
    }
}