00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "rootsolve.h"
00021
00022 #include <stdio.h>
00023
00024
00025 namespace NUM
00026 {
00027
00028
00029
// Default solver parameters; the constructor copies them into `par`.
// NOTE(review): INIT_FIELD is a project macro (presumably a designated-
// initializer shim) -- if it can expand to nothing, these entries must stay
// in the declaration order of Params.
//   nearly_zero     - |coefficient| below this is treated as zero
//                     (solve(), _Recurse()).
//   maxiter         - iteration cap for one _FindQuadFact() run.
//   iter_chg        - _FindQuadFact() multiplies its convergence tolerance
//                     by 10 every iter_chg iterations.
//   initial_epsilon - starting convergence tolerance on |dr| + |ds|.
//   disc_epsilon    - relative threshold below which a quadratic's
//                     discriminant is treated as zero (double root).
const PolyRootSolver_Bairstow::Params PolyRootSolver_Bairstow::defaults =
{
    INIT_FIELD(nearly_zero) 1e-10,
    INIT_FIELD(maxiter) 500,
    INIT_FIELD(iter_chg) 200,
    INIT_FIELD(initial_epsilon) 1e-10,
    INIT_FIELD(disc_epsilon) 1e-10
};
00038
00039
00040
00041 int PolyRootSolver_Bairstow::_QuadricRealRoots(dbl *a,int n,dbl *wr)
00042 {
00043 int m=n;
00044 int numroots=0;
00045 while(m>1)
00046 {
00047 dbl b2 = -0.5*a[m-2];
00048 dbl c = a[m-1];
00049
00050
00051 dbl disc = b2*b2-c;
00052 if(fabs(disc) <= par.disc_epsilon * (b2*b2+fabs(c)))
00053 {
00054
00055 wr[numroots++] = b2;
00056 }
00057 else if(disc>0.0)
00058 {
00059
00060 disc = sqrt(disc);
00061 disc = b2<0 ? (b2-disc) : (b2+disc);
00062 wr[numroots++] = disc;
00063 wr[numroots++] = c/disc;
00064 }
00065
00066 m -= 2;
00067 }
00068
00069 if(m==1)
00070 { wr[numroots++] = -a[0]; }
00071
00072 return(numroots);
00073 }
00074
00075
00076
00077 int PolyRootSolver_Bairstow::_QuadricRoots(dbl *a,int n,dbl *wr,
00078 dbl *wi)
00079 {
00080 dbl sq,b2,c,disc;
00081 int m,numroots;
00082
00083
00084
00085
00086 CritAssert(0);
00087
00088 m = n;
00089 numroots = 0;
00090 while (m > 1) {
00091 b2 = -0.5*a[m-2];
00092 c = a[m-1];
00093
00094 disc = b2*b2-c;
00095 if (fabs(disc) <= par.disc_epsilon*(b2*b2+fabs(c))) disc = 0.0;
00096 if (disc < 0.0)
00097 {
00098
00099 sq = sqrt(-disc);
00100 wr[m-2] = b2;
00101 wi[m-2] = sq;
00102 wr[m-1] = b2;
00103 wi[m-1] = -sq;
00104 numroots+=2;
00105 }
00106 else
00107 {
00108 sq = sqrt(disc);
00109
00110
00111
00112
00113
00114 wr[m-2] = b2<0 ? (b2-sq) : (b2+sq);
00115
00116 if(wr[m-2] == 0)
00117 {
00118
00119
00120
00121
00122
00123
00124
00125 wr[m-1] = 0;
00126 }
00127 else
00128 {
00129 wr[m-1] = c/wr[m-2];
00130 numroots+=2;
00131 }
00132
00133
00134 wi[m-2] = 0.0;
00135 wi[m-1] = 0.0;
00136 }
00137 m -= 2;
00138 }
00139 if (m == 1) {
00140 wr[0] = -a[0];
00141 wi[0] = 0.0;
00142 numroots++;
00143 }
00144 return numroots;
00145 }
00146
00147
00148
00149
00150 int PolyRootSolver_Bairstow::_FindQuadFact(dbl *a,int n,dbl *b,
00151 dbl *quad,int red_fact,dbl *err_est)
00152 {
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166 dbl c[n+1];
00167 c[0] = 1.0;
00168 dbl r = quad[1];
00169 dbl s = quad[0];
00170 dbl dr = 1.0;
00171 dbl ds = 0;
00172 dbl eps = par.initial_epsilon;
00173
00174 int iter=1;
00175
00176 do
00177 {
00178 if(iter>par.maxiter) break;
00179
00180
00181 if(!(iter % par.iter_chg))
00182 { eps*=10.0; }
00183
00184 b[1] = a[1] - r;
00185 c[1] = b[1] - r;
00186
00187 for(int i=2; i<=n; i++)
00188 {
00189 b[i] = a[i] - r * b[i-1] - s * b[i-2];
00190 c[i] = b[i] - r * c[i-1] - s * c[i-2];
00191 }
00192
00193 dbl drn=b[n] * c[n-3] - b[n-1] * c[n-2];
00194 dbl dsn=b[n-1] * c[n-1] - b[n] * c[n-2];
00195
00196 dbl dn=c[n-1] * c[n-3] - c[n-2] * c[n-2];
00197
00198 if(fabs(dn) < 1e-10)
00199 { dn = dn<0.0 ? -1e-8 : 1e-8; }
00200
00201
00202
00203 if(red_fact)
00204 {
00205 switch(red_fact)
00206 {
00207 case 1: dn*=0.999; break;
00208 case 2: dn*=1.1; break;
00209 default: dn*=0.8; break;
00210 }
00211 }
00212
00213 dn = 1.0/dn;
00214 dr = drn * dn;
00215 ds = dsn * dn;
00216 r += dr;
00217 s += ds;
00218
00219 ++iter;
00220
00221
00222 }
00223 while((fabs(dr)+fabs(ds)) > eps);
00224
00225 quad[0] = s;
00226 quad[1] = r;
00227 *err_est = fabs(ds)+fabs(dr);
00228
00229
00230
00231 return(iter);
00232 }
00233
00234
00235
00236
00237
00238 void PolyRootSolver_Bairstow::_GetQuads(dbl *a,int n,dbl *quad,dbl *x)
00239 {
00240 dbl tmp;
00241 if((tmp = a[0]) != 1.0)
00242 {
00243 a[0] = 1.0;
00244 for(int i=1; i<=n; i++)
00245 a[i] /= tmp;
00246 }
00247 if(n == 2)
00248 {
00249 x[0] = a[1];
00250 x[1] = a[2];
00251 return;
00252 }
00253 if(n == 1)
00254 {
00255 x[0] = a[1];
00256 return;
00257 }
00258
00259 dbl b[n+1];
00260 dbl z[n+1];
00261 b[0] = 1.0;
00262 for(int i=0; i<=n; i++)
00263 {
00264 z[i] = a[i];
00265 x[i] = 0.0;
00266 }
00267
00268 dbl err;
00269 int m = n;
00270 do
00271 {
00272 int red_fact=0;
00273
00274 if(n>m)
00275 {
00276 quad[0] = 3.14159e-1;
00277 quad[1] = 2.78127e-1;
00278 }
00279 loop:
00280 int iter=_FindQuadFact(z,m,b,quad,red_fact,&err);
00281 if( err>1e-7 || iter>par.maxiter )
00282 {
00283 _DiffPoly(z,m,b);
00284 iter = _Recurse(z,m,b,m-1,quad,&err);
00285 }
00286 err=_Deflate(z,m,b,quad);
00287 if(err > 0.01)
00288 {
00289 fprintf(stderr,"Bairstow: Excessive error %g (iter=%d, n=%d) %s\n",
00290 err,iter,n,red_fact<3 && iter ? "[re-try]" : "[giving up]");
00291 if(red_fact<3 && iter)
00292 {
00293
00294 ++red_fact;
00295 goto loop;
00296 }
00297
00298
00299
00300
00301
00302
00303
00304
00305 x[m-2] = 0;
00306 x[m-1] = 2;
00307 goto go_on;
00308 }
00309 if(err>1) goto loop;
00310 x[m-2] = quad[1];
00311 x[m-1] = quad[0];
00312 go_on:
00313 m -= 2;
00314 for(int i=0; i<=m; i++)
00315 z[i] = b[i];
00316 }
00317 while(m>2);
00318
00319 if(m==2)
00320 {
00321 x[0] = b[1];
00322 x[1] = b[2];
00323 }
00324 else x[0] = b[1];
00325 }
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348 int PolyRootSolver_Bairstow::_Recurse(dbl *a,int n,dbl *b,int m,
00349 dbl *quad,dbl *err)
00350 {
00351
00352 if(fabs(b[m]) < par.nearly_zero) --m;
00353
00354 if(m == 2)
00355 {
00356 quad[0] = b[2];
00357 quad[1] = b[1];
00358 *err = 0;
00359 return(0);
00360 }
00361
00362 dbl c[m+1];
00363 dbl x[n+1];
00364 c[0] = x[0] = 1.0;
00365 dbl rs[2];
00366 rs[0] = quad[0];
00367 rs[1] = quad[1];
00368 int iter=_FindQuadFact(b,m,c,rs,0,err);
00369 dbl tst = fabs(rs[0]-quad[0]) + fabs(rs[1]-quad[1]);
00370 if(*err < 1e-12)
00371 {
00372 quad[0] = rs[0];
00373 quad[1] = rs[1];
00374 }
00375
00376 if( (iter>5 && tst<1e-4) || (iter>20 && tst<1e-1) )
00377 {
00378 _DiffPoly(b,m,c);
00379 iter=_Recurse(a,n,c,m-1,rs,err);
00380 quad[0] = rs[0];
00381 quad[1] = rs[1];
00382 }
00383
00384 return(iter);
00385 }
00386
00387
00388
00389 void PolyRootSolver_Bairstow::_DiffPoly(dbl *a,int n,dbl *b)
00390 {
00391 dbl coef=n;
00392 b[0] = 1.0;
00393 for(int i=1; i<n; i++)
00394 b[i] = a[i]*(n-i)/coef;
00395 }
00396
00397
00398
00399
00400 dbl PolyRootSolver_Bairstow::_Deflate(dbl *a,int n,dbl *b,
00401 dbl *quad)
00402 {
00403 dbl c[n+1];
00404
00405 dbl r = quad[1];
00406 dbl s = quad[0];
00407
00408 b[1] = a[1] - r;
00409 c[1] = b[1] - r;
00410
00411 for(int i=2; i<=n; i++)
00412 {
00413 b[i] = a[i] - r * b[i-1] - s * b[i-2];
00414 c[i] = b[i] - r * c[i-1] - s * c[i-2];
00415 }
00416
00417 return( fabs(b[n])+fabs(b[n-1]) );
00418 }
00419
00420
00421
00422
00423 int PolyRootSolver_Bairstow::solve(int n,const dbl *_c,dbl *r)
00424 {
00425 int nroots=0;
00426
00427
00428
00429
00430 while(n>=0 && fabs(_c[n])<par.nearly_zero) --n;
00431 if(!n) return(nroots);
00432
00433 if(fabs(_c[0])<par.nearly_zero)
00434 {
00435
00436 r[nroots++]=0;
00437
00438
00439 int i=1;
00440 while(i<=n && fabs(_c[i])<par.nearly_zero) ++i;
00441 if(i>=n) return(nroots);
00442 _c+=i;
00443 n-=i;
00444 }
00445
00446
00447
00448 dbl c[n+1];
00449 for(int i=0; i<=n; i++) c[i]=_c[n-i];
00450
00451
00452 dbl quad[2];
00453 quad[0] = 2.71828e-1;
00454 quad[1] = 3.14159e-1;
00455
00456
00457 dbl x[n+1];
00458 _GetQuads(c,n,quad,x);
00459
00460
00461
00462
00463
00464 nroots+=_QuadricRealRoots(x,n,&r[nroots]);
00465
00466 return(nroots);
00467 }
00468
00469
00470 PolyRootSolver_Bairstow::PolyRootSolver_Bairstow() :
00471 PolyRootSolver(),
00472 par(defaults)
00473 {
00474
00475 }
00476
00477 PolyRootSolver_Bairstow::~PolyRootSolver_Bairstow()
00478 {
00479
00480 }
00481
00482 }
00483
00484
#if 0
// Stand-alone self-test driver (compiled out).  Prints the real roots of two
// fixed polynomials, then solves 30000 random polynomials of degree 3..10
// with coefficients in [-10, 9.99] and reports every returned root whose
// residual |p(root)| is not below 1e-6.

#include <stdlib.h>   // rand()

// Print "<n> roots: r[0] r[1] ..." to stderr.
static void PrintRoots(int n,const dbl *r)
{
    fprintf(stderr,"%d roots:",n);
    for(int i=0; i<n; i++)
        { fprintf(stderr," %.15g",r[i]); }
    fprintf(stderr,"\n");
}

int main()
{
    NUM::PolyRootSolver_Bairstow solver;

    {
        // (x-1)(x-2)(x-3)^2 = x^4 - 9x^3 + 29x^2 - 39x + 18
        dbl c[10],r[9];
        int deg;
        deg=4; c[4]=1; c[3]=-9; c[2]=29; c[1]=-39; c[0]=18;

        int n=solver.solve(deg,c,r);
        PrintRoots(n,r);
    }

    {
        dbl c[10],r[9];
        int deg;
        deg=3; c[3]=3.96; c[2]=2.69; c[1]=-5.78; c[0]=2.24;

        int n=solver.solve(deg,c,r);
        PrintRoots(n,r);
    }

#if 1
    // Randomized residual check.
    const int maxorder=11;
    dbl c[maxorder+1],r[maxorder];
    int cnt=0;
    for(int i=0; i<30000; i++)
    {
        int deg=rand()%(maxorder-3)+3;
        for(int j=0; j<=deg; j++)
            c[j]=(rand()%2000-1000)/100.0;
        int n=solver.solve(deg,c,r);
        for(int j=0; j<n; j++)
        {
            // NOTE(review): PolvEval is not defined in this file; confirm
            // the name against rootsolve.h (possible typo of PolyEval).
            dbl R=NUM::PolvEval(r[j],c,deg);
            if(fabs(R)<1e-6) continue;

            fprintf(stderr,"[%3d]",i);
            for(int k=0; k<=deg; k++) fprintf(stderr," %g",c[k]);
            fprintf(stderr,"-> %g for %g (n=%d, j=%d)\n",R,r[j],n,j);
            ++cnt;
        }
    }
    fprintf(stderr,"cnt=%d\n",cnt);
#endif

    return(0);
}
#endif