// Copyright (C) 2002-2003 Michael Anthony <mike at unlikely org>
// Permission to copy, use, modify, sell and distribute this software is
// granted provided this copyright notice appears in all copies.  This
// software is provided "as is" without express or implied warranty, and
// with no claim as to its suitability for any purpose.

#ifndef NSET_H
#define NSET_H

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>


// A container for sets of natural numbers.  Operations are O(log N) (more
// precisely, O(logB N)) with a small constant, except find and count,
// which are O(1), and set equality, which is O(N log N) (but which should
// be O(N)).
//
// Assuming that tree node objects would consume 20 bytes of storage and
// you use a branching factor, B, of 128, this data structure is as space
// efficient as a tree when it is approximately 0.6% "full": that is, when
// 0.6% of the numbers in the range between 0 and the set maximum are in the
// set.  It gets more space efficient the fuller it is.  Reducing the
// branching factor increases the space requirements.
//
// The container is implemented as a hierarchy of bitmaps.  At the base of
// the tree, there is a bitmap containing a bit for every whole number in
// the range [0, set-max].  Each bit at offset N in a higher level tracks
// whether the Nth B bitmaps below it contain 1s.
//
// B should be a multiple of word size (in bits); ideally, it would also be
// a multiple of cache line size.
//
// Warning: The maximum value storable in this structure is less than the
// maximum value that key_type can hold.  It is one less than the largest
// power of B representable in key_type.  Of course, you don't have that
// much RAM anyhow.
//
// REV: Should people really be futzing with B?  At some point, increased
// tree depth is probably no longer worth making fewer comparisons.  On the
// other end, the data structure becomes an ordinary bitmap.
//
// TODO: The origin should not have to be zero.

// A set of natural numbers (key_type) stored as a hierarchy of bitmaps with
// branching factor B.  The class owns two heap arrays (m_data and m_trees)
// and provides copy construction, assignment and swap for them.
template<size_t B = 128>
class nset
{
public:
    typedef size_t key_type;

private:
    // Position within a set, shared by both iterator flavours.  The
    // position is a key value; size_t (-1) is the "before the first"
    // position and m_limit + 1 is the "past the last" position.
    //
    // This used to be spelled "friend class iter_base { ... }", which
    // defines a type inside a friend declaration and is ill-formed.  The
    // class is now an ordinary nested class and is befriended separately.
    class iter_base
    {
    public:
        typedef typename nset::key_type key_type;

        iter_base ();
        iter_base (const nset<B>*, key_type);

        void up ();                         // Move to the next larger member.
        void down ();                       // Move to the next smaller member.
        key_type value () const;            // Key at the current position.
        int cmp (const iter_base &o) const; // -1, 0 or 1 by position.

    private:
        const nset *m_nset;  // Set being traversed (not owned).
        key_type m_v;        // Current position.
    };
    friend class iter_base;

public:
    // Forward-ordered, read-only iterator: ++ moves to larger keys.
    class iterator : private iter_base
    {
    public:
        typedef typename iter_base::key_type key_type;

        iterator () : iter_base () {}

        iterator operator++ (int);
        iterator& operator++ ();
        iterator operator-- (int);
        iterator& operator-- ();
        bool operator== (const iterator&) const;
        bool operator!= (const iterator&) const;
        bool operator<= (const iterator&) const;
        bool operator< (const iterator&) const;
        bool operator>= (const iterator&) const;
        bool operator> (const iterator&) const;
        key_type operator* () const;

    private:
        friend class nset<B>;
        iterator (const nset<B>*, key_type);
    };
    friend class iterator;

    // Reverse-ordered, read-only iterator: ++ moves to smaller keys.
    class reverse_iterator : private iter_base
    {
    public:
        typedef typename iter_base::key_type key_type;

        reverse_iterator () : iter_base () {}

        reverse_iterator operator++ (int);
        reverse_iterator& operator++ ();
        reverse_iterator operator-- (int);
        reverse_iterator& operator-- ();
        bool operator== (const reverse_iterator&) const;
        bool operator!= (const reverse_iterator&) const;
        bool operator<= (const reverse_iterator&) const;
        bool operator< (const reverse_iterator&) const;
        bool operator>= (const reverse_iterator&) const;
        bool operator> (const reverse_iterator&) const;
        key_type operator* () const;

    private:
        friend class nset<B>;
        reverse_iterator (const nset<B>*, key_type);
    };
    friend class reverse_iterator;

    // Iterators never allow modification, so the const variants are the
    // same types.
    typedef iterator const_iterator;
    typedef reverse_iterator const_reverse_iterator;

    static size_t branching_factor ();

    nset (size_t size = 1);
    nset (const nset<B>&);
    template<typename T>
    nset (T begin, T end);
    ~nset ();

    nset& operator= (const nset&);
    std::pair<iterator,bool> insert (key_type);
    size_t erase (key_type);
    void erase (iterator);
    void clear ();
    void resize (size_t);
    void swap (nset<B> &o);
    size_t count (key_type) const;
    const_iterator find (key_type) const;
    const_iterator begin () const;
    const_iterator end () const;
    const_reverse_iterator rbegin () const;
    const_reverse_iterator rend () const;
    bool empty () const;
    size_t max_size () const;
    size_t size () const;

private:
    typedef unsigned char uc;

    size_t m_height;    // Height of the tree.  One level: height is 0.
    size_t m_limit;     // Largest number we can store.
    size_t m_data_size; // Number of bytes in m_data.
    uc *m_data;         // All of the bitmaps.
    uc **m_trees;       // Array of pointers to bitmaps.
    size_t m_size;      // Number of elements in the set.

    void init (size_t size);
};


// Integer power: returns x raised to the n-th power using repeated
// multiplication (n - 1 multiplications).  Returns 1 when n is 0.
// Not bothering the FPU for pow shows a 78% speed-up on a PII.
//
// The obsolete "register" specifier was dropped; it was removed from the
// language in C++17.
template<typename T>
inline T
ipow (T x, T n)
{
    if (n == 0) return 1;

    T result = x;
    for (T remaining = n; remaining > 1; --remaining)
    {
        result *= x;
    }

    return result;
}


// Integer division rounding up: returns n / d, plus one when the division
// leaves a remainder.
// Not bothering the FPU for ceil shows a 38% speed-up on a PII.
//
// The obsolete "register" specifier was dropped; it was removed from the
// language in C++17.
template<typename T>
inline T
iceildiv (T n, T d)
{
    T quotient = n / d;

    if (n % d) ++quotient;

    return quotient;
}


// Set equality: true when both sets hold exactly the same keys.  The
// capacities (limits) of the two sets do not have to match.
template<size_t B>
bool
operator== (const nset<B> &l, const nset<B> &r)
{
    if (l.size () != r.size ()) return false;

    // const_iterator is a dependent name here, so "typename" is required.
    typename nset<B>::const_iterator lit (l.begin ());
    typename nset<B>::const_iterator lend (l.end ());
    typename nset<B>::const_iterator rit (r.begin ());
    typename nset<B>::const_iterator rend (r.end ());

    // PERF: This should be a member function.  We should statically check
    // that the alignment is good.  We should check that the bounds are the
    // same and then, assuming that they are, perform this operation as a
    // memcmp on the lowest level bitmap.
    for (; lit != lend && rit != rend; ++lit, ++rit)
    {
        if (*lit != *rit) return false;
    }

    // Equal only if both walks ran out at the same time.
    return lit == lend && rit == rend;
}


// Builds an empty set sized by sz; all member setup is done by init().
template<size_t B>
nset<B>::nset (size_t sz)
{
    init (sz);
}


// Copy constructor: allocates storage of the same shape as o and copies
// every bitmap byte and the element count.
template<size_t B>
nset<B>::nset (const nset<B> &o)
{
    this->init (o.m_limit + 1);
    assert (o.m_limit == m_limit);
    assert (o.m_height == m_height);
    assert (o.m_data_size == m_data_size);
    // Qualified: the arguments are plain pointers, so an unqualified call
    // would not find std::copy.
    std::copy (o.m_data, o.m_data + o.m_data_size, m_data);
    m_size = o.m_size;
}


// Range constructor: starts from a set initialised for B*B values and
// inserts every element of [begin, end) one at a time.
template<size_t B>
template<typename T>
nset<B>::nset (T begin, T end)
{
    // PERF: If we specialized this for bidirectional iterators (how? we
    // could at least do pointers, but what about the rest?), we could make
    // this more efficient, since we would know the final extents of our set.
    this->init (B*B);
    for (; begin != end; ++begin)
    {
        this->insert (*begin);
    }
}

// Disabled sketch of a pointer-range constructor that would size the set
// from the length of the input range.  The template header below is a
// placeholder rather than valid C++, so this must stay compiled out.
#if 0
template<size_t B>
template<all pointers, damnit>
nset<B>::nset (const T *begin, const T *end)
{
    // PERF: How can we specialize so that this works for all bidi
    // iterators?
    this->init (end - begin);
    while (begin != end)
    {
        this->insert (*begin);
        ++begin;
    }
}
#endif

// Releases both heap arrays allocated by init().
template<size_t B>
nset<B>::~nset ()
{
    // m_trees only holds pointers into m_data; each array is freed once.
    delete [] m_trees;
    delete [] m_data;
}


// Copy assignment via copy-and-swap: the temporary copy takes over the old
// storage and frees it when it goes out of scope.
template<size_t B>
nset<B>&
nset<B>::operator= (const nset &o)
{
    nset<B> replacement (o);
    replacement.swap (*this);
    return *this;
}


template<size_t B>
std::pair<nset<B>::iterator,bool>
nset<B>::insert (key_type n)
{
    assert (n < size_t (-1));

    if (n > m_limit) this->resize (n + 1);

    bool r;
    size_t l = m_height;

    r = (m_trees[l][n / 8U] & (1 << (n % 8U)));
    m_trees[l][n / 8U] |= (1 << (n % 8U));
    if (l) do
    {
        --l;
        n /= B;
        m_trees[l][n / 8U] |= (1 << (n % 8U));
    }
    while (l > 0);

    if (!r) ++m_size;

    return make_pair (iterator (this, n), r);
}


// Removes n from the set.  Returns 1 when n was a member and 0 otherwise
// (including when n is above the current limit).
template<size_t B>
size_t
nset<B>::erase (key_type n)
{
    assert (n < size_t (-1));

    if (n > m_limit)
    {
        return 0;
    }

    size_t level = m_height;
    const bool was_member =
        (m_trees[level][n / 8U] & (1 << (n % 8U))) != 0;
    bool sibling_set = false;

    // Clear the bit at this level; if no other bit in its group of B is
    // still set, the parent's bit has to be cleared as well, and so on
    // towards the root.
    while (level > 0)
    {
        m_trees[level][n / 8U] &= ~(1 << (n % 8U));

        const size_t first = (n / B) * B;
        const size_t last = first + B - 1;

        sibling_set = false;
        for (size_t i = first; i <= last; ++i)
        {
            if (m_trees[level][i / 8U] & (1 << (i % 8U)))
            {
                sibling_set = true;
                break;
            }
        }

        --level;
        n = first / B;  // Index of the group's bit one level up.
        if (sibling_set)
        {
            break;
        }
    }

    if (!sibling_set)
    {
        // Nothing is left anywhere below the root.
        assert (level == 0);
        m_trees[0][0] = 0;
    }

    if (was_member)
    {
        --m_size;
    }

    return was_member;
}


// Removes the key that it refers to by forwarding to erase (key_type).
template<size_t B>
void
nset<B>::erase (iterator it)
{
    const key_type victim = *it;
    this->erase (victim);
}


// Removes every element by zeroing all bitmap bytes.  The storage (and
// therefore the limit) is kept.
template<size_t B>
void
nset<B>::clear ()
{
    // Qualified: the arguments are plain pointers, so an unqualified call
    // would not find std::fill.
    std::fill (m_data, m_data + m_data_size, uc (0));
    m_size = 0;
}


// Grows the set so that it can hold at least the values [0, x).  Never
// shrinks: a request that already fits (including x == 0) is a no-op.
template<size_t B>
void
nset<B>::resize (size_t x)
{
    // x == 0 is tested separately because x - 1 would wrap around and the
    // request would then reach init (0).
    if (x == 0 || (x - 1) <= m_limit) return;

    nset<B> o (x);

    assert (o.m_height > m_height);
    assert (x <= o.m_limit + 1);
    if (m_trees[0][0])
    {
        // The old tree becomes the left-most subtree of the taller one:
        // old level i lands at new level i + diff, and bit 0 of each new
        // upper level marks that subtree as non-empty.
        const size_t diff = o.m_height - m_height;
        for (size_t i = 0; i < diff; ++i)
        {
            o.m_trees[i][0] = 1;
        }
        for (size_t i = 0; i <= m_height; ++i)
        {
            // Explicit <size_t>: size_t and 8U are different types on
            // LP64 targets, which breaks template argument deduction.
            const size_t bytes =
                iceildiv<size_t> (ipow<size_t> (B, i), 8U);
            std::copy (m_trees[i], m_trees[i] + bytes,
                       o.m_trees[i + diff]);
        }
    }
    o.m_size = m_size;

    this->swap (o);
}


// Exchanges the complete state (storage pointers and bookkeeping) of this
// set and o.  No bitmap data is copied.
template<size_t B>
void
nset<B>::swap (nset<B> &o)
{
    // std::swap, not ::swap: there is no swap in the global namespace.
    std::swap (m_height, o.m_height);
    std::swap (m_limit, o.m_limit);
    std::swap (m_data_size, o.m_data_size);
    std::swap (m_data, o.m_data);
    std::swap (m_trees, o.m_trees);
    std::swap (m_size, o.m_size);
}


// Returns 1 when v is a member of the set and 0 otherwise.  Only the base
// bitmap is consulted.
template<size_t B>
size_t
nset<B>::count (key_type v) const
{
    if (v > m_limit)
    {
        return 0;
    }

    const uc byte = m_trees[m_height][v / 8U];
    return (byte & (1 << (v % 8U))) ? 1 : 0;
}


template<size_t B>
nset<B>::const_iterator
nset<B>::find (key_type v) const
{
    if (this->count (v))
    {
        return iterator (this, v);
    }
    else
    {
        return this->end ();
    }
}


template<size_t B>
nset<B>::iterator
nset<B>::begin () const
{
    iterator it (this, 0);
    if ((m_trees[m_height][0] & 1) == 0)
        ++it;
    return it;
}


template<size_t B>
nset<B>::iterator
nset<B>::end () const
{
    return iterator (this, m_limit + 1);
}


template<size_t B>
nset<B>::const_reverse_iterator
nset<B>::rbegin () const
{
    reverse_iterator it (this, m_limit);
    if ((m_trees[m_height][m_limit / 8U]
         & (1 << (m_limit % 8U))) == 0)
        ++it;
    return it;
}


// Allocates and zeroes the bitmaps for an empty set able to hold at least
// the values [0, sz).  sz must be non-zero.  Sets every data member; it
// does not free anything, so it is only for use on a fresh object.
template<size_t B>
void nset<B>::init (size_t sz)
{
    assert (sz);
    assert (B > 1);  // The height loop below needs h /= B to shrink h.
    const size_t l = sz - 1;

    // Height is the number of base-B digits of the largest value.
    m_height = 0;
    for (size_t h = l; h; h /= B)
        ++m_height;
    m_limit = ipow<size_t> (B, m_height) - 1;
    if (m_limit < l)
    {
        ++m_height;
        m_limit = ipow<size_t> (B, m_height) - 1;
    }

    // Level i holds B^i bits.  The old code sized every level like the
    // largest one, which over-allocated; this sum matches the layout
    // assigned to m_trees below exactly.
    // Explicit <size_t>: size_t and 8U are different types on LP64
    // targets, which breaks template argument deduction.
    m_data_size = 0;
    for (size_t i = 0; i <= m_height; ++i)
    {
        m_data_size += iceildiv<size_t> (ipow<size_t> (B, i), 8U);
    }
    m_data = new uc [m_data_size];
    m_trees = new uc* [m_height+1];
    m_size = 0;

    std::fill (m_data, m_data + m_data_size, uc (0));

    // Lay the levels out back to back inside m_data, root first.
    size_t accum = 0;

    for (size_t i = 0; i <= m_height; ++i)
    {
        m_trees[i] = m_data + accum;
        accum += iceildiv<size_t> (ipow<size_t> (B, i), 8U);
    }
}


template<size_t B>
nset<B>::const_reverse_iterator
nset<B>::rend () const
{
    return reverse_iterator (this, key_type (-1));
}


// True when the set has no members, judged by the root bitmap byte being
// zero.
template<size_t B>
bool
nset<B>::empty () const
{
    const uc root = m_trees[0][0];
    return root == 0;
}


// Upper bound on the number of elements: half of the largest size_t.
template<size_t B>
size_t
nset<B>::max_size () const
{
    const size_t widest = size_t (-1);
    return widest / 2; // Conservative guess.
}


// Number of members, as tracked in m_size.
template<size_t B>
size_t
nset<B>::size () const
{
    return this->m_size;
}


// The template parameter B, made available at run time.
template<size_t B>
size_t
nset<B>::branching_factor ()
{
    const size_t factor = B;
    return factor;
}


// Default position: attached to no set, at key 0.
template<size_t B>
nset<B>::iter_base::iter_base ()
    : m_nset (0), m_v (0)
{
}


// Position pos within the set owner.  The set is referenced, not copied.
template<size_t B>
nset<B>::iter_base::iter_base (const nset<B> *owner, key_type pos)
    : m_nset (owner), m_v (pos)
{
}


// Moves the position to the next larger member of the set.  When there is
// none, the position becomes m_limit + 1 (the value end () uses).  From
// the size_t (-1) position (the value rend () uses) the search starts at 0.
template<size_t B>
void
nset<B>::iter_base::up ()
{
    if (m_v == size_t (-1))
    {
        // Coming from the "before the first" position: 0 is the first
        // candidate, so stop right away if it is a member.
        m_v = 0;
        if ((m_nset->m_trees[m_nset->m_height][0] & 1) == 1)
        {
            return;
        }
    }
    else
    {
        assert (m_v <= m_nset->m_limit + 1);
    }

    // Already at the largest storable value: only the end position is left.
    if (m_v == m_nset->m_limit)
    {
        ++m_v;
        return;
    }

    size_t m;                        // Last index of the group being scanned.
    size_t l = m_nset->m_height;     // Level being scanned (base first).
    size_t n = m_v + 1;              // Candidate index at level l.
    bool b = false;                  // True once a set bit has been found.

    // Phase 1: scan the rest of the current group of B bits; if nothing is
    // set, continue one level up, just after this group's parent bit.
    for (; l;)
    {
        m = ((n / B) * B) + B - 1;
        // "b || ++n" advances n only while nothing has been found, so n
        // stays on the set bit when the scan succeeds.
        for (b = false; !b && n <= m; b || ++n)
        {
            b = (m_nset->m_trees[l][n / 8U]
                 & (1 << (n % 8U)));
        }
        // Stop on a hit, or when the scan ran off the end of this level.
        if (b || n == size_t (ipow (B, l)))
        {
            break;
        }
        --l;
        n = (m / B) + 1;
    }

    if (!b)
    {
        m_v = m_nset->m_limit + 1;
        return;
    }

    // Phase 2: descend from the hit back to the base level, taking the
    // first set bit among the B children at each step.
    for (++l; l <= m_nset->m_height; ++l)
    {
        m = ((n + 1) * B) - 1;
        n = n * B;
        for (b = false; !b && n <= m; ++n)
        {
            b = (m_nset->m_trees[l][n / 8U]
                 & (1 << (n % 8U)));
            if (b) break;
        }
        // A set bit must have at least one set child.
        assert (b);
    }

    m_v = n;
}


// Moves the position to the next smaller member of the set.  When there is
// none, the position becomes size_t (-1) (the value rend () uses).  Must
// not be called on a position that is already size_t (-1).
template<size_t B>
void
nset<B>::iter_base::down ()
{
    assert (m_v != size_t (-1));
    if (m_v == 0)
    {
        m_v = size_t (-1);
        return;
    }

    size_t m;                        // First index of the group being scanned.
    size_t l = m_nset->m_height;     // Level being scanned (base first).
    size_t n = m_v - 1;              // Candidate index at level l.
    bool b = false;                  // True once a set bit has been found.

    // Phase 1: scan backwards to the start of the current group of B bits;
    // if nothing is set, continue one level up, just before this group's
    // parent bit.
    for (; l;)
    {
        m = (n / B) * B;
        b = false;
        for (;;)
        {
            b = (m_nset->m_trees[l][n / 8U]
                 & (1 << (n % 8U)));
            if (b || n == m) break;
            --n;
        }
        // Stop on a hit.  Also stop when group 0 has been exhausted:
        // nothing smaller exists, and (m / B) - 1 would wrap around to
        // size_t (-1) and index far outside the bitmap on the next pass.
        if (b || m == 0)
        {
            break;
        }
        --l;
        n = (m / B) - 1;
    }

    if (!b)
    {
        m_v = size_t (-1);
        return;
    }

    // Phase 2: descend from the hit back to the base level, taking the
    // last set bit among the B children at each step.
    for (++l; l <= m_nset->m_height; ++l)
    {
        m = n * B;
        n = ((n + 1) * B) - 1;
        b = false;
        for (;;)
        {
            b = (m_nset->m_trees[l][n / 8U]
                 & (1 << (n % 8U)));
            if (b || n == m) break;
            --n;
        }
        // A set bit must have at least one set child.
        assert (b);
    }

    m_v = n;
}


template<size_t B>
nset<B>::iter_base::key_type
nset<B>::iter_base::value () const
{
    assert (m_v <= m_nset->m_limit);
    assert (m_nset->m_trees[m_nset->m_height][m_v / 8U]
            & (1 << m_v % 8U));
    return m_v;
}


// Three-way comparison of two positions in the same set: returns 0 when
// equal, 1 when this position is after o, and -1 when it is before o.
template<size_t B>
int
nset<B>::iter_base::cmp (const iter_base &o) const
{
    assert (o.m_nset == m_nset);

    // Adding 1 wraps the size_t (-1) position around to 0, so it sorts
    // before every real key.
    const key_type mine = m_v + 1;
    const key_type theirs = o.m_v + 1;

    if (mine == theirs)
    {
        return 0;
    }
    return (mine > theirs) ? 1 : -1;
}


template<size_t B>
nset<B>::iterator
nset<B>::iterator::operator++ (int)
{
    iterator o (*this);
    this->iter_base::up ();
    return o;
}


template<size_t B>
nset<B>::iterator&
nset<B>::iterator::operator++ ()
{
    this->iter_base::up ();
    return *this;
}


template<size_t B>
nset<B>::iterator
nset<B>::iterator::operator-- (int)
{
    iterator o (*this);
    this->iter_base::down ();
    return o;
}


template<size_t B>
nset<B>::iterator&
nset<B>::iterator::operator-- ()
{
    this->iter_base::down ();
    return *this;
}


// True when both iterators are at the same position.
template<size_t B>
bool
nset<B>::iterator::operator== (const iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order == 0;
}


// True when the iterators are at different positions.
template<size_t B>
bool
nset<B>::iterator::operator!= (const iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order != 0;
}


// True when this position is at or before o's position.
template<size_t B>
bool
nset<B>::iterator::operator<= (const iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order <= 0;
}


// True when this position is strictly before o's position.
template<size_t B>
bool
nset<B>::iterator::operator< (const iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order < 0;
}


// True when this position is at or after o's position.
template<size_t B>
bool
nset<B>::iterator::operator>= (const iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order >= 0;
}


// True when this position is strictly after o's position.
template<size_t B>
bool
nset<B>::iterator::operator> (const iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order > 0;
}


template<size_t B>
nset<B>::iterator::key_type
nset<B>::iterator::operator* () const
{
    return this->iter_base::value ();
}


// Private constructor used by nset: an iterator over owner at position pos.
template<size_t B>
nset<B>::iterator::iterator (const nset<B> *owner, key_type pos)
    : iter_base (owner, pos)
{
}


template<size_t B>
nset<B>::reverse_iterator
nset<B>::reverse_iterator::operator++ (int)
{
    reverse_iterator o (*this);
    this->iter_base::down ();
    return o;
}


template<size_t B>
nset<B>::reverse_iterator&
nset<B>::reverse_iterator::operator++ ()
{
    this->iter_base::down ();
    return *this;
}


template<size_t B>
nset<B>::reverse_iterator
nset<B>::reverse_iterator::operator-- (int)
{
    reverse_iterator o (*this);
    this->iter_base::up ();
    return o;
}


template<size_t B>
nset<B>::reverse_iterator&
nset<B>::reverse_iterator::operator-- ()
{
    this->iter_base::up ();
    return *this;
}


// True when both reverse iterators are at the same position.
template<size_t B>
bool
nset<B>::reverse_iterator::operator== (const reverse_iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order == 0;
}


// True when the reverse iterators are at different positions.
template<size_t B>
bool
nset<B>::reverse_iterator::operator!= (const reverse_iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order != 0;
}


// True when this iterator's key position is at or below o's.
// NOTE(review): positions are compared in ascending key order, although ++
// on a reverse_iterator moves towards smaller keys -- confirm that this
// ordering is intended.
template<size_t B>
bool
nset<B>::reverse_iterator::operator<= (const reverse_iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order <= 0;
}


// True when this iterator's key position is strictly below o's.
// NOTE(review): positions are compared in ascending key order, although ++
// on a reverse_iterator moves towards smaller keys -- confirm that this
// ordering is intended.
template<size_t B>
bool
nset<B>::reverse_iterator::operator< (const reverse_iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order < 0;
}


// True when this iterator's key position is at or above o's.
// NOTE(review): positions are compared in ascending key order, although ++
// on a reverse_iterator moves towards smaller keys -- confirm that this
// ordering is intended.
template<size_t B>
bool
nset<B>::reverse_iterator::operator>= (const reverse_iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order >= 0;
}


// True when this iterator's key position is strictly above o's.
// NOTE(review): positions are compared in ascending key order, although ++
// on a reverse_iterator moves towards smaller keys -- confirm that this
// ordering is intended.
template<size_t B>
bool
nset<B>::reverse_iterator::operator> (const reverse_iterator &o) const
{
    const int order = this->iter_base::cmp (o);
    return order > 0;
}


template<size_t B>
nset<B>::reverse_iterator::key_type
nset<B>::reverse_iterator::operator* () const
{
    return this->iter_base::value ();
}


// Private constructor used by nset: a reverse iterator over owner at
// position pos.
template<size_t B>
nset<B>::reverse_iterator::reverse_iterator (
    const nset<B> *owner, key_type pos)
    : iter_base (owner, pos)
{
}

#endif
