MODULE JSUniMin;

IMPORT Wr, Thread, Fmt;
FROM Stdio IMPORT stderr;

PROCEDURE Minimize(
    f: Function;    (* Function to minimize *)
    VAR x: REAL;    (* In: starting guess. Out: best point found. *)
    VAR fx: REAL;   (* In/Out: "f(x)". *)
    VAR dfx: REAL;  (* Out: estimate of first deriv. of "f" at "x". *)
    VAR ddfx: REAL; (* Out: estimate of second deriv. of "f" at "x". *)
    step: REAL;     (* Starting step. *)
    minStep: REAL;  (* Minimum step. *)
    maxStep: REAL;  (* Maximum step. *)
    maxCalls: CARDINAL;        (* Maximum number of calls to "f" *)
    report: ReportProc := NIL; (* Client procedure to report progress. *)
    verbose: BOOLEAN := FALSE; (* TRUE to mumble while working. *)
    Alpha: REAL := 0.01;     (* Min splitting ratio. *)
  ) =
  (* Searches for a local minimum of "f" near "x".

     Keeps three probe points "a < m < b" and repeatedly chooses a new
     probe "s".  "s" starts as the vertex of the parabola through the
     three points.  It is then clamped so that the loop invariants below
     still hold for whichever triple survives.  When "a" (or "b") is lower
     than the other two points, "s" is placed outside "[a..b]" on that
     side.

     The caller must supply "fx = f(x)"; it is not re-evaluated.  "f" is
     called at "x-step" and "x+step" unconditionally, then once per
     iteration, so at least two calls are made even if "maxCalls < 2".

     Returns when:
       - "report" (called only when "fm" improves) returns TRUE;
       - "maxCalls" calls have been made; "x" is then the lowest of the
         three probe points;
       - "m" is no higher than "a" and "b" and no admissible probe point
         remains between them. *)
  VAR
    (* Probe points: *)
    a: REAL := x - step; fa: REAL := f(a);  (* Low point. *)
    b: REAL := x + step; fb: REAL := f(b);  (* High point. *)
    m: REAL := x; fm: REAL := fx;           (* Middle point. *)

    nCalls: CARDINAL := 2;                  (* Counts calls to "f". *)
    dfm: LONGREAL; ddfm: LONGREAL;          (* Derivatives at "m". *)
    s: REAL; fs: REAL;                      (* Splitting point. *)

    fmOld: REAL := fm;                      (* Previous value of "fm" *)

    (* Two adjacent gaps "d", "w" satisfy invariant 3 iff
       "Ateb <= d/w <= Beta".  They are set only after "Alpha" has been
       checked, to avoid dividing by zero. *)
    Beta: REAL;                             (* "(1-Alpha)/Alpha" *)
    Ateb: REAL;                             (* "Alpha/(1-Alpha)" *)

    PROCEDURE ComputeDerivatives () =
      (* Fits the parabola through "(a,fa)", "(m,fm)", "(b,fb)".  Stores its
         first and second derivatives at "m" into "dfm" and "ddfm".

         "da" and "db" are the chord slopes over "[a..m]" and "[m..b]".  The
         slope at "m" is their average, each weighted by the opposite gap.
         Arithmetic is done in LONGREAL, presumably to limit round-off in
         the divided differences. *)
      BEGIN
        WITH
          aL = FLOAT(a, LONGREAL), faL = FLOAT(fa, LONGREAL),
          bL = FLOAT(b, LONGREAL), fbL = FLOAT(fb, LONGREAL),
          mL = FLOAT(m, LONGREAL), fmL = FLOAT(fm, LONGREAL),
          wab = bL - aL,
          wam = mL - aL,
          wmb = bL - mL,
          da = (fmL - faL)/wam,
          db = (fbL - fmL)/wmb,
          ra = wam/wab,
          rb = wmb/wab
        DO
          dfm := da*rb + db*ra;
          ddfm := 2.0d0 * (db - da)/wab;
        END;
      END ComputeDerivatives;

    PROCEDURE Message (msg: TEXT) =
      (* Writes "msg" followed by a newline to "stderr". *)
      <* FATAL Wr.Failure, Thread.Alerted *>
      BEGIN
        Wr.PutText(stderr, msg);
        Wr.PutText(stderr, "\n");
      END Message;

    PROCEDURE PrintStatus () =
      (* Writes the call count and the current triple to "stderr".  The
         derivatives are printed only for "m". *)
      <* FATAL Wr.Failure, Thread.Alerted *>
      BEGIN
        Wr.PutText(stderr, "calls = ");
        Wr.PutText(stderr, Fmt.Pad(Fmt.Int(nCalls), 9));
        Wr.PutText(stderr, "\n");
        PrintPoint("a", a, fa);
        PrintPoint("m", m, fm, dfm, ddfm);
        PrintPoint("b", b, fb);
        Wr.PutText(stderr, "\n");
      END PrintStatus;

    PROCEDURE PrintPoint (msg: TEXT; u, fu: REAL; dfu, ddfu: LONGREAL := 0.0d0) =
      (* Writes one line "msg = u fmsg = fu" to "stderr".  Appends the
         derivatives only when at least one of them is nonzero, so callers
         omit them by leaving the defaults. *)
      <* FATAL Wr.Failure, Thread.Alerted *>
      BEGIN
        Wr.PutText(stderr, "    ");
        Wr.PutText(stderr, msg);
        Wr.PutText(stderr, " = ");
        Wr.PutText(stderr, Fmt.Pad(Fmt.Real(u), 12));
        Wr.PutText(stderr, " f" & msg & " = ");
        Wr.PutText(stderr, Fmt.Pad(Fmt.Real(fu), 12));
        IF dfu # 0.0d0 OR ddfu # 0.0d0 THEN
          Wr.PutText(stderr, " df" & msg & " = ");
          Wr.PutText(stderr, Fmt.Pad(Fmt.LongReal(dfu), 18));
          Wr.PutText(stderr, " ddf" & msg & " = ");
          Wr.PutText(stderr, Fmt.Pad(Fmt.LongReal(ddfu), 18));
        END;
        Wr.PutText(stderr, "\n");
      END PrintPoint;

  BEGIN
    <* ASSERT minStep > 0.0 *>
    <* ASSERT step >= minStep *>
    <* ASSERT step <= maxStep *>
    <* ASSERT Alpha > 0.0 AND Alpha < 0.5 *>
    Beta := (1.0 - Alpha)/Alpha;
    Ateb := Alpha/(1.0 - Alpha);

    LOOP
      (* Loop invariants:
          1. Probe points are definitely ordered:
            "a + minStep <= m <= b - minStep"
          2. Gaps aren't too wide:
            "a + maxStep >= m >= b - maxStep"
          3. Gaps aren't too unequal:
            "MIN(m-a, b-m) >= Alpha * (b-a)"
      *)

      (* Fit a parabola to the triple; this sets "dfm" and "ddfm": *)
      ComputeDerivatives();

      IF verbose THEN PrintStatus() END;

      (* Publish "m" as the current answer: *)
      x := m; fx := fm; dfx := FLOAT(dfm); ddfx := FLOAT(ddfm);

      (* Report new minimum and abort if the client asks to: *)
      IF fm < fmOld THEN
        IF report # NIL AND report(x, fx, dfx, ddfx) THEN RETURN END;
        fmOld := fm;
      END;

      (* Check budget: *)
      IF nCalls >= maxCalls THEN
        (* After an extrapolation step, the new outer point may be lower
           than "m".  Hand back the lowest of the three points, with its
           slope read off the same parabola.  The parabola's second
           derivative is constant, so "ddfx" already applies. *)
        IF fa < fm AND fa <= fb THEN
          x := a; fx := fa;
          dfx := FLOAT(dfm + ddfm * (FLOAT(a, LONGREAL) - FLOAT(m, LONGREAL)))
        ELSIF fb < fm THEN
          x := b; fx := fb;
          dfx := FLOAT(dfm + ddfm * (FLOAT(b, LONGREAL) - FLOAT(m, LONGREAL)))
        END;
        RETURN
      END;

      (* Stop if "m" is lowest and neither gap can hold another point: *)
      IF fa >= fm AND fb >= fm
      AND (a + minStep > m - minStep AND m + minStep > b - minStep) THEN
        RETURN
      END;

      (* Estimate minimum position from the parabola: *)
      IF ddfm <= 0.0d0 THEN
        (* No interior vertex: aim "s" far toward the downhill side. *)
        IF verbose THEN Message("convex triple") END;
        IF dfm < 0.0d0 THEN
          s := LAST(REAL)
        ELSIF dfm > 0.0d0 THEN
          s := FIRST(REAL)
        ELSE
          s := m
        END
      ELSE
        s := FLOAT(FLOAT(m, LONGREAL) - dfm/ddfm)
      END;

      (* Adjust estimate to preserve loop invariants.  An extrapolation
         step of length "d" next to a gap "w" keeps them iff
         "MAX(minStep, Ateb*w) <= d <= MIN(maxStep, Beta*w)".  This range
         is non-empty whenever the invariants hold for "w". *)
      IF fa < fm AND fa <= fb THEN
        (* "a" is lowest (an exact tie with "b" also lands here): probe
           before "a". *)
        WITH w = m - a DO
          s := MIN(a - MAX(minStep, Ateb * w),
                   MAX(s, a - MIN(maxStep, Beta * w)))
        END
      ELSIF fb < MIN(fm, fa) THEN
        (* "b" is strictly lowest: probe after "b". *)
        WITH w = b - m DO
          s := MAX(b + MAX(minStep, Ateb * w),
                   MIN(s, b + MIN(maxStep, Beta * w)))
        END
      ELSE
        (* Here "fa >= fm" and "fb >= fm": split one of the two gaps.
           "[amLo..amHi]" and "[mbLo..mbHi]" are the admissible split points
           in "a..m" and "m..b", whichever triple survives. *)
        WITH
          amLo = a + MAX(minStep, Alpha*(m-a)),
          amHi = m - MAX(minStep, MAX(Ateb*(b-m), Alpha*(m-a))),
          mbLo = m + MAX(minStep, MAX(Ateb*(m-a), Alpha*(b-m))),
          mbHi = b - MAX(minStep, Alpha*(b-m))
        DO
          IF amLo > amHi AND mbLo > mbHi THEN
            (* The bracket cannot be refined any further.  "m" is the
               lowest of the three points and is already stored in "x". *)
            RETURN
          END;
          (* Split the side that contains "s", if that side allows it: *)
          IF s <= m THEN
            IF amLo <= amHi THEN
              s := MAX(amLo, MIN(amHi, s))
            ELSE
              IF verbose THEN Message("splitting wrong side") END;
              s := mbLo
            END
          ELSE
            IF mbLo <= mbHi THEN
              s := MAX(mbLo, MIN(mbHi, s))
            ELSE
              IF verbose THEN Message("splitting wrong side") END;
              s := amHi
            END
          END
        END
      END;

      (* Evaluate new point: *)
      fs := f(s); INC(nCalls);

      (* Discard one point and rename others: *)
      IF s < a THEN
        (* Slide the triple down: (s, a, m). *)
        IF verbose THEN Message("extraploating down") END;
        b := m; fb := fm;
        m := a; fm := fa;
        a := s; fa := fs;
      ELSIF s > b THEN
        (* Slide the triple up: (m, b, s). *)
        IF verbose THEN Message("extraploating up") END;
        a := m; fa := fm;
        m := b; fm := fb;
        b := s; fb := fs;
      ELSIF fs < fm THEN
        (* "s" becomes the new middle; keep the gap it lies in. *)
        IF s <= m THEN
          IF verbose THEN Message("new minimum in [a__m]") END;
          b := m; fb := fm;
        ELSE
          IF verbose THEN Message("new minimum in [m__b]") END;
          a := m; fa := fm;
        END;
        m := s; fm := fs;
      ELSE
        (* "m" stays; "s" replaces the outer point on its side. *)
        IF s <= m THEN
          IF verbose THEN Message("high point in [a__m]") END;
          a := s; fa := fs;
        ELSE
          IF verbose THEN Message("high point in [m__b]") END;
          b := s; fb := fs;
        END
      END

    END
  END Minimize;

BEGIN
  (* Module initialization: nothing to do. *)
END JSUniMin.
