// Copyright (C) 2002-2003 Michael Anthony <mike at unlikely org>
// Permission to copy, use, modify, sell and distribute this software is
// granted provided this copyright notice appears in all copies.  This
// software is provided "as is" without express or implied warranty, and
// with no claim as to its suitability for any purpose.
//
// oslist: A list (implemented as a tree) which keeps order statistics,
// enabling one to fetch items by ordinal value in O(lg N) time rather than
// O(N) time.  The trade-off is that insertion and removal also take O(lg N)
// time, rather than O(1).
//
// Light experimentation shows that, assuming a random access pattern, a
// few hundred thousand items need to be involved before oslist provides a
// performance advantage.  The higher constants and O(lg N) insertion and
// removal go a long way toward offsetting the gains of O(lg N) rather than
// O(N) seeks.
//
// The definition of oslist can be found at the bottom of the file,
// following all the supporting general-purpose red-black tree code.
// The interface is STL-like.  Seeking based on ordinal values is performed
// with the nth method, rather than square brackets.  I chose a method
// since square brackets tend to imply different semantics than oslist
// exhibits: removing item #0 will cause all items in the list to shift
// down, and access via ordinal value is not a constant time operation.
//
// Optimization has not been performed.

#ifndef OSLIST_H
#define OSLIST_H

#include <cassert>
#include <cstddef>

namespace tree_impl
{
    // Within this namespace resides a red-black tree implementation that
    // maintains order statistic information.

    //
    // The types.
    //

    // Node colours used by the red-black balancing code below.  The node
    // constructor initialises every node (sentinels included) to black;
    // rb_insert recolours freshly linked nodes red.
    enum color { black, red };

    // Key-less tree node: the links, the colour and the order statistic.
    // All of the balancing code works on this type only.
    struct node
    {
        node *p;  // parent; the tree code treats 0 as "no parent" (the root)
        node *l;  // left child, or the tree's nil sentinel
        node *r;  // right child, or the tree's nil sentinel
        color c;  // red-black colour
        int n;    // size of the subtree rooted here, this node included

        // Creates a detached black node whose two children both point at
        // `nil` and whose subtree size starts out as `count`.
        node (node *nil, int count)
            : p (0),
              l (nil),
              r (nil),
              c (black),
              n (count)
        {
        }
    };

    // A node that additionally carries a value of type K.
    template<typename K>
    struct knode : node
    {
        typedef K key_type;

        K k;  // the stored value

        // Sentinel flavour: subtree size 0, `k` is left default-initialised.
        knode (node *nil)
            : node (nil, 0)
        {
        }

        // Regular flavour: subtree size 1, `k` is copied from `key`.
        knode (node *nil, const K &key)
            : node (nil, 1),
              k (key)
        {
        }
    };

    //
    // These functions are (red-black) tree specific but are not specific to
    // search trees.
    //

    // Returns the leftmost node of the subtree rooted at x, i.e. follows
    // left links until the next one would be `nil`.
    //
    // x is dereferenced unconditionally, so it must be a real node.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline node*
    minimum (node *nil, node *x)
    {
        while (x->l != nil) x = x->l;
        return x;
    }

    // Returns the rightmost node of the subtree rooted at x, i.e. follows
    // right links until the next one would be `nil`.
    //
    // x is dereferenced unconditionally, so it must be a real node.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline node*
    maximum (node *nil, node *x)
    {
        while (x->r != nil) x = x->r;
        return x;
    }

    // Returns the in-order successor of x.
    //
    // If x has a right subtree the successor is that subtree's minimum.
    // Otherwise we climb until we leave a left subtree; the node we arrive
    // at is the successor.  When x is the last node the climb runs past the
    // root and the function returns 0 (a null pointer, not `nil`).
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline node*
    successor (node *nil, node *x)
    {
        if (x->r != nil)
        {
            return minimum (nil, x->r);
        }
        else
        {
            node *y = x->p;
            while (y != 0 && x == y->r)
            {
                x = y;
                y = y->p;
            }
            return y;
        }
    }

    // Returns the in-order predecessor of x.
    //
    // If x has a left subtree the predecessor is that subtree's maximum.
    // Otherwise we climb until we leave a right subtree; the node we arrive
    // at is the predecessor.  When x is the first node the climb runs past
    // the root and the function returns 0 (a null pointer, not `nil`).
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline node*
    predecessor (node *nil, node *x)
    {
        if (x->l != nil)
        {
            return maximum (nil, x->l);
        }
        else
        {
            node *y = x->p;
            while (y != 0 && x == y->l)
            {
                x = y;
                y = y->p;
            }
            return y;
        }
    }

    // Rotates left around x: x's right child y takes x's place and x
    // becomes y's left child.  *t is updated when x was the root.
    //
    // The subtree sizes are repaired as part of the rotation: y inherits
    // x's old size (it now heads the same set of nodes) and x's size is
    // recomputed from its new children.
    //
    // Precondition (asserted): x has a real right child.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline void
    left_rotate (node *nil, node **t, node *x)
    {
        assert (x->r != nil);
        // set y
        node *y = x->r;
        // turn y's left subtree into x's right subtree
        x->r = y->l;
        if (y->l != nil) y->l->p = x;
        // link x's parent to y
        y->p = x->p;
        // make y the child of x's parent or the tree root, if no parent
        if (x->p == 0) *t = y;
        else if (x == x->p->l) x->p->l = y;
        else x->p->r = y;
        // put x on y's left
        y->l = x;
        x->p = y;
        // fix the order statistics: y now covers what x used to cover
        y->n = x->n;
        x->n = x->l->n + x->r->n + 1;
    }

    // Rotates right around y: y's left child x takes y's place and y
    // becomes x's right child.  *t is updated when y was the root.
    //
    // The subtree sizes are repaired as part of the rotation: x inherits
    // y's old size (it now heads the same set of nodes) and y's size is
    // recomputed from its new children.
    //
    // Precondition (asserted): y has a real left child.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline void
    right_rotate (node *nil, node **t, node *y)
    {
        assert (y->l != nil);
        // set x
        node *x = y->l;
        // turn x's right subtree into y's left subtree
        y->l = x->r;
        if (x->r != nil) x->r->p = y;
        // link y's parent to x
        x->p = y->p;
        // make x the child of y's parent or the tree root, if no parent
        if (y->p == 0) *t = x;
        else if (y == y->p->r) y->p->r = x;
        else y->p->l = x;
        // put y on x's right
        x->r = y;
        y->p = x;
        // fix the order statistics: x now covers what y used to cover
        x->n = y->n;
        y->n = y->l->n + y->r->n + 1;
    }

    // Restores the red-black colouring after x has been linked into the
    // tree as a leaf.
    //
    // The caller must already have attached x (see unique_insert and
    // insert_before) and updated the subtree sizes; this function only
    // recolours and rotates.  The assert below catches a node that has not
    // been linked: it is neither the root nor does it have a parent.
    //
    // x is painted red and the "red node with a red parent" violation is
    // then pushed towards the root: a red uncle is handled by recolouring,
    // a black uncle by one or two rotations.  The two outer branches are
    // mirror images of each other.  The root is painted black at the end.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline void
    rb_insert (node *nil, node **t, node *x)
    {
        node *y;
        assert (*t == x || x->p != 0); // make sure that insert has been called.
        x->c = red;
        while (x != *t && x->p->c == red)
        {
            if (x->p == x->p->p->l)
            {
                // y is x's uncle
                y = x->p->p->r;
                if (y->c == red)
                {
                    // red uncle: recolour and continue from the grandparent
                    x->p->c = black;
                    y->c = black;
                    x->p->p->c = red;
                    x = x->p->p;
                }
                else
                {
                    if (x == x->p->r)
                    {
                        // inner grandchild: rotate it to the outside first
                        x = x->p;
                        left_rotate (nil, t, x);
                    }
                    x->p->c = black;
                    x->p->p->c = red;
                    right_rotate (nil, t, x->p->p);
                }
            }
            else
            {
                // mirror image of the branch above
                y = x->p->p->l;
                if (y->c == red)
                {
                    x->p->c = black;
                    y->c = black;
                    x->p->p->c = red;
                    x = x->p->p;
                }
                else
                {
                    if (x == x->p->l)
                    {
                        x = x->p;
                        right_rotate (nil, t, x);
                    }
                    x->p->c = black;
                    x->p->p->c = red;
                    left_rotate (nil, t, x->p->p);
                }
            }
        }
        (*t)->c = black;
    }

    // Restores the red-black colouring after a black node has been spliced
    // out of the tree.  x is the node that took the removed node's place.
    //
    // The loop reads x->p, so x's parent pointer must be valid on entry.
    // While x is black and not the root, its sibling w is examined:
    //   - a red sibling is rotated away so that the sibling becomes black;
    //   - a black sibling with two black children is recoloured red and the
    //     problem moves up to the parent;
    //   - otherwise one or two rotations finish the repair and the loop is
    //     ended by setting x to the root.
    // The two outer branches are mirror images of each other.  x is painted
    // black at the end.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline void
    rb_delete_fixup (node *nil, node **t, node *x)
    {
        while (x != *t && x->c == black)
        {
            node *w;
            if (x == x->p->l)
            {
                w = x->p->r;
                if (w->c == red)
                {
                    w->c = black;
                    x->p->c = red;
                    left_rotate (nil, t, x->p);
                    w = x->p->r;
                }
                if (w->l->c == black && w->r->c == black)
                {
                    w->c = red;
                    x = x->p;
                }
                else
                {
                    if (w->r->c == black)
                    {
                        w->l->c = black;
                        w->c = red;
                        right_rotate (nil, t, w);
                        w = x->p->r;
                    }
                    w->c = x->p->c;
                    x->p->c = black;
                    w->r->c = black;
                    left_rotate (nil, t, x->p);
                    x = *t;
                }
            }
            else
            {
                // mirror image of the branch above
                w = x->p->l;
                if (w->c == red)
                {
                    w->c = black;
                    x->p->c = red;
                    right_rotate (nil, t, x->p);
                    w = x->p->l;
                }
                if (w->r->c == black && w->l->c == black)
                {
                    w->c = red;
                    x = x->p;
                }
                else
                {
                    if (w->l->c == black)
                    {
                        w->r->c = black;
                        w->c = red;
                        left_rotate (nil, t, w);
                        w = x->p->l;
                    }
                    w->c = x->p->c;
                    x->p->c = black;
                    w->l->c = black;
                    right_rotate (nil, t, x->p);
                    x = *t;
                }
            }
        }
        x->c = black;
    }

    template<typename N>
    N*
    rb_delete (N *nil, N **t, N *z)
    {
        // Figure out which node we're splicing out.
        N *y;
        if (z->l == nil || z->r == nil) y = z;
        else y = (N*) successor (nil, z);

        // Decrement the node counts in all its parents.
        for (node *p = y->p; p; p = p->p)
        {
            --p->n;
        }

        // Figure out which node we're replacing it with.
        N *x;
        if (y->l != nil) x = (N*) y->l;
        else x = (N*) y->r;

        // And replace it.
        x->p = y->p;
        if (y->p == 0)
        {
            *t = x;
        }
        else if (y == y->p->l)
        {
            y->p->l = x;
        }
        else
        {
            y->p->r = x;
        }
        if (y != z)
        {
            // "key[z] = key[y]"
            z->k = y->k;
        }

        // Now we fix our RB properties.
        if (y->c == black) rb_delete_fixup (nil, (node**)t, x);

        // And take out the trash.
        return y;
    }

    // Follows parent links from n until a node without a parent is reached
    // and returns that node.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline node* root (node *n)
    {
        while (n->p) n = n->p;
        return n;
    }

    // Recursively deletes every node of the subtree rooted at n.  Children
    // are destroyed before their parent; the nil sentinel itself is never
    // deleted.  Nodes are deleted through N*, so N must be their real type.
    template<typename N>
    void
    deep_delete (node *nil, N *n)
    {
        if (n == nil) return;

        N *left = (N*) n->l;
        N *right = (N*) n->r;
        deep_delete (nil, left);
        deep_delete (nil, right);
        delete n;
    }

    //
    // These functions are for maintaining a tree with ordered unique keys.
    //

    // Binary search for key k in the tree rooted at t.  Returns the node
    // whose key compares equal to k, or `nil` if the descent falls off the
    // tree.  Keys are compared with == and >.
    template<typename N, typename K>
    N*
    search (N *nil, N *t, K k)
    {
        for (;;)
        {
            if (t == nil || t->k == k) return t;
            // larger keys at this node send us left, smaller ones right
            t = (N*) (t->k > k ? t->l : t->r);
        }
    }

    // Inserts v into the key-ordered tree *t unless an equal key is already
    // present, in which case the tree is left unchanged.
    //
    // The new node is allocated up front and its key is used for the
    // comparisons (< and >) on the way down.  Subtree sizes are incremented
    // during the descent; if an equal key turns up, those increments are
    // rolled back and the unused node is freed again.
    template<typename N>
    void
    unique_insert (N *nil, N **t, typename N::key_type v)
    {
        N *z = new N (nil, v);
        N *y = nil;
        N *x = *t;
        while (x != nil)
        {
            ++x->n;
            y = x;
            if (z->k < x->k) x = (N*) x->l;
            else if (z->k > x->k) x = (N*) x->r;
            else
            {
                // We've incremented the node counts on our way down.  Now that
                // we have found that there is no need to insert a node, we
                // need to undo this for x and every one of its ancestors.
                for (node *u = x; u; u = u->p)
                {
                    --u->n;
                }
                // The node allocated above never made it into the tree, so
                // nothing else will ever free it.
                delete z;
                // REV: assign z->k to x->k, in case N::key_type has fields
                // that don't play into sorting/equality?
                return;
            }
        }
        if (y == nil)
        {
            // The loop never ran: the tree was empty and z becomes the root.
            z->p = 0;
            *t = z;
        }
        else
        {
            // y is the last real node visited; hang z off the correct side.
            z->p = y;
            if (z->k < y->k) y->l = z;
            else y->r = z;
        }
        rb_insert (nil, (node**)t, z);
    }

    // Removes the node whose key equals v from the key-ordered tree *t and
    // frees the node that rb_delete unlinked.  Does nothing when no such
    // key exists.
    template<typename N>
    void
    unique_delete (N *nil, N **t, typename N::key_type v)
    {
        N *hit = search (nil, *t, v);
        if (hit == nil) return;

        N *unlinked = rb_delete (nil, t, hit);
        delete unlinked;
    }

    //
    // These functions presume unordered keys, where the insertion order
    // given by the client is the relevant order.
    //

    template<typename N>
    void
    insert_before (N *nil, N **t, typename N::key_type v, node *b)
    {
        N *z = new N (nil, v);
        // REV: Perhaps we should do away with this special case by keeping
        // a sentinal end node that is always one past the end of the list.
        assert (b != nil);
        if (b->l == nil)
        {
            b->l = z;
            z->p = b;
        }
        else
        {
            node *y = b->l;
            while (y != nil)
            {
                b = y;
                y = y->r;
            }
            b->r = z;
            z->p = b;
        }
        // PERF: this could be done more efficiently if it was moved above.
        while (b != 0)
        {
            ++b->n;
            b = b->p;
        }
        rb_insert (nil, (node**)t, z);
    }

    // Removes the value held by s from the tree and frees the node that
    // rb_delete unlinked.
    //
    // rb_delete may unlink a different node than s and return that one.  If
    // the returned node is the one *end_hack points at, *end_hack is
    // redirected to s before the node is freed.
    template<typename N>
    void
    delete_node (N *nil, N **t, N *s, N **end_hack)
    {
        assert (s && s != nil);

        N *trash = rb_delete (nil, t, s);
        const bool end_was_unlinked = (*end_hack == trash);
        if (end_was_unlinked) *end_hack = s;
        delete trash;
    }

    //
    // These functions use the order statistics.
    //

    // Returns the node at 0-based in-order position n of the tree rooted at
    // t, or `nil` when n is not smaller than the number of nodes.
    //
    // At each node the size of the left subtree says how many nodes come
    // before it: a smaller n lies in the left subtree, a larger n lies in
    // the right subtree (after discounting the left subtree and the node
    // itself), and an equal n is the node itself.
    //
    // The arithmetic is done in std::size_t throughout, so there is no
    // signed/unsigned comparison and no index for which n + 1 would wrap.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline node*
    nth (node *nil, node *t, std::size_t n)
    {
        while (t != nil)
        {
            // number of nodes that precede t within t's own subtree
            const std::size_t before = static_cast<std::size_t> (t->l->n);
            if (n < before)
            {
                t = t->l;
            }
            else if (n > before)
            {
                n -= before + 1;
                t = t->r;
            }
            else
            {
                break;
            }
        }
        return t;
    }

    // Returns the 1-based in-order position of x within the tree rooted at
    // t.
    //
    // Starts with the nodes in x's left subtree plus x itself, then walks
    // up to t; every time the walk leaves a right subtree, the parent and
    // the parent's left subtree are added because they all precede x.
    //
    // x must lie in the tree rooted at t: the walk follows parent links
    // until it reaches t and dereferences each parent on the way.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline int
    rank (node *t, node *x)
    {
        int r = x->l->n + 1;
        node *y = x;
        while (y != t)
        {
            if (y == y->p->r)
            {
                r += y->p->l->n + 1;
            }
            y = y->p;
        }
        return r;
    }
}

// Debugging code.  Keep going to find the definition of oslist.  You're
// almost there.
#ifndef NDEBUG
// Debug builds (NDEBUG not defined): forward to the dbg_validate and
// dbg_print helpers in namespace tree_impl.
#define tree_impl_validate(n,x) ::tree_impl::dbg_validate(n,x)
#define tree_impl_print(n,x) ::tree_impl::dbg_print(n,x)
#include <iostream>
namespace tree_impl
{
    // Recursively checks the structural invariants of the subtree rooted at
    // x with asserts and returns its black height (the nil sentinel counts
    // as one black node).
    //
    // Checked per node: the sentinel has size 0 and is black; parent and
    // child links agree with each other; a red node has two black children;
    // the stored subtree size equals left + right + 1; both subtrees have
    // the same black height.
    //
    // Declared inline because this is a non-template function defined in a
    // header; without it, including the header from two translation units
    // produces duplicate definitions at link time.
    inline int
    dbg_validate (const node *nil, const node *x)
    {
        int n;
        if (x == nil)
        {
            assert (x->n == 0);
            assert (x->c == black);
            n = 1;
        }
        else
        {
            if (x->p != 0) assert (x->p->l == x || x->p->r == x);
            if (x->l != nil) assert (x->l->p == x);
            if (x->r != nil) assert (x->r->p == x);
            if (x->c == red)
            {
                assert (x->l->c == black);
                assert (x->r->c == black);
            }
            else assert (x->c == black);
            assert (x->n == x->l->n + x->r->n + 1);
            n = dbg_validate (nil, x->l);
            // Kept outside the assert so the right subtree is validated as a
            // statement in its own right.
            bool b = (n == dbg_validate (nil, x->r));
            assert (b);
            if (x->c == black) ++n;
        }
        return n;
    }
    // Dumps the subtree rooted at x to std::cerr in in-order sequence, one
    // line per node: address, key, parent, children, colour, subtree size
    // and rank within the whole tree.
    template<typename K>
    void
    dbg_print (knode<K> *nil, knode<K> *x)
    {
        if (x == nil) return;

        dbg_print (nil, (knode<K>*) x->l);

        std::cerr << (void*) x << ' ' << x->k << " (p=" << (void*) x->p;

        std::cerr << " l=";
        if (x->l == nil) std::cerr << "nil";
        else std::cerr << (void*) x->l;

        std::cerr << " r=";
        if (x->r == nil) std::cerr << "nil";
        else std::cerr << (void*) x->r;

        const char *colour_name = (x->c == red ? "red" : "black");
        std::cerr << " c=" << colour_name
                  << " n=" << x->n
                  << " r=" << rank (root (x), x)
                  << ")" << std::endl;

        dbg_print (nil, (knode<K>*) x->r);
    }
}
#else
// NDEBUG builds: validation collapses to the constant 0 and printing
// expands to nothing.
#define tree_impl_validate(n,x) 0
#define tree_impl_print(n,x)
#endif

template<typename T>
class oslist
{
public:
    typedef T key_type;

    friend class iterator
    {
    public:
        iterator () : m_n (0), m_l (0) { }

        typedef T key_type;

        key_type& operator* ()
        {
            assert (m_n);
            assert (m_n != m_l->m_nil);
            assert (m_n != m_l->m_end);
            return m_n->k;
        }

        iterator& operator++ ()
        {
            assert (m_n != m_l->m_end);
            m_n = (tree_impl::knode<key_type>*) tree_impl::successor (
                m_l->m_nil, m_n);
        }

        iterator& operator-- ()
        {
            m_n = tree_impl::predecessor (m_l->m_nil, m_n);
            assert (m_n != m_l->m_nil);
        }

        bool operator== (const iterator &o)
        {
            return (m_n == o.m_n);
        }

        bool operator!= (const iterator &o)
        {
            return (m_n != o.m_n);
        }

    private:
        friend class oslist<key_type>;

        iterator (tree_impl::knode<key_type> *n, oslist<key_type> *l) :
            m_n (n), m_l (l)
        {
        }

        tree_impl::knode<key_type> *m_n;
        oslist<key_type> *m_l;
    };

    oslist () :
        m_nil (new tree_impl::knode<key_type> (0)),
        m_end (new tree_impl::knode<key_type> (m_nil, key_type ())),
        m_root (m_end)
    {
    }

    ~oslist ()
    {
        tree_impl::deep_delete (m_nil, m_root);
        delete m_nil;
    }

    size_t size () const
    {
        return tree_impl::rank (m_root, m_end) - 1;
    }

    iterator begin ()
    {
        return iterator (
            (tree_impl::knode<key_type>*) tree_impl::minimum (m_nil, m_root),
            this);
    }

    iterator end ()
    {
        return iterator (m_end, this);
    }

    void erase (iterator it)
    {
        assert (it.m_n);
        assert (it.m_n != m_nil);
        assert (it.m_n != m_end);
        assert (it.m_l == this);
        tree_impl::delete_node (m_nil, &m_root, it.m_n, &m_end);
    }

    void insert (iterator it, const key_type &k)
    {
        assert (it.m_n);
        assert (it.m_n != m_nil);
        assert (it.m_l == this);
        tree_impl::insert_before (m_nil, &m_root, k, it.m_n);
    }

    iterator nth (size_t n)
    {
        return iterator (
            (tree_impl::knode<key_type>*) tree_impl::nth (m_nil, m_root, n),
            this);
    }

private:
    // Not yet.
    oslist (const oslist&);
    oslist& operator= (const oslist&);

public:
    tree_impl::knode<key_type> *m_nil;
    tree_impl::knode<key_type> *m_end;
    tree_impl::knode<key_type> *m_root;
};

#endif
