Pages

Saturday, January 28, 2012

A probable mistake while overloading the '<' operator

I was not aware of the following basic thing. See the code segment below:


struct T
{
    int a, b;

    // Default constructor leaves a and b uninitialized.
    T() {}
    T(int _a, int _b) : a(_a), b(_b) {}

    // Deliberately incomplete ordering used by this demonstration: only `a`
    // takes part, so T(0,0) and T(0,1) are equivalent (neither is less than
    // the other) and std::set keeps just one of them.
    bool operator < (T obj) const
    {
        const bool smallerA = a < obj.a;
        return smallerA;
    }
};

int main()
{
 set<T> S;
 
 S.insert(T(0,0));
 S.insert(T(0,1));
 
 cout<<S.size()<<"\n";
 
    return 0;
}
The output of the given code is: 1 (I was surprised!)
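The overloaded operator '<' cannot distinguish between T(0,0) and T(0,1): neither compares less than the other, so std::set treats them as equivalent and silently drops the second insert (pretty basic... huh!)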

I learnt this when writing the State class for my implementation of Dijkstra's algorithm. I was using an STL set as the priority queue, and some nodes were "mysteriously" vanishing!

The overloaded operator should look like this:

    bool operator < (T obj) const
    {
        if(a!=obj.a)
            return a < obj.a;
        else
            return b < obj.b;
    }
This doesn't matter if you use an STL priority_queue, which can store more than one copy of the 'same' element. But sometimes you need a set, especially when deleting nodes from the priority queue is important.

Friday, January 27, 2012

SPOJ ORDERSET (With Treap)

I LOVE TREAP !!


Honestly! The treap is the most beautiful data structure I know!
In my previous post I solved this problem with a splay tree, which took more than 500 lines of code and 5.98 seconds to run on SPOJ.

This time I solved the same problem with a treap: it is only a little over 200 lines and runs in 4.42 seconds on SPOJ.

The following implementation is based on the code given here.


// SPOJ: ORDERSET
// using Treap
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<ctype.h>
#include<stdlib.h>
#include<time.h>
#include<assert.h>

#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<iostream>
#include<algorithm>
#include<string>

using namespace std;

#define FOR(i,n) for(int i=0;i<(n);++i)
#define REP(i,a,b) for(int i=(a);i<=(b);++i)
#define CLR(a,x) memset(a,(x),sizeof(a))

// Sentinel returned by Treap::FindKth when the requested rank does not exist.
// NOTE(review): must lie outside the range of keys that are ever stored.
#define INF (1<<30)

typedef long long LL;
typedef pair<int,int> pii;

// One treap node: BST-ordered by `key`, max-heap-ordered by `priority`.
// `cnt` is the number of nodes in the subtree rooted here.
struct Node {
    int key;
    int cnt;
    int priority;

    Node *left, *right;

    Node(){key = 0; cnt = 0; priority = 0; left = right = NULL;}
    explicit Node(int _key){cnt = 1; key = _key; priority = rand(); left = right = NULL;}
    Node(int _key, int pr){cnt = 1; key = _key; priority = pr; left = right = NULL;}
};

// Ordered set of distinct ints with rank queries, implemented as a treap.
// Inserting a key that is already present and deleting a missing key are both
// no-ops. The treap owns its nodes and frees them when it is destroyed.
struct Treap {
    Node* root;

    Treap(){root = NULL; srand(time(NULL));}

    ~Treap(){Destroy(root); root = NULL;}

    // Number of nodes in the subtree rooted at T (0 for an empty subtree).
    int TreeSize(Node* T)
    {
        return T==NULL?0:T->cnt;
    }

    // Recompute T's cached subtree size from its children.
    void UpdateCnt(Node* &T)
    {
        if(T)
        {
            T->cnt = 1 + TreeSize(T->left) + TreeSize(T->right);
        }
    }

    // Lift T->left into T's position; the old T becomes its right child.
    void LeftRotate(Node* &T)
    {
        Node* temp = T->left;
        T->left = temp->right;
        temp->right = T;
        T = temp;

        UpdateCnt(T->right);
        UpdateCnt(T);
    }

    // Lift T->right into T's position; the old T becomes its left child.
    void RightRotate(Node* &T)
    {
        Node* temp = T->right;
        T->right = temp->left;
        temp->left = T;
        T = temp;

        UpdateCnt(T->left);
        UpdateCnt(T);
    }

    // Plain BST insert, then rotations on the way back up restore the
    // max-heap order on priority. An equal key is ignored.
    void Insert(Node* &T, int _key)
    {
        if(T == NULL)
        {
            T = new Node(_key);
            return;
        }

        if(T->key > _key)
        {
            Insert(T->left, _key);

            if(T->priority < T->left->priority)
                LeftRotate(T);
        }
        else if(T->key < _key)
        {
            Insert(T->right, _key);

            if(T->priority < T->right->priority)
                RightRotate(T);
        }

        UpdateCnt(T);
    }

    void Insert(int _key)
    {
        Insert(root, _key);
    }

    // Remove _key from the subtree at T. A node with two children is rotated
    // down (lifting its higher-priority child) until it has at most one
    // child, and is then spliced out and freed.
    void Delete(Node* &T, int _key)
    {
        if(T == NULL)
            return;

        if(T->key > _key)
            Delete(T->left, _key);
        else if(T->key < _key)
            Delete(T->right, _key);
        else
        {
            if(T->left && T->right)
            {
                if(T->left->priority > T->right->priority)
                    LeftRotate(T);
                else
                    RightRotate(T);

                Delete(T, _key);
            }
            else
            {
                Node* temp = T;

                if(T->left)
                    T = T->left;
                else
                    T = T->right;

                delete temp;
            }
        }


        UpdateCnt(T);
    }

    void Delete(int _key)
    {
        Delete(root, _key);
    }

    // Number of keys in the subtree at T that are strictly less than bound.
    int Count(Node* T, int bound)
    {
        if(T == NULL)
            return 0;

        if(T->key < bound)
            return 1 + TreeSize(T->left) + Count(T->right, bound);

        return Count(T->left, bound);
    }

    int Count(int bound)
    {
        return Count(root,bound);
    }

    // k-th smallest key (1-based) in the subtree at T, or -INF when k is not
    // in [1, size of the subtree].
    int FindKth(Node* T, int k)
    {
        // Rejecting k < 1 also guards the recursion: with k <= 0 the walk
        // would keep going left until T is NULL and then dereference it.
        if(k < 1 || TreeSize(T) < k)
            return -INF;

        int sz = 1 + TreeSize(T->left);

        if(sz == k)
            return T->key;

        if(sz < k)
            return FindKth(T->right, k-sz);

        return FindKth(T->left,k);
    }

    int FindKth(int k)
    {
        return FindKth(root, k);
    }

private:
    // Post-order release of every node in the subtree rooted at T.
    void Destroy(Node* T)
    {
        if(T == NULL)
            return;
        Destroy(T->left);
        Destroy(T->right);
        delete T;
    }

    // The treap owns its nodes, so a shallow copy would double-free them.
    // Copying is disabled: these are declared but intentionally not defined.
    Treap(const Treap&);
    Treap& operator=(const Treap&);
};

int main()
{
    // freopen("in.txt","r",stdin);

    int Q;
    if(scanf("%d",&Q) != 1)
        return 0; // no query count: nothing to do

    Treap oTreap;

    while(Q--)
    {
        // Only the first letter of the operation is used. The %4s width keeps
        // the read inside the 5-byte buffer, so an over-long token cannot overflow it.
        char t[5];
        int p;
        if(scanf("%4s%d",t,&p) != 2)
            break; // truncated or malformed input: stop instead of acting on stale values

        if(t[0]=='I')
        {
            oTreap.Insert(p);
        }
        else if(t[0]=='D')
        {
            oTreap.Delete(p);
        }
        else if(t[0]=='K')
        {
            int v = oTreap.FindKth(p);

            if(v > -INF)
            {
                printf("%d\n",v);
            }
            else
                puts("invalid");
        }
        else
        {
            int v = oTreap.Count(p);

            printf("%d\n",v);
        }

    }

    return 0;
}

A Needlessly Long but Correct Splay Tree Implementation

What follows is my accepted solution to this problem (SPOJ ORDERSET).
I hope to write a shorter splay tree soon.


// Splay Tree Implementation
// used for SPOJ Problem: ORDERSET

#include<stdio.h>
#include<string.h>
#include<math.h>
#include<ctype.h>
#include<assert.h>
#include<stdlib.h>
#include<time.h>

#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<iostream>
#include<algorithm>
#include<string>

using namespace std;

// Type of the keys stored in the splay tree.
#define KEY_TYPE int

// Sentinel used as +/- "infinity". -INF also serves as the "no such element"
// result of findKth. NOTE(review): assumed larger in magnitude than any stored key.
#define INF (1<<30)

typedef pair<int,int> pii;
typedef long long LL;

// Loop / memory helpers.
#define FOR(i,n)   for (int i = 0; i < (n); ++i)
#define REP(i,a,b) for (int i = (a); i <= (b); ++i)
#define CLR(a,x)   memset(a, (x), sizeof(a))

// A splay-tree node. Every operation that splays a node also updates the
// global ROOT pointer, which is how the tree is reached.
class Node
{
private:
    KEY_TYPE    key;
    int         cnt; // size of the subtree rooted here

    Node* left;
    Node* right;
    Node* parent;    // NULL for the root

    // 1 + the cached sizes of the two children.
    int calcTreeSize();

public:
    bool isLeftChild();
    bool isRightChild();
    bool isRoot();

    KEY_TYPE getKeyVal();
    int      getCnt();

    // Rotation primitives used by the splay step.
    void Rotate();
    void Zig();
    void ZigZig();
    void ZigZag();

    // Returns false (and links nothing) if the key is already present.
    bool NormalBSTInsert(Node* newNode);
    void Splay();
    void SplayUntil(Node* until);

    Node();
    // explicit: a key should not silently convert into a Node.
    explicit Node(const KEY_TYPE _key);

    ~Node();

    Node* FindNearest(const KEY_TYPE _key);
    bool Exists(const KEY_TYPE _key);
    void Insert(KEY_TYPE _key);
    void Insert(Node* newNode);
    void Delete();
    void Delete(KEY_TYPE _key);

    void PrintInOrder();
    void PrintPreOrder();

    // 1-based rank query; returns -INF when k is out of range.
    KEY_TYPE findKth(int k);
    // Number of keys strictly less than _key.
    int Count(KEY_TYPE _key);
};


// Root of the single global splay tree; NULL while the tree is empty.
// Splay() and the delete routines keep it current.
Node* ROOT = NULL;

bool Node::isLeftChild()
{
    return (this->parent!=NULL && this->parent->left == this);
}

bool Node::isRightChild()
{
    return (this->parent!=NULL && this->parent->right == this);
}

bool Node::isRoot()
{
    return this->parent == NULL;
}

// Accessor for the key stored in this node.
KEY_TYPE Node::getKeyVal() { return key; }

int Node::getCnt()
{
    return this->cnt;
}

int Node::calcTreeSize()
{
    return 1 + (this->left?this->left->cnt:0) + (this->right?this->right->cnt:0);
}

// Single rotation that lifts this node one level, above its parent.
// The old parent becomes this node's child on the opposite side, and the
// subtree that sat between them is handed over to the old parent.
// Subtree sizes are refreshed for the two nodes whose children changed.
// Does nothing when called on the root (it is neither a left nor a right child).
void Node::Rotate()
{
    if(isLeftChild())
    {
        // Remember which side of the grandparent the parent hangs from,
        // so this node can take the parent's place there.
        bool parWasLeftChild = this->parent?this->parent->isLeftChild():false;

        // This node's right subtree becomes the old parent's left subtree.
        parent->left = this->right;

        if(this->right)
            this->right->parent = this->parent;

        Node* newParent = this->parent->parent;

        // The old parent becomes this node's right child.
        this->right = this->parent;

        if(this->right)
            this->right->parent = this;

        this->parent = newParent;

        // Hook this node into the grandparent (if any) where the old parent was.
        if(this->parent != NULL)
        {
            if(parWasLeftChild)
                this->parent->left = this;
            else
                this->parent->right = this;
        }

        // Child first: the old parent's new size feeds into this node's size.
        this->right->cnt = this->right->calcTreeSize();
        this->cnt = this->calcTreeSize();
    }
    else if(isRightChild())
    {
        // Mirror image of the branch above.
        bool parWasLeftChild = this->parent?this->parent->isLeftChild():false;

        // This node's left subtree becomes the old parent's right subtree.
        parent->right = this->left;

        if(this->left)
            this->left->parent = parent;

        Node* newParent = this->parent->parent;

        // The old parent becomes this node's left child.
        this->left = this->parent;

        if(this->left)
            this->left->parent = this;

        this->parent = newParent;

        if(this->parent != NULL)
        {
            if(parWasLeftChild)
                this->parent->left = this;
            else
                this->parent->right = this;
        }

        this->left->cnt = this->left->calcTreeSize();
        this->cnt = this->calcTreeSize();
    }
}

// Zig step: one rotation, used when the parent is the splay target's top.
void Node::Zig() { Rotate(); }

void Node::ZigZig()
{
    this->parent->Rotate();
    this->Rotate();
}

void Node::ZigZag()
{
    this->Rotate();
    this->Rotate();
}

void Node::Splay()
{
    while( !this->isRoot() )
    {
        if(this->parent->isRoot()) // Zig
        {
            this->Zig();
        }
        else if(this->isLeftChild() == this->parent->isLeftChild()) // ZigZig
        {
            this->ZigZig();
        }
        else // ZigZag
        {
            this->ZigZag();
        }
    }

    ROOT = this;
}

// Default node: key 0, subtree size 0, not linked to anything.
Node::Node()
    : key(0), cnt(0), left(NULL), right(NULL), parent(NULL)
{
}

// Detached single-node tree holding _key (subtree size 1).
Node::Node(const KEY_TYPE _key)
    : key(_key), cnt(1), left(NULL), right(NULL), parent(NULL)
{
}

// Releases nothing: a node does not own its children, and Delete(KEY_TYPE)
// unlinks a node from the tree before destroying it.
Node::~Node() {}

Node* Node::FindNearest(const KEY_TYPE _key)
{
    // Walk down from this node toward _key. Stop at the node holding _key,
    // or at the last node visited before the path would leave the tree.
    Node* cur = this;

    for (;;)
    {
        Node* next;

        if (cur->key < _key)
            next = cur->right;
        else if (cur->key > _key)
            next = cur->left;
        else
            return cur;

        if (next == NULL)
            return cur;

        cur = next;
    }
}

bool Node::Exists(const KEY_TYPE _key)
{
    Node* temp = this->FindNearest(_key);
    temp->Splay();

    return ROOT->key == _key;
}

// Links newNode in as a leaf by ordinary BST descent from this node and
// refreshes the subtree sizes of all its ancestors.
// Returns false, leaving the tree untouched, when a node with
// key == newNode->key already exists.
bool Node::NormalBSTInsert(Node* newNode)
{
    Node* cur = this;

    for (;;)
    {
        const bool goRight = newNode->key > cur->key;
        const bool goLeft  = newNode->key < cur->key;

        if (!goRight && !goLeft)
            return false; // duplicate key

        Node* &slot = goRight ? cur->right : cur->left;

        if (slot == NULL)
        {
            slot = newNode;
            newNode->parent = cur;
            break;
        }

        cur = slot;
    }

    // Every ancestor gained one node in its subtree.
    for (Node* anc = newNode->parent; anc != NULL; anc = anc->parent)
        anc->cnt = anc->calcTreeSize();

    return true;
}

// Inserts _key into the tree this node belongs to and splays the new node to
// the root. If _key is already present the tree is left unchanged.
void Node::Insert(KEY_TYPE _key)
{
    // Keep ownership of the freshly allocated node until it is linked in.
    // NormalBSTInsert refuses duplicates, and going through Insert(Node*)
    // would leak the rejected node.
    Node* newNode = new Node(_key);

    if(this->NormalBSTInsert(newNode))
        newNode->Splay();
    else
        delete newNode;
}

void Node::Insert(Node* newNode)
{
    // Link newNode in as a BST leaf and, if it was accepted, splay it up.
    // A NULL pointer or a duplicate key leaves the tree untouched.
    if (newNode != NULL && NormalBSTInsert(newNode))
        newNode->Splay();
}

void Node::SplayUntil(Node* until)
{
    // Splay this node upward until it takes `until`'s place, i.e. until its
    // parent is the node that was `until`'s parent when we started.
    Node* const stopAt = until->parent;

    while (parent != stopAt)
    {
        if (parent == until)
            Zig();
        else if (isLeftChild() == parent->isLeftChild())
            ZigZig();
        else
            ZigZag();
    }
}

// Unlinks this node from the tree and repairs ROOT; the node itself is not
// freed here (Delete(KEY_TYPE) does that). This node is first splayed to the
// root. If it has a left subtree, the maximum of that subtree is splayed to
// the top of it, where it has no right child, so the old right subtree can be
// attached there. Otherwise the mirror image is done with the right subtree.
void Node::Delete()
{
    this->Splay();

    if(this->left)
    {
        // FindNearest(INF) keeps walking right, reaching the maximum.
        // NOTE(review): assumes every stored key is below INF.
        Node* maxNode = this->left->FindNearest(INF);
        maxNode->SplayUntil(this->left);

        // this->left now refers to maxNode, which has no right child.
        this->left->right = this->right;

        if(this->right)
            this->right->parent = this->left;

        ROOT = this->left;
        ROOT->parent = NULL;
        ROOT->cnt = ROOT->calcTreeSize();
    }
    else if(this->right)
    {
        // No left subtree: promote the right subtree after splaying its minimum
        // to the top. this->left is NULL here, so nothing is attached on the left.
        Node* minNode = this->right->FindNearest(-INF);
        minNode->SplayUntil(this->right);

        this->right->left = this->left;

        if(this->left)
            this->left->parent = this->right;

        ROOT = this->right;
        ROOT->parent = NULL;
        ROOT->cnt = ROOT->calcTreeSize();
    }
    else
    {
        // This was the only node: the tree becomes empty.
        ROOT = NULL;
    }
}

void Node::Delete(KEY_TYPE _key)
{
    // Exists() splays the matching node (if any) to ROOT.
    if (!this->Exists(_key))
        return;

    Node* victim = ROOT;
    victim->Delete();   // unlink it and repair ROOT
    delete victim;
}

// Returns the k-th smallest key (1-based) in the subtree rooted at this node
// and splays that node to the root. Returns -INF when k is not in [1, cnt].
KEY_TYPE Node::findKth(int k)
{
    // Rejecting k < 1 is required: with k <= 0 the walk below keeps taking
    // left children until temp becomes NULL and is then dereferenced.
    if(k < 1 || this->cnt < k)
        return -INF;

    Node* temp = this;

    while(true)
    {
        // Rank of temp within the subtree currently being searched.
        int c = (temp->left?temp->left->cnt:0) + 1;

        if(c == k)
            break;

        if(c < k)
        {
            k -= c;
            temp = temp->right;
        }
        else
        {
            temp = temp->left;
        }
    }

    // Splay the accessed node, as every lookup in this tree does.
    temp->Splay();
    return ROOT->key;
}

int Node::Count(KEY_TYPE _key)
{
    // Number of keys strictly smaller than _key in the subtree rooted here.
    int smaller = 0;
    Node* cur = this;

    for (;;)
    {
        const int leftSize = cur->left ? cur->left->cnt : 0;

        if (cur->key < _key)
        {
            // cur and its whole left subtree are smaller than _key.
            smaller += leftSize + 1;
            if (cur->right == NULL)
                break;
            cur = cur->right;
        }
        else if (cur->key > _key)
        {
            if (cur->left == NULL)
                break;
            cur = cur->left;
        }
        else
        {
            // Exact match: only its left subtree is smaller.
            smaller += leftSize;
            break;
        }
    }

    // Splay the last node visited.
    cur->Splay();

    return smaller;
}

void Node::PrintInOrder() // not needed for the problem
{
    if(this->left)
        this->left->PrintInOrder();

    cout<<this->getKeyVal()<<" ";

    if(this->right)
        this->right->PrintInOrder();
}

void Node::PrintPreOrder() // not needed for the problem
{
    cout<<"["<<this->getKeyVal()<<":"<<this->cnt<<"]";

    printf("(");
    if(this->left)
        this->left->PrintPreOrder();
    printf(")");

    printf("(");
    if(this->right)
        this->right->PrintPreOrder();
    printf(")");
}

// Inserts `key` into the global tree, creating the root when the tree is empty.
void SplayTreeInsert(KEY_TYPE key)
{
    if (ROOT == NULL)
    {
        ROOT = new Node(key);
        return;
    }

    ROOT->Insert(key);
}

int main()
{
    int Q;
    if(scanf("%d",&Q) != 1)
        return 0; // no query count: nothing to do

    ROOT = NULL;

    while(Q--)
    {
        // Only the first letter of the operation is used. The %4s width keeps
        // the read inside the 5-byte buffer, so an over-long token cannot overflow it.
        char t[5];
        int p;
        if(scanf("%4s%d",t,&p) != 2)
            break; // truncated or malformed input: stop instead of acting on stale values

        if(t[0]=='I')
        {
            SplayTreeInsert(p);
        }
        else if(t[0]=='D')
        {
            if(ROOT)
                ROOT->Delete(p);
        }
        else if(t[0]=='K')
        {
            int v = -INF;

            if(ROOT)
                v = ROOT->findKth(p);

            if(v > -INF)
            {
                printf("%d\n",v);
            }
            else
                puts("invalid");
        }
        else
        {
            int v = 0;
            if(ROOT)
                v = ROOT->Count(p);

            printf("%d\n",v);
        }

    }

    return 0;
}