
Monday, April 23, 2012

f(f(x)) = -x, A neat Puzzle!

I found this problem on a blog that I no longer remember:

The Problem:
Write a C function f(x) that takes an integer x and returns another integer such that f(f(x)) = -x
You are not allowed to use any global variables or static variables or file operations.

The Solution:
Let f(x) = y, then f(y) = -x
so f(-x) = -y                          [ -y = f(f(y)) = f(-x) ]
and f(-y) = x                          [ x = f(f(-x)) = f(-y) ]

f(x) creates a cycle of integers: (x) ==> (y) ==> (-x) ==> (-y) ==> (x)


So, we have to partition the set of nonzero integers into cycles of length 4. But 0 is a special case: f(0) = 0.
For example, one such partition can be:
(1) ==> (2) ==> (-1) ==> (-2) ==> (1)
(3) ==> (4) ==> (-3) ==> (-4) ==> (3)
(5) ==> (6) ==> (-5) ==> (-6) ==> (5)
... and so on

[Edit: the following portion has been added after misof correctly pointed out my mistake]

It is impossible to write a C function that solves the problem completely. With B bits we can represent 2^B numbers, and in two's complement the range is [ -(2^(B-1)) to (2^(B-1))-1 ]. Note that 2^(B-1) doesn't exist in the domain, so the problem is unsolvable for -(2^(B-1)). In fact, there will be another number for which we can't solve it. Let's see why.

Note that 0 can't be part of a cycle of length 4, because f(f(0)) must be -0 = 0.

First, let f(0) = 0. That leaves (2^B)-1 numbers to partition into cycles of length 4. But [(2^B)-1] % 4 = 3, so 3 numbers will be left over. One of them is -(2^(B-1)), and the other two are some y and -y. We can arrange these 3 numbers in a partial cycle, (y) ==> (-(2^(B-1))) ==> (-y), so that exactly one of them (y) satisfies f(f(y)) = -y. So, there will be 2 unlucky numbers.

Alternatively, let f(0) = x for some x != 0. If we want f to work for 0, we need f(f(0)) = f(x) = 0. This gives a cycle of length 2, (0) ==> (x) ==> (0), in which x is unlucky, since f(f(x)) = x. It's best to choose x = -(2^(B-1)), which is unlucky anyway. That leaves (2^B)-2 numbers. After partitioning them into cycles of length 4, 2 numbers will be left: y and -y. We can solve exactly one of them by assigning f(y) = -y and f(-y) = -y, or vice versa. So, one of these 2 will be unlucky.

In both cases, 2 numbers are unlucky!


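Still, the following function satisfies f(f(x)) = -x for every 32-bit int except three values. For (2^31)-1 and -(2^31), the expressions x+1 and -x overflow. -(2^31)+1 fails because f maps it to -(2^31):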

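int f(int x){
    if(x==0) return 0;           // 0 forms a cycle of length 1

    // Cycles: (2k-1) ==> (2k) ==> -(2k-1) ==> -(2k) ==> (2k-1)
    if(x>0){
        if(x%2==1)
            return x+1;          // overflows when x == INT_MAX
        else
            return -(x-1);
    }else{
        if((-x)%2==1)            // -x overflows when x == INT_MIN
            return x-1;
        else
            return -(x+1);
    }
}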

Friday, March 9, 2012

My Practice Log (Google Spreadsheet)

Here is the Google spreadsheet that I try to maintain.

This helps me to track my progress.
Ever since I started this spreadsheet, I haven't had any long, abrupt break from problem solving.
Whenever I write "NOTHING" in the "What did I do?" column, I feel guilty! :-)


Saturday, January 28, 2012

A probable mistake while overloading the '<' operator

I was not aware of the following basic thing. See the code segment below:


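#include <iostream>
#include <set>
using namespace std;

struct T
{
    int a, b;

    T(){}
    T(int _a, int _b){a=_a;b=_b;}

    // compares only 'a', so T(0,0) and T(0,1) look "equal" to set
    bool operator < (T obj) const
    {
        return a < obj.a;
    }
};

int main()
{
    set<T> S;

    S.insert(T(0,0));
    S.insert(T(0,1));

    cout<<S.size()<<"\n";

    return 0;
}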
The output of the given code is: 1 (I was surprised!)
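The overloaded operator '<' cannot distinguish between T(0,0) and T(0,1). Neither compares less than the other, so set treats them as equivalent and ignores the second insert (pretty basic.. huh!)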

I learned this while writing the State class for a Dijkstra implementation, where I was using an STL set as the priority queue. Some nodes were "mysteriously" vanishing!

The overloaded operator should look like this:

    bool operator < (T obj) const
    {
        if(a!=obj.a)
            return a < obj.a;
        else
            return b < obj.b;
    }
This doesn't matter if you use an STL priority_queue, because it can store more than one copy of the 'same' element. But sometimes you need a set, especially when deleting nodes from the priority queue is important.

Friday, January 27, 2012

SPOJ ORDERSET (With Treap)

I LOVE TREAP !!


Honestly! Treap is the most beautiful Data Structure I know!
In my previous post, I solved This Problem with a Splay Tree. That code was more than 500 lines long and took 5.98 seconds to run on SPOJ.

This time I solved the same problem with a Treap. It's only a little over 200 lines, and it runs in 4.42 seconds on SPOJ.

The following implementation is based on the code given Here.


// SPOJ: ORDERSET
// using Treap
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<ctype.h>
#include<stdlib.h>
#include<time.h>
#include<assert.h>

#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<iostream>
#include<algorithm>
#include<string>

using namespace std;

#define FOR(i,n) for(int i=0;i<(n);++i)
#define REP(i,a,b) for(int i=(a);i<=(b);++i)
#define CLR(a,x) memset(a,(x),sizeof(a))

#define INF (1<<30)

typedef long long LL;
typedef pair<int,int> pii;

struct Node {
    int key;
    int cnt;
    int priority;

    Node *left, *right;

    Node(){cnt = 0; priority = 0; left = right = NULL;}
    Node(int _key){cnt = 1; key = _key; priority = rand(); left = right = NULL;}
    Node(int _key, int pr){cnt = 1; key = _key; priority = pr; left = right = NULL;}
};

struct Treap {
    Node* root;

    Treap(){root = NULL; srand(time(NULL));}

    int TreeSize(Node* T)
    {
        return T==NULL?0:T->cnt;
    }

    void UpdateCnt(Node* &T)
    {
        if(T)
        {
            T->cnt = 1 + TreeSize(T->left) + TreeSize(T->right);
        }
    }

    // Lifts T->left into T's position (often called a right rotation)
    void LeftRotate(Node* &T)
    {
        Node* temp = T->left;
        T->left = temp->right;
        temp->right = T;
        T = temp;

        UpdateCnt(T->right);
        UpdateCnt(T);
    }

    // Lifts T->right into T's position (often called a left rotation)
    void RightRotate(Node* &T)
    {
        Node* temp = T->right;
        T->right = temp->left;
        temp->left = T;
        T = temp;

        UpdateCnt(T->left);
        UpdateCnt(T);
    }

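    // BST insert; rotations on the way back up restore max-heap order on priority. Duplicate keys are ignored.
    void Insert(Node* &T, int _key)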
    {
        if(T == NULL)
        {
            T = new Node(_key);
            return;
        }

        if(T->key > _key)
        {
            Insert(T->left, _key);

            if(T->priority < T->left->priority)
                LeftRotate(T);
        }
        else if(T->key < _key)
        {
            Insert(T->right, _key);

            if(T->priority < T->right->priority)
                RightRotate(T);
        }

        UpdateCnt(T);
    }

    void Insert(int _key)
    {
        Insert(root, _key);
    }

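    // Rotates the node down (lifting its higher-priority child) until it has at most one child, then splices it out
    void Delete(Node* &T, int _key)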
    {
        if(T == NULL)
            return;

        if(T->key > _key)
            Delete(T->left, _key);
        else if(T->key < _key)
            Delete(T->right, _key);
        else
        {
            if(T->left && T->right)
            {
                if(T->left->priority > T->right->priority)
                    LeftRotate(T);
                else
                    RightRotate(T);

                Delete(T, _key);
            }
            else
            {
                Node* temp = T;

                if(T->left)
                    T = T->left;
                else
                    T = T->right;

                delete temp;
            }
        }


        UpdateCnt(T);
    }

    void Delete(int _key)
    {
        Delete(root, _key);
    }

    // Number of keys strictly less than bound
    int Count(Node* T, int bound)
    {
        if(T == NULL)
            return 0;

        if(T->key < bound)
            return 1 + TreeSize(T->left) + Count(T->right, bound);

        return Count(T->left, bound);
    }

    int Count(int bound)
    {
        return Count(root,bound);
    }

    // k-th smallest key (1-based), or -INF if the tree holds fewer than k keys
    int FindKth(Node* T, int k)
    {
        if(TreeSize(T) < k)
            return -INF;

        int sz = 1 + TreeSize(T->left);

        if(sz == k)
            return T->key;

        if(sz < k)
            return FindKth(T->right, k-sz);

        return FindKth(T->left,k);
    }

    int FindKth(int k)
    {
        return FindKth(root, k);
    }
};

int main()
{
    // freopen("in.txt","r",stdin);

    int Q; scanf("%d",&Q);

    Treap oTreap;

    while(Q--)
    {
        char t[5];
        int p;
        scanf("%s%d",t,&p);

        if(t[0]=='I')
        {
            oTreap.Insert(p);
        }
        else if(t[0]=='D')
        {
            oTreap.Delete(p);
        }
        else if(t[0]=='K')
        {
            int v = oTreap.FindKth(p);

            if(v > -INF)
            {
                printf("%d\n",v);
            }
            else
                puts("invalid");
        }
        else
        {
            int v = oTreap.Count(p);

            printf("%d\n",v);
        }

    }

    return 0;
}

A needlessly long but Correct Splay Tree implementation

What follows is my Accepted solution to This Problem.
I hope to write a shorter splay tree soon.


// Splay Tree Implementation
// used for SPOJ Problem: ORDERSET

#include<stdio.h>
#include<string.h>
#include<math.h>
#include<ctype.h>
#include<assert.h>
#include<stdlib.h>
#include<time.h>

#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<iostream>
#include<algorithm>
#include<string>

using namespace std;

#define FOR(i,n) for(int i=0;i<(n);++i)
#define REP(i,a,b) for(int i=(a);i<=(b);++i)
#define CLR(a,x) memset(a,(x),sizeof(a))

typedef long long LL;
typedef pair<int,int> pii;

#define INF (1<<30)

#define KEY_TYPE int

class Node
{
private:
    KEY_TYPE    key;
    int         cnt; // size of the subtree rooted here

    Node* left;
    Node* right;
    Node* parent;

    int calcTreeSize();

public:
    bool isLeftChild();
    bool isRightChild();
    bool isRoot();

    KEY_TYPE getKeyVal();
    int      getCnt();

    void Rotate();
    void Zig();
    void ZigZig();
    void ZigZag();

    bool NormalBSTInsert(Node* newNode);
    void Splay();
    void SplayUntil(Node* until);

    Node();
    Node(const KEY_TYPE _key);

    ~Node();

    Node* FindNearest(const KEY_TYPE _key);
    bool Exists(const KEY_TYPE _key);
    void Insert(KEY_TYPE _key);
    void Insert(Node* newNode);
    void Delete();
    void Delete(KEY_TYPE _key);

    void PrintInOrder();
    void PrintPreOrder();

    KEY_TYPE findKth(int k);
    int Count(KEY_TYPE _key);
};


Node* ROOT;

bool Node::isLeftChild()
{
    return (this->parent!=NULL && this->parent->left == this);
}

bool Node::isRightChild()
{
    return (this->parent!=NULL && this->parent->right == this);
}

bool Node::isRoot()
{
    return this->parent == NULL;
}

KEY_TYPE Node::getKeyVal()
{
    return this->key;
}

int Node::getCnt()
{
    return this->cnt;
}

int Node::calcTreeSize()
{
    return 1 + (this->left?this->left->cnt:0) + (this->right?this->right->cnt:0);
}

// Rotates this node above its parent, fixing parent links and subtree sizes
void Node::Rotate()
{
    if(isLeftChild())
    {
        bool parWasLeftChild = this->parent?this->parent->isLeftChild():false;

        parent->left = this->right;

        if(this->right)
            this->right->parent = this->parent;

        Node* newParent = this->parent->parent;

        this->right = this->parent;

        if(this->right)
            this->right->parent = this;

        this->parent = newParent;

        if(this->parent != NULL)
        {
            if(parWasLeftChild)
                this->parent->left = this;
            else
                this->parent->right = this;
        }

        this->right->cnt = this->right->calcTreeSize();
        this->cnt = this->calcTreeSize();
    }
    else if(isRightChild())
    {
        bool parWasLeftChild = this->parent?this->parent->isLeftChild():false;

        parent->right = this->left;

        if(this->left)
            this->left->parent = parent;

        Node* newParent = this->parent->parent;

        this->left = this->parent;

        if(this->left)
            this->left->parent = this;

        this->parent = newParent;

        if(this->parent != NULL)
        {
            if(parWasLeftChild)
                this->parent->left = this;
            else
                this->parent->right = this;
        }

        this->left->cnt = this->left->calcTreeSize();
        this->cnt = this->calcTreeSize();
    }
}

void Node::Zig()
{
    Rotate();
}

void Node::ZigZig()
{
    this->parent->Rotate();
    this->Rotate();
}

void Node::ZigZag()
{
    this->Rotate();
    this->Rotate();
}

void Node::Splay()
{
    while( !this->isRoot() )
    {
        if(this->parent->isRoot()) // Zig
        {
            this->Zig();
        }
        else if(this->isLeftChild() == this->parent->isLeftChild()) // ZigZig
        {
            this->ZigZig();
        }
        else // ZigZag
        {
            this->ZigZag();
        }
    }

    ROOT = this;
}

Node::Node()
{
    key     = 0;
    cnt     = 0;
    left    = NULL;
    right   = NULL;
    parent  = NULL;
}

Node::Node(const KEY_TYPE _key)
{
    key     = _key;
    cnt     = 1;
    left    = NULL;
    right   = NULL;
    parent  = NULL;
}

Node::~Node()
{

}

Node* Node::FindNearest(const KEY_TYPE _key)
{
    Node *ret = this;

    while(true)
    {

        if(ret->key < _key)
        {
            if(ret->right)
            {
                ret = ret->right;
            }
            else
                break;
        }
        else if(ret->key > _key)
        {
            if(ret->left)
            {
                ret = ret->left;
            }
            else
                break;
        }
        else
        {
            break;
        }
    }

    return ret;
}

bool Node::Exists(const KEY_TYPE _key)
{
    Node* temp = this->FindNearest(_key);
    temp->Splay();

    return ROOT->key == _key;
}

// returns false if a node with key = newNode->key already exists in the tree
bool Node::NormalBSTInsert(Node* newNode)
{
    Node* curNode = this;
    bool alreadyThere = false;

    while(true)
    {
        if(newNode->key > curNode->key)
        {
            if(curNode->right)
            {
                curNode = curNode->right;
            }
            else
            {
                curNode->right = newNode;
                newNode->parent = curNode;
                break;
            }
        }
        else if(newNode->key < curNode->key)
        {
            if(curNode->left)
            {
                curNode = curNode->left;
            }
            else
            {
                curNode->left = newNode;
                newNode->parent = curNode;
                break;
            }
        }
        else
        {
            alreadyThere = true;
            break;
        }
    }

    if(!alreadyThere)
    {
        Node* temp = newNode->parent;

        while(temp)
        {
            temp->cnt = temp->calcTreeSize();
            temp = temp->parent;
        }

        return true;
    }
    else
        return false;
}

void Node::Insert(KEY_TYPE _key)
{
    this->Insert(new Node(_key));
}

void Node::Insert(Node* newNode)
{
    if(newNode == NULL)
        return;

    if(this->NormalBSTInsert(newNode))
        newNode->Splay();
    else
        delete newNode; // key already present; the new node was never linked in
}

// Splays this node until it takes the place of 'until' (which must be one of its ancestors)
void Node::SplayUntil(Node* until)
{
    Node* grandParent = until->parent;

    while(this->parent != grandParent)
    {
        if(this->parent == until) // zig
        {
            this->Zig();
        }
        else if(this->isLeftChild() == this->parent->isLeftChild()) // zigzig
        {
            this->ZigZig();
        }
        else
            this->ZigZag();
    }
}

void Node::Delete()
{
    this->Splay();

    if(this->left)
    {
        Node* maxNode = this->left->FindNearest(INF);
        maxNode->SplayUntil(this->left);

        this->left->right = this->right;

        if(this->right)
            this->right->parent = this->left;

        ROOT = this->left;
        ROOT->parent = NULL;
        ROOT->cnt = ROOT->calcTreeSize();
    }
    else if(this->right)
    {
        Node* minNode = this->right->FindNearest(-INF);
        minNode->SplayUntil(this->right);

        this->right->left = this->left;

        if(this->left)
            this->left->parent = this->right;

        ROOT = this->right;
        ROOT->parent = NULL;
        ROOT->cnt = ROOT->calcTreeSize();
    }
    else
    {
        ROOT = NULL;
    }
}

void Node::Delete(KEY_TYPE _key)
{
    if(this->Exists(_key))
    {
        Node* forDelete = ROOT; // the Exists method splays the node to the root
        ROOT->Delete();
        delete forDelete;
        forDelete = NULL;
    }
}

KEY_TYPE Node::findKth(int k)
{
    if(this->cnt < k)
        return -INF;

    Node* temp = this;

    while(true)
    {
        int c = (temp->left?temp->left->cnt:0) + 1;

        if(c == k)
            break;

        if(c < k)
        {
            k -= c;
            temp = temp->right;
        }
        else
        {
            temp = temp->left;
        }
    }

    temp->Splay();
    return ROOT->key;
}

int Node::Count(KEY_TYPE _key)
{
    Node* temp = this;

    int ret = 0;

    while(true)
    {
        if(temp->key < _key)
        {
            int c = (temp->left?temp->left->cnt:0)+1;

            ret += c;

            if(temp->right)
            {
                temp = temp->right;
            }
            else
            {
                break;
            }
        }
        else if(temp->key > _key)
        {
            if(temp->left)
            {
                temp = temp->left;
            }
            else
                break;
        }
        else
        {
            ret += (temp->left?temp->left->cnt:0);
            break;
        }
    }

    temp->Splay();

    return ret;
}

void Node::PrintInOrder() // not needed for the problem
{
    if(this->left)
        this->left->PrintInOrder();

    cout<<this->getKeyVal()<<" ";

    if(this->right)
        this->right->PrintInOrder();
}

void Node::PrintPreOrder() // not needed for the problem
{
    cout<<"["<<this->getKeyVal()<<":"<<this->cnt<<"]";

    printf("(");
    if(this->left)
        this->left->PrintPreOrder();
    printf(")");

    printf("(");
    if(this->right)
        this->right->PrintPreOrder();
    printf(")");
}

void SplayTreeInsert(KEY_TYPE key)
{
    if(ROOT)
        ROOT->Insert(key);
    else
        ROOT = new Node(key);
}

int main()
{
    int Q; scanf("%d",&Q);

    ROOT = NULL;

    while(Q--)
    {
        char t[5];
        int p;
        scanf("%s%d",t,&p);

        if(t[0]=='I')
        {
            SplayTreeInsert(p);
        }
        else if(t[0]=='D')
        {
            if(ROOT)
                ROOT->Delete(p);
        }
        else if(t[0]=='K')
        {
            int v = -INF;

            if(ROOT)
                v = ROOT->findKth(p);

            if(v > -INF)
            {
                printf("%d\n",v);
            }
            else
                puts("invalid");
        }
        else
        {
            int v = 0;
            if(ROOT)
                v = ROOT->Count(p);

            printf("%d\n",v);
        }

    }

    return 0;
}

Thursday, December 29, 2011

SGU-103 Traffic Lights

Problem Link

The Problem in short:
We are given an undirected weighted graph in which each node has a traffic light. Each light shows two colors: Blue and Purple. Blue is visible for t_b seconds and Purple for t_p seconds, and this repeats periodically. These two values may differ from node to node. A car can cross an edge (u,v) iff u and v show the same color at the moment the car departs from u.

Now, we are given a source S, a destination T and the initial state of every light (its current color and how much time remains for that color). Calculate the minimum time to get from S to T.

Solution:
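Run Dijkstra's algorithm from S. The tricky part is this: if the car is at node u at time t, when will the lights at u and at its neighbor v first show the same color? The colors only change at switching moments. So the earliest such time is either t itself or a switching time of one of the two lights. calcTime in the solution checks t and the next five switching times of each light, and treats the edge as unusable if none of them works.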

#include<stdio.h>
#include<string.h>
#include<math.h>
#include<ctype.h>
#include<assert.h>
#include<stdlib.h>
#include<time.h>

#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<iostream>
#include<algorithm>
#include<string>

using namespace std;

#define FOR(i,n) for(int i=0;i<(n);++i)
#define REP(i,a,b) for(int i=(a);i<=(b);++i)
#define CLR(a,x) memset(a,(x),sizeof(a))

typedef long long LL;
typedef pair<int,int> pii;

struct State {
    int at;
    int when;

    State(){}
    State(int _at, int _when){at = _at; when = _when;}

    bool operator < (const State other) const
    {
        return when > other.when;
    }
};

const int INF = (1<<29);

vector<int> adj[305];
int nodes, edges;
int cost[305][305];
int from,to;
int dist[305], parent[305];

int remaining[305], initColor[305], duration[305][2];

typedef pair<int,int> colorInfo;

// (color shown by light u at time t, how long that color still lasts)
colorInfo getColorInfo(int u, int t)
{
    if(t < remaining[u])
        return colorInfo(initColor[u],remaining[u]-t);

    int M = duration[u][0]+duration[u][1];
    int r = (t-remaining[u])%M;
    if(r < duration[u][1-initColor[u]])
        return colorInfo(1-initColor[u],duration[u][1-initColor[u]]-r);
    else
        return colorInfo(initColor[u],duration[u][initColor[u]]-(r-duration[u][1-initColor[u]]));
}

// Appends t and the next 5 moments at which light u switches color
void getNext4InterestingTime(int u, int t, vector<int>& VT)
{
    VT.push_back(t);
    colorInfo info = getColorInfo(u,t);
    t += info.second;
    VT.push_back(t);

    FOR(i,4)
    {
        if(i%2==0)
            t += duration[u][1-info.first];
        else
            t += duration[u][info.first];

        VT.push_back(t);
    }
}

int calcTime(int u, int v, int t) // currently at u, current time t, want to go to v
{
    vector<int> interestingTime;

    getNext4InterestingTime(u,t,interestingTime);
    getNext4InterestingTime(v,t,interestingTime);

    sort(interestingTime.begin(),interestingTime.end());

    FOR(i,interestingTime.size())
    {
        if(getColorInfo(u,interestingTime[i]).first == getColorInfo(v,interestingTime[i]).first)
            return interestingTime[i];
    }

    return -1;
}

int solve()
{
    REP(i,1,nodes)
        dist[i] = INF;
    dist[from] = 0;
    parent[from] = from;

    State source(from,0);

    priority_queue<State> pq;
    pq.push(source);

    while(pq.empty()==false)
    {
        State u = pq.top();
        pq.pop();

        if(u.when > dist[u.at])
            continue;

        FOR(i,adj[u.at].size())
        {
            int x = adj[u.at][i];
            int y = calcTime(u.at,x,u.when);
            if(y == -1) continue;

            y += cost[u.at][x];
            if(dist[x] > y)
            {
                parent[x] = u.at;
                dist[x] = y;
                pq.push(State(x,y));
            }
        }
    }

    return dist[to];
}

int main()
{
    scanf("%d%d",&from,&to);
    scanf("%d%d",&nodes,&edges);

    FOR(i,nodes)
    {
        char s[5];
        scanf("%s%d%d%d",s,&remaining[i+1],&duration[i+1][0],&duration[i+1][1]);

        if(s[0]=='B')
            initColor[i+1] = 0;
        else
            initColor[i+1] = 1;
    }

    FOR(e,edges)
    {
        int u,v,w; scanf("%d%d%d",&u,&v,&w);
        adj[u].push_back(v);
        adj[v].push_back(u);
        cost[u][v] = cost[v][u] = w;
    }

    int ans = solve();

    if(ans >= INF)
    {
        puts("0");
    }
    else
    {
        vector<int> path;

        while(1)
        {
            path.push_back(to);
            if(to==from)
                break;
            to = parent[to];
        }

        cout<<ans<<"\n";

        for(int i=path.size()-1;i>=0;--i)
        {
            printf("%d",path[i]);
            if(i>0)
                printf(" ");
        }

        puts("");
    }

    return 0;
}

Friday, July 29, 2011

SPOJ :: KOICOST

Problem Link

The Problem in Short:
Given an undirected weighted graph.
There is a function Cost(u,v), which is defined as follows:
While there is a path between vertex u and v, delete the edge with the smallest weight. Cost(u,v) is the sum of the weights of the edges that were deleted in this process.
Your task is to calculate the sum of Cost(u,v) over all pairs of vertices u and v with u < v.
The Solution:

To describe the solution, we'll need some notation:
bound(x,y) = the last edge that has to be deleted to disconnect vertices x and y by the process described in the problem statement.
f(u,v) = number of pairs of vertices (x,y) such that bound(x,y) = edge (u,v)
w(u,v) = weight of the edge (u,v)
csum(u,v) = cumulative sum of the weights of all edges (u',v') such that w(u',v') <= w(u,v)
Since edges are deleted in increasing order of weight, Cost(x,y) = csum(bound(x,y)). So, result = sum of csum(u,v)*f(u,v) over all edges (u,v).


Now, the problem is to calculate f(u,v) efficiently.

Consider the following process which is the reverse of the process described in the problem statement.

Let's assume that initially all the vertices are disconnected. We add the edges one by one in decreasing order of their weights.

Define g(u,v) = the number of pairs of vertices (x,y) that become connected when we add the edge (u,v) in the reverse process.

Note that f(u,v) = g(u,v) (the proof is left to the reader).

Here is the algorithm that calculates g(u,v) for all the edges.

1. Sort all the edges in decreasing order of their weights.
2. Create a disjoint set for each node (yes, we'll need the disjoint-set data structure).
3. In the root of each disjoint set, store the size of that set.
   Let's denote by sz(u) the size of the disjoint set that vertex u is in.

4. Starting from result = 0, for each edge (u,v) in that order:
   if u and v are already connected (i.e. in the same set):
       f(u,v) = g(u,v) = 0
   else:
       f(u,v) = g(u,v) = sz(u) * sz(v)
       result = result + f(u,v)*csum(u,v)
       disjoint_set_union(u,v)