00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include "rootsolve.h"
00023
00024 #include <stdio.h>
00025
00026
00027 namespace NUM
00028 {
00029
00030
00031
00032 const PolyRootSolver_Sturm::Params PolyRootSolver_Sturm::defaults =
00033 {
00034 // Default tuning parameters; the constructor in this file copies them into `par`.
00035 // NOTE(review): INIT_FIELD is defined elsewhere -- presumably it names the field being initialised; confirm.
00036 INIT_FIELD(nearly_zero) 1.0e-10, // modp() zeroes remainder coefficients whose magnitude is below this
00037 INIT_FIELD(relerror) 1.0e-14, // convergence tolerance tested in modrf() and sbisect()
00038 INIT_FIELD(sbisect_maxit) 64, // iteration cap for the bisection loops in sbisect()
00039 INIT_FIELD(modrf_maxit) 64, // iteration cap for the false-position loop in modrf()
00040 };
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051 int PolyRootSolver_Sturm::modrf(int ord,dbl *coef,dbl a,dbl b,
00052 dbl *val)
00053 {
00054
00055 dbl *scoef = coef;
00056 dbl *ecoef = &coef[ord];
00057 dbl fa,fb = fa = *ecoef;
00058 for(dbl *fp = ecoef - 1; fp >= scoef; fp--)
00059 {
00060 fa = a * fa + *fp;
00061 fb = b * fb + *fp;
00062 }
00063
00064
00065 if(fa * fb > 0.0) return(0);
00066
00067 if(fabs(fa) < par.relerror) { *val = a; return(1); }
00068 if(fabs(fb) < par.relerror) { *val = b; return(1); }
00069
00070 dbl lfx = fa;
00071 dbl fx;
00072 for(int its = 0; its<par.modrf_maxit; its++)
00073 {
00074 dbl x = (fb*a - fa*b) / (fb-fa);
00075
00076 fx = *ecoef;
00077 for(dbl *fp = ecoef-1; fp >= scoef; fp--)
00078 fx = x*fx + *fp;
00079
00080 if( (fabs(x) > par.relerror && fabs(fx/x) < par.relerror) ||
00081 fabs(fx) < par.relerror )
00082 {
00083 if(fabs(fx/x) < par.relerror)
00084 { *val=x; return(1); }
00085 }
00086
00087 if(fa*fx < 0)
00088 {
00089 b = x;
00090 fb = fx;
00091 if((lfx * fx) > 0) fa *= 0.5;
00092 }
00093 else
00094 {
00095 a = x;
00096 fa = fx;
00097 if((lfx * fx) > 0) fb *= 0.5;
00098 }
00099
00100 if(fabs(b-a) < par.relerror)
00101 { *val=x; return(1); }
00102
00103 lfx = fx;
00104 }
00105
00106
00107
00108
00109 return(0);
00110 }
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121 int PolyRootSolver_Sturm::modp(poly *u,poly *v,poly *r)
00122 {
00123 #if !PolyRootSolver_Sturm_BufOnStack
00124
00125
00126
00127 r->AllocCoeffs(u->ord);
00128 #endif
00129
00130 dbl *nr = r->coef;
00131 dbl *end = &u->coef[u->ord];
00132 dbl *uc = u->coef;
00133 while (uc <= end)
00134 *nr++ = *uc++;
00135
00136 if(v->coef[v->ord] < 0.0)
00137 {
00138 for(int k = u->ord - v->ord - 1; k >= 0; k -= 2)
00139 r->coef[k] = -r->coef[k];
00140
00141 for(int k = u->ord - v->ord; k >= 0; k--)
00142 for(int j = v->ord + k - 1; j >= k; j--)
00143 r->coef[j] = -r->coef[j] - r->coef[v->ord+k] * v->coef[j-k];
00144 }
00145 else
00146 {
00147 for(int k = u->ord - v->ord; k >= 0; k--)
00148 for(int j = v->ord + k - 1; j >= k; j--)
00149 r->coef[j] -= r->coef[v->ord+k] * v->coef[j-k];
00150 }
00151
00152 int k = v->ord - 1;
00153 while(k >= 0 && fabs(r->coef[k]) < par.nearly_zero) r->coef[k--]=0.0;
00154
00155 r->ord = (k<0) ? 0 : k;
00156
00157 return(r->ord);
00158 }
00159
00160
00161
00162
00163
00164
00165
00166
00167 int PolyRootSolver_Sturm::buildsturm(int ord,poly *sseq)
00168 {
00169 sseq[1].ord = ord-1;
00170 #if !PolyRootSolver_Sturm_BufOnStack
00171 sseq[1].AllocCoeffs(ord-1);
00172 #endif
00173
00174
00175 dbl f = fabs(sseq[0].coef[ord] * ord);
00176 dbl *fp = sseq[1].coef;
00177 dbl *fc = sseq[0].coef+1;
00178 for(int i=1; i<=ord; i++)
00179 *fp++ = *fc++ * i / f;
00180
00181
00182 poly *sp;
00183 for(sp = sseq+2; modp(sp-2,sp-1,sp); sp++)
00184 {
00185
00186 f = -fabs(sp->coef[sp->ord]);
00187 for(fp = &sp->coef[sp->ord]; fp >= sp->coef; fp--) *fp /= f;
00188 }
00189
00190 sp->coef[0] = -sp->coef[0];
00191
00192 return(sp-sseq);
00193 }
00194
00195
00196
00197
00198
00199
00200
00201 int PolyRootSolver_Sturm::numroots(int np,poly *sseq,int *atneg,int *atpos)
00202 {
00203 poly *s;
00204
00205
00206 int atposinf=0;
00207 dbl lf = sseq[0].coef[sseq[0].ord];
00208 #if 0
00209
00210 for(s = sseq+1; s <= sseq+np; s++)
00211 {
00212 dbl f = s->coef[s->ord];
00213 if(lf==0.0 || lf * f < 0) atposinf++;
00214 lf=f;
00215 }
00216 #else
00217 for(s=sseq+1; lf==0.0 && s <= sseq+np; s++) lf = s->coef[s->ord];
00218 for(; s<=sseq+np; s++)
00219 {
00220 dbl f = s->coef[s->ord];
00221 if(f==0) continue;
00222 if(f*lf<0) ++atposinf;
00223 lf=f;
00224 }
00225 #endif
00226 *atpos = atposinf;
00227
00228
00229 int atneginf=0;
00230 if(sseq[0].ord & 1) lf = -sseq[0].coef[sseq[0].ord];
00231 else lf = sseq[0].coef[sseq[0].ord];
00232 #if 0
00233
00234 for(s = sseq+1; s <= sseq+np; s++)
00235 {
00236 dbl f = (s->ord & 1) ? -s->coef[s->ord] : s->coef[s->ord];
00237 if(lf==0.0 || lf * f < 0) atneginf++;
00238 lf=f;
00239 }
00240 #else
00241 for(s=sseq+1; lf==0.0 && s <= sseq+np; s++)
00242 lf = (s->ord & 1) ? -s->coef[s->ord] : s->coef[s->ord];
00243 for(; s<=sseq+np; s++)
00244 {
00245 dbl f = (s->ord & 1) ? -s->coef[s->ord] : s->coef[s->ord];
00246 if(f==0) continue;
00247 if(f*lf<0) ++atneginf;
00248 lf=f;
00249 }
00250 #endif
00251 *atneg = atneginf;
00252
00253 #if 0
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264 int at0=0;
00265 lf = sseq[0].coef[0];
00266 #if 0
00267 for(poly *s = sseq+1; s <= sseq+np; s++)
00268 {
00269 dbl f = s->coef[0];
00270 if(lf==0.0 || lf * f < 0) ++at0;
00271 lf=f;
00272 }
00273 #else
00274 for(s=sseq+1; lf==0.0 && s <= sseq+np; s++) lf = s->coef[0];
00275 for(; s <= sseq+np; s++)
00276 {
00277 dbl f = s->coef[0];
00278 if(f==0) continue;
00279 if(f*lf<0) ++at0;
00280 lf=f;
00281 }
00282 #endif
00283 *at00=at0;
00284 #endif
00285
00286 return(atneginf-atposinf);
00287 }
00288
00289
00290
00291
00292
00293
00294
00295
00296 int PolyRootSolver_Sturm::numchanges(int np,poly *sseq,dbl a)
00297 {
00298 #if 0
00299 #warning "Buggy version enabled."
00300
00301
00302 int changes=0;
00303 dbl lf = PolvEval(a,sseq[0].coef,sseq[0].ord);
00304 for(poly *s = sseq+1; s<=sseq+np; s++)
00305 {
00306 dbl f = PolvEval(a,s->coef,s->ord);
00307 if(lf == 0.0 || lf*f < 0) changes++;
00308 lf = f;
00309 }
00310 #else
00311
00312
00313
00314 int changes=0;
00315 poly *s = sseq;
00316 dbl lf;
00317
00318 for(; s<=sseq+np; s++)
00319 {
00320 lf = PolvEval(a,s->coef,s->ord);
00321 if(lf!=0.0) break;
00322 }
00323
00324 for(; s<=sseq+np; s++)
00325 {
00326 dbl f = PolvEval(a,s->coef,s->ord);
00327 if(f==0.0) continue;
00328 if(f*lf<0) ++changes;
00329 lf=f;
00330 }
00331 #endif
00332
00333 return(changes);
00334 }
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344 int PolyRootSolver_Sturm::sbisect(int np,poly *sseq,dbl min,dbl max,
00345 int atmin,int atmax,dbl *roots)
00346 {
00347 if(atmin-atmax == 1)
00348 {
00349
00350 if(modrf(sseq->ord,sseq->coef,min,max,&roots[0])) return(1);
00351
00352
00353
00354 dbl mid;
00355 for(int its=0; its<par.sbisect_maxit; its++)
00356 {
00357
00358
00359
00360
00361
00362 mid = 0.5*(min+max);
00363
00364 if( (fabs(mid) > par.relerror && fabs((max-min)/mid) < par.relerror) ||
00365 fabs(max-min) < par.relerror )
00366 {
00367
00368
00369 dbl lv = PolvEval(min,sseq[0].coef,sseq[0].ord);
00370 dbl rv = PolvEval(max,sseq[0].coef,sseq[0].ord);
00371 roots[0] = min + (max-min)*lv/(lv-rv);
00372 return(1);
00373 }
00374
00375 int atmid = numchanges(np,sseq,mid);
00376
00377
00378 Assert(atmid>=atmax && atmid<=atmin);
00379
00380
00381
00382 if(atmin==atmid) min=mid;
00383 else max=mid;
00384 }
00385
00386 fprintf(stderr,"sbisect: overflow min %g max %g diff %g nroot %d\n",
00387 min,max,max-min,atmin-atmax);
00388
00389
00390
00391 dbl lv = PolvEval(min,sseq[0].coef,sseq[0].ord);
00392 dbl rv = PolvEval(max,sseq[0].coef,sseq[0].ord);
00393 roots[0] = (lv*rv<0) ? ( min + (max-min)*lv/(lv-rv) ) : mid;
00394
00395 return(1);
00396 }
00397
00398
00399 int n1,n2;
00400 dbl mid;
00401 for(int its=0; its<par.sbisect_maxit; its++)
00402 {
00403 mid = 0.5*(min+max);
00404 if( (fabs(mid) > par.relerror && fabs((max-min)/mid) < par.relerror) ||
00405 fabs(max-min) < par.relerror )
00406 {
00407
00408
00409 dbl lv = PolvEval(min,sseq[0].coef,sseq[0].ord);
00410 dbl rv = PolvEval(max,sseq[0].coef,sseq[0].ord);
00411 roots[0] = min + (max-min)*lv/(lv-rv);
00412 return(1);
00413 }
00414
00415 int atmid = numchanges(np,sseq,mid);
00416 Assert(atmid>=atmax && atmid<=atmin);
00417
00418
00419 n1 = atmin-atmid;
00420 n2 = atmid-atmax;
00421
00422 if(n1 && n2)
00423 {
00424
00425 n1=sbisect(np,sseq,min,mid,atmin,atmid,roots);
00426 n2=sbisect(np,sseq,mid,max,atmid,atmax,&roots[n1]);
00427 return(n1+n2);
00428 }
00429
00430 if(!n1) min=mid;
00431 else max=mid;
00432 }
00433
00434 fprintf(stderr, "sbisect: roots too close together\n");
00435 fprintf(stderr,
00436 "sbisect: overflow min %g max %g diff %g nroot %d n1 %d n2 %d\n",
00437 min, max, max - min, atmin-atmax, n1, n2);
00438
00439
00440
00441 dbl lv = PolvEval(min,sseq[0].coef,sseq[0].ord);
00442 dbl rv = PolvEval(max,sseq[0].coef,sseq[0].ord);
00443 if(lv*rv<0)
00444 { mid = min + (max-min)*lv/(lv-rv); }
00445
00446 #if 0
00447
00448
00449 int n2=atmin-atmax;
00450 for(n1=0; n1<n2; n1++)
00451 roots[n1] = mid;
00452 return(n2);
00453 #else
00454
00455
00456 roots[0]=mid;
00457 return(1);
00458 #endif
00459 }
00460
00461
00462 int PolyRootSolver_Sturm::solve(int order,const dbl *c,dbl *r)
00463 {
00464
00465 poly sseq[order+1];
00466 #if PolyRootSolver_Sturm_BufOnStack
00467 dbl buf_on_stack[(order+1)*(order+1)];
00468 for(int i=0; i<=order; i++)
00469 sseq[i].coef = &buf_on_stack[i*(order+1)];
00470 #else
00471 sseq[0].AllocCoeffs(order);
00472 #endif
00473
00474
00475 sseq[0].ord=order;
00476 for(int i=0; i<=order; i++)
00477 sseq[0].coef[i]=c[i];
00478
00479
00480 int np = buildsturm(order, sseq);
00481 Assert(np<=order);
00482
00483
00484 int atmin,atmax;
00485 int nroots = numroots(np,sseq,&atmin,&atmax);
00486 if(!nroots) return(0);
00487
00488
00489 dbl min,max;
00491
00492 dbl bracket_min=10,bracket_max=-10;
00493 #define MAXPOW 32
00494
00495
00496 if(bracket_min>bracket_max)
00497 {
00498 min = -1.0;
00499 int nchanges = numchanges(np,sseq,min);
00500 for(int i=0; nchanges!=atmin && i!=MAXPOW; i++)
00501 {
00502 min *= 10.0;
00503 nchanges = numchanges(np,sseq,min);
00504 }
00505 if(nchanges != atmin)
00506 {
00507 fprintf(stderr,"sturm solve: unable to bracket all negetive roots\n");
00508 atmin = nchanges;
00509 }
00510
00511 max = 1.0;
00512 nchanges = numchanges(np, sseq, max);
00513 for(int i=0; nchanges!=atmax && i!=MAXPOW; i++)
00514 {
00515 max *= 10.0;
00516 nchanges = numchanges(np,sseq,max);
00517 }
00518 if(nchanges != atmax)
00519 {
00520 fprintf(stderr,"sturm solve: unable to bracket all positive roots\n");
00521 atmax = nchanges;
00522 }
00523 }
00524 else
00525 {
00526 min=bracket_min;
00527 max=bracket_max;
00528
00529 atmin = numchanges(np,sseq,min);
00530 atmax = numchanges(np,sseq,max);
00531 }
00532
00533
00534 nroots = atmin-atmax;
00535 if(!nroots) return(0);
00536
00537
00538 nroots=sbisect(np,sseq,min,max,atmin,atmax,r);
00539 return(nroots);
00540 }
00541
00542
00543 PolyRootSolver_Sturm::PolyRootSolver_Sturm() :
00544 PolyRootSolver(),
00545 par(defaults)
00546 {
00547
00548 }
00549
00550 PolyRootSolver_Sturm::~PolyRootSolver_Sturm()
00551 {
00552
00553 }
00554
00555 }
00556
00557
#if 0
// Disabled self-test: solves 30000 polynomials of degree 3..10 with random
// coefficients in [-10,10) and counts the reported roots whose residual is
// not below 1e-3.  The count is printed to stderr.
int main()
{
    const int maxorder=11;
    dbl coeffs[maxorder+1],found[maxorder];
    int bad=0;

    NUM::PolyRootSolver_Sturm solver;

    for(int trial=0; trial<30000; trial++)
    {
        // Random degree first, then the coefficients.
        const int deg = rand()%(maxorder-3)+3;
        for(int j=0; j<=deg; j++)
            coeffs[j] = (rand()%2000-1000)/100.0;

        const int n = solver.solve(deg,coeffs,found);
        for(int j=0; j<n; j++)
        {
            // Plug each reported root back into the polynomial.
            const dbl residual = NUM::PolvEval(found[j],coeffs,deg);
            if(!(fabs(residual)<1e-3))
                ++bad;
        }
    }

    fprintf(stderr,"cnt=%d\n",bad);
    return(0);
}
#endif