MODULE GradMin;

IMPORT Wr, Thread, Text, Fmt, Math;
FROM Stdio IMPORT stderr;

PROCEDURE Minimize (
    f: GoalFunc;
    VAR x: ARRAY OF REAL;           (* In: starting point, Out: best point found. *)
    VAR y: REAL;                    (* Out: value of "f" at optimum. *)
    VAR dy: ARRAY OF REAL;          (* Out: gradient of "f" at "x". *)
    maxCalls: CARDINAL;             (* Maximum number of calls to "f". *)
    minStep: REAL := 1.0e-30;       (* Minimum step size. *)
    maxStep: REAL := 1.0e+30;       (* Maximum step size. *)
    report: ReportProc := NIL;      (* Called when a new minimum is found. *)
    normalize: NormalizeProc := NIL; (* Called to modify point before applying "f". *)
    verbose: BOOLEAN := FALSE;      (* TRUE to mumble while working. *)
    Alpha: REAL := 2.0;             (* Step lengthening factor. *)
    Beta: REAL := 0.5;              (* Step shortening factor. *)
  ) =
  (* Gradient-based descent with a smoothed search direction.

     The initial direction is "-dy".  Each call to "MinStep" performs a
     line search along the current direction "v"; when it finds a point
     with a smaller "f", it replaces "x", "y", "dy", blends the new
     negative gradient into "v" with weight "gamma", and lengthens "step"
     by "Alpha" (capped at "maxStep").  When a trial fails, the step is
     re-estimated from a cubic fit and capped by a bound that shrinks by
     "Beta" after each failure.

     The procedure returns when "report" returns TRUE, when a line search
     fails to improve "y", or when the number of calls to "f" reaches
     "maxCalls".  "f" is always evaluated once at the starting point, so
     that "y" and "dy" are defined on return.

     Diagnostic messages are written to "stderr" unconditionally;
     "verbose" only controls the per-trial dump of "xt", "yt", "dyt". *)
  BEGIN
    WITH
      N = NUMBER(x),
      v = NEW(REF ARRAY OF REAL, N)^,    (* The current stepping direction *)
      xt = NEW(REF ARRAY OF REAL, N)^,   (* A trial point *)
      dyt = NEW(REF ARRAY OF REAL, N)^   (* Gradient of "f" at "xt" *)
    DO
      VAR
        yt: REAL;                  (* Value of "f" at "xt" *)
        step: REAL := 1.0;         (* Current step size *)
        calls: CARDINAL := 0;      (* Number of calls to "f" made so far *)
        (* Gradient decay constant.  "MAX(N, 1)" only guards against a
           division by zero when "x" is empty; for "N >= 1" it is "1/N". *)
        gamma: REAL := 1.0/FLOAT(MAX(N, 1));

      PROCEDURE CallF (VAR z: ARRAY OF REAL; VAR u: REAL; VAR du: ARRAY OF REAL) =
        (* Calls normalize on "z" (if given), counts the call, then
           evaluates "f(z, u, du)". *)
        BEGIN
          IF normalize # NIL THEN normalize(z) END;
          INC(calls);
          f(z, u, du);
        END CallF;

      PROCEDURE DotProd(READONLY u, v: ARRAY OF REAL): LONGREAL =
        (* Inner product of "u" and "v", accumulated in LONGREAL.
           Iterates over the index range of "u". *)
        VAR s: LONGREAL := 0.0D0;
        BEGIN
          FOR i := 0 TO LAST(u) DO
            WITH
              uu = FLOAT(u[i], LONGREAL),
              vv = FLOAT(v[i], LONGREAL)
            DO
              s := s + uu*vv
            END
          END;
          RETURN s
        END DotProd;

      PROCEDURE ComputeNextTrial(): BOOLEAN =
        (* Computes "xt := x + step * v"; returns TRUE iff "xt" is
           distinct from "x" (i.e. some coordinate actually moved after
           rounding). *)
        VAR diff: REAL := 0.0;
        BEGIN
          FOR i := 0 TO N-1 DO
            WITH xti = xt[i], xi = x[i] DO
              xti := xi + step * v[i];
              diff := MAX(diff, ABS(xti - xi))
            END;
          END;
          RETURN (diff > 0.0)
        END ComputeNextTrial;

      PROCEDURE QuadraticRoot(a, b, c: LONGREAL): LONGREAL =
        (* Returns the smallest root of "a*t^2 + b*t + c = 0" that lies
           in [0 _ 1].  Requires "c < 0" and a non-negative discriminant;
           the assertions fail if no root lies in that range.

           The roots are computed as "c/q" and "q/a" with
           "q = -(b + sign(b)*sqrt(disc))/2", which avoids subtracting
           nearly equal quantities and remains valid when "a = 0" (the
           linear case, which occurs whenever the goal function is
           exactly quadratic along the search line). *)
        VAR
          r: LONGREAL := 2.0D0;  (* Out-of-range sentinel: no admissible root yet. *)
          q, t: LONGREAL;
        BEGIN
          <* ASSERT c < 0.0D0 *>
          WITH
            disc = b*b - 4.0D0 * a * c,
            sqd = Math.sqrt(MAX(disc, 0.0D0))
          DO
            <* ASSERT disc >= 0.0D0 *>
            IF b >= 0.0D0 THEN
              q := -0.5D0 * (b + sqd)
            ELSE
              q := -0.5D0 * (b - sqd)
            END
          END;
          (* First candidate: "c/q" (exists unless "b" and "disc" are both zero). *)
          IF q # 0.0D0 THEN
            t := c / q;
            IF t >= 0.0D0 AND t <= 1.0D0 THEN r := t END
          END;
          (* Second candidate: "q/a" (exists only for a true quadratic). *)
          IF a # 0.0D0 THEN
            t := q / a;
            IF t >= 0.0D0 AND t <= 1.0D0 AND t < r THEN r := t END
          END;
          <* ASSERT r >= 0.0D0 *>
          <* ASSERT r <= 1.0D0 *>
          RETURN r
        END QuadraticRoot;

      PROCEDURE MinStep () =
        (* Searches along the line from "x" in the direction "v";
           replaces "x", "y", and "dy" if it finds a better point.
           Stops without improvement if the step becomes null or too
           small, or if the budget of "maxCalls" evaluations runs out. *)
        VAR
          sy: LONGREAL;    (* Derivative of "f" at "x" along "v". *)
          syt: LONGREAL;   (* Derivative of "f" at "xt" along "v" *)
          trimStep: REAL;  (* Maximum step (to ensure progress) *)
        BEGIN
          sy := DotProd(v, dy);
          IF sy > 0.0d0 THEN
            (* "v" points uphill: reverse it so that the search descends. *)
            Message("** direction flip:");
            PrintVec(" v ", v);
            PrintVec(" dy ", dy);
            FOR i := 0 TO N-1 DO v[i] := -v[i] END;
            sy := -sy
          END;
          trimStep := step;
          LOOP
            IF NOT ComputeNextTrial() THEN
              Message("** null step = " & Fmt.Real(step));
              RETURN
            END;
            (* Honor the evaluation budget inside the line search too.
               The caller guarantees "calls < maxCalls" on entry, so at
               least one trial is always evaluated. *)
            IF calls >= maxCalls THEN RETURN END;
            CallF(xt, yt, dyt);
            IF verbose THEN PrintPoint("trial", xt, yt, dyt); END;
            IF yt < y THEN
              (* New minimum; update "x", "y", "dy", and "v". *)
              x := xt;
              y := yt;
              dy := dyt;
              FOR i := 0 TO LAST(v) DO
                WITH vi = v[i] DO
                  vi := (1.0 - gamma) * vi - gamma * dy[i]
                END
              END;
              step := MIN(step * Alpha, maxStep);
              RETURN
            ELSIF step <= minStep OR trimStep < minStep THEN
              Message("** step too small");
              step := minStep;
              RETURN
            ELSE
              syt := DotProd(v, dyt);
              (*
                Estimates position of minimum of "h(t) = f(x + t*step*v)",
                where "t" ranges over [0_1], by fitting a cubic "g(t)" to
                its end values "h0 = y" and "h1 = yt", and its end
                derivatives "dh0 = step*sy" and "dh1 = step*syt".
                The cubic is
              |   g(t)  = (dh0-2*h1+2*h0+dh1)*t**3+(-2*dh0-dh1-3*h0+3*h1)*t**2+dh0*t+h0
              |   dg(t) = (3*(dh0+dh1)-6*(h1-h0))*t**2+(6*(h1-h0)-4*dh0-2*dh1)*t+dh0
              *)
              WITH
                h0 = FLOAT(y, LONGREAL),
                h1 = FLOAT(yt, LONGREAL),
                ss = FLOAT(step, LONGREAL),
                dh0 = ss * sy,
                dh1 = ss * syt,
                a = 3.0D0*(dh0 + dh1) - 6.0D0*(h1 - h0),
                b = 6.0D0*(h1 - h0) - 4.0D0*dh0 - 2.0D0*dh1,
                c = dh0,
                root = QuadraticRoot(a, b, c)
              DO
                step := MAX(minStep, FLOAT(root) * step);
                IF step > trimStep THEN
                  Message("** step = " & Fmt.Real(step) & " trimmed");
                  step := trimStep
                END;
                trimStep := Beta * trimStep; (* Next time better make some progress... *)
              END
            END;
          END;
        END MinStep;

      PROCEDURE Message (msg: TEXT) =
        (* Writes "msg" and a newline to "stderr". *)
        <* FATAL Wr.Failure, Thread.Alerted *>
        BEGIN
          Wr.PutText(stderr, msg);
          Wr.PutText(stderr, "\n");
        END Message;

      PROCEDURE PrintVec (
          msg: TEXT;
          READONLY w: ARRAY OF REAL;
        ) =
        (* Writes "msg" followed by the first "N" elements of "w" to
           "stderr", on one line. *)
        <* FATAL Wr.Failure, Thread.Alerted *>
        BEGIN
          Wr.PutText(stderr, msg);
          Wr.PutChar(stderr, ' ');
          Wr.PutText(stderr, " = ( ");
          FOR j := 0 TO N-1 DO
            Wr.PutText(stderr, " ");
            Wr.PutText(stderr, Fmt.Pad(Fmt.Real(w[j]), 10));
          END;
          Wr.PutText(stderr, " )\n");
        END PrintVec;

      PROCEDURE PrintPoint (
          msg: TEXT;
          READONLY z: ARRAY OF REAL;
          READONLY u: REAL;
          READONLY du: ARRAY OF REAL;
        ) =
        (* Writes to "stderr" a header line with "msg", the call count
           and the current step, followed by the point "z", the value
           "u", and the gradient "du".  Continuation lines are indented
           by "Text.Length(msg) + 1" blanks, the width of "msg" plus the
           blank that follows it on the header line. *)
        <* FATAL Wr.Failure, Thread.Alerted *>
        VAR M := Text.Length(msg);
        BEGIN
          Wr.PutText(stderr, msg);
          Wr.PutChar(stderr, ' ');
          Wr.PutText(stderr, "calls = " & Fmt.Pad(Fmt.Int(calls), 9));
          Wr.PutText(stderr, " ");
          Wr.PutText(stderr, "step = " & Fmt.Pad(Fmt.Real(step), 12));
          Wr.PutText(stderr, "\n");

          FOR i := 0 TO M DO Wr.PutChar(stderr, ' ') END;
          Wr.PutText(stderr, "x = ( ");
          FOR j := 0 TO N-1 DO
            Wr.PutText(stderr, " ");
            Wr.PutText(stderr, Fmt.Pad(Fmt.Real(z[j]), 10));
          END;
          Wr.PutText(stderr, " )\n");

          FOR i := 0 TO M DO Wr.PutChar(stderr, ' ') END;
          Wr.PutText(stderr, "f = " & Fmt.Pad(Fmt.Real(u), 12));
          Wr.PutText(stderr, "\n");

          FOR i := 0 TO M DO Wr.PutChar(stderr, ' ') END;
          Wr.PutText(stderr, "df = ( ");
          FOR j := 0 TO N-1 DO
            Wr.PutText(stderr, " ");
            Wr.PutText(stderr, Fmt.Pad(Fmt.Real(du[j]), 10));
          END;
          Wr.PutText(stderr, " )\n");
          Wr.PutText(stderr, "\n");
        END PrintPoint;

      VAR oldMin: REAL;  (* Value of "y" before the latest line search. *)
      BEGIN
        (* Evaluate the starting point so that "y" and "dy" are defined. *)
        CallF(x, y, dy);
        IF report # NIL THEN
          IF report(x, y, dy) THEN RETURN END;
        END;
        (* Start along the steepest-descent direction. *)
        FOR i := 0 TO LAST(v) DO v[i] := -dy[i] END;
        WHILE calls < maxCalls DO
          oldMin := y;
          MinStep ();
          (* The line search gave up without improving: stop. *)
          IF y >= oldMin THEN RETURN END;
          IF report # NIL THEN
            IF report(x, y, dy) THEN RETURN END
          END;
        END
      END
    END
  END Minimize;

BEGIN
END GradMin.