source/numeric/quadfitlinesearch.cxx

Go to the documentation of this file.
00001 /*************************************************************************
00002  *
00003  *  The Contents of this file are made available subject to
00004  *  the terms of GNU Lesser General Public License Version 2.1.
00005  *
00006  *
00007  *    GNU Lesser General Public License Version 2.1
00008  *    =============================================
00009  *    Copyright 2005-2008, by Kohei Yoshida.
00010  *    1039 Kingsway Dr., Apex, NC 27502, USA
00011  *
00012  *    This library is free software; you can redistribute it and/or
00013  *    modify it under the terms of the GNU Lesser General Public
00014  *    License version 2.1, as published by the Free Software Foundation.
00015  *
00016  *    This library is distributed in the hope that it will be useful,
00017  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00018  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019  *    Lesser General Public License for more details.
00020  *
00021  *    You should have received a copy of the GNU Lesser General Public
00022  *    License along with this library; if not, write to the Free Software
00023  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston,
00024  *    MA  02111-1307  USA
00025  *
00026  ************************************************************************/
00027 
00028 #include "numeric/quadfitlinesearch.hxx"
00029 #include "numeric/exception.hxx"
00030 #include "numeric/funcobj.hxx"
00031 #include "numeric/polyeqnsolver.hxx"
00032 #include "numeric/matrix.hxx"
00033 #include "numeric/diff.hxx"
00034 
00035 #include <stdio.h>
00036 #include <memory>
00037 #include <vector>
00038 #include <cmath>
00039 #include <exception>
00040 #include <string>
00041 #include <sstream>
00042 
00043 using ::std::auto_ptr;
00044 using ::std::vector;
00045 using ::std::exception;
00046 using ::std::swap;
00047 using ::std::fabs;
00048 using ::std::string;
00049 
00050 namespace scsolver { namespace numeric {
00051 
/// Thrown by calcStepLength() when no step length satisfying Armijo's rule
/// could be found within its iteration budget.
class StepCalculationFailed : public ::std::exception
{
public:
    virtual const char* what() const throw()
    {
        static const char* const message =
            "step calculation failed during quadratic fit search";
        return message;
    }
};
00060 
/// Thrown by findInitialPoints() when it cannot locate three points whose
/// middle point has the lowest function value.
class InitialPointsNotFound : public ::std::exception
{
public:
    virtual const char* what() const throw()
    {
        static const char* const message = "initial 3 points not found";
        return message;
    }
};
00069 
00070 struct QuadFitSearchData
00071 {
00072     double P1;
00073     double P2;
00074     double P3;
00075 
00076     double StepLength;
00077     double XOffset;
00078 
00079     GoalType _GoalType;
00080     SingleVarFuncObj* pFunc;
00081 
00082     explicit QuadFitSearchData(SingleVarFuncObj* p) :
00083         StepLength(0.0),
00084         XOffset(0.0),
00085         pFunc(p)
00086     {
00087     }
00088 };
00089 
/// Outcome of the two-sided Armijo step test performed by satisfiesArmijosRule().
enum ArmijoCheckStatus
{
    /// No check has been performed yet.
    Armijo_Unknown = 0,

    /// f(step) > f(0) + step*e*f'(0); calcStepLength() shrinks the step.
    Armijo_InitialCheckFailed = 1,

    /// f(alpha*step) <= f(0) + alpha*step*e*f'(0); calcStepLength() grows the step.
    Armijo_AlphaCheckFailed = 2,

    /// Both conditions hold.
    Armijo_Success = 3
};
00113 
00114 static ArmijoCheckStatus satisfiesArmijosRule(QuadFitSearchData& rData, double step, 
00115                                               const double e, const double alpha, 
00116                                               const double f0, const double ff0, 
00117                                               const double xoffset = 0.0,
00118                                               bool debug = false);
00119 
00136 static double findLargestStep(QuadFitSearchData& rData, double step, 
00137                               const double e, const double alpha, const double f0, const double ff0, const double xoffset)
00138 {
00139 //  fprintf(stdout, "numeric::findLargestStep: --begin (step = %g)\n", step);
00140 
00141     ArmijoCheckStatus status = satisfiesArmijosRule(rData, step, e, alpha, f0, ff0, false);
00142     if (status != Armijo_Success)
00143         // The initial step length already fails.  Just return the original 
00144         // step length.
00145         return step;
00146 
00147     double lastGoodStep = step;
00148     do
00149     {
00150         lastGoodStep = step;
00151         step *= alpha;
00152 
00153         rData.pFunc->setVar(step + xoffset);
00154         double testLeft = rData.pFunc->eval();
00155         double testRight = f0 + step*e*ff0;
00156 //      fprintf(stdout, "numeric::findLargestStep:   step = %g; f left = %g; f right = %g\n",
00157 //              step, testLeft, testRight);
00158 
00159         if (testLeft > testRight)
00160         {
00161 //          fprintf(stdout, "numeric::findLargestStep:   test failed \n");
00162             // the condition "f(step) <= f(0) + step * e * ff(0)" is not met.
00163             break;
00164         }
00165     }
00166     while (true);
00167 
00168 //  fprintf(stdout, "numeric::findLargestStep: --end (final step = %g)\n", lastGoodStep);
00169     return lastGoodStep;
00170 }
00171 
00186 static ArmijoCheckStatus satisfiesArmijosRule(QuadFitSearchData& rData, double step, 
00187                                               const double e, const double alpha, 
00188                                               const double f0, const double ff0,
00189                                               const double xoffset,
00190                                               bool debug)
00191 {
00192     if (debug)
00193         fprintf(stdout, "numeric::satisfiesArmijosRule: ---------- step = %g\n", step);
00194 
00195     // Check for f(step) <= f(0) + step * e * ff(0)
00196 
00197     rData.pFunc->setVar(step + xoffset);
00198     double testLeft = rData.pFunc->eval();
00199     double testRight = f0 + step*e*ff0;
00200     if (debug)
00201         fprintf(stdout, "numeric::satisfiesArmijosRule:   (initial check) step = %g; f left = %g; f right = %g\n",
00202                 step, testLeft, testRight);
00203 
00204     if (testLeft > testRight)
00205         // the condition "f(step) <= f(0) + step * e * ff(0)" is not met.
00206         return Armijo_InitialCheckFailed;
00207 
00208     // Next, check for f(alpha*step) > f(0) + alpha*step * e * ff(0)
00209 
00210     rData.pFunc->setVar(step*alpha + xoffset);
00211     testLeft = rData.pFunc->eval();
00212 
00213     testRight = f0 + ff0*step*alpha*e;
00214 
00215     if (debug)
00216         fprintf(stdout, "numeric::satisfiesArmijosRule:   (alpha check)   step*alpha = %g; f left = %g; f right = %g\n",
00217                 step*alpha, testLeft, testRight);
00218 
00219     if (testLeft <= testRight)
00220         return Armijo_AlphaCheckFailed;
00221 
00222     if (debug)
00223         fprintf(stdout, "numeric::satisfiesArmijosRule:   this satisfies Armijo's rule.\n");
00224 
00225     return Armijo_Success;
00226 }
00227 
00228 static void printArmijoStatus(ArmijoCheckStatus status)
00229 {
00230     switch (status)
00231     {
00232         case Armijo_AlphaCheckFailed:
00233             printf("  alpha check failed\n");
00234         break;
00235         case Armijo_InitialCheckFailed:
00236             printf("  initial check failed\n");
00237         break;
00238         case Armijo_Success:
00239             printf("  success\n");
00240         break;
00241         case Armijo_Unknown:
00242             printf("  unknown\n");
00243         break;
00244         default:
00245             printf("  other\n");
00246     }
00247 }
00248 
00249 static void calcStepLength(QuadFitSearchData& rData, bool debug)
00250 {
00251     double step = 10;
00252 
00253     double e = 0.2, alpha = 2.0;
00254 
00255     NumericalDiffer diff;
00256     diff.setFuncObject(rData.pFunc);
00257     diff.setPrecision(5);
00258 
00259     double xoffset = 0.0;
00260 
00261     // f'(0)
00262     diff.setVariable(xoffset);
00263     double ff0 = diff.run();
00264 
00265     if ((ff0 > 0 ? ff0 : -ff0) < 1.0e-16)
00266     {
00267         xoffset += 1.0;
00268         diff.setVariable(xoffset);
00269         ff0 = diff.run();
00270     }
00271 
00272     // f(0)
00273     rData.pFunc->setVar(xoffset);
00274     double f0 = rData.pFunc->eval();
00275 
00276     bool negativeDirection = false;
00277     if (ff0 > 0)
00278     {
00279         negativeDirection = true;
00280         step = -step;
00281     }
00282 
00283     if (debug)
00284         fprintf(stdout, "numeric::calcStepLength:   f(%g) = %g; ff(%g) = %g\n", xoffset, f0, xoffset, ff0);
00285 
00286     // loop until a first step value that satisfies Armijo's rule is found.
00287     ArmijoCheckStatus armijoStatus = Armijo_Unknown;
00288     bool checkNeverFailed = true;
00289     for (int i = 0; i < 4000000; ++i)
00290     {
00291         armijoStatus = satisfiesArmijosRule(rData, step, e, alpha, f0, ff0, xoffset, debug);
00292         if (debug)
00293             printArmijoStatus(armijoStatus);
00294         if (armijoStatus == Armijo_Success)
00295             break;
00296 
00297         checkNeverFailed = false;
00298 
00299         if (armijoStatus == Armijo_InitialCheckFailed)
00300         {
00301             step /= alpha;
00302             continue;
00303         }
00304 
00305         step *= alpha;
00306     }
00307 
00308     if (armijoStatus != Armijo_Success)
00309         throw StepCalculationFailed();
00310 
00311     if (checkNeverFailed)
00312         step = findLargestStep(rData, step, e, alpha, f0, ff0, xoffset);
00313 
00314     if (debug)
00315         fprintf(stdout, "numeric::calcStepLength: final step length = %g\n", step);
00316 
00317     rData.StepLength = step;
00318     rData.XOffset = xoffset;
00319 }
00320 
00329 static void findInitialPoints(QuadFitSearchData& data, bool debug)
00330 {
00331     using ::std::isfinite;
00332 
00333     const int maxIteration = 2000;
00334 
00335     // First, find an acceptable step length.
00336     calcStepLength(data, debug);
00337     double step = data.StepLength;
00338     if (debug)
00339         fprintf(stdout, "numeric::findInitialPoints:   calculated step length = %g\n", data.StepLength);
00340 
00341     SingleVarFuncObj& F = *data.pFunc;
00342 
00343     // Now, find three points such that the middle point has the lowest f(x) 
00344     // value.  Note that the step length can be negative if the slope at the
00345     // x-offset position is descending in the negative-x direction.
00346 
00347     // initial point is x1 = 0, then the first test point is xt = x1 + step.
00348     double x1 = data.XOffset, x2 = data.XOffset, x3 = data.XOffset;
00349     double f1 = F(x1);
00350     bool pointsFound = false;
00351     if (debug)
00352         fprintf(stdout, "numeric::findInitialPoints:   initial x1 = %g\n", x1);
00353     if (f1 <= F(x1 + step))
00354     {
00355         if (debug)
00356             fprintf(stdout, "numeric::findInitialPoints:   F(x1) <= F(x1 + step) - keep halving the x1 - x3 interval\n");
00357 
00358         x3 = x1 + step;
00359 
00360         // Now, keep halving the x1 - x3 interval.
00361         for (int i = 0; i < maxIteration; ++i)
00362         {
00363             x2 = (x1 + x3)/2.0;
00364             double f2 = F(x2);
00365             double f3 = F(x3);
00366             if (debug)
00367                 fprintf(stdout, "numeric::findInitialPoints:   f1 = %g; f2 = %g; f3 = %g\n", f1, f2, f3);
00368             if (!isfinite(f2) || !isfinite(f3))
00369             {
00370                 if (debug)
00371                     fprintf(stdout, "numeric::findInitialPoints:   either f2 or f3 is not a number.\n");
00372                 break;
00373             }
00374             if (f2 < f1 && f2 < f3)
00375             {
00376                 pointsFound = true;
00377                 break;
00378             }
00379 
00380             x3 = x2;
00381         }
00382     }
00383     else
00384     {
00385         if (debug)
00386             fprintf(stdout, "numeric::findInitialPoints:   F(x1) > F(x1 + step) : keep doubling the x1 - x2 interval\n");
00387 
00388         x2 = x1 + step;
00389 
00390         // Keep doubling the x1 - x2 interval.
00391         for (int i = 0; i < maxIteration; ++i)
00392         {
00393             x3 = x1 + (x2 - x1)*2;
00394             if (debug)
00395                 fprintf(stdout, "numeric::findInitialPoints:   iter %d: x1 = %g; x2 = %g; x3 = %g\n", i, x1, x2, x3);
00396             double f2 = F(x2);
00397             double f3 = F(x3);
00398             if (debug)
00399                 fprintf(stdout, "numeric::findInitialPoints:   f1 = %g; f2 = %g; f3 = %g\n", f1, f2, f3);
00400             if (!isfinite(f2) || !isfinite(f3))
00401             {
00402                 if (debug)
00403                     fprintf(stdout, "numeric::findInitialPoints:   either f2 or f3 is not a number.\n");
00404                 break;
00405             }
00406             if (f2 < f1 && f2 < f3)
00407             {
00408                 pointsFound = true;
00409                 break;
00410             }
00411 
00412             x2 = x3;
00413         }
00414     }
00415 
00416     if (!pointsFound)
00417         throw InitialPointsNotFound();
00418 
00419     // Optionally re-order the points so that x1 < x2 < x3.
00420     if (x1 > x3)
00421         ::std::swap(x1, x3);
00422 
00423     data.P1 = x1;
00424     data.P2 = x2;
00425     data.P3 = x3;
00426 }
00427 
00428 class PrefixedSingleVarFuncObj : public SingleVarFuncObj
00429 {
00430 public:
00431     explicit PrefixedSingleVarFuncObj(SingleVarFuncObj& rParent, double prefix) :
00432         m_rParent(rParent), m_prefix(prefix)
00433     {
00434     }
00435 
00436     virtual double eval() const
00437     {
00438         return m_prefix * m_rParent.eval();
00439     }
00440 
00441     virtual const::std::string getFuncString() const
00442     {
00443         ::std::ostringstream os;
00444         os << m_prefix << " * f(x) where f(x) = ";
00445         os << m_rParent.getFuncString();
00446         return os.str();
00447     }
00448 
00449     virtual double getVar() const
00450     {
00451         return m_rParent.getVar();
00452     }
00453 
00454     virtual void setVar(double var)
00455     {
00456         m_rParent.setVar(var);
00457     }
00458 
00459     void setPrefix(double prefix)
00460     {
00461         m_prefix = prefix;
00462     }
00463 
00464 private:
00465     SingleVarFuncObj& m_rParent;
00466     double m_prefix;
00467 };
00468 
00469 // --------------------------------------------------------------------------
00470 
00471 QuadFitLineSearch::QuadFitLineSearch() :
00472     BaseLineSearch(NULL),
00473     m_maxIteration(500)
00474 {
00475 }
00476 
00477 QuadFitLineSearch::QuadFitLineSearch(SingleVarFuncObj* pFuncObj) :
00478     BaseLineSearch(pFuncObj),
00479     m_maxIteration(500)
00480 {
00481 }
00482 
00483 QuadFitLineSearch::~QuadFitLineSearch()
00484 {
00485 }
00486 
00487 double QuadFitLineSearch::solve()
00488 {
00489     SingleVarFuncObj* pOrigFuncObj = getFuncObj();
00490     bool debug = isDebug();
00491 
00492     if (!pOrigFuncObj)
00493         throw Exception("QuadFitLineSearch::solve: no function object set");
00494 
00495     PrefixedSingleVarFuncObj prefixedFuncObj(*pOrigFuncObj, 1.0);
00496     if (getGoal() == GOAL_MAXIMIZE)
00497         prefixedFuncObj.setPrefix(-1.0);
00498 
00499     SingleVarFuncObj* pFuncObj = &prefixedFuncObj;
00500 
00501     if (debug)
00502         fprintf(stdout, "QuadFitLineSearch::solve:   function = %s\n", pFuncObj->getFuncString().c_str());
00503     QuadFitSearchData data(pFuncObj);
00504     data._GoalType = getGoal();
00505 
00506     // 1. Find three points such that the 2nd point be the lowest.
00507     // 2. Find the quadratic function that passes through all three points.
00508     // 3. Find the minimizer point of that quadratic function.
00509     // 4. Replace one of the three points with the minimizer point based on
00510     //    some conditions.
00511     // 5. Terminate when P1 - P3 < e.
00512 
00513     findInitialPoints(data, debug);
00514 
00515     if (debug)
00516         fprintf(stdout, "QuadFitLineSearch::solve:   initial points: p1 = %g; p2 = %g; p3 = %g\n",
00517                 data.P1, data.P2, data.P3);
00518 
00519     SingleVarFuncObj& F = *pFuncObj;
00520 
00521     const double eps = 1.0e-6; // tolerance limit (epsilon)
00522 
00523     bool solutionFound = false;
00524     double oldRange = data.P3 - data.P1;;
00525     for (size_t i = 0; i < m_maxIteration; ++i)
00526     {
00527         if (debug)
00528             fprintf(stdout, "QuadFitLineSearch::solve: ITERATION %d\n", i);
00529 
00530         // Solve the quadratic function.
00531         PolyEqnSolver eqnSolver;
00532         eqnSolver.addDataPoint(data.P1, F(data.P1));
00533         eqnSolver.addDataPoint(data.P2, F(data.P2));
00534         eqnSolver.addDataPoint(data.P3, F(data.P3));
00535         Matrix sol = eqnSolver.solve();
00536         if (debug)
00537         {
00538             fprintf(stdout, "QuadFitLineSearch::solve:   3-pt quad equation: ");
00539             sol.trans().print();
00540         }
00541     
00542         // Get the peak of that quad function.
00543         double x, y;
00544         getQuadraticPeak(x, y, sol);
00545         if (debug)
00546             fprintf(stdout, "QuadFitLineSearch::solve:   peak at (%g, %g)\n", x, y);
00547     
00548         if (data.P1 > x || x > data.P3)
00549         {
00550             ::std::ostringstream os;
00551             os << "calculated quadratic peak is not between P1 and P3";
00552             os << "  P1 = " << data.P1 << "  P3 = " << data.P3 << "  quadratic peak = " << x;
00553             throw Exception(os.str());
00554         }
00555 
00556         double delta = x - data.P2;
00557 
00558         if (debug)
00559             fprintf(stdout, "QuadFitLineSearch::solve:   delta = %g\n", delta);
00560 
00561         if ((delta > 0 ? delta : -delta) < 3.89e-15)
00562         {
00563             // case 3
00564             if (debug)
00565                 fprintf(stdout, "QuadFitLineSearch::solve:   case 3 (equality)\n");
00566 
00567             if (data.P3 - data.P1 < eps)
00568             {
00569                 solutionFound = true;
00570                 break;
00571             }
00572                 
00573             double l1 = data.P2 - data.P1, l2 = data.P3 - data.P2;
00574             if (l1 > l2)
00575             {
00576                 if (l1 > eps/2.0)
00577                     data.P2 -= eps*1.0e-1;
00578                 else
00579                     throw exception();
00580             }
00581             else
00582             {
00583                 if (l2 > eps/2.0)
00584                     data.P2 += eps*1.0e-1;
00585                 else
00586                     throw exception();
00587             }
00588         }
00589         else if (delta > 0)
00590         {
00591             // case 1 - test point is on the right side of P2.
00592             if (debug)
00593                 fprintf(stdout, "QuadFitLineSearch::solve:   case 1 (test > P2)\n");
00594 
00595             if (F(x) >= F(data.P2))
00596             {
00597                 data.P3 = x;
00598                 if (data.P2 - data.P3 < 3.89e-15)
00599                     data.P3 += eps*1.0e-1;
00600             }
00601             else
00602             {
00603                 swap(data.P1, data.P2);
00604                 data.P2 = x;
00605             }
00606         }
00607         else if (delta < 0)
00608         {
00609             // case 2 - test point is on the left side of P2.
00610             if (debug)
00611                 fprintf(stdout, "QuadFitLineSearch::solve:   case 2 (test < P2)\n");
00612 
00613             double valTest = F(x), valP2 = F(data.P2);
00614             if (debug)
00615             {
00616                 fprintf(stdout, "QuadFitLineSearch::solve:     F(test) = %g; F(P2) = %g (delta = %g)\n", 
00617                         valTest, valP2, valTest - valP2);
00618             }
00619 
00620             if (valTest >= valP2)
00621             {
00622                 if (debug)
00623                     fprintf(stdout, "QuadFitLineSearch::solve:     F(test) >= F(P2)\n");
00624 
00625                 data.P1 = x;
00626                 if (data.P1 - data.P2 < 3.89e-15)
00627                     data.P1 -= eps*1.0e-1;
00628             }
00629             else
00630             {
00631                 if (debug)
00632                     fprintf(stdout, "QuadFitLineSearch::solve:     F(test) < F(P2)\n");
00633 
00634                 swap(data.P2, data.P3);
00635                 data.P2 = x;
00636             }
00637         }
00638         else
00639         {
00640             throw Exception("unknown condition encountered");
00641         }
00642 
00643         double newRange = data.P3 - data.P1;
00644         if (debug)
00645         {
00646             fprintf(stdout, "QuadFitLineSearch::solve:   new points: p1 = %g; p2 = %g; p3 = %g (range = %g; delta = %g)\n",
00647                     data.P1, data.P2, data.P3, newRange, oldRange - newRange);
00648         }
00649 
00650         if (newRange < eps)
00651         {
00652             solutionFound = true;
00653             break;
00654         }
00655 
00656         double l1 = data.P2 - data.P1;
00657         double l2 = data.P3 - data.P2;
00658         if (debug)
00659             fprintf(stdout, "QuadFitLineSearch::solve:   l1 = %g; l2 = %g\n", data.P2 - data.P1, data.P3 - data.P2);
00660 
00661         if (l1 < eps/2.0 && l2 > eps/2.0)
00662         {
00663             // Length between P1 and P2 is less than half the epsion.  Push P2 away from P1 to make more room.
00664             data.P2 = data.P1 + eps/2.0;
00665             if (debug)
00666                 fprintf(stdout, "QuadFitLineSearch::solve:   l1 is less than half the epsilon\n");
00667         }
00668         else if (l1 > eps/2.0 && l2 < eps/2.0)
00669         {
00670             // Length between P2 and P3 is less than half the epsion.  Push P2 away from P3 to make more room.
00671             data.P2 = data.P3 - eps/2.0;
00672             if (debug)
00673                 fprintf(stdout, "QuadFitLineSearch::solve:   l2 is less than half the epsilon\n");
00674         }
00675 
00676         oldRange = newRange;
00677     }
00678 
00679     if (!solutionFound)
00680         throw MaxIterationReached();
00681 
00682     if (debug)
00683         fprintf(stdout, "QuadFitLineSearch::solve: final points: p1 = %g; p2 = %g; p3 = %g\n",
00684                 data.P1, data.P2, data.P3);
00685 
00686     return data.P2;
00687 }
00688 
00689 }}

Generated on Mon Jul 28 09:13:20 2008 for scsolver by  doxygen 1.5.3