Page principale | Hiérarchie des classes | Liste des classes | Liste des fichiers | Membres de classe | Membres de fichier

DisjointSet.cpp

Aller à la documentation de ce fichier.
00001 /// \file 00002 /// Implémentation des classes utiles pour le clustering de classes disjointes 00003 00004 #include "DisjointSet.hpp" 00005 /////////////////////////////////////////////////////////////////////////////// 00006 /// Classe : DisjointSet_Base 00007 /////////////////////////////////////////////////////////////////////////////// 00008 DisjointSet_Base::DisjointSet_Base(const Node _graph):graph(_graph){ 00009 // nombre classe 00010 nbClass = graph.getNodeSize(); 00011 int i = 0; 00012 for(Node_Iterator it = graph.getNodeIterator(); it.hasNext();){ 00013 // On parcourt tous les éléments du graphe 00014 Node node = it.next(); 00015 // MAJ du rang 00016 node.setRank(i); 00017 // Copie locale 00018 vNode.push_back(node); 00019 // Pas de parent 00020 parent.push_back(-1); 00021 // Copie de la taille 00022 // int sz = node.getNodeSize(); 00023 int sz = node.getLeafSize(); 00024 size.push_back(sz); 00025 i++; 00026 } 00027 } 00028 ///. 00029 /// Retourne l'index du parent 00030 /// Réalise la compression de chemin 00031 int DisjointSet_Base::getParentRank(int i){ 00032 // Si l'index est < 0 c'est un parent 00033 if(parent[i] < 0){ 00034 return i; 00035 } else { 00036 // Sinon, on cherche recursivement le parent 00037 return (parent[i] = getParentRank(parent[i])); 00038 } 00039 } 00040 ///. 00041 /// Fusion de 2 ensembles disjoint 00042 /// Réalise l'union par rang 00043 Node DisjointSet_Base::mergeDisjointSet(Node n1, Node n2){ 00044 // index des cluster 00045 int c1 = getParentRank(n1.getRank()); 00046 int c2 = getParentRank(n2.getRank()); 00047 // On fusionne que si les ensembles sont différent 00048 if(c1 != c2){ 00049 // Cluster résultat 00050 Node ret = Node::CreateLeaf(); 00051 if((-1 == parent[c1]) && (-1 == parent[c2])){ 00052 // Aucun des 2 clusters ne possède d'élément 00053 // MAJ parent 00054 parent[c1] += parent[c2]; 00055 parent[c2] = c1; 00056 // MAJ taille 00057 size[c1] += size[c2]; 00058 size[c2] = 0; 00059 // Creation d'un nouveau DisjointSet 00060 Node cluster = Node::CreateGraph("CLUSTER"); 00061 cluster.addNode(n1); 00062 cluster.addNode(n2); 00063 // MAJ hashtable des parents 00064 table[parent[c2]]= cluster; 00065 ret = cluster; 00066 } else if(parent[c1] <= parent[c2]){ 00067 // On a au moins 1 cluster 00068 if(parent[c2] == -1){ 00069 // 1 cluster et un noeud feuille : on ajoute simplement le noeud feuille 00070 Node n = table.find(c1)->second; 00071 n.addNode(n2); 00072 ret = n; 00073 } else { 00074 // 2 clusters : on fait l'union par le rang 00075 Node n1 = table.find(c1)->second; 00076 table_iter it = table.find(c2); 00077 Node n2 = it->second; 00078 // Fusion 00079 n1.merge(n2); 00080 // MAJ hashtable des parents 00081 table.erase(it); 00082 ret = n1; 00083 00084 00085 } 00086 // MAJ parent 00087 parent[c1] += parent[c2]; 00088 parent[c2] = c1; 00089 // MAJ taille 00090 size[c1] += size[c2]; 00091 size[c2] = 0; 00092 } else { 00093 if(parent[c1] == -1){ 00094 // 1 cluster et un noeud feuille : on ajoute simplement le noeud feuille 00095 Node n = table.find(c2)->second; 00096 n.addNode(n1); 00097 ret = n; 00098 } else { 00099 // 2 clusters : on fait l'union par le rang 00100 table_iter it = table.find(c1); 00101 Node n1 = it->second; 00102 Node n2 = table.find(c2)->second; 00103 // Fusion 00104 n2.merge(n1); 00105 // MAJ hashtable des parents 00106 table.erase(it); 00107 ret = n2; 00108 } 00109 // MAJ parent 00110 parent[c2] += parent[c1]; 00111 parent[c1] = c2; 00112 // MAJ taille 00113 size[c2] += size[c1]; 00114 size[c1] = 0; 00115 } 00116 // Fusion : 1 classe de moins 00117 nbClass--; 00118 return ret; 00119 } 00120 } 00121 string DisjointSet_Base::toString(){ 00122 std::ostringstream o; 00123 o << "Classes " << endl; 00124 for(int i = 0; i < parent.size(); i++){ 00125 o << i << "\t"; 00126 } 00127 o << endl; 00128 for(int i = 0; i < parent.size(); i++){ 00129 o << i << "\t" << parent[i] << "\n"; 00130 } 00131 o << endl; 00132 o << "Classes number : " << nbClass; 00133 return o.str(); 00134 } 00135 bool DisjointSet_Base::sameDisjointSet(Node n1, Node n2){ 00136 // Compare l'index des parents 00137 int c1 = getParentRank(n1.getRank()); 00138 int c2 = getParentRank(n2.getRank()); 00139 return c1 == c2; 00140 } 00141 ///. 00142 /// Construit un noeud après la fusion 00143 Node DisjointSet_Base::makeNode(string label){ 00144 Trace trace("makeNode", channelGraph); 00145 // Valeur de retour; 00146 Node ret = Node::CreateGraph(label); 00147 // Vecteur des parents 00148 vector<Node> parentNode; 00149 for(int i = 0; i < parent.size(); i++){ 00150 // On parcourt les parents 00151 if(parent[i] < 0){ 00152 // Noeud parent 00153 if(parent[i] == -1){ 00154 // Parent isolé, on crée un cluster avec un noeud 00155 Node cluster = Node::CreateGraph("CLUSTER"); 00156 cluster.addNode(vNode[i]); 00157 cluster.setPixel(vNode[i].getPixel()); 00158 cluster.setWeight(vNode[i].getWeight()); 00159 ret.addNode(cluster); 00160 // On ajoute le cluster à la liste 00161 parentNode.push_back(cluster); 00162 } else { 00163 // On récupère le noeud de la hashtable 00164 Node cluster = table.find(i)->second; 00165 ret.addNode(cluster); 00166 // On ajoute le cluster à la liste 00167 parentNode.push_back(cluster); 00168 } 00169 } else { 00170 int r = getParentRank(i); 00171 Node cluster = table.find(r)->second; 00172 // On ajoute le cluster à la liste 00173 parentNode.push_back(cluster); 00174 } 00175 } 00176 00177 // A ce stade on a un vecteur de parent 00178 // Il reste à ajouter les arrêtes d'adjacence 00179 00180 // Hashtable des arres 00181 typedef pair<int, int> aPair; 00182 map< aPair , Edge> table; 00183 typedef map< aPair , Edge>::iterator table_iter; 00184 00185 for(int i = 0; i < edge.size(); i++){ 00186 // On parcourt les arrêtes 00187 Edge aEdge = edge[i]; 00188 // Recherche du rang des parents 00189 int r1 = aEdge.getSRC().getRank(); 00190 int r2 = aEdge.getDST().getRank(); 00191 int pr1 = getParentRank(r1); 00192 int pr2 = getParentRank(r2); 00193 if(pr1 == pr2){ 00194 // Si les index sont égaux, l'arrête relie 00195 // 2 noeuds d'un même cluster 00196 continue; 00197 } 00198 // On cherche s'il y a déjà une arrête reliant les 2 clusters 00199 aPair pair1 = make_pair(pr1, pr2); 00200 aPair pair2 = make_pair(pr2, pr1); 00201 table_iter b1 = table.find(pair1); 00202 table_iter b2 = table.find(pair2); 00203 table_iter e = table.end(); 00204 if((table.size() == 0) || ((b1 == e) && (b2 == e)) ){ 00205 // On peut ajouter l'arrête : on l'a crée 00206 Node p1 = parentNode[r1]; 00207 Node p2 = parentNode[r2]; 00208 float weight = aEdge.getWeight(); 00209 Edge newEdge(p1, p2, weight); 00210 // On recopie les information de l'ancienne arrête 00211 int x0 = aEdge.getX0(); 00212 int y0 = aEdge.getY0(); 00213 int x1 = aEdge.getX1(); 00214 int y1 = aEdge.getY1(); 00215 newEdge.setCoord(x0, y0, x1, y1); 00216 ret.addEdge(newEdge); 00217 // MAJ hashtable 00218 table.insert(make_pair(pair1, newEdge)); 00219 } 00220 } 00221 return ret; 00222 } 00223 int DisjointSet_Base::getParentSize(Node node){ 00224 int rank = getParentRank(node.getRank()); 00225 return size[rank]; 00226 } 00227 Node DisjointSet_Base::getParent(Node node){ 00228 int rank = getParentRank(node.getRank()); 00229 if(parent[rank] == -1){ 00230 // Noeud feuille 00231 return node; 00232 } else { 00233 // Recherche du cluster dans la hashtable 00234 table_iter it = table.find(rank); 00235 return it->second; 00236 } 00237 } 00238 /////////////////////////////////////////////////////////////////////////////// 00239 /// Classe : Kruskal 00240 /// Implemente l'algorithme de Kruskal pour trouver le MST d'un graphe 00241 /////////////////////////////////////////////////////////////////////////////// 00242 Kruskal::Kruskal(Node _graph):DisjointSet_Base(_graph){} 00243 Node Kruskal::makeDisjointSet(){ 00244 Trace trace("makeDisjointSet", channelGraph); 00245 trace.print("sort"); 00246 // Tri des arrêtes par poids croissant 00247 graph.sortEdge(); 00248 // Parcourt des arrêtes 00249 for(Edge_Iterator it = graph.getEdgeIterator(); it.hasNext();){ 00250 Edge e = it.next(); 00251 // Source et destination 00252 Node n1 = e.getSRC(); 00253 Node n2 = e.getDST(); 00254 00255 if(!sameDisjointSet(n1, n2)){ 00256 // Si la source et la destination ne sont pas dans le 00257 // même cluster, on fusionne 00258 Node n = mergeDisjointSet(n1, n2); 00259 // Ajout de l'arrête du MST 00260 n.addEdge(e); 00261 } 00262 } 00263 Node ret = makeNode("KRUSKAL"); 00264 return ret; 00265 } 00266 /////////////////////////////////////////////////////////////////////////////// 00267 /// Classe : MInt 00268 /// Implemente l'algorithme de l'article pour segmenter une image 00269 /////////////////////////////////////////////////////////////////////////////// 00270 MInt::MInt(Node _graph, float _tau): 00271 DisjointSet_Base(_graph), 00272 tau(_tau){ 00273 } 00274 Node MInt::makeDisjointSet(){ 00275 Trace trace("MInt", channelGraph); 00276 00277 { 00278 ostringstream o; 00279 o << "Node size " << graph.getNodeSize() << " edge " << graph.getEdgeSize() << " tau " << tau; 00280 trace.print(o.str()); 00281 } 00282 00283 trace.print("sort"); 00284 graph.sortEdge(); 00285 00286 00287 int i = 0; 00288 for(Edge_Iterator it = graph.getEdgeIterator(); it.hasNext();){ 00289 Edge e = it.next(); 00290 Node n1 = e.getSRC(); 00291 Node n2 = e.getDST(); 00292 if(!sameDisjointSet(e.getSRC(), e.getDST())){ 00293 bool merge = false; 00294 float edgeWeight = e.getWeight(); 00295 { 00296 int size1 = getParentSize(n1); 00297 int size2 = getParentSize(n2); 00298 00299 float f1 = getParent(n1).getWeight();//getInt(n1); 00300 float f2 = getParent(n2).getWeight();//getInt(n2); 00301 00302 float k1 = (tau / size1); 00303 float k2 = (tau / size2); 00304 00305 float mint1 = f1 + k1; 00306 float mint2 = f2 + k2; 00307 00308 float mint = mint1; 00309 if(mint2 < mint){ 00310 mint = mint2; 00311 } 00312 if(edgeWeight <= mint){ 00313 merge = true; 00314 } 00315 } 00316 if(merge){ 00317 int size1 = getParentSize(n1); 00318 int size2 = getParentSize(n2); 00319 Color c1 = getParent(n1).getPixel().getColor(); 00320 Color c2 = getParent(n2).getPixel().getColor(); 00321 Node n = mergeDisjointSet(n1, n2); 00322 Color newColor = Color::getAverageColor(c1, size1, c2, size2); 00323 n.getPixel().setColor(newColor); 00324 n.setWeight(edgeWeight); 00325 n.addEdge(e); 00326 } else { 00327 edge.push_back(e); 00328 } 00329 } 00330 i++; 00331 } 00332 Node ret = makeNode("SEGMENT"); 00333 return ret; 00334 } 00335 /////////////////////////////////////////////////////////////////////////////// 00336 /// Classe : Segment 00337 /// Variation de l'algorithme de l'article pour segmenter une image 00338 /////////////////////////////////////////////////////////////////////////////// 00339 Segment::Segment(Node _graph, float _k): 00340 DisjointSet_Base(_graph), 00341 k(_k){ 00342 } 00343 Node Segment::makeDisjointSet(){ 00344 Trace trace("makeDisjointSet", channelCluster); 00345 { 00346 ostringstream o; 00347 o << "Node size " << graph.getNodeSize() << " edge " << graph.getEdgeSize(); 00348 trace.print(o.str()); 00349 } 00350 { 00351 Trace trace("sort edge", channelCluster); 00352 graph.sortEdge(edgeWeightMin, edgeWeightMax); 00353 edgeWeightDelta = edgeWeightMax - edgeWeightMin; 00354 ostringstream o; 00355 o << "Delta " << edgeWeightDelta; 00356 trace.print(o.str()); 00357 } 00358 /* 00359 { 00360 PixelMap* map = graph.makeMap(1); 00361 map->parentToFile("Parent.dat"); 00362 delete map; 00363 } 00364 ofstream out1("Segment.dat"); 00365 */ 00366 int i = 0; 00367 for(Edge_Iterator it = graph.getEdgeIterator(); it.hasNext();){ 00368 Edge e = it.next(); 00369 Node n1 = e.getSRC(); 00370 Node n2 = e.getDST(); 00371 if(!sameDisjointSet(e.getSRC(), e.getDST())){ 00372 bool merge = false; 00373 float edgeWeight = e.getWeight(); 00374 { 00375 bool normalize = true; 00376 00377 int size1 = getParentSize(n1); 00378 int size2 = getParentSize(n2); 00379 int size3 = size2 + size1; 00380 00381 float f1 = getParent(n1).getWeight(); 00382 float f2 = getParent(n2).getWeight(); 00383 00384 /* 00385 * Test géométrique 00386 f1 = f1 / size1; 00387 f2 = f2 / size2; 00388 float K = 200.0; 00389 if(size1 <= k){ 00390 f1 = edgeWeightMax * size1; 00391 } else { 00392 f1 = f1 + (K * edgeWeightDelta) / size1; 00393 } 00394 if(size1 <= k){ 00395 f2 = edgeWeightMax * size2; 00396 } else { 00397 f2 = f2 + (K * edgeWeightDelta) / size2; 00398 } 00399 if(normalize){ 00400 } 00401 edgeWeight = (edgeWeight - edgeWeightMin) / edgeWeightDelta; 00402 f1 = (f1 - edgeWeightMin) / edgeWeightDelta; 00403 f2 = (f2 - edgeWeightMin) / edgeWeightDelta; 00404 00405 00406 float avg = (f1 + f2 + edgeWeight) / 3; 00407 00408 float error1 = k * f1; 00409 float error2 = k * f2; 00410 float diff = 0.0; 00411 float error1 = (k * edgeWeightDelta) / size1; 00412 float error2 = (k * edgeWeightDelta) / size2; 00413 */ 00414 00415 00416 float k1 = (k * edgeWeightDelta) / size1; 00417 float k2 = (k * edgeWeightDelta) / size2; 00418 00419 f1 = f1 / size1 + k1; 00420 f2 = f2 / size2 + k2; 00421 00422 float estimation = (f1 + f2 + edgeWeight) / 3; 00423 // float estimation = (f1 * size1 + f2 * size2 + edgeWeight) / size3; 00424 00425 if(edgeWeight <= estimation){ 00426 merge = true; 00427 } else { 00428 merge = false; 00429 } 00430 if(merge){ 00431 // if(true){ 00432 int p1 = getParentRank(n1.getRank()); 00433 int p2 = getParentRank(n2.getRank()); 00434 Color c1 = getParent(n1).getPixel().getColor(); 00435 Color c2 = getParent(n2).getPixel().getColor(); 00436 int size1 = getParentSize(n1); 00437 int size2 = getParentSize(n2); 00438 Node n = mergeDisjointSet(n1, n2); 00439 Color newColor = Color::getAverageColor(c1, size1, c2, size2); 00440 n.getPixel().setColor(newColor); 00441 int _p1 = getParentRank(n1.getRank()); 00442 int _p2 = getParentRank(n2.getRank()); 00443 00444 /* 00445 if(p1 != _p1){ 00446 ostringstream o; 00447 o << setprecision(5) << fixed << setw(5) << p1 << setw(5) << _p1 << setw(5) << size1 << setw(10) << k1 << setw(10) << k2 << setw(10) << f1 << setw(10) << f2 << setw(10) << estimation << setw(10) << edgeWeight; 00448 out1 << o.str() << endl; 00449 } else if(p2 != _p2){ 00450 ostringstream o; 00451 o << setprecision(5) << fixed << setw(5) << p2 << setw(5) << _p2 << setw(5) << size2 << setw(10) << k1 << setw(10) << k2 << setw(10) << f1 << setw(10) << f2 << setw(10) << estimation << setw(10) << edgeWeight; 00452 out1 << o.str() << endl; 00453 } 00454 */ 00455 00456 float sum = n.getWeight() + edgeWeight; 00457 n.setWeight(sum); 00458 n.addEdge(e); 00459 } else { 00460 edge.push_back(e); 00461 00462 /* 00463 ostringstream o; 00464 o << setprecision(5) << fixed << setw(5) << (-size1) << setw(5) << (-size2) << setw(5) << size3 << setw(10) << k1 << setw(10) << k2 << setw(10) << f1 << setw(10) << f2 << setw(10) << estimation << setw(10) << edgeWeight; 00465 out1 << o.str() << endl; 00466 */ 00467 00468 } 00469 } 00470 } 00471 i++; 00472 } 00473 Node ret = makeNode("SEGMENT"); 00474 // out1.close(); 00475 { 00476 ostringstream o; 00477 o << "Node size " << ret.getNodeSize() << " edge " << ret.getEdgeSize(); 00478 trace.print(o.str()); 00479 } 00480 00481 return ret; 00482 } 00483 /////////////////////////////////////////////////////////////////////////////// 00484 /// MeanshiftFusion 00485 /////////////////////////////////////////////////////////////////////////////// 00486 MeanShiftFusion::MeanShiftFusion(Node _graph, float _radius): 00487 DisjointSet_Base(_graph), 00488 radius(_radius){ 00489 } 00490 Node MeanShiftFusion::makeDisjointSet(){ 00491 Trace trace("Fusion", channelCluster); 00492 { 00493 ostringstream o; 00494 o << "Node size " << graph.getNodeSize() << " edge " << graph.getEdgeSize(); 00495 trace.print(o.str()); 00496 } 00497 int count = 0; 00498 graph.sortEdge(); 00499 for(Edge_Iterator it = graph.getEdgeIterator(); it.hasNext();){ 00500 Edge e = it.next(); 00501 Node n1 = e.getSRC(); 00502 Node n2 = e.getDST(); 00503 if(!sameDisjointSet(n1, n2)){ 00504 Node p1 = getParent(n1); 00505 Node p2 = getParent(n2); 00506 Color c1 = p1.getPixel().getColor(); 00507 Color c2 = p2.getPixel().getColor(); 00508 bool merge; 00509 { 00510 float* f1 = c1.getFloat(false); 00511 float* f2 = c2.getFloat(false); 00512 merge = Color::inWindow(f1, radius, f2); 00513 delete[] f1; 00514 delete[] f2; 00515 } 00516 if(merge){ 00517 int size1 = getParentSize(n1); 00518 int size2 = getParentSize(n2); 00519 Node n = mergeDisjointSet(n1, n2); 00520 Color newColor = Color::getAverageColor(c1, size1, c2, size2); 00521 n.getPixel().setColor(newColor); 00522 } else { 00523 edge.push_back(e); 00524 } 00525 count++; 00526 } 00527 } 00528 Node ret = makeNode("MEANSHIFT"); 00529 { 00530 ostringstream o; 00531 o << "Node size " << ret.getNodeSize() << " edge " << ret.getEdgeSize(); 00532 trace.print(o.str()); 00533 } 00534 00535 return ret; 00536 } 00537 /////////////////////////////////////////////////////////////////////////////// 00538 /// Prune 00539 /////////////////////////////////////////////////////////////////////////////// 00540 Prune::Prune(Node _graph, int _area): 00541 DisjointSet_Base(_graph), 00542 area(_area){ 00543 } 00544 Node Prune::makeDisjointSet(){ 00545 Trace trace("Prune", channelCluster); 00546 { 00547 ostringstream o; 00548 o << "Node size " << graph.getNodeSize() << " edge " << graph.getEdgeSize(); 00549 trace.print(o.str()); 00550 } 00551 int count = 0; 00552 graph.sortEdge(); 00553 for(Edge_Iterator it = graph.getEdgeIterator(); it.hasNext();){ 00554 Edge e = it.next(); 00555 Node n1 = e.getSRC(); 00556 Node n2 = e.getDST(); 00557 if(!sameDisjointSet(n1, n2)){ 00558 Node p1 = getParent(n1); 00559 Node p2 = getParent(n2); 00560 Color c1 = p1.getPixel().getColor(); 00561 Color c2 = p2.getPixel().getColor(); 00562 int size1 = getParentSize(n1); 00563 int size2 = getParentSize(n2); 00564 00565 bool merge = false; 00566 if((size1 <= area) || (size2 < area)){ 00567 merge = true; 00568 } 00569 if(merge){ 00570 Node n = mergeDisjointSet(n1, n2); 00571 Color newColor = Color::getAverageColor(c1, size1, c2, size2); 00572 n.getPixel().setColor(newColor); 00573 } else { 00574 edge.push_back(e); 00575 } 00576 count++; 00577 } 00578 } 00579 Node ret = makeNode("MEANSHIFT"); 00580 { 00581 ostringstream o; 00582 o << "Node size " << ret.getNodeSize() << " edge " << ret.getEdgeSize(); 00583 trace.print(o.str()); 00584 } 00585 00586 return ret; 00587 } 00588

Généré le Thu Jul 1 23:13:32 2004 pour segment par doxygen 1.3.7