TooN 2.0.0-beta8
#include <TooN/optimization/brent.h>
#include <utility>
#include <cmath>
#include <cassert>
#include <cstdlib>

namespace TooN{
namespace Internal{


	///Turn a multidimensional function into a 1D function by specifying a
	///point and direction. A new function is defined:
	///\f[
	///	g(a) = f(\Vec{s} + a \Vec{d})
	///\f]
	///@ingroup gOptimize
	template<int Size, typename Precision, typename Func> struct LineSearch
	{
		const Vector<Size, Precision>& start;     ///< \f$\Vec{s}\f$
		const Vector<Size, Precision>& direction; ///< \f$\Vec{d}\f$

		const Func& f; ///< \f$f(\cdotp)\f$

		///Set up the line search class.
		///@param s Start point, \f$\Vec{s}\f$.
		///@param d Direction, \f$\Vec{d}\f$.
		///@param func Function, \f$f(\cdotp)\f$.
		LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
		:start(s),direction(d),f(func)
		{}

		///@param x Position at which to evaluate the function.
		///@return \f$f(\Vec{s} + x\Vec{d})\f$
		Precision operator()(Precision x) const
		{
			return f(start + x * direction);
		}
	};
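	/* A minimal usage sketch (illustrative only, not part of TooN): reducing
	   a 2D function to 1D along a direction. `Norm` is a made-up functor;
	   any functor mapping a Vector<Size, Precision> to a Precision will do.
	   Note that LineSearch stores references, so the point and direction
	   must outlive it:

	       struct Norm
	       {
	           double operator()(const Vector<2>& v) const { return v*v; } // dot product
	       };

	       Norm f;
	       Vector<2> s = makeVector(1,1);
	       Vector<2> d = makeVector(0,1);
	       LineSearch<2, double, Norm> line(s, d, f);
	       double v = line(0.5); // evaluates f(s + 0.5*d) = f([1, 1.5]) = 3.25
	*/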
	///Bracket a 1D function by searching forward from zero. The assumption
	///is that a minimum exists in \f$f(x),\ x>0\f$, and this function searches
	///for a bracket using exponentially growing or shrinking steps.
	///@param a_val The value of the function at zero.
	///@param func Function to bracket.
	///@param initial_lambda Initial step size.
	///@param zeps Minimum bracket size.
	///@return <code>m[i][0]</code> contains the values of \f$x\f$ for the bracket, in increasing order,
	///        and <code>m[i][1]</code> contains the corresponding values of \f$f(x)\f$. If the bracket
	///        drops below the minimum bracket size, all zeros are returned.
	///@ingroup gOptimize
	template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
	{
		//Get a, b, c to bracket a minimum along a line
		Precision a, b, c, b_val, c_val;

		a=0;

		//Search forward in steps of lambda
		Precision lambda=initial_lambda;
		b = lambda;
		b_val = func(b);

		while(std::isnan(b_val))
		{
			//We've probably gone into an invalid region. This can happen even
			//if following the gradient would never get us there.
			//Try backing off lambda.
			lambda*=.5;
			b = lambda;
			b_val = func(b);
		}

		if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
		{
			Precision last_good_lambda = lambda;

			for(;;)
			{
				lambda *= 2;
				c = lambda;
				c_val = func(c);

				if(std::isnan(c_val))
					break;
				last_good_lambda = lambda;

				if(c_val > b_val) // we have a bracket
					break;
				else //Shift the triple along
				{
					a = b;
					a_val = b_val;
					b = c;
					b_val = c_val;
				}
			}

			//We took a step too far.
			//Back up: this will not attempt to ensure a bracket.
			if(std::isnan(c_val))
			{
				Precision bad_lambda=lambda;
				Precision l=1;

				for(;;)
				{
					l*=.5;
					c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
					c_val = func(c);

					if(!std::isnan(c_val))
						break;
				}
			}
		}
		else //We've overshot the minimum, so back up
		{
			c = b;
			c_val = b_val;
			//Here, c_val > a_val

			for(;;)
			{
				lambda *= .5;
				b = lambda;
				b_val = func(b);

				if(b_val < a_val) // we have a bracket
					break;
				else if(lambda < zeps)
					return Zeros;
				else //Contract the bracket
				{
					c = b;
					c_val = b_val;
				}
			}
		}

		Matrix<3,2,Precision> ret;
		ret[0] = makeVector(a, a_val);
		ret[1] = makeVector(b, b_val);
		ret[2] = makeVector(c, c_val);

		return ret;
	}
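	/* An illustrative sketch (not part of TooN): bracketing the minimum of a
	   simple quadratic. `Quadratic` is a made-up functor:

	       struct Quadratic
	       {
	           double operator()(double x) const { return (x-2.5)*(x-2.5); }
	       };

	       Quadratic g;
	       Matrix<3,2> m = bracket_minimum_forward(g(0.0), g, 1.0, 1e-20);
	       // Returns x = {1, 2, 4}: the minimum at x=2.5 lies inside the
	       // bracket, and g(2) = 0.25 is below both g(1) and g(4) = 2.25.
	*/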
} // namespace Internal


/** This class provides a nonlinear conjugate-gradient optimizer. The following
code snippet will perform an optimization on the Rosenbrock banana function in
two dimensions:

@code
#include <TooN/optimization/conjugate_gradient.h>
#include <iostream>
using namespace TooN;
using namespace std;

double sq(double x) { return x*x; } // square helper

double Rosenbrock(const Vector<2>& v)
{
	return sq(1 - v[0]) + 100 * sq(v[1] - sq(v[0]));
}

Vector<2> RosenbrockDerivatives(const Vector<2>& v)
{
	double x = v[0];
	double y = v[1];

	Vector<2> ret;
	ret[0] = -2+2*x-400*(y-sq(x))*x;
	ret[1] = 200*y-200*sq(x);

	return ret;
}

int main()
{
	ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);

	while(cg.iterate(Rosenbrock, RosenbrockDerivatives))
		cout << "y_" << cg.iterations << " = " << cg.y << endl;

	cout << "Optimal value: " << cg.y << endl;
}
@endcode

The chances are that you will want to read the documentation for
ConjugateGradient::ConjugateGradient and ConjugateGradient::iterate.

The line search is currently performed using Brent's method, and the conjugate
vector updates are performed using the Polak-Ribiere equations. There are many
tunable parameters, and the internals are readily accessible, so alternative
termination conditions etc. can easily be substituted. However, usually these
will not be necessary.

@ingroup gOptimize
*/
template<int Size, class Precision=double> struct ConjugateGradient
{
	const int size;                  ///< Dimensionality of the space.
	Vector<Size, Precision> g;       ///< Gradient vector used by the next call to iterate()
	Vector<Size, Precision> h;       ///< Conjugate vector to be searched along in the next call to iterate()
	Vector<Size, Precision> minus_h; ///< Negative of h; stored because it must be passed by reference into the line search (so it cannot be a temporary)
	Vector<Size, Precision> old_g;   ///< Gradient vector used to compute \f$h\f$ in the last call to iterate()
	Vector<Size, Precision> old_h;   ///< Conjugate vector searched along in the last call to iterate()
	Vector<Size, Precision> x;       ///< Current position (best known point)
	Vector<Size, Precision> old_x;   ///< Previous best known point (not set at construction)
	Precision y;                     ///< Function value at \f$x\f$
	Precision old_y;                 ///< Function value at old_x

	Precision tolerance;      ///< Tolerance used to determine if the optimization is complete. Defaults to the square root of machine precision.
	Precision epsilon;        ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
	int       max_iterations; ///< Maximum number of iterations. Defaults to \c size \f$\times\ 100\f$.

	Precision bracket_initial_lambda;    ///< Initial step size used in bracketing the minimum for the line search. Defaults to 1.
	Precision linesearch_tolerance;      ///< Tolerance used to determine if the line search is complete. Defaults to the square root of machine precision.
	Precision linesearch_epsilon;        ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
	int       linesearch_max_iterations; ///< Maximum number of iterations in the line search. Defaults to 100.

	Precision bracket_epsilon; ///< Minimum size for the initial minimum bracketing. Below this, the system is assumed to have converged. Defaults to 1e-20.

	int iterations; ///< Number of iterations performed

	///Initialize the ConjugateGradient class with sensible values.
	///@param start Starting point, \e x.
	///@param func Function \e f to compute \f$f(x)\f$.
	///@param deriv Function to compute \f$\nabla f(x)\f$.
	template<class Func, class Deriv> ConjugateGradient(const Vector<Size, Precision>& start, const Func& func, const Deriv& deriv)
	: size(start.size()),
	  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
	{
		init(start, func(start), deriv(start));
	}

	///Initialize the ConjugateGradient class with sensible values.
	///@param start Starting point, \e x.
	///@param func Function \e f to compute \f$f(x)\f$.
	///@param deriv \f$\nabla f(x)\f$.
	template<class Func> ConjugateGradient(const Vector<Size, Precision>& start, const Func& func, const Vector<Size, Precision>& deriv)
	: size(start.size()),
	  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
	{
		init(start, func(start), deriv);
	}

	///Initialize the ConjugateGradient class with sensible values. Used internally.
	///@param start Starting point, \e x.
	///@param func \f$f(x)\f$.
	///@param deriv \f$\nabla f(x)\f$.
	void init(const Vector<Size, Precision>& start, const Precision& func, const Vector<Size, Precision>& deriv)
	{
		using std::numeric_limits;
		using std::sqrt;

		x = start;

		//Start with the conjugate direction aligned with
		//the gradient
		g = deriv;
		h = g;
		minus_h=-h;

		y = func;
		old_y = y;

		tolerance = sqrt(numeric_limits<Precision>::epsilon());
		epsilon = 1e-20;
		max_iterations = size * 100;

		bracket_initial_lambda = 1;

		linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
		linesearch_epsilon = 1e-20;
		linesearch_max_iterations=100;

		bracket_epsilon=1e-20;

		iterations=0;
	}


	///Perform a line search from the current point (x) along the current
	///conjugate vector (h). The line search does not make use of derivatives.
	///You probably do not want to use this function. See iterate() instead.
	///This function updates:
	/// - x
	/// - old_x
	/// - y
	/// - old_y
	/// - iterations
	///
	///Note that the conjugate direction and gradient are not updated.
	///If bracket_minimum_forward detects a local maximum, then essentially a zero
	///sized step is taken.
	///@param func Functor returning the function value at a given point.
	template<class Func> void find_next_point(const Func& func)
	{
		Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);

		//Always search in the conjugate direction (h).
		//First, bracket a minimum.
		Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);

		Precision a = bracket[0][0];
		Precision b = bracket[1][0];
		Precision c = bracket[2][0];

		Precision a_val = bracket[0][1];
		Precision b_val = bracket[1][1];
		Precision c_val = bracket[2][1];

		old_y = y;
		old_x = x;
		iterations++;

		//Local maximum achieved!
		if(a == 0 && b == 0 && c == 0)
			return;

		//We should have a bracket here

		if(c < b)
		{
			//Failed to bracket due to NaN, so c is the best known point.
			//Simply go there.
			x -= h * c;
			y = c_val;
		}
		else
		{
			assert(a < b && b < c);
			assert(a_val > b_val && b_val < c_val);

			//Find the real minimum
			Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);

			assert(m[0] >= a && m[0] <= c);
			assert(m[1] <= b_val);

			//Update the current position and value
			x -= m[0] * h;
			y = m[1];
		}
	}
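	/* A sketch (illustrative only, not part of TooN) of driving the optimizer
	   by hand instead of via iterate(), for example to substitute a custom
	   termination condition. Assumes the Rosenbrock functions from the class
	   documentation above:

	       ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
	       for(;;)
	       {
	           cg.find_next_point(Rosenbrock);      // line search along h
	           if(cg.finished() || cg.y < 1e-10)    // custom stopping rule
	               break;
	           cg.update_vectors_PR(RosenbrockDerivatives(cg.x));
	       }
	*/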
	///Check to see if iteration should stop. You probably do not want to use
	///this function. See iterate() instead. This function updates nothing.
	bool finished()
	{
		using std::abs;
		return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
	}

	///After an iteration, update the gradient and conjugate vector using the
	///Polak-Ribiere equations:
	///\f[
	///	\gamma = \frac{\Vec{g}\cdot(\Vec{g} - \Vec{g}_{old})}{\Vec{g}_{old}\cdot\Vec{g}_{old}},
	///	\qquad \Vec{h} = \Vec{g} + \gamma \Vec{h}_{old}
	///\f]
	///This function updates:
	///- g
	///- old_g
	///- h
	///- old_h
	///@param grad The derivatives of the function at \e x
	void update_vectors_PR(const Vector<Size, Precision>& grad)
	{
		//Update the position, gradient and conjugate directions
		old_g = g;
		old_h = h;

		g = grad;
		Precision gamma = (g * g - old_g*g)/(old_g * old_g);
		h = g + gamma * old_h;
		minus_h=-h;
	}

	///Use this function to iterate over the optimization. Note that after
	///iterate returns false, g, h, old_g and old_h will not have been
	///updated.
	///This function updates:
	/// - x
	/// - old_x
	/// - y
	/// - old_y
	/// - iterations
	/// - g*
	/// - old_g*
	/// - h*
	/// - old_h*
	///
	///*'d variables are not updated on the last iteration.
	///@param func Functor returning the function value at a given point.
	///@param deriv Functor to compute derivatives at the specified point.
	///@return Whether to continue.
	template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
	{
		find_next_point(func);

		if(!finished())
		{
			update_vectors_PR(deriv(x));
			return true;
		}
		else
			return false;
	}
};

} // namespace TooN
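/* A second, minimal usage sketch (illustrative only, not part of TooN):
   minimising a simple quadratic bowl with the constructor that takes a
   precomputed gradient vector. `Quadratic` and `QuadraticDerivatives` are
   made-up names; the gradient of (v-c).(v-c) is 2(v-c):

       double Quadratic(const Vector<2>& v)
       {
           return (v - makeVector(1.0, 2.0)) * (v - makeVector(1.0, 2.0)); // dot product
       }

       Vector<2> QuadraticDerivatives(const Vector<2>& v)
       {
           return 2 * (v - makeVector(1.0, 2.0));
       }

       int main()
       {
           ConjugateGradient<2> cg(makeVector(0,0), Quadratic, QuadraticDerivatives(makeVector(0,0)));
           while(cg.iterate(Quadratic, QuadraticDerivatives)) {}
           // cg.x is now close to (1,2), and cg.y is close to 0.
       }
*/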