00001
00002
00003
00004
#include "DisjointSet.hpp"
00005
00006
00007
00008 DisjointSet_Base::DisjointSet_Base(
const Node _graph):graph(_graph){
00009
00010
nbClass =
graph.
getNodeSize();
00011
int i = 0;
00012
for(
Node_Iterator it =
graph.
getNodeIterator(); it.
hasNext();){
00013
00014
Node node = it.
next();
00015
00016 node.
setRank(i);
00017
00018
vNode.push_back(node);
00019
00020
parent.push_back(-1);
00021
00022
00023
int sz = node.
getLeafSize();
00024
size.push_back(sz);
00025 i++;
00026 }
00027 }
00028
00029
00030
00031 int DisjointSet_Base::getParentRank(
int i){
00032
00033
if(
parent[i] < 0){
00034
return i;
00035 }
else {
00036
00037
return (
parent[i] =
getParentRank(
parent[i]));
00038 }
00039 }
00040
00041
00042
00043 Node DisjointSet_Base::mergeDisjointSet(
Node n1,
Node n2){
00044
00045
int c1 =
getParentRank(n1.
getRank());
00046
int c2 = getParentRank(n2.
getRank());
00047
00048
if(c1 != c2){
00049
00050
Node ret = Node::CreateLeaf();
00051
if((-1 ==
parent[c1]) && (-1 ==
parent[c2])){
00052
00053
00054
parent[c1] +=
parent[c2];
00055 parent[c2] = c1;
00056
00057
size[c1] +=
size[c2];
00058 size[c2] = 0;
00059
00060
Node cluster = Node::CreateGraph(
"CLUSTER");
00061 cluster.
addNode(n1);
00062 cluster.
addNode(n2);
00063
00064
table[parent[c2]]= cluster;
00065 ret = cluster;
00066 }
else if(
parent[c1] <=
parent[c2]){
00067
00068
if(parent[c2] == -1){
00069
00070
Node n =
table.find(c1)->second;
00071 n.
addNode(n2);
00072 ret = n;
00073 }
else {
00074
00075
Node n1 =
table.find(c1)->second;
00076
table_iter it =
table.find(c2);
00077
Node n2 = it->second;
00078
00079 n1.
merge(n2);
00080
00081
table.erase(it);
00082 ret = n1;
00083
00084
00085 }
00086
00087 parent[c1] += parent[c2];
00088 parent[c2] = c1;
00089
00090
size[c1] +=
size[c2];
00091 size[c2] = 0;
00092 }
else {
00093
if(parent[c1] == -1){
00094
00095
Node n =
table.find(c2)->second;
00096 n.
addNode(n1);
00097 ret = n;
00098 }
else {
00099
00100
table_iter it =
table.find(c1);
00101
Node n1 = it->second;
00102
Node n2 =
table.find(c2)->second;
00103
00104 n2.
merge(n1);
00105
00106
table.erase(it);
00107 ret = n2;
00108 }
00109
00110 parent[c2] += parent[c1];
00111 parent[c1] = c2;
00112
00113
size[c2] +=
size[c1];
00114 size[c1] = 0;
00115 }
00116
00117
nbClass--;
00118
return ret;
00119 }
00120 }
00121 string
DisjointSet_Base::toString(){
00122 std::ostringstream o;
00123 o <<
"Classes " << endl;
00124
for(
int i = 0; i <
parent.size(); i++){
00125 o << i <<
"\t";
00126 }
00127 o << endl;
00128
for(
int i = 0; i <
parent.size(); i++){
00129 o << i <<
"\t" <<
parent[i] <<
"\n";
00130 }
00131 o << endl;
00132 o <<
"Classes number : " <<
nbClass;
00133
return o.str();
00134 }
00135 bool DisjointSet_Base::sameDisjointSet(
Node n1,
Node n2){
00136
00137
int c1 =
getParentRank(n1.
getRank());
00138
int c2 = getParentRank(n2.
getRank());
00139
return c1 == c2;
00140 }
00141
00142
00143 Node DisjointSet_Base::makeNode(string label){
00144
Trace trace(
"makeNode", channelGraph);
00145
00146
Node ret = Node::CreateGraph(label);
00147
00148 vector<Node> parentNode;
00149
for(
int i = 0; i <
parent.size(); i++){
00150
00151
if(
parent[i] < 0){
00152
00153
if(
parent[i] == -1){
00154
00155
Node cluster = Node::CreateGraph(
"CLUSTER");
00156 cluster.
addNode(
vNode[i]);
00157 cluster.
setPixel(
vNode[i].getPixel());
00158 cluster.
setWeight(
vNode[i].getWeight());
00159 ret.
addNode(cluster);
00160
00161 parentNode.push_back(cluster);
00162 }
else {
00163
00164
Node cluster =
table.find(i)->second;
00165 ret.
addNode(cluster);
00166
00167 parentNode.push_back(cluster);
00168 }
00169 }
else {
00170
int r =
getParentRank(i);
00171
Node cluster =
table.find(r)->second;
00172
00173 parentNode.push_back(cluster);
00174 }
00175 }
00176
00177
00178
00179
00180
00181
typedef pair<int, int> aPair;
00182 map< aPair , Edge>
table;
00183
typedef map< aPair , Edge>::iterator
table_iter;
00184
00185
for(
int i = 0; i <
edge.size(); i++){
00186
00187
Edge aEdge =
edge[i];
00188
00189
int r1 = aEdge.
getSRC().
getRank();
00190
int r2 = aEdge.
getDST().
getRank();
00191
int pr1 =
getParentRank(r1);
00192
int pr2 = getParentRank(r2);
00193
if(pr1 == pr2){
00194
00195
00196
continue;
00197 }
00198
00199 aPair pair1 = make_pair(pr1, pr2);
00200 aPair pair2 = make_pair(pr2, pr1);
00201 table_iter b1 = table.find(pair1);
00202 table_iter b2 = table.find(pair2);
00203 table_iter e = table.end();
00204
if((table.size() == 0) || ((b1 == e) && (b2 == e)) ){
00205
00206
Node p1 = parentNode[r1];
00207
Node p2 = parentNode[r2];
00208
float weight = aEdge.
getWeight();
00209
Edge newEdge(p1, p2, weight);
00210
00211
int x0 = aEdge.
getX0();
00212
int y0 = aEdge.
getY0();
00213
int x1 = aEdge.
getX1();
00214
int y1 = aEdge.
getY1();
00215 newEdge.
setCoord(x0, y0, x1, y1);
00216 ret.
addEdge(newEdge);
00217
00218 table.insert(make_pair(pair1, newEdge));
00219 }
00220 }
00221
return ret;
00222 }
00223 int DisjointSet_Base::getParentSize(
Node node){
00224
int rank =
getParentRank(node.
getRank());
00225
return size[rank];
00226 }
00227 Node DisjointSet_Base::getParent(
Node node){
00228
int rank =
getParentRank(node.
getRank());
00229
if(
parent[rank] == -1){
00230
00231
return node;
00232 }
else {
00233
00234
table_iter it =
table.find(rank);
00235
return it->second;
00236 }
00237 }
00238
00239
00240
00241
00242 Kruskal::Kruskal(
Node _graph):
DisjointSet_Base(_graph){}
00243 Node Kruskal::makeDisjointSet(){
00244
Trace trace(
"makeDisjointSet", channelGraph);
00245 trace.
print(
"sort");
00246
00247 graph.
sortEdge();
00248
00249
for(
Edge_Iterator it = graph.
getEdgeIterator(); it.
hasNext();){
00250
Edge e = it.
next();
00251
00252
Node n1 = e.
getSRC();
00253
Node n2 = e.
getDST();
00254
00255
if(!sameDisjointSet(n1, n2)){
00256
00257
00258
Node n = mergeDisjointSet(n1, n2);
00259
00260 n.
addEdge(e);
00261 }
00262 }
00263
Node ret = makeNode(
"KRUSKAL");
00264
return ret;
00265 }
00266
00267
00268
00269
00270 MInt::MInt(
Node _graph,
float _tau):
00271
DisjointSet_Base(_graph),
00272 tau(_tau){
00273 }
00274 Node MInt::makeDisjointSet(){
00275
Trace trace(
"MInt", channelGraph);
00276
00277 {
00278 ostringstream o;
00279 o <<
"Node size " << graph.
getNodeSize() <<
" edge " << graph.
getEdgeSize() <<
" tau " <<
tau;
00280 trace.
print(o.str());
00281 }
00282
00283 trace.
print(
"sort");
00284 graph.
sortEdge();
00285
00286
00287
int i = 0;
00288
for(
Edge_Iterator it = graph.
getEdgeIterator(); it.
hasNext();){
00289
Edge e = it.
next();
00290
Node n1 = e.
getSRC();
00291
Node n2 = e.
getDST();
00292
if(!sameDisjointSet(e.
getSRC(), e.
getDST())){
00293
bool merge =
false;
00294
float edgeWeight = e.
getWeight();
00295 {
00296
int size1 = getParentSize(n1);
00297
int size2 = getParentSize(n2);
00298
00299
float f1 = getParent(n1).
getWeight();
00300
float f2 = getParent(n2).getWeight();
00301
00302
float k1 = (
tau / size1);
00303
float k2 = (
tau / size2);
00304
00305
float mint1 = f1 + k1;
00306
float mint2 = f2 + k2;
00307
00308
float mint = mint1;
00309
if(mint2 < mint){
00310 mint = mint2;
00311 }
00312
if(edgeWeight <= mint){
00313 merge =
true;
00314 }
00315 }
00316
if(merge){
00317
int size1 = getParentSize(n1);
00318
int size2 = getParentSize(n2);
00319
Color c1 = getParent(n1).
getPixel().
getColor();
00320
Color c2 = getParent(n2).getPixel().
getColor();
00321
Node n = mergeDisjointSet(n1, n2);
00322
Color newColor = Color::getAverageColor(c1, size1, c2, size2);
00323 n.
getPixel().
setColor(newColor);
00324 n.
setWeight(edgeWeight);
00325 n.
addEdge(e);
00326 }
else {
00327 edge.push_back(e);
00328 }
00329 }
00330 i++;
00331 }
00332
Node ret = makeNode(
"SEGMENT");
00333
return ret;
00334 }
00335
00336
00337
00338
00339 Segment::Segment(
Node _graph,
float _k):
00340
DisjointSet_Base(_graph),
00341 k(_k){
00342 }
00343 Node Segment::makeDisjointSet(){
00344
Trace trace(
"makeDisjointSet", channelCluster);
00345 {
00346 ostringstream o;
00347 o <<
"Node size " << graph.
getNodeSize() <<
" edge " << graph.
getEdgeSize();
00348 trace.
print(o.str());
00349 }
00350 {
00351
Trace trace(
"sort edge", channelCluster);
00352 graph.
sortEdge(
edgeWeightMin,
edgeWeightMax);
00353
edgeWeightDelta =
edgeWeightMax -
edgeWeightMin;
00354 ostringstream o;
00355 o <<
"Delta " <<
edgeWeightDelta;
00356 trace.
print(o.str());
00357 }
00358
00359
00360
00361
00362
00363
00364
00365
00366
int i = 0;
00367
for(
Edge_Iterator it = graph.
getEdgeIterator(); it.
hasNext();){
00368
Edge e = it.
next();
00369
Node n1 = e.
getSRC();
00370
Node n2 = e.
getDST();
00371
if(!sameDisjointSet(e.
getSRC(), e.
getDST())){
00372
bool merge =
false;
00373
float edgeWeight = e.
getWeight();
00374 {
00375
bool normalize =
true;
00376
00377
int size1 = getParentSize(n1);
00378
int size2 = getParentSize(n2);
00379
int size3 = size2 + size1;
00380
00381
float f1 = getParent(n1).
getWeight();
00382
float f2 = getParent(n2).getWeight();
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
float k1 = (
k *
edgeWeightDelta) / size1;
00417
float k2 = (
k *
edgeWeightDelta) / size2;
00418
00419 f1 = f1 / size1 + k1;
00420 f2 = f2 / size2 + k2;
00421
00422
float estimation = (f1 + f2 + edgeWeight) / 3;
00423
00424
00425
if(edgeWeight <= estimation){
00426 merge =
true;
00427 }
else {
00428 merge =
false;
00429 }
00430
if(merge){
00431
00432
int p1 = getParentRank(n1.
getRank());
00433
int p2 = getParentRank(n2.
getRank());
00434
Color c1 = getParent(n1).getPixel().getColor();
00435
Color c2 = getParent(n2).getPixel().
getColor();
00436
int size1 = getParentSize(n1);
00437
int size2 = getParentSize(n2);
00438
Node n = mergeDisjointSet(n1, n2);
00439
Color newColor = Color::getAverageColor(c1, size1, c2, size2);
00440 n.
getPixel().
setColor(newColor);
00441
int _p1 = getParentRank(n1.
getRank());
00442
int _p2 = getParentRank(n2.
getRank());
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
float sum = n.
getWeight() + edgeWeight;
00457 n.
setWeight(sum);
00458 n.
addEdge(e);
00459 }
else {
00460 edge.push_back(e);
00461
00462
00463
00464
00465
00466
00467
00468 }
00469 }
00470 }
00471 i++;
00472 }
00473
Node ret = makeNode(
"SEGMENT");
00474
00475 {
00476 ostringstream o;
00477 o <<
"Node size " << ret.
getNodeSize() <<
" edge " << ret.
getEdgeSize();
00478 trace.
print(o.str());
00479 }
00480
00481
return ret;
00482 }
00483
00484
00485
00486 MeanShiftFusion::MeanShiftFusion(
Node _graph,
float _radius):
00487
DisjointSet_Base(_graph),
00488 radius(_radius){
00489 }
00490 Node MeanShiftFusion::makeDisjointSet(){
00491
Trace trace(
"Fusion", channelCluster);
00492 {
00493 ostringstream o;
00494 o <<
"Node size " << graph.
getNodeSize() <<
" edge " << graph.
getEdgeSize();
00495 trace.
print(o.str());
00496 }
00497
int count = 0;
00498 graph.
sortEdge();
00499
for(
Edge_Iterator it = graph.
getEdgeIterator(); it.
hasNext();){
00500
Edge e = it.
next();
00501
Node n1 = e.
getSRC();
00502
Node n2 = e.
getDST();
00503
if(!sameDisjointSet(n1, n2)){
00504
Node p1 = getParent(n1);
00505
Node p2 = getParent(n2);
00506
Color c1 = p1.
getPixel().
getColor();
00507
Color c2 = p2.
getPixel().
getColor();
00508
bool merge;
00509 {
00510
float* f1 = c1.
getFloat(
false);
00511
float* f2 = c2.
getFloat(
false);
00512 merge = Color::inWindow(f1,
radius, f2);
00513
delete[] f1;
00514
delete[] f2;
00515 }
00516
if(merge){
00517
int size1 = getParentSize(n1);
00518
int size2 = getParentSize(n2);
00519
Node n = mergeDisjointSet(n1, n2);
00520
Color newColor = Color::getAverageColor(c1, size1, c2, size2);
00521 n.
getPixel().
setColor(newColor);
00522 }
else {
00523 edge.push_back(e);
00524 }
00525 count++;
00526 }
00527 }
00528
Node ret = makeNode(
"MEANSHIFT");
00529 {
00530 ostringstream o;
00531 o <<
"Node size " << ret.
getNodeSize() <<
" edge " << ret.
getEdgeSize();
00532 trace.
print(o.str());
00533 }
00534
00535
return ret;
00536 }
00537
00538
00539
00540 Prune::Prune(
Node _graph,
int _area):
00541
DisjointSet_Base(_graph),
00542 area(_area){
00543 }
00544 Node Prune::makeDisjointSet(){
00545
Trace trace(
"Prune", channelCluster);
00546 {
00547 ostringstream o;
00548 o <<
"Node size " << graph.
getNodeSize() <<
" edge " << graph.
getEdgeSize();
00549 trace.
print(o.str());
00550 }
00551
int count = 0;
00552 graph.
sortEdge();
00553
for(
Edge_Iterator it = graph.
getEdgeIterator(); it.
hasNext();){
00554
Edge e = it.
next();
00555
Node n1 = e.
getSRC();
00556
Node n2 = e.
getDST();
00557
if(!sameDisjointSet(n1, n2)){
00558
Node p1 = getParent(n1);
00559
Node p2 = getParent(n2);
00560
Color c1 = p1.
getPixel().
getColor();
00561
Color c2 = p2.
getPixel().
getColor();
00562
int size1 = getParentSize(n1);
00563
int size2 = getParentSize(n2);
00564
00565
bool merge =
false;
00566
if((size1 <=
area) || (size2 <
area)){
00567 merge =
true;
00568 }
00569
if(merge){
00570
Node n = mergeDisjointSet(n1, n2);
00571
Color newColor = Color::getAverageColor(c1, size1, c2, size2);
00572 n.
getPixel().
setColor(newColor);
00573 }
else {
00574 edge.push_back(e);
00575 }
00576 count++;
00577 }
00578 }
00579
Node ret = makeNode(
"MEANSHIFT");
00580 {
00581 ostringstream o;
00582 o <<
"Node size " << ret.
getNodeSize() <<
" edge " << ret.
getEdgeSize();
00583 trace.
print(o.str());
00584 }
00585
00586
return ret;
00587 }
00588