# Implementation of module {contact}
# Last edited on 2021-05-15 15:09:12 by jstolfi

import contact
import move
import move_parms
import path
import block
import hacks
import rn
import pyx 
from math import nan, inf, sqrt
import sys

class Contact_IMP:
  # Concrete representation of a contact object. Every field is a pair
  # (2-tuple) and is meant to be read-only after construction.

  def __init__(self, p0, p1, mv0, tc0, mv1, tc1, bc0, bc1, rl0, rl1):
    # The two endpoints {p0,p1} of the contact:
    self.pts = p0, p1
    # The two moves {mv0,mv1} that are the sides 0 and 1 of the contact:
    self.mv = mv0, mv1
    # The cover times {tc0,tc1}, one for each side, in the same order as {self.mv}:
    self.tcov = tc0, tc1
    # Per-side items returned by {side_block} (presumably blocks -- confirm with callers):
    self.bc = bc0, bc1
    # Per-side items returned as a pair by {get_raster_link} (presumably raster links -- confirm):
    self.rl = rl0, rl1


def make(p0, p1, mv0, mv1, bc0, bc1, rl0, rl1):
  # Creates and returns a {contact.Contact} object with endpoints
  # {p0,p1}, whose sides are the moves {mv0} and {mv1}. The arguments
  # {bc0,bc1} and {rl0,rl1} are stored in the object unchanged.
  #
  # The moves must be distinct {move.Move} objects that are not jumps.
  # The cover time of each side is obtained from {move.cover_time},
  # evaluated at the midpoint of {p0} and {p1}.

  # Validate the endpoints and freeze them as tuples:
  assert hacks.is_point(p0)
  p0 = tuple(p0)
  assert hacks.is_point(p1)
  p1 = tuple(p1)

  # Validate the two sides:
  for mv in (mv0, mv1):
    assert isinstance(mv, move.Move) and not move.is_jump(mv)
  assert mv0 != mv1

  # Cover times are evaluated at the midpoint of the contact:
  pm = rn.mix(0.5, p0, 0.5, p1)
  tcs = (move.cover_time(mv0, pm), move.cover_time(mv1, pm))
  return contact.Contact(p0, p1, mv0, tcs[0], mv1, tcs[1], bc0, bc1, rl0, rl1)

def endpoints(ct):
  # Returns the pair of endpoints of the contact {ct},
  # which must be a {contact.Contact} object.
  assert isinstance(ct, contact.Contact)
  ends = ct.pts
  return ends

def pmid(ct):
  # Returns the midpoint of the two endpoints of the contact {ct},
  # computed as their equal-weight mix with {rn.mix}.
  p = ct.pts[0]
  q = ct.pts[1]
  return rn.mix(0.5, p, 0.5, q)

def get_raster_link(ct):
  # Returns the pair stored in the {rl} field of the contact {ct}
  # (the arguments {rl0,rl1} given when it was created).
  links = ct.rl
  return links

def side_block(ct, i):
  # Returns element {i} (0 or 1) of the {bc} field of the contact {ct}.
  blocks = ct.bc
  return blocks[i]

def side(ct, i):
  # Returns the move that is side {i} (0 or 1) of the contact {ct}.
  sides = ct.mv
  return sides[i]

def tcov(ct, i):
  # Returns the cover time stored for side {i} (0 or 1) of the contact {ct}.
  times = ct.tcov
  return times[i]

def which_side(mv, ct):
  # Returns the index (0 or 1) of the side of the contact {ct} that
  # is equal to the move {mv}, or {None} if {mv} is neither side.
  # Side 0 is tested first.
  assert isinstance(mv, move.Move)
  if ct.mv[0] == mv: return 0
  if ct.mv[1] == mv: return 1
  return None
 
# COVERAGE BY PATHS

def covindices(oph, ct):
  # Returns a pair {(ix0, ix1)} where {ix[i]} is the index of the
  # element of the oriented path {oph} whose move is side {i} of the
  # contact {ct}, or {None} if that side does not occur in {oph}.
  # Fails if a side occurs more than once in the path.
  found = [None, None]
  for k in range(path.nelems(oph)):
    # Only the move matters here; its direction in the path is ignored:
    mvk, _ = move.unpack(path.elem(oph, k))
    for j in (0, 1):
      if side(ct, j) == mvk:
        assert found[j] is None, "repeated move in path"
        found[j] = k
  return tuple(found)

def covtime(oph, imv, ct, i):
  # Returns the time at which side {i} (0 or 1) of the contact {ct} is
  # covered when the oriented path {oph} is executed, given the index
  # {imv} in {oph} of the element that is that side. Returns {None}
  # if {imv} is {None}.
  #
  # If the element has direction bit 0, the result is the element's
  # initial time {path.tini(oph,imv)} plus the stored cover time
  # {ct.tcov[i]}. Otherwise it is the element's final time
  # {path.tfin(oph,imv)} minus {ct.tcov[i]} (presumably because the
  # move is then traversed in reverse -- confirm in {move.unpack}).
  #
  # NOTE: this function does NOT check that element {imv} of {oph}
  # is actually side {i} of {ct}; the caller must ensure that.
  assert isinstance(ct, contact.Contact)
  path.unpack(oph) # Called only for the typechecking; result is not needed.
  if imv is None:
    return None
  omv = path.elem(oph, imv)
  _, dr = move.unpack(omv) # Only the direction bit is needed here.
  if dr == 0:
    return path.tini(oph, imv) + ct.tcov[i]
  else:
    return path.tfin(oph, imv) - ct.tcov[i]

def covtimes(oph, ct):
  # Returns a pair {(tc0, tc1)} where {tc[i]} is the time at which
  # side {i} of the contact {ct} is covered when the oriented path
  # {oph} is executed, or {None} if that side does not occur in {oph}.
  # The sides are located in the path with {covindices}.
  ixs = covindices(oph, ct)
  assert len(ixs) == 2

  def side_time(i):
    # Cover time of side {i}, or {None} if it is not in the path.
    imv = ixs[i]
    if imv is None:
      return None
    _, dr = move.unpack(path.elem(oph, imv))
    if dr == 0:
      # Direction bit 0: count from the start of the element.
      return path.tini(oph, imv) + ct.tcov[i]
    # Otherwise count back from the end of the element.
    return path.tfin(oph, imv) - ct.tcov[i]

  return (side_time(0), side_time(1))

def tcool(oph, ct):
  # Returns the absolute difference between the cover times of the
  # two sides of the contact {ct} in the oriented path {oph}, as given
  # by {covtimes}. Returns {None} if either side is not covered by {oph}.
  pair = covtimes(oph, ct)
  assert type(pair) in (list, tuple)
  assert len(pair) == 2
  if pair[0] is None or pair[1] is None:
    return None
  return abs(pair[0] - pair[1])

def pair_tcool(oph0, tconn, oph1, ct):
  # Returns the difference between the cover times of the two sides of
  # the contact {ct}, when side 0 is in the oriented path {oph0} and
  # side 1 is in the oriented path {oph1}. The cover time of side 1 is
  # offset by the execution time of {oph0} plus {tconn} (presumably
  # {oph1} runs after {oph0}, with a connection taking {tconn} between
  # them -- confirm with callers).
  #
  # Returns {None} if side 0 is not found in {oph0} or side 1 is not
  # found in {oph1}. Fails if side 1 would be covered before side 0.

  # Locate the two sides of {ct} in their respective paths:
  ix0 = path.find(oph0, side(ct, 0))
  ix1 = path.find(oph1, side(ct, 1))
  if ix0 is None or ix1 is None:
    return None

  # Cover times, each relative to the start of its own path:
  t0 = covtime(oph0, ix0, ct, 0)
  t1 = covtime(oph1, ix1, ct, 1)
  # Shift the second time to the time frame of the first path:
  t1 = t1 + (path.extime(oph0) + tconn)
  assert t1 >= t0
  return t1 - t0
  # ----------------------------------------------------------------------

# TOOLS FOR LISTS OF CONTACTS

def max_tcool(oph, CTS):
  # Returns the largest value of {tcool(oph, ct)} among the contacts
  # {ct} of the list or tuple {CTS}, ignoring the contacts for which
  # that value is {None}. Returns {-inf} if no contact has a defined value.
  assert type(CTS) in (list, tuple)
  tcs = [ tcool(oph, ct) for ct in CTS ]
  # The leading {-inf} is the result when there are no defined values:
  return max([ -inf ] + [ tc for tc in tcs if tc is not None ])
  # ----------------------------------------------------------------------

def pair_max_tcool(oph0, tconn, oph1, CTS):
  # Returns the largest value of {pair_tcool(oph0, tconn, oph1, ct)}
  # among the contacts {ct} of the list or tuple {CTS}, ignoring the
  # contacts for which that value is {None}. Returns {-inf} if no
  # contact has a defined value.
  assert type(CTS) in (list, tuple)
  tcs = [ pair_tcool(oph0, tconn, oph1, ct) for ct in CTS ]
  # The leading {-inf} is the result when there are no defined values:
  return max([ -inf ] + [ tc for tc in tcs if tc is not None ])
  # ----------------------------------------------------------------------
  
def min_tcov(oph, CTS):
  # Considers the contacts of the list or tuple {CTS} that have exactly
  # one side covered by the oriented path {oph}, and returns the
  # smallest cover time of those covered sides. Contacts with both
  # sides or no sides covered by {oph} are ignored. Returns {+inf}
  # if no contact qualifies.
  assert type(CTS) in (list, tuple)
  earliest = +inf
  for ct in CTS:
    t0, t1 = covtimes(oph, ct)
    if t0 is None and t1 is not None:
      cand = t1
    elif t0 is not None and t1 is None:
      cand = t0
    else:
      # Both sides covered, or neither: not relevant here.
      continue
    if cand < earliest:
      earliest = cand
  return earliest

# PLOTTING

def plot(c, ct, dp, clr, wd_line, sz_tic, arrow): 
  # Plots the contact {ct} on the {pyx} canvas {c}, as a line segment
  # between its two endpoints, stroked with the style element {clr}
  # (presumably a {pyx} color) and line width {wd_line}.
  #
  # If {dp} is not {None}, it must be a pair of coordinates; the plot
  # is then translated by {(dp[0], dp[1])}.
  #
  # If {sz_tic} is positive or {arrow} is true, also plots a tic
  # through the midpoint of the contact, along the direction returned
  # by {get_perp_dir} (from side 0 towards side 1). A {sz_tic} of
  # {None} is treated as 0. If {arrow} is true, the tic carries an
  # arrowhead of size {3*wd_line}, and the tic is lengthened if needed
  # to be at least 0.80 of that size. In that case a debugging line
  # with the sizes used is also written to {sys.stderr}.
  p = ct.pts[0]
  q = ct.pts[1]
  dpq = rn.dist(p,q)
  peps = 0.01*wd_line if dpq < 1.0e-6 else 0 # Perturbation for equal points.
  sty_basic = [
    pyx.style.linecap.round, 
    clr,
  ]
  sty_line = sty_basic + [ pyx.style.linewidth(wd_line), ]
  if dp is not None: sty_line.append(pyx.trafo.translate(dp[0], dp[1]))
  c.stroke(pyx.path.line(p[0]-peps, p[1]-peps, q[0]+peps, q[1]+peps), sty_line)

  # Should we plot a transversal tic or arrowhead?
  if sz_tic is None: sz_tic = 0 # Simplification.
  if sz_tic > 0 or arrow:
    # Plot the transversal tic or arrowhead:
    m = rn.mix(0.5, p, 0.5, q)  # Midpoint.
    u = get_perp_dir(m, ct.mv[0], ct.mv[1])
    sz_arrow = 3*wd_line if arrow else 0
    # We need a tic with a certain min size for the arrowhead:
    sz_tic = max(sz_tic, 0.80*sz_arrow)
    a = rn.mix(1.0, m, -0.5*sz_tic, u)
    b = rn.mix(1.0, m, +0.5*sz_tic, u)
    # Note: {sty_tic} is only rebound below, never modified in place
    # before being copied, so {sty_basic} is not affected.
    sty_tic = sty_basic
    if sz_arrow > 0:
      # Add the arrowhead to the tic.
      arrowpos = 0.5  # Position of arrow on transversal line.
      wd_arrow = sz_arrow/5 # Linewidth for stroking the arrowhead (guess).
      sty_arrow = sty_basic + [ 
        pyx.deco.stroked([pyx.style.linewidth(wd_arrow), pyx.style.linejoin.round]),
        pyx.deco.filled([])
      ]
      sty_tic = sty_tic + \
        [ 
          pyx.deco.earrow(sty_arrow, size=sz_arrow, constriction=None, pos=arrowpos, angle=35)
        ]
      # Debugging output; all three sizes are printed with 3 decimals:
      sys.stderr.write("sz_arrow = %.3f wd_arrow = %.3f sz_tic = %.3f\n" % (sz_arrow, wd_arrow, sz_tic))
    
    sty_tic = sty_tic + [ pyx.style.linewidth(wd_line), ]
    if dp is not None: sty_tic.append(pyx.trafo.translate(dp[0], dp[1]))
    c.stroke(pyx.path.line(a[0], a[1], b[0], b[1]), sty_tic)

def plot_link(c, oph, clr, wd):
  # Plots every element of the oriented path {oph} on the canvas {c},
  # in path order, by calling {move.plot_layer} with the arguments
  # {clr} and {wd} (presumably the color and line width -- see
  # {move.plot_layer} for the meaning of the other, fixed arguments).
  # Returns nothing.
  nmv = path.nelems(oph)
  for k in range(nmv):
    move.plot_layer(c, path.elem(oph, k), None, clr, wd, False, 0, 0)

def get_perp_dir(m, mv0, mv1):
  # Returns the direction from trace {mv0} towards trace {mv1}
  # at the point {m}, assumed to be the midpoint of a contact
  # between them.
  #
  # For each move, finds the point on the segment between its endpoints
  # that is nearest to {m}. The result is the first value returned by
  # {rn.dir} for the vector from the point on {mv0} to the point on
  # {mv1}. The two moves must be distinct, and so must those two points.
  # Also writes {m} and the two points to {sys.stderr}, for debugging.

  sys.stderr.write("m = ( %.3f %.3f )\n" % ( m[0], m[1],))
  assert hacks.is_point(m)
  assert mv0 != mv1, "both sides on same move?"

  def nearest(mv):
    # Point of the segment between the endpoints of {mv} nearest to {m}.
    org, dst = move.endpoints(mv)
    r = rn.pos_on_line(org, dst, m)
    r = min(1, max(0, r)) # Clip the relative position to the segment.
    return rn.mix(1-r, org, r, dst)

  a0 = nearest(mv0)
  a1 = nearest(mv1)
  assert a0 != a1
  sys.stderr.write("a = ( %.3f %.3f ) ( %.3f %.3f )\n" % ( a0[0], a0[1], a1[0], a1[1],))
  u, _ = rn.dir(rn.sub(a1, a0)) # Second result of {rn.dir} is not needed.
  return u
  # ----------------------------------------------------------------------
