#! /usr/bin/python3
# Last edited on 2021-08-02 13:30:40 by jstolfi

import sys
from math import sqrt, hypot, sin, cos, exp, log, floor, ceil, inf, nan, pi;

import rn

from sampling_2d_grid_choose import sampling_2d_grid_choose
from weights_mesa import weights_mesa

from basis_2d_grid_choose import basis_2d_grid_choose

from basis_2d_gauss_eval import basis_2d_gauss_eval
from basis_2d_gauss_sample import basis_2d_gauss_sample
from basis_sampled_orthize import basis_sampled_orthize
from basis_sampled_project import basis_sampled_project
from basis_sampled_residuals_compute import basis_sampled_residuals_compute
from basis_sampled_residuals_combine import basis_sampled_residuals_combine

from basis_sampled_write import basis_sampled_write

def main():
  # Driver of the script: chooses a grid-like Gaussian hump basis from the
  # data points {pd}, samples it at those points, orthizes the sampled
  # basis, and projects the data values {fd} onto the resulting space.
  #
  # NOTE(review): this procedure is unfinished.  The data-reading statement
  # below is a placeholder that is not valid Python, so the file does not
  # parse as is; and the final step (choosing sampling points to plot) has
  # no code yet.  Nothing is returned or written out.

  # Read data points:
  # TODO: placeholder -- the code that reads the data is still missing.
  # Presumably {pd} is the list of 2D data points and {fd} the list of the
  # corresponding function values; confirm once the reader is written.
  ??? pd, fd = ???
  # NOTE(review): {nd} is not used by the rest of this procedure as written.
  nd = len(pd) # Number of data points.
  
  # Create grid-like Gaussian hump basis:
  # {c} and {r} are the centroids and radii chosen by
  # {fit_scattered_basis_2d_choose} from the data points alone.
  c, r = fit_scattered_basis_2d_choose(pd)
  
  # Eval basis at data points:
  B = basis_2d_gauss_sample(c, r, pd)
  
  # Orthize basis at those data points:
  # NOTE(review): the meaning of the {None} argument is not visible here;
  # presumably it selects default sample weights -- confirm against
  # {basis_sampled_orthize}.
  C = basis_sampled_orthize(B, None)
  
  # Project function on basis space:
  # NOTE(review): presumably {cf} are the coefficients, {fp} the projected
  # values and {fr} the residuals, and {None} again stands for default
  # weights -- confirm against {basis_sampled_project}.
  cf, fp, fr = basis_sampled_project(fd, C, None)
  
  # Choose sampling points to plot:
  # TODO: not implemented yet.
  
  
  # ----------------------------------------------------------------------

def fit_scattered_basis_2d_choose(pd):
  # Determines the grid centroids {c[..]} and radii {r[..]} for a Gaussian grid
  # approximation basis, given the list of data points {pd}.
  #
  # Each element of {pd} must be indexable, with the {x} coordinate at
  # index 0 and the {y} coordinate at index 1.  There must be at least
  # 4 data points, and they must not all coincide.
  #
  # The resulting basis will have elements of the same radius
  # along the {x} and {y} axes, and will be a subset of a grid
  # spanning a rectangle that encloses the data points.  Grid elements
  # that have no data point strictly within one radius of their centroid
  # along both axes are discarded.
  #
  # The procedure tries to obtain a basis with no more elements than
  # data points.  It starts with a target of about {2/3} of the number
  # of data points and, if the reduced basis is still larger than the
  # data set, shrinks the target by {2/3} and tries again, up to 3
  # attempts.  If all attempts fail, it writes a warning to {sys.stderr}
  # and returns the last basis anyway.
  #
  # Returns the pair {c,r}, the centroids and radii of the retained
  # elements, as slices of the sequences returned by {basis_2d_grid_choose}.
  #
  # Fails with {AssertionError} if there are fewer than 4 data points, if
  # all data points coincide, or if an internal consistency check fails.
  
  nd = len(pd) # Number of data points.
  assert nd >= 4, "Needs at least 4 data points"
   
  # Find bounding box:
  bbox = None
  for p in pd:
    bbox = rn.box_include_point(bbox, p)
  xmin = bbox[0][0]; xmax = bbox[1][0]; xctr = (xmin + xmax)/2  
  ymin = bbox[0][1]; ymax = bbox[1][1]; yctr = (ymin + ymax)/2  
  dx = xmax - xmin
  dy = ymax - ymin
  
  # A box with zero extent on both axes cannot define a grid radius:
  assert dx > 0 or dy > 0, "Data points are all coincident"
  
  # Square root of the aspect ratio of the box, so that
  # {(s*rat)/(s/rat) = dx/dy}.  Undefined if the box is flat:
  flat = (dx <= 0 or dy <= 0)
  rat = None if flat else sqrt(dx/dy)
  
  # Determine the nominal domain {[xmin _ xmax] × [ymin _ ymax]} of the basis
  # and the basis centroids {c} and radii {r}:
  nb_exp = 2*nd//3  # Initial guess for tot number of basis elements
  maxtry = 3  # Max attempts to determine {nb}
  ntry = 0 # Number of attempts made.
  nb_ok = False
  while ntry < maxtry and not nb_ok:

    # Determine {nx,ny} so that {nx*ny ~ nb} and {nx/ny ~ dx/dy}.
    if not flat:
      nx = max(2, int(ceil(sqrt(nb_exp)*rat)))
      ny = max(2, int(ceil(sqrt(nb_exp)/rat)))
    elif dx > 0:
      # Box has zero height: use the minimum of 2 rows, and put
      # the remaining elements along {x}.
      nx = max(2, int(ceil(nb_exp/2))); ny = 2
    else:
      # Box has zero width: use the minimum of 2 columns, and put
      # the remaining elements along {y}.
      nx = 2; ny = max(2, int(ceil(nb_exp/2)))
    
    # Choose the max radius:
    rad = max(dx/(nx-1), dy/(ny-1))
    
    # Enlarge domain so that radii are equal:
    dx_a = (nx-1)*rad; xmin_a = xctr - dx_a/2; xmax_a = xctr + dx_a/2
    dy_a = (ny-1)*rad; ymin_a = yctr - dy_a/2; ymax_a = yctr + dy_a/2
    
    # Select the centroids and radii:
    span = True
    c, r = basis_2d_grid_choose(xmin_a,xmax_a,nx, ymin_a,ymax_a,ny, span)
    nb_full = len(c)
    assert nb_full == nx*ny
    
    # Eliminate basis elements whose nominal domain has no data points.
    # The kept elements are compacted in place at the front of {c,r};
    # this is safe because {nb_red <= k} at all times.
    nb_red = 0  # Number of elements in reduced basis.
    for k in range(nb_full):
      assert abs(r[k][0] - rad) < 1.0e-12, "bug 0"
      assert abs(r[k][1] - rad) < 1.0e-12, "bug 1"
      xck = c[k][0]; yck = c[k][1]
      hit = any(abs(p[0] - xck) < rad and abs(p[1] - yck) < rad for p in pd)
      if hit:
        c[nb_red] = c[k]; r[nb_red] = r[k]
        nb_red += 1
    assert nb_red >= 1, "bug 3"
    c = c[0:nb_red]
    r = r[0:nb_red]
    ntry += 1
    
    # Is the basis OK? 
    if nb_red > nd:
      # Reduce the target basis size:
      nb_exp = 2*nb_exp//3
    else:
      nb_ok = True
      
  if not nb_ok:
    sys.stderr.write("!! warning: basis seems too big for given data points\n")
    
  return c, r
  # ----------------------------------------------------------------------
# Script entry point: {main} is called unconditionally, so it runs both when
# this file is executed and when it is imported (there is no {__main__} guard).
main()
