TooN 2.0.0-beta8
optimization/conjugate_gradient.h
#include <TooN/optimization/brent.h>
#include <utility>
#include <cmath>
#include <cassert>
#include <cstdlib>

namespace TooN{
    namespace Internal{

    ///Turn a multidimensional function into a 1D function by specifying a
    ///point and direction. A new function is defined:
    ///\f[
    /// g(a) = f(\Vec{s} + a \Vec{d})
    ///\f]
    ///@ingroup gOptimize
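    ///
    ///A minimal usage sketch (illustrative only; the functor \c MyFunc below is
    ///hypothetical and not part of TooN):
    ///@code
    /// struct MyFunc {
    ///     double operator()(const Vector<2>& v) const { return v * v; } // squared norm via the dot product
    /// };
    ///
    /// MyFunc f;
    /// Vector<2> s = makeVector(1.0, 2.0), d = makeVector(0.0, 1.0);
    /// Internal::LineSearch<2, double, MyFunc> line(s, d, f);
    /// double v = line(0.5); // evaluates f(s + 0.5 * d)
    ///@endcode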
    template<int Size, typename Precision, typename Func> struct LineSearch
    {
        const Vector<Size, Precision>& start; ///< \f$\Vec{s}\f$
        const Vector<Size, Precision>& direction;///< \f$\Vec{d}\f$

        const Func& f;///< \f$f(\cdotp)\f$

        ///Set up the line search class.
        ///@param s Start point, \f$\Vec{s}\f$.
        ///@param d Direction, \f$\Vec{d}\f$.
        ///@param func Function, \f$f(\cdotp)\f$.
        LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
        :start(s),direction(d),f(func)
        {}

        ///@param x Position at which to evaluate the function
        ///@return \f$f(\Vec{s} + x\Vec{d})\f$
        Precision operator()(Precision x) const
        {
            return f(start + x * direction);
        }
    };

    ///Bracket a 1D function by searching forward from zero. The assumption
    ///is that a minimum exists in \f$f(x),\ x>0\f$, and this function searches
    ///for a bracket using exponentially growing or shrinking steps.
    ///@param a_val The value of the function at zero.
    ///@param func Function to bracket
    ///@param initial_lambda Initial step size
    ///@param zeps Minimum bracket size.
    ///@return <code>m[i][0]</code> contains the values of \f$x\f$ for the bracket, in increasing order,
    ///        and <code>m[i][1]</code> contains the corresponding values of \f$f(x)\f$. If the bracket
    ///        drops below the minimum bracket size, all zeros are returned.
    ///@ingroup gOptimize
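    ///
    ///A sketch of typical use (mirroring ConjugateGradient::find_next_point below;
    ///the 1D functor \c line is assumed to have been built with LineSearch):
    ///@code
    /// Matrix<3,2> m = Internal::bracket_minimum_forward(line(0.0), line, 1.0, 1e-20);
    /// double a = m[0][0], b = m[1][0], c = m[2][0]; // bracket positions, normally a < b < c
    /// double b_val = m[1][1];                       // f(b), normally below f(a) and f(c)
    /// bool collapsed = (a == 0 && b == 0 && c == 0);// all zeros: bracket shrank below zeps
    ///@endcode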
    template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
    {
        //Get a, b, c to bracket a minimum along a line
        Precision a, b, c, b_val, c_val;

        a=0;

        //Search forward in steps of lambda
        Precision lambda=initial_lambda;
        b = lambda;
        b_val = func(b);

        while(std::isnan(b_val))
        {
            //We've probably gone into an invalid region. This can happen even
            //if following the gradient would never get us there.
            //Try backing off lambda.
            lambda*=.5;
            b = lambda;
            b_val = func(b);

        }


        if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
        {
            double last_good_lambda = lambda;

            for(;;)
            {
                lambda *= 2;
                c = lambda;
                c_val = func(c);

                if(std::isnan(c_val))
                    break;
                last_good_lambda = lambda;
                if(c_val >  b_val) // we have a bracket
                    break;
                else
                {
                    a = b;
                    a_val = b_val;
                    b=c;
                    b_val=c_val;

                }
            }

            //We took a step too far.
            //Back up: this will not attempt to ensure a bracket
            if(std::isnan(c_val))
            {
                double bad_lambda=lambda;
                double l=1;

                for(;;)
                {
                    l*=.5;
                    c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
                    c_val = func(c);

                    if(!std::isnan(c_val))
                        break;
                }


            }

        }
        else //We've overshot the minimum, so back up
        {
            c = b;
            c_val = b_val;
            //Here, c_val > a_val

            for(;;)
            {
                lambda *= .5;
                b = lambda;
                b_val = func(b);

                if(b_val < a_val)// we have a bracket
                    break;
                else if(lambda < zeps)
                    return Zeros;
                else //Contract the bracket
                {
                    c = b;
                    c_val = b_val;
                }
            }
        }

        Matrix<3,2,Precision> ret;
        ret[0] = makeVector(a, a_val);
        ret[1] = makeVector(b, b_val);
        ret[2] = makeVector(c, c_val);

        return ret;
    }

}

/** This class provides a nonlinear conjugate-gradient optimizer. The following
code snippet will perform an optimization on the Rosenbrock banana function in
two dimensions:

@code
double Rosenbrock(const Vector<2>& v)
{
    return sq(1 - v[0]) + 100 * sq(v[1] - sq(v[0]));
}

Vector<2> RosenbrockDerivatives(const Vector<2>& v)
{
    double x = v[0];
    double y = v[1];

    Vector<2> ret;
    ret[0] = -2+2*x-400*(y-sq(x))*x;
    ret[1] = 200*y-200*sq(x);

    return ret;
}

int main()
{
    ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);

    while(cg.iterate(Rosenbrock, RosenbrockDerivatives))
        cout << "y_" << cg.iterations << " = " << cg.y << endl;

    cout << "Optimal value: " << cg.y << endl;
}
@endcode

The chances are that you will want to read the documentation for
ConjugateGradient::ConjugateGradient and ConjugateGradient::iterate.

The line search is currently performed using Brent's method (golden-section
search with parabolic interpolation; see brent_line_search), and conjugate
vector updates are performed using the Polak-Ribiere equations. There are many
tunable parameters, and the internals are readily accessible, so alternative
termination conditions etc. can easily be substituted. However, usually these
will not be necessary.
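
For example, the termination and line search parameters are public members and
can be adjusted directly after construction (a minimal sketch; the values shown
are arbitrary):

@code
ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
cg.max_iterations = 1000;          // allow more conjugate gradient iterations
cg.linesearch_max_iterations = 50; // but cap each line search sooner
cg.tolerance = 1e-10;              // tighter fractional convergence test
@endcode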

@ingroup gOptimize
*/
template<int Size, class Precision=double> struct ConjugateGradient
{
    const int size;      ///< Dimensionality of the space.
    Vector<Size> g;      ///< Gradient vector used by the next call to iterate()
    Vector<Size> h;      ///< Conjugate vector to be searched along in the next call to iterate()
    Vector<Size> minus_h;///< Negative of h; stored because it must be passed to a function taking a reference (so it cannot be a temporary)
    Vector<Size> old_g;  ///< Gradient vector used to compute \c h in the last call to iterate()
    Vector<Size> old_h;  ///< Conjugate vector searched along in the last call to iterate()
    Vector<Size> x;      ///< Current position (best known point)
    Vector<Size> old_x;  ///< Previous best known point (not set at construction)
    Precision y;         ///< Function value at \f$x\f$
    Precision old_y;     ///< Function value at old_x

    Precision tolerance; ///< Tolerance used to determine if the optimization is complete. Defaults to square root of machine precision.
    Precision epsilon;   ///< Additive term in tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
    int       max_iterations; ///< Maximum number of iterations. Defaults to \c size\f$*100\f$.

    Precision bracket_initial_lambda;///< Initial step size used in bracketing the minimum for the line search. Defaults to 1.
    Precision linesearch_tolerance; ///< Tolerance used to determine if the line search is complete. Defaults to square root of machine precision.
    Precision linesearch_epsilon; ///< Additive term in tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
    int linesearch_max_iterations;  ///< Maximum number of iterations in the line search. Defaults to 100.

    Precision bracket_epsilon; ///< Minimum size for initial minimum bracketing. Below this, it is assumed that the system has converged. Defaults to 1e-20.

    int iterations; ///< Number of iterations performed

    ///Initialize the ConjugateGradient class with sensible values.
    ///@param start Starting point, \e x
    ///@param func  Function \e f to compute \f$f(x)\f$
    ///@param deriv Function to compute \f$\nabla f(x)\f$
    template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
    : size(start.size()),
      g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
    {
        init(start, func(start), deriv(start));
    }

    ///Initialize the ConjugateGradient class with sensible values.
    ///@param start Starting point, \e x
    ///@param func  Function \e f to compute \f$f(x)\f$
    ///@param deriv \f$\nabla f(x)\f$
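    ///
    ///For example (a sketch reusing the Rosenbrock functions from the class
    ///documentation above), this form is convenient when the derivative at the
    ///starting point has already been computed:
    ///@code
    /// Vector<2> start = makeVector(0,0);
    /// Vector<2> d = RosenbrockDerivatives(start);
    /// ConjugateGradient<2> cg(start, Rosenbrock, d);
    ///@endcode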
    template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
    : size(start.size()),
      g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
    {
        init(start, func(start), deriv);
    }

    ///Initialize the ConjugateGradient class with sensible values. Used internally.
    ///@param start Starting point, \e x
    ///@param func  \f$f(x)\f$
    ///@param deriv \f$\nabla f(x)\f$
    void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
    {

        using std::numeric_limits;
        x = start;

        //Start with the conjugate direction aligned with
        //the gradient
        g = deriv;
        h = g;
        minus_h=-h;

        y = func;
        old_y = y;

        tolerance = sqrt(numeric_limits<Precision>::epsilon());
        epsilon = 1e-20;
        max_iterations = size * 100;

        bracket_initial_lambda = 1;

        linesearch_tolerance =  sqrt(numeric_limits<Precision>::epsilon());
        linesearch_epsilon = 1e-20;
        linesearch_max_iterations=100;

        bracket_epsilon=1e-20;

        iterations=0;
    }

    ///Perform a line search from the current point (x) along the current
    ///conjugate vector (h). The line search does not make use of derivatives.
    ///You probably do not want to use this function. See iterate() instead.
    ///This function updates:
    /// - x
    /// - old_x
    /// - y
    /// - old_y
    /// - iterations
    /// Note that the conjugate direction and gradient are not updated.
    /// If bracket_minimum_forward detects a local maximum, then essentially a
    /// zero-sized step is taken.
    /// @param func Functor returning the function value at a given point.
    template<class Func> void find_next_point(const Func& func)
    {
        Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);

        //Always search in the conjugate direction (h)
        //First bracket a minimum.
        Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);

        double a = bracket[0][0];
        double b = bracket[1][0];
        double c = bracket[2][0];

        double a_val = bracket[0][1];
        double b_val = bracket[1][1];
        double c_val = bracket[2][1];

        old_y = y;
        old_x = x;
        iterations++;

        //Local maximum achieved!
        if(a==0 && b== 0 && c == 0)
            return;

        //We should have a bracket here

        if(c < b)
        {
            //Failed to bracket due to NaN, so c is the best known point.
            //Simply go there.
            x-=h * c;
            y=c_val;

        }
        else
        {
            assert(a < b && b < c);
            assert(a_val > b_val && b_val < c_val);

            //Find the real minimum
            Vector<2, Precision>  m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);

            assert(m[0] >= a && m[0] <= c);
            assert(m[1] <= b_val);

            //Update the current position and value
            x -= m[0] * h;
            y = m[1];
        }
    }

    ///Check to see if iteration should stop. You probably do not want to use
    ///this function. See iterate() instead. This function updates nothing.
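    ///
    ///The test applied (matching the code below) is that \c iterations exceeds
    ///\c max_iterations, or
    ///\f[
    /// 2\,|y - y_{\mathrm{old}}| \le \mathrm{tolerance} \cdot \left(|y| + |y_{\mathrm{old}}| + \epsilon\right).
    ///\f]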
    bool finished()
    {
        using std::abs;
        return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
    }

    ///After an iteration, update the gradient and conjugate vectors using the
    ///Polak-Ribiere equations.
    ///This function updates:
    ///- g
    ///- old_g
    ///- h
    ///- old_h
    ///@param grad The derivatives of the function at \e x
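    ///
    ///The update applied (matching the code below) is
    ///\f[
    /// \gamma = \frac{(\Vec{g} - \Vec{g}_{\mathrm{old}}) \cdot \Vec{g}}{\Vec{g}_{\mathrm{old}} \cdot \Vec{g}_{\mathrm{old}}},
    /// \qquad
    /// \Vec{h} = \Vec{g} + \gamma\,\Vec{h}_{\mathrm{old}}.
    ///\f]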
    void update_vectors_PR(const Vector<Size>& grad)
    {
        //Update the gradient and conjugate directions
        old_g = g;
        old_h = h;

        g = grad;
        Precision gamma = (g * g - old_g*g)/(old_g * old_g);
        h = g + gamma * old_h;
        minus_h=-h;
    }

    ///Use this function to drive the optimization; each call performs one
    ///iteration. Note that after iterate() returns false, g, h, old_g and
    ///old_h will not have been updated.
    ///This function updates:
    /// - x
    /// - old_x
    /// - y
    /// - old_y
    /// - iterations
    /// - g*
    /// - old_g*
    /// - h*
    /// - old_h*
    /// Variables marked * are not updated on the last iteration.
    ///@param func Functor returning the function value at a given point.
    ///@param deriv Functor to compute derivatives at the specified point.
    ///@return Whether to continue iterating.
    template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
    {
        find_next_point(func);

        if(!finished())
        {
            update_vectors_PR(deriv(x));
            return true;
        }
        else
            return false;
    }
};

}