#! /usr/bin/python3
# Test program for module {contact}
# Last edited on 2021-05-31 18:13:13 by jstolfi

import contact
import contact_example
import move
import move_parms
import path
import block
import block_example
import palette
import hacks
import job_parms
import rn
import pyx
import sys
from math import sqrt, sin, cos, floor, ceil, pi, nan, inf

# Job parameters shared by all tests below. The raster and contour trace
# widths are set explicitly (presumably so that the expected geometry in the
# tests does not depend on the defaults of {job_parms.typical_js}).
parms = job_parms.typical_js()
parms['solid_raster_width'] = 1.00
parms['contour_trace_width'] = 0.50

# Move parameter records for jumps, contour traces, and filling traces:
mp_jump = move_parms.make_for_jumps(parms)
mp_cont = move_parms.make_for_contours(parms)
mp_fill = move_parms.make_for_fillings(parms)

# Nominal widths of filling ({wdf}) and contour ({wdc}) traces:
wdf = move_parms.width(mp_fill)
wdc = move_parms.width(mp_cont)

def test_basic():
  """Checks {contact.from_moves} and {contact.side} on a few simple moves.

  A filling trace A and a contour trace B are placed with their axes
  separated by the sum of their half-widths, overlapping in X over [1,2].
  The contact between them must be found for every orientation of the two
  traces, with the expected endpoints. Too-strict length limits, a
  non-adjacent trace C, and a jump must yield no contact."""

  sys.stderr.write("--- testing {from_moves,side} ---\n")

  # Ordinates of the trace axes: A and B are adjacent, B and C are adjacent,
  # and C is farther from A than the sum of their half-widths.
  y_A = 1
  y_B = y_A + (wdf + wdc)/2
  y_C = y_B + wdc

  # Tolerance for endpoint positions; also used to tilt trace B slightly.
  tol = 0.02

  # Trace A: filling, from X=0 to X=2.
  trA = move.make((0, y_A), (2, y_A), mp_fill)
  # Trace B: contour, from X=1 to X=8, slightly tilted.
  pB_ini = (1, y_B + 0.1*tol)
  pB_fin = (8, y_B - 0.3*tol)
  trB = move.make(pB_ini, pB_fin, mp_cont)
  # Trace C: contour, from X=0 to X=2.
  trC = move.make((0, y_C), (2, y_C), mp_cont)
  # A jump starting at the end of trace B:
  jmB = move.make(pB_fin, (10, y_B), mp_jump)

  # Expected endpoints of the A-B contact, in order of increasing X:
  expected = ((1, y_A + wdf/2), (2, y_A + wdf/2))

  for oA in (trA, move.rev(trA)):
    for oB in (trB, move.rev(trB)):
      ct = contact.from_moves(oA, oB, 0.9, 0.49)
      assert ct is not None
      assert isinstance(ct, contact.Contact)
      # The sides are expected to be the unoriented traces:
      assert contact.side(ct, 0) == trA
      assert contact.side(ct, 1) == trB
      pt0, pt1 = contact.endpoints(ct)
      if pt0[0] > pt1[0]:
        pt0, pt1 = pt1, pt0
      sys.stderr.write("ctA.endpoints = ( %.9f %.9f ) ( %.9f %.9f )" % (pt0 + pt1))
      sys.stderr.write(" length = %.9f\n" % rn.dist(pt0, pt1))
      sys.stderr.write("expected      = ( %.9f %.9f ) ( %.9f %.9f )\n" % (expected[0] + expected[1]))
      assert rn.dist(pt0, expected[0]) < tol
      assert rn.dist(pt1, expected[1]) < tol

  # The A-B contact (length about 1) must be rejected by an absolute
  # minimum size of 1.1, and by a relative minimum size of 0.51:
  assert contact.from_moves(trA, trB, 1.1, 0.00) is None
  assert contact.from_moves(trA, trB, 0.0, 0.51) is None

  # Non-adjacent traces A and C must have no contact:
  assert contact.from_moves(trA, trC, 0, 0) is None

  # A trace and a jump must have no contact:
  assert contact.from_moves(trA, jmB, 0, 0) is None

  return
  # ----------------------------------------------------------------------

def test_names():
  """Checks {contact.set_name}, {contact.get_name}, and {contact.tag_names}
  on two contacts between three horizontal traces A, B, C."""

  sys.stderr.write("--- testing {set_name,get_name,tag_names} ---\n")

  # Axes of the traces; A is adjacent to B, and B is adjacent to C.
  y_A = 1
  y_B = y_A + (wdf + wdc)/2
  y_C = y_B + wdc

  trA = move.make((0, y_A), (2, y_A), mp_fill)
  trB = move.make((1, y_B), (8, y_B), mp_cont)
  trC = move.make((0, y_C), (2, y_C), mp_cont)

  ct_AB = contact.from_moves(trA, trB, 0.9, 0.49)
  assert ct_AB is not None

  ct_BC = contact.from_moves(trB, trC, 0.9, 0.49)
  assert ct_BC is not None

  # A freshly created contact is expected to have the default name "C?":
  assert contact.get_name(ct_AB) == "C?"
  contact.set_name(ct_AB, "Close")
  assert contact.get_name(ct_AB) == "Close"

  # {tag_names} is expected to prefix the tag to every contact's name:
  sys.stderr.write("applying {contact.tag_names}:\n")
  contact.tag_names([ct_AB, ct_BC], "Tag.")
  assert contact.get_name(ct_AB) == "Tag.Close"
  assert contact.get_name(ct_BC) == "Tag.C?"
  return
  # ----------------------------------------------------------------------

def test_more_makes():
  """Checks {contact.from_move_lists}, {contact.from_paths}, and
  {contact.from_blocks} on the blocks, paths, and traces of
  {block_example.misc_G}.

  Five variants are tried:
  - the non-jump moves of the two paths;
  - the same moves reversed;
  - the two paths;
  - the two paths reversed;
  - the two blocks.
  Each variant must yield exactly six contacts. The contacts are identified
  by the names of their side 0 and side 1 traces."""

  sys.stderr.write("--- testing {from_moves,from_move_lists,from_paths,from_blocks} ---\n")

  BCS,PHS,TRS0,TRS1 = block_example.misc_G(mp_cont, mp_fill, mp_jump)

  sys.stderr.write("  ... traces of block 0 ...\n")
  move.show_list(sys.stderr, TRS0, 2)

  sys.stderr.write("  ... traces of block 1 ...\n")
  move.show_list(sys.stderr, TRS1, 2)

  sys.stderr.write("  ... paths ...\n")
  path.show_list(sys.stderr, PHS, True, 2)

  sys.stderr.write("  ... blocks ...\n")
  block.show_list(sys.stderr, BCS, True, 2)

  ph0 = PHS[0]; nmv0 = path.nelems(ph0)
  ph1 = PHS[1]; nmv1 = path.nelems(ph1)

  bcA = BCS[0]
  bcB = BCS[1]

  # All elements of the two paths (including jumps), in path order.
  # These do not depend on the variant, so they are computed once.
  MVS0 = [ path.elem(ph0,k) for k in range(nmv0) ]
  MVS1 = [ path.elem(ph1,k) for k in range(nmv1) ]

  # Expected contacts, as sorted (side 0 name, side 1 name) pairs.
  # {sorted} must be used here rather than {list.sort}: the latter sorts in
  # place and returns {None}, which would make the final check {None == None}.
  CTNS_exp = sorted([
    ("TGa0", "TGb0"),
    ("TGc0", "TGb0"),
    ("TGa1", "TGb0"),
    ("TGa1", "TGb1"),
    ("TGc2", "TGb0"),
    ("TGc2", "TGb1"),
  ])

  szmin = 0.9
  rszmin = 0.19
  for ifun in range(5):
    if ifun == 0 or ifun == 1:
      if ifun == 0:
        sys.stderr.write("  ... {from_move_lists} ...\n")
        OMVS0 = [ mv for mv in MVS0 if not move.is_jump(mv) ]
        OMVS1 = [ mv for mv in MVS1 if not move.is_jump(mv) ]
      else:
        sys.stderr.write("  ... {from_move_lists} (reversed) ...\n")
        OMVS0 = [ move.rev(mv) for mv in MVS0 if not move.is_jump(mv) ]
        OMVS1 = [ move.rev(mv) for mv in MVS1 if not move.is_jump(mv) ]
      CTS = contact.from_move_lists(OMVS0, OMVS1, szmin, rszmin)
    elif ifun == 2:
      sys.stderr.write("  ... {from_paths} ...\n")
      CTS = contact.from_paths(ph0, ph1, szmin, rszmin)
    elif ifun == 3:
      sys.stderr.write("  ... {from_paths} (reversed) ...\n")
      CTS = contact.from_paths(path.rev(ph0), path.rev(ph1), szmin, rszmin)
    elif ifun == 4:
      sys.stderr.write("  ... {from_blocks} ...\n")
      CTS = contact.from_blocks(bcA, bcB, szmin, rszmin)
    else:
      assert False
    contact.show_list(sys.stderr, CTS, 2)
    for k, ctk in enumerate(CTS):
      sys.stderr.write("\n")
      sys.stderr.write("  contact %d: " % k)
      contact.show(sys.stderr, ctk, 0, 4)
      sys.stderr.write("\n")
    assert len(CTS) == 6

    # Check contacts by names of sides (order of contacts is irrelevant):
    CTNS_obs = sorted(
      ( move.get_name(contact.side(ct, 0)), move.get_name(contact.side(ct, 1)) )
      for ct in CTS
    )
    assert CTNS_obs == CTNS_exp, \
      ("ifun = %d: contacts %s, expected %s" % (ifun, str(CTNS_obs), str(CTNS_exp)))

  return
  # ----------------------------------------------------------------------

def test_show():
  """Exercises {contact.show} and {contact.show_list} on the contacts of
  {contact_example.misc_B}. The output goes to {stderr} for visual checking."""

  sys.stderr.write("--- testing {show,show_list} ---\n")

  CTS, OPHS, TRS = contact_example.misc_B(mp_fill, mp_jump)

  # The first three contacts, one by one, with the last argument of {show}
  # (presumably a name field width) taking the values 5, 7, 9:
  sys.stderr.write("  ... {show} ...\n")
  for k in range(3):
    ct = CTS[k]
    wna = 5 + 2*k
    sys.stderr.write("[")
    contact.show(sys.stderr, ct, 4, wna)
    sys.stderr.write("]\n")

  # All contacts as a table:
  sys.stderr.write("  ... {show_list} ...\n")
  contact.show_list(sys.stderr, CTS, 6)
  return
  # ----------------------------------------------------------------------

def test_plot_to_files():
  """Exercises {contact.plot_to_files} on the contacts and oriented paths of
  {contact_example.misc_B}.

  It is called once for each combination of {tics} and {ct_arrows} except
  both {True}. The file name prefix passed is
  "tests/out/contact_TST_plot_to_files_tc{T}_ar{A}", where {T} and {A} are
  0 or 1."""

  sys.stderr.write("--- testing {plot_to_files} ---\n")

  tag = "plot_to_files"
  CTS, OPHS, TRS = contact_example.misc_B(mp_fill, mp_jump)

  CLRS = [ pyx.color.rgb(0.300, 0.600, 0.000), ] # A single color, passed for all paths.

  rwd = 0.80
  wd_axes = 0.05*wdf
  clr_ct = pyx.color.rgb.red # Color for contact lines

  for tics in (False, True):
    for ct_arrows in (False, True):
      if tics and ct_arrows: continue # Only three of the four combinations are plotted.
      fname = ("tests/out/contact_TST_%s_tc%d_ar%d" % (tag,int(tics),int(ct_arrows)))
      contact.plot_to_files(fname, CTS, clr_ct, OPHS, CLRS, rwd, wd_axes, tics, ct_arrows)
  return
  # ----------------------------------------------------------------------

def test_plot_single():
  """Exercises {contact.plot_single} on the contacts of
  {contact_example.misc_B}, drawn over their oriented paths.

  One plot is written for each combination of {tics} and {ct_arrows}
  except both {True}. Each plot is written via {hacks.write_plot} with name
  "tests/out/contact_TST_plot_single_tc{T}_ar{A}", where {T} and {A} are
  0 or 1."""

  sys.stderr.write("--- testing {plot_single} ---\n")

  tag = "plot_single"
  CTS, OPHS, TRS = contact_example.misc_B(mp_fill, mp_jump)

  CLRS = hacks.trace_colors(len(OPHS))

  # Get the enclosing box of the paths and contacts:
  B = rn.box_join(path.bbox(OPHS), contact.bbox(CTS))

  dp = (0,0)

  wd_axes = 0.05*wdf
  rwd = 0.80
  wd_ct = 1.5*wd_axes
  clr_ct = pyx.color.rgb.red # Color for contact lines

  # Options for {path.plot_standard}, the same for every plot:
  axes = False
  dots = True
  ph_arrows = True
  matter = False

  for tics in (False, True):
    for ct_arrows in (False, True):
      if tics and ct_arrows: continue # Only three of the four combinations are plotted.

      c, _, _ = hacks.make_canvas(B, dp, True, True, 1, 1)
      path.plot_standard(c, OPHS, None, None, CLRS, rwd, wd_axes, axes, dots, ph_arrows, matter)

      sz_tics = wd_ct if tics else 0
      for ct in CTS:
        contact.plot_single(c, ct, None, clr=clr_ct, wd=wd_ct, sz_tic=sz_tics, arrow=ct_arrows)

      hacks.write_plot(c, ("tests/out/contact_TST_%s_tc%d_ar%d" % (tag,int(tics),int(ct_arrows))))

  return
  # ----------------------------------------------------------------------

def test_endpoints_sides():
  """Checks {contact.endpoints}, {contact.pmid}, {contact.side},
  {contact.which_side}, and {contact.covindices} on the contacts, oriented
  paths, and traces of {contact_example.misc_B}."""
  sys.stderr.write("--- testing {endpoints,pmid,side,which_side,covindices} ---\n")

  CTS, OPHS, TRS = contact_example.misc_B(mp_fill, mp_jump)

  # {pmid} must be the midpoint of the {endpoints}:
  ep0, ep1 = contact.endpoints(CTS[0])
  assert rn.dist(rn.mix(0.50, ep0, 0.50, ep1), contact.pmid(CTS[0])) < 1.0e-8

  # Expected data for each contact. This table depends on the specific
  # {OPHS,CTS,TRS} built by {contact_example.misc_B}.
  expected = (
    # contact side 0  side 1  path 0   ix0  path 1   ix1
    ( CTS[0], TRS[0], TRS[1], OPHS[0], 0,   OPHS[2], 2 ),
    ( CTS[1], TRS[1], TRS[2], OPHS[2], 2,   OPHS[1], 0 ),
    ( CTS[2], TRS[0], TRS[2], OPHS[0], 0,   OPHS[1], 0 ),
    ( CTS[3], TRS[2], TRS[3], OPHS[1], 0,   OPHS[2], 0 ),
    ( CTS[4], TRS[1], TRS[4], OPHS[2], 2,   OPHS[1], 4 ),
    ( CTS[5], TRS[7], TRS[9], OPHS[3], 0,   OPHS[4], 0 ),
    ( CTS[6], TRS[8], TRS[9], OPHS[3], 2,   OPHS[4], 0 ),
    ( CTS[7], TRS[7], TRS[5], OPHS[3], 0,   OPHS[1], 2 ),
  )
  for ct, mv0, mv1, oph0, ix0, oph1, ix1 in expected:
    sys.stderr.write("\n")
    contact.show(sys.stderr, ct, 2, 4)
    sys.stderr.write("\n")

    # Sides, and {which_side} for both orientations of each side trace:
    assert contact.side(ct, 0) == mv0
    assert contact.side(ct, 1) == mv1
    for isd, mv in enumerate((mv0, mv1)):
      assert contact.which_side(mv, ct) == isd
      assert contact.which_side(move.rev(mv), ct) == isd
    # {TRS[6]} is not a side of any contact in the table:
    assert contact.which_side(TRS[6], ct) is None

    # The path elements named in the table must be the side traces:
    mv0a, _ = move.unpack(path.elem(oph0, ix0))
    mv1a, _ = move.unpack(path.elem(oph1, ix1))
    assert mv0a == mv0
    assert mv1a == mv1

    # Each path covers only its own side of the contact:
    assert contact.covindices(oph0, ct) == ( ix0, None )
    assert contact.covindices(oph1, ct) == ( None, ix1 )

    # If the two paths differ, their concatenation must cover both sides.
    # One extra (connector) element is expected when the end of {oph0}
    # differs from the start of {oph1}.
    if path.unpack(oph0)[0] != path.unpack(oph1)[0]:
      n0 = path.nelems(oph0)
      n1 = path.nelems(oph1)
      nconn = int(path.pfin(oph0) != path.pini(oph1))
      use_jumps = True
      use_links = False
      oph01 = path.concat((oph0, oph1), use_jumps, use_links, mp_jump)
      assert path.nelems(oph01) == n0 + nconn + n1
      assert contact.covindices(oph01, ct) == ( ix0, n0 + nconn + ix1 )

  return
  # ----------------------------------------------------------------------

def test_coverage_times():
  """Checks {contact.covtime}, {contact.covtimes}, {contact.tcool},
  {contact.max_tcool}, {contact.pair_tcool}, {contact.pair_max_tcool}, and
  {contact.min_tcov}.

  For each orientation {dr} and order {od} of paths 1 and 2 of
  {contact_example.misc_B}, the two paths are concatenated with a connecting
  jump into a single path {oph}. The coverage and cooling times of every
  contact are then recomputed by hand and compared with the library
  functions."""
  sys.stderr.write("--- testing {covtime,covtimes,tcool,max_tcool,min_tcov} ---\n")

  CTS, OPHS, TRS = contact_example.misc_B(mp_fill, mp_jump)
  
  for dr in 0, 1:
    # Variable {dr} defines the orientation of the two test paths.
    for od in 0, 1:
      # Variable {od} defines the order of the two test paths.
      # sys.stderr.write("- - - - - - - - - - - - - - - - - - - - - - - - \n")
      # sys.stderr.write("od = %d  dr = %d\n" % (od,dr))
      # Counts of contacts closed, active, and inactive relative to {oph}:
      nclosed = 0
      nactive = 0
      ninactive = 0
      # Make a test path from two of the paths:
      if od == 0:
        oph0 = path.spin(OPHS[1],dr)
        oph1 = path.spin(OPHS[2],dr)
      else:
        oph0 = path.spin(OPHS[2],dr)
        oph1 = path.spin(OPHS[1],dr)
      use_jump = True
      use_link = False
      oph = path.concat([oph0, oph1,], use_jump, use_link, mp_jump)
      # Exactly one connector element is expected between the two paths:
      assert path.nelems(oph) == path.nelems(oph0) + 1 + path.nelems(oph1)

      # Compute the connector time, from the last move of {oph0} to the
      # first move of {oph1}, and check that execution times add up:
      omv_prev = move.rev(path.elem(path.rev(oph0),0))
      omv_next = path.elem(oph1,0)
      tconn = move.connector_extime(omv_prev, omv_next, use_jump, use_link, mp_jump)
      assert abs((path.extime(oph0) + tconn + path.extime(oph1)) - path.extime(oph)) < 1.0e-8
      
      # Compute {max_tcool,pair_max_tcool,min_tcov} by hand:
      maxtcool = -inf
      pmaxtcool = -inf
      mintcov = +inf
      for ct in CTS:
        ixs = contact.covindices(oph, ct)
        tcs = contact.covtimes(oph, ct)
        # {covtimes} must agree with {covtime} applied to each side:
        tc0 = contact.covtime(oph, ixs[0], ct, 0); assert tcs[0] == tc0
        tc1 = contact.covtime(oph, ixs[1], ct, 1); assert tcs[1] == tc1
        # sys.stderr.write("tcs = %s\n" % str(tcs))
        for i in range(2):
          # Check {covindices} and {covtimes} for side {i}:
          imv = ixs[i]
          if imv == None:
            # Side {i} is not in {oph}, so it has no coverage time:
            assert path.find_move(oph, contact.side(ct, i)) == None
            assert tcs[i] == None
          else:
            omvk = path.elem(oph, imv)
            mvk, drk = move.unpack(omvk)
            assert mvk == contact.side(ct, i)
            # The coverage time is measured from whichever end of the move
            # is executed first, depending on its orientation {drk}:
            if drk == 0:
              assert tcs[i] == path.tini(oph,imv) + contact.tcov(ct, i)
            else:
              assert tcs[i] == path.tfin(oph,imv) - contact.tcov(ct, i)
            assert tcs[i] <= path.extime(oph)
        # Recompute {max_tcool,min_tcov} from scratch:
        tcool = contact.tcool(oph, ct)
        if tcs[0] == None and tcs[1] == None:
          # Contact is inactive rel to {oph}:
          assert tcool == None
          ninactive += 1
        elif tcs[0] == None or tcs[1] == None:
          # Contact is active rel to {oph}:
          assert tcool == None
          tci = tcs[0] if tcs[0] != None else tcs[1]
          if tci < mintcov: mintcov = tci
          nactive += 1
        else:
          # Contact is closed by {oph}:
          # sys.stderr.write("tcs: %s tcool: %.6f\n" % (str(tcs),tcool))
          assert tcool == abs(tcs[0] - tcs[1])
          if tcool > maxtcool: maxtcool = tcool
          nclosed += 1
        # {pair_tcool} is expected to be defined only when side 0 is covered
        # by {oph0} and side 1 by {oph1}; then it must match {tcool}:
        ixs0 = contact.covindices(oph0, ct)
        ixs1 = contact.covindices(oph1, ct)
        ptcool = contact.pair_tcool(oph0, tconn, oph1, ct)
        if ixs0[0] == None or ixs1[1] == None:
          assert ptcool == None
        else:
          assert abs(ptcool - tcool) < 1.0e-8
          if ptcool > pmaxtcool: pmaxtcool = ptcool
      # These counts are specific to the contacts and paths built by
      # {contact_example.misc_B}:
      assert nclosed == 3
      assert nactive == 3
      assert ninactive == 2
      assert contact.max_tcool(oph, CTS) == maxtcool
      assert contact.pair_max_tcool(oph0, tconn, oph1, CTS) == pmaxtcool
      assert contact.min_tcov(oph, CTS) == mintcov

  return
  # ----------------------------------------------------------------------

# Run all tests in sequence. Each one writes progress to {stderr} and stops
# with an {AssertionError} on a mismatch. {test_names} was defined but never
# invoked, so it is included here.
test_basic()
test_names()
test_more_makes()
test_show()
test_endpoints_sides()
test_coverage_times()
test_plot_single()
test_plot_to_files()
