00001
00002
00003
#include "MeanShift.hpp"
00004
00005
00006
00007 Lattice::Lattice(
PNG _image):
00008 image(_image),
00009 width(_image.getWidth()),
00010 height(_image.getHeight()){
00011
size =
width *
height;
00012
visited.reserve(
size);
00013
state.reserve(
size);
00014
for(
int i = 0; i <
size; i++){
00015
visited.push_back(0);
00016
state.push_back(ELATTICE_NO_MODE);
00017 }
00018
input.resize(size);
00019
output.resize(size);
00020 }
00021 int Lattice::getWidth()const{
00022
return width;
00023 }
00024 int Lattice::getHeight()const{
00025
return height;
00026 }
00027 FVector Lattice::getFVector(
int x,
int y){
00028
int rank = (y *
width) + x;
00029
if(
visited[rank] == 0){
00030
Color c =
image.
getColor(x, y);
00031
FVector v(x, y, c.
getHue(), c.
getSaturation(), c.
getValue());
00032
input[rank] = v;
00033
visited[rank]++;
00034 }
00035
return input[rank];
00036 }
00037 void Lattice::setOutputFVector(
const int x,
const int y,
const FVector v){
00038
int rank = (y *
width) + x;
00039
output[rank] = v;
00040 }
00041 FVector Lattice::getOutputFVector(
const int x,
const int y)
const{
00042
int rank = (y * width) + x;
00043
return output[rank];
00044 }
00045
00046 ELATTICE_STATE Lattice::getState(
int x,
int y)
const{
00047
int rank = (y * width) + x;
00048
return state[rank];
00049 }
00050 void Lattice::setState(
const int x,
const int y,
const ELATTICE_STATE _state){
00051
int rank = (y *
width) + x;
00052
state[rank] = _state;
00053 }
00054
00055
00056
00057 MeanShiftWindow::MeanShiftWindow(
int _size,
float _radius,
Lattice& _lattice):
00058 size(_size),
00059 radius(_radius),
00060 lattice(_lattice){
00061 }
00062
00063 string
MeanShiftWindow::toString()const{
00064 ostringstream o;
00065 o <<
"size " <<
size <<
" radius " <<
radius <<
" " <<
center.
toString();
00066 string ret = o.str();
00067
return ret;
00068 }
00069
00070 void MeanShiftWindow::setCenter(
const int x,
const int y){
00071
FVector point =
lattice.
getFVector(x, y);
00072
center = point;
00073 }
00074
00075 void MeanShiftWindow::findPixel(){
00076
data.clear();
00077
int width =
lattice.
getWidth();
00078
int height =
lattice.
getHeight();
00079
00080
int startX = ((
int)(
center[0])) -
size;
00081
if(startX < 0){
00082 startX = 0;
00083 }
00084
int startY = ((
int)(
center[1])) -
size;
00085
if(startY < 0){
00086 startY = 0;
00087 }
00088
int endX = ((
int)(
center[0])) +
size;
00089
if(endX >= width){
00090 endX = width - 1;
00091 }
00092
int endY = ((
int)(
center[1])) +
size;
00093
if(endY >= height){
00094 endY = height - 1;
00095 }
00096
float centerColor[3];
00097 centerColor[0] =
center[2];
00098 centerColor[1] = center[3];
00099 centerColor[2] = center[4];
00100
for(
int i = startX; i <= endX; i++){
00101
for(
int j = startY; j <= endY; j++){
00102
FVector ta =
lattice.
getFVector(i, j);
00103
float color[3];
00104 color[0] = ta[2];
00105 color[1] = ta[3];
00106 color[2] = ta[4];
00107
if(Color::inWindow(centerColor,
radius, color)){
00108
data.push_back(ta);
00109 }
00110 }
00111 }
00112 }
00113 float MeanShiftWindow::computeMean(){
00114
float wSize =
data.size();
00115
FVector mean;
00116
for(
int i = 0; i <
data.size(); i++){
00117
FVector fv =
data[i];
00118 mean += fv;
00119 }
00120 mean /= wSize;
00121
float ret =
norm(
center, mean);
00122
center = mean;
00123
00124
return ret;
00125 }
00126 FVector MeanShiftWindow::getCenter()
const{
00127
return center;
00128 }
00129
00130
00131
00132 MeanShift::MeanShift(
int _size,
float _radius,
PNG _image,
float _tol,
int _maxIter):
00133 size(_size),
00134 radius(_radius),
00135 lattice(_image),
00136 window(_size, _radius, lattice),
00137 tol(_tol),
00138 maxIter(_maxIter){
00139 }
00140 string
MeanShift::toString()const{
00141 ostringstream o;
00142 o <<
window.
toString() <<
"\n";
00143 string ret = o.str();
00144
return ret;
00145 }
00146
00147 void MeanShift::findMode(
const int i,
const int j){
00148
if(ELATTICE_HAS_MODE ==
lattice.
getState(i, j)){
00149
return;
00150 }
00151
00152
int count = 0;
00153
float norm = 0.0;
00154
FVector center;
00155 vector<FVector> lPoint;
00156
window.
setCenter(i, j);
00157
do{
00158
window.
findPixel();
00159
norm =
window.
computeMean();
00160 center =
window.
getCenter();
00161
int xCandidate = (
int)(center[0] + 0.5);
00162
int yCandidate = (
int)(center[1] + 0.5);
00163
FVector candidate =
lattice.
getFVector(xCandidate, yCandidate);
00164
float diff =
normColor(center, candidate);
00165
ELATTICE_STATE state =
lattice.
getState(xCandidate, yCandidate);
00166
if(diff < 0.01){
00167
switch(state){
00168
case ELATTICE_NO_MODE:
00169 {
00170 lPoint.push_back(candidate);
00171
lattice.
setState(xCandidate, yCandidate, ELATTICE_CANDIDATE_MODE);
00172 }
00173
break;
00174
case ELATTICE_CANDIDATE_MODE:
00175 {
00176 }
00177
break;
00178 }
00179 }
00180
if(state == ELATTICE_HAS_MODE){
00181
break;
00182 }
00183 count++;
00184 }
while(count < maxIter && norm >
tol);
00185
lattice.
setOutputFVector(i, j, center);
00186
lattice.
setState(i, j, ELATTICE_HAS_MODE);
00187
for(
int k = 0; k < lPoint.size(); k++){
00188
FVector v = lPoint[k];
00189
int x = v.
getX();
00190
int y = v.
getY();
00191
lattice.
setOutputFVector(x, y, v);
00192
lattice.
setState(x, y, ELATTICE_HAS_MODE);
00193 }
00194 }
00195 void MeanShift::doFilter(){
00196
Trace trace(
"doFilter", channelMeanShift);
00197
int width =
lattice.
getWidth();
00198
int height =
lattice.
getHeight();
00199
00200
for(
int i = 0; i < width; i++){
00201
for(
int j = 0; j < height; j++){
00202
findMode(i, j);
00203 }
00204 }
00205 }
00206
00207 Node MeanShift::makeNode(){
00208
int w =
lattice.
getWidth();
00209
int h =
lattice.
getHeight();
00210
PixelMap pmap(w, h);
00211
for(
int i = 0; i < w; i++){
00212
for(
int j = 0; j < h; j++){
00213
FVector v =
lattice.
getOutputFVector(i, j);
00214
Color color(v[2], v[3], v[4],
false);
00215
Pixel pixel(i, j, color);
00216 pmap.
addPixel(pixel);
00217 }
00218 }
00219
Node ret = pmap.
makeComponent();
00220
return ret;
00221 }