00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 #include "numeric/quadfitlinesearch.hxx"
00029 #include "numeric/exception.hxx"
00030 #include "numeric/funcobj.hxx"
00031 #include "numeric/polyeqnsolver.hxx"
00032 #include "numeric/matrix.hxx"
00033 #include "numeric/diff.hxx"
00034
00035 #include <stdio.h>
00036 #include <memory>
00037 #include <vector>
00038 #include <cmath>
00039 #include <exception>
00040 #include <string>
00041 #include <sstream>
00042
00043 using ::std::auto_ptr;
00044 using ::std::vector;
00045 using ::std::exception;
00046 using ::std::swap;
00047 using ::std::fabs;
00048 using ::std::string;
00049
00050 namespace scsolver { namespace numeric {
00051
/// Thrown by calcStepLength() when no trial step satisfying Armijo's rule
/// could be found within its iteration limit.
class StepCalculationFailed : public ::std::exception
{
public:
    virtual const char* what() const throw()
    {
        static const char* const message =
            "step calculation failed during quadratic fit search";
        return message;
    }
};
00060
/// Thrown by findInitialPoints() when it cannot locate three points that
/// bracket a minimum (x1 < x2 < x3 with F(x2) below both ends).
class InitialPointsNotFound : public ::std::exception
{
public:
    virtual const char* what() const throw()
    {
        static const char* const message = "initial 3 points not found";
        return message;
    }
};
00069
00070 struct QuadFitSearchData
00071 {
00072 double P1;
00073 double P2;
00074 double P3;
00075
00076 double StepLength;
00077 double XOffset;
00078
00079 GoalType _GoalType;
00080 SingleVarFuncObj* pFunc;
00081
00082 explicit QuadFitSearchData(SingleVarFuncObj* p) :
00083 StepLength(0.0),
00084 XOffset(0.0),
00085 pFunc(p)
00086 {
00087 }
00088 };
00089
/// Result of testing one trial step with satisfiesArmijosRule().
enum ArmijoCheckStatus
{
    Armijo_Unknown = 0,            ///< no check has been performed yet
    Armijo_InitialCheckFailed = 1, ///< f(x0 + step) exceeds the sufficient-decrease bound
    Armijo_AlphaCheckFailed = 2,   ///< f(x0 + alpha*step) still meets the bound
    Armijo_Success = 3             ///< both checks passed
};
00113
00114 static ArmijoCheckStatus satisfiesArmijosRule(QuadFitSearchData& rData, double step,
00115 const double e, const double alpha,
00116 const double f0, const double ff0,
00117 const double xoffset = 0.0,
00118 bool debug = false);
00119
00136 static double findLargestStep(QuadFitSearchData& rData, double step,
00137 const double e, const double alpha, const double f0, const double ff0, const double xoffset)
00138 {
00139
00140
00141 ArmijoCheckStatus status = satisfiesArmijosRule(rData, step, e, alpha, f0, ff0, false);
00142 if (status != Armijo_Success)
00143
00144
00145 return step;
00146
00147 double lastGoodStep = step;
00148 do
00149 {
00150 lastGoodStep = step;
00151 step *= alpha;
00152
00153 rData.pFunc->setVar(step + xoffset);
00154 double testLeft = rData.pFunc->eval();
00155 double testRight = f0 + step*e*ff0;
00156
00157
00158
00159 if (testLeft > testRight)
00160 {
00161
00162
00163 break;
00164 }
00165 }
00166 while (true);
00167
00168
00169 return lastGoodStep;
00170 }
00171
00186 static ArmijoCheckStatus satisfiesArmijosRule(QuadFitSearchData& rData, double step,
00187 const double e, const double alpha,
00188 const double f0, const double ff0,
00189 const double xoffset,
00190 bool debug)
00191 {
00192 if (debug)
00193 fprintf(stdout, "numeric::satisfiesArmijosRule: ---------- step = %g\n", step);
00194
00195
00196
00197 rData.pFunc->setVar(step + xoffset);
00198 double testLeft = rData.pFunc->eval();
00199 double testRight = f0 + step*e*ff0;
00200 if (debug)
00201 fprintf(stdout, "numeric::satisfiesArmijosRule: (initial check) step = %g; f left = %g; f right = %g\n",
00202 step, testLeft, testRight);
00203
00204 if (testLeft > testRight)
00205
00206 return Armijo_InitialCheckFailed;
00207
00208
00209
00210 rData.pFunc->setVar(step*alpha + xoffset);
00211 testLeft = rData.pFunc->eval();
00212
00213 testRight = f0 + ff0*step*alpha*e;
00214
00215 if (debug)
00216 fprintf(stdout, "numeric::satisfiesArmijosRule: (alpha check) step*alpha = %g; f left = %g; f right = %g\n",
00217 step*alpha, testLeft, testRight);
00218
00219 if (testLeft <= testRight)
00220 return Armijo_AlphaCheckFailed;
00221
00222 if (debug)
00223 fprintf(stdout, "numeric::satisfiesArmijosRule: this satisfies Armijo's rule.\n");
00224
00225 return Armijo_Success;
00226 }
00227
00228 static void printArmijoStatus(ArmijoCheckStatus status)
00229 {
00230 switch (status)
00231 {
00232 case Armijo_AlphaCheckFailed:
00233 printf(" alpha check failed\n");
00234 break;
00235 case Armijo_InitialCheckFailed:
00236 printf(" initial check failed\n");
00237 break;
00238 case Armijo_Success:
00239 printf(" success\n");
00240 break;
00241 case Armijo_Unknown:
00242 printf(" unknown\n");
00243 break;
00244 default:
00245 printf(" other\n");
00246 }
00247 }
00248
00249 static void calcStepLength(QuadFitSearchData& rData, bool debug)
00250 {
00251 double step = 10;
00252
00253 double e = 0.2, alpha = 2.0;
00254
00255 NumericalDiffer diff;
00256 diff.setFuncObject(rData.pFunc);
00257 diff.setPrecision(5);
00258
00259 double xoffset = 0.0;
00260
00261
00262 diff.setVariable(xoffset);
00263 double ff0 = diff.run();
00264
00265 if ((ff0 > 0 ? ff0 : -ff0) < 1.0e-16)
00266 {
00267 xoffset += 1.0;
00268 diff.setVariable(xoffset);
00269 ff0 = diff.run();
00270 }
00271
00272
00273 rData.pFunc->setVar(xoffset);
00274 double f0 = rData.pFunc->eval();
00275
00276 bool negativeDirection = false;
00277 if (ff0 > 0)
00278 {
00279 negativeDirection = true;
00280 step = -step;
00281 }
00282
00283 if (debug)
00284 fprintf(stdout, "numeric::calcStepLength: f(%g) = %g; ff(%g) = %g\n", xoffset, f0, xoffset, ff0);
00285
00286
00287 ArmijoCheckStatus armijoStatus = Armijo_Unknown;
00288 bool checkNeverFailed = true;
00289 for (int i = 0; i < 4000000; ++i)
00290 {
00291 armijoStatus = satisfiesArmijosRule(rData, step, e, alpha, f0, ff0, xoffset, debug);
00292 if (debug)
00293 printArmijoStatus(armijoStatus);
00294 if (armijoStatus == Armijo_Success)
00295 break;
00296
00297 checkNeverFailed = false;
00298
00299 if (armijoStatus == Armijo_InitialCheckFailed)
00300 {
00301 step /= alpha;
00302 continue;
00303 }
00304
00305 step *= alpha;
00306 }
00307
00308 if (armijoStatus != Armijo_Success)
00309 throw StepCalculationFailed();
00310
00311 if (checkNeverFailed)
00312 step = findLargestStep(rData, step, e, alpha, f0, ff0, xoffset);
00313
00314 if (debug)
00315 fprintf(stdout, "numeric::calcStepLength: final step length = %g\n", step);
00316
00317 rData.StepLength = step;
00318 rData.XOffset = xoffset;
00319 }
00320
00329 static void findInitialPoints(QuadFitSearchData& data, bool debug)
00330 {
00331 using ::std::isfinite;
00332
00333 const int maxIteration = 2000;
00334
00335
00336 calcStepLength(data, debug);
00337 double step = data.StepLength;
00338 if (debug)
00339 fprintf(stdout, "numeric::findInitialPoints: calculated step length = %g\n", data.StepLength);
00340
00341 SingleVarFuncObj& F = *data.pFunc;
00342
00343
00344
00345
00346
00347
00348 double x1 = data.XOffset, x2 = data.XOffset, x3 = data.XOffset;
00349 double f1 = F(x1);
00350 bool pointsFound = false;
00351 if (debug)
00352 fprintf(stdout, "numeric::findInitialPoints: initial x1 = %g\n", x1);
00353 if (f1 <= F(x1 + step))
00354 {
00355 if (debug)
00356 fprintf(stdout, "numeric::findInitialPoints: F(x1) <= F(x1 + step) - keep halving the x1 - x3 interval\n");
00357
00358 x3 = x1 + step;
00359
00360
00361 for (int i = 0; i < maxIteration; ++i)
00362 {
00363 x2 = (x1 + x3)/2.0;
00364 double f2 = F(x2);
00365 double f3 = F(x3);
00366 if (debug)
00367 fprintf(stdout, "numeric::findInitialPoints: f1 = %g; f2 = %g; f3 = %g\n", f1, f2, f3);
00368 if (!isfinite(f2) || !isfinite(f3))
00369 {
00370 if (debug)
00371 fprintf(stdout, "numeric::findInitialPoints: either f2 or f3 is not a number.\n");
00372 break;
00373 }
00374 if (f2 < f1 && f2 < f3)
00375 {
00376 pointsFound = true;
00377 break;
00378 }
00379
00380 x3 = x2;
00381 }
00382 }
00383 else
00384 {
00385 if (debug)
00386 fprintf(stdout, "numeric::findInitialPoints: F(x1) > F(x1 + step) : keep doubling the x1 - x2 interval\n");
00387
00388 x2 = x1 + step;
00389
00390
00391 for (int i = 0; i < maxIteration; ++i)
00392 {
00393 x3 = x1 + (x2 - x1)*2;
00394 if (debug)
00395 fprintf(stdout, "numeric::findInitialPoints: iter %d: x1 = %g; x2 = %g; x3 = %g\n", i, x1, x2, x3);
00396 double f2 = F(x2);
00397 double f3 = F(x3);
00398 if (debug)
00399 fprintf(stdout, "numeric::findInitialPoints: f1 = %g; f2 = %g; f3 = %g\n", f1, f2, f3);
00400 if (!isfinite(f2) || !isfinite(f3))
00401 {
00402 if (debug)
00403 fprintf(stdout, "numeric::findInitialPoints: either f2 or f3 is not a number.\n");
00404 break;
00405 }
00406 if (f2 < f1 && f2 < f3)
00407 {
00408 pointsFound = true;
00409 break;
00410 }
00411
00412 x2 = x3;
00413 }
00414 }
00415
00416 if (!pointsFound)
00417 throw InitialPointsNotFound();
00418
00419
00420 if (x1 > x3)
00421 ::std::swap(x1, x3);
00422
00423 data.P1 = x1;
00424 data.P2 = x2;
00425 data.P3 = x3;
00426 }
00427
00428 class PrefixedSingleVarFuncObj : public SingleVarFuncObj
00429 {
00430 public:
00431 explicit PrefixedSingleVarFuncObj(SingleVarFuncObj& rParent, double prefix) :
00432 m_rParent(rParent), m_prefix(prefix)
00433 {
00434 }
00435
00436 virtual double eval() const
00437 {
00438 return m_prefix * m_rParent.eval();
00439 }
00440
00441 virtual const::std::string getFuncString() const
00442 {
00443 ::std::ostringstream os;
00444 os << m_prefix << " * f(x) where f(x) = ";
00445 os << m_rParent.getFuncString();
00446 return os.str();
00447 }
00448
00449 virtual double getVar() const
00450 {
00451 return m_rParent.getVar();
00452 }
00453
00454 virtual void setVar(double var)
00455 {
00456 m_rParent.setVar(var);
00457 }
00458
00459 void setPrefix(double prefix)
00460 {
00461 m_prefix = prefix;
00462 }
00463
00464 private:
00465 SingleVarFuncObj& m_rParent;
00466 double m_prefix;
00467 };
00468
00469
00470
00471 QuadFitLineSearch::QuadFitLineSearch() :
00472 BaseLineSearch(NULL),
00473 m_maxIteration(500)
00474 {
00475 }
00476
00477 QuadFitLineSearch::QuadFitLineSearch(SingleVarFuncObj* pFuncObj) :
00478 BaseLineSearch(pFuncObj),
00479 m_maxIteration(500)
00480 {
00481 }
00482
00483 QuadFitLineSearch::~QuadFitLineSearch()
00484 {
00485 }
00486
00487 double QuadFitLineSearch::solve()
00488 {
00489 SingleVarFuncObj* pOrigFuncObj = getFuncObj();
00490 bool debug = isDebug();
00491
00492 if (!pOrigFuncObj)
00493 throw Exception("QuadFitLineSearch::solve: no function object set");
00494
00495 PrefixedSingleVarFuncObj prefixedFuncObj(*pOrigFuncObj, 1.0);
00496 if (getGoal() == GOAL_MAXIMIZE)
00497 prefixedFuncObj.setPrefix(-1.0);
00498
00499 SingleVarFuncObj* pFuncObj = &prefixedFuncObj;
00500
00501 if (debug)
00502 fprintf(stdout, "QuadFitLineSearch::solve: function = %s\n", pFuncObj->getFuncString().c_str());
00503 QuadFitSearchData data(pFuncObj);
00504 data._GoalType = getGoal();
00505
00506
00507
00508
00509
00510
00511
00512
00513 findInitialPoints(data, debug);
00514
00515 if (debug)
00516 fprintf(stdout, "QuadFitLineSearch::solve: initial points: p1 = %g; p2 = %g; p3 = %g\n",
00517 data.P1, data.P2, data.P3);
00518
00519 SingleVarFuncObj& F = *pFuncObj;
00520
00521 const double eps = 1.0e-6;
00522
00523 bool solutionFound = false;
00524 double oldRange = data.P3 - data.P1;;
00525 for (size_t i = 0; i < m_maxIteration; ++i)
00526 {
00527 if (debug)
00528 fprintf(stdout, "QuadFitLineSearch::solve: ITERATION %d\n", i);
00529
00530
00531 PolyEqnSolver eqnSolver;
00532 eqnSolver.addDataPoint(data.P1, F(data.P1));
00533 eqnSolver.addDataPoint(data.P2, F(data.P2));
00534 eqnSolver.addDataPoint(data.P3, F(data.P3));
00535 Matrix sol = eqnSolver.solve();
00536 if (debug)
00537 {
00538 fprintf(stdout, "QuadFitLineSearch::solve: 3-pt quad equation: ");
00539 sol.trans().print();
00540 }
00541
00542
00543 double x, y;
00544 getQuadraticPeak(x, y, sol);
00545 if (debug)
00546 fprintf(stdout, "QuadFitLineSearch::solve: peak at (%g, %g)\n", x, y);
00547
00548 if (data.P1 > x || x > data.P3)
00549 {
00550 ::std::ostringstream os;
00551 os << "calculated quadratic peak is not between P1 and P3";
00552 os << " P1 = " << data.P1 << " P3 = " << data.P3 << " quadratic peak = " << x;
00553 throw Exception(os.str());
00554 }
00555
00556 double delta = x - data.P2;
00557
00558 if (debug)
00559 fprintf(stdout, "QuadFitLineSearch::solve: delta = %g\n", delta);
00560
00561 if ((delta > 0 ? delta : -delta) < 3.89e-15)
00562 {
00563
00564 if (debug)
00565 fprintf(stdout, "QuadFitLineSearch::solve: case 3 (equality)\n");
00566
00567 if (data.P3 - data.P1 < eps)
00568 {
00569 solutionFound = true;
00570 break;
00571 }
00572
00573 double l1 = data.P2 - data.P1, l2 = data.P3 - data.P2;
00574 if (l1 > l2)
00575 {
00576 if (l1 > eps/2.0)
00577 data.P2 -= eps*1.0e-1;
00578 else
00579 throw exception();
00580 }
00581 else
00582 {
00583 if (l2 > eps/2.0)
00584 data.P2 += eps*1.0e-1;
00585 else
00586 throw exception();
00587 }
00588 }
00589 else if (delta > 0)
00590 {
00591
00592 if (debug)
00593 fprintf(stdout, "QuadFitLineSearch::solve: case 1 (test > P2)\n");
00594
00595 if (F(x) >= F(data.P2))
00596 {
00597 data.P3 = x;
00598 if (data.P2 - data.P3 < 3.89e-15)
00599 data.P3 += eps*1.0e-1;
00600 }
00601 else
00602 {
00603 swap(data.P1, data.P2);
00604 data.P2 = x;
00605 }
00606 }
00607 else if (delta < 0)
00608 {
00609
00610 if (debug)
00611 fprintf(stdout, "QuadFitLineSearch::solve: case 2 (test < P2)\n");
00612
00613 double valTest = F(x), valP2 = F(data.P2);
00614 if (debug)
00615 {
00616 fprintf(stdout, "QuadFitLineSearch::solve: F(test) = %g; F(P2) = %g (delta = %g)\n",
00617 valTest, valP2, valTest - valP2);
00618 }
00619
00620 if (valTest >= valP2)
00621 {
00622 if (debug)
00623 fprintf(stdout, "QuadFitLineSearch::solve: F(test) >= F(P2)\n");
00624
00625 data.P1 = x;
00626 if (data.P1 - data.P2 < 3.89e-15)
00627 data.P1 -= eps*1.0e-1;
00628 }
00629 else
00630 {
00631 if (debug)
00632 fprintf(stdout, "QuadFitLineSearch::solve: F(test) < F(P2)\n");
00633
00634 swap(data.P2, data.P3);
00635 data.P2 = x;
00636 }
00637 }
00638 else
00639 {
00640 throw Exception("unknown condition encountered");
00641 }
00642
00643 double newRange = data.P3 - data.P1;
00644 if (debug)
00645 {
00646 fprintf(stdout, "QuadFitLineSearch::solve: new points: p1 = %g; p2 = %g; p3 = %g (range = %g; delta = %g)\n",
00647 data.P1, data.P2, data.P3, newRange, oldRange - newRange);
00648 }
00649
00650 if (newRange < eps)
00651 {
00652 solutionFound = true;
00653 break;
00654 }
00655
00656 double l1 = data.P2 - data.P1;
00657 double l2 = data.P3 - data.P2;
00658 if (debug)
00659 fprintf(stdout, "QuadFitLineSearch::solve: l1 = %g; l2 = %g\n", data.P2 - data.P1, data.P3 - data.P2);
00660
00661 if (l1 < eps/2.0 && l2 > eps/2.0)
00662 {
00663
00664 data.P2 = data.P1 + eps/2.0;
00665 if (debug)
00666 fprintf(stdout, "QuadFitLineSearch::solve: l1 is less than half the epsilon\n");
00667 }
00668 else if (l1 > eps/2.0 && l2 < eps/2.0)
00669 {
00670
00671 data.P2 = data.P3 - eps/2.0;
00672 if (debug)
00673 fprintf(stdout, "QuadFitLineSearch::solve: l2 is less than half the epsilon\n");
00674 }
00675
00676 oldRange = newRange;
00677 }
00678
00679 if (!solutionFound)
00680 throw MaxIterationReached();
00681
00682 if (debug)
00683 fprintf(stdout, "QuadFitLineSearch::solve: final points: p1 = %g; p2 = %g; p3 = %g\n",
00684 data.P1, data.P2, data.P3);
00685
00686 return data.P2;
00687 }
00688
00689 }}