import rn
import path
import contact
  
def build(OPHS, SCS, z, mp_jump):
  # Builds a snake path from the rasters {OPHS}, grouped into scan-lines by
  # {SCS} ({SCS[i]} lists indices into {OPHS}), and collects its contacts.
  #
  # The rasters are split by {make_snakes_paths} into two lists {Ba} (grown
  # from the bottom scan-line) and {Bz} (grown from the top one). The path
  # {bph} is {path.concat} of {Ba}, followed, if {Bz} is not empty, by the
  # reversal of {path.concat} of {Bz}. {mp_jump} is passed unchanged to every
  # {path.concat} call (its meaning is defined by the {path} module).
  #
  # Returns {(bph, CTS_lo, TCVS_lo, CTS_hi, TCVS_hi)}, where {CTS_lo} and
  # {CTS_hi} are the bottom- and top-scan-line contacts returned by
  # {get_raster_contacts}, each sorted by the first coordinate of
  # {contact.pmid(ct)}. {TCVS_lo[k]} is element 1 of
  # {contact.path_tcovs(bph, CTS_lo[k])}, and {TCVS_hi[k]} is element 0 of
  # {contact.path_tcovs(rev(bph), CTS_hi[k])}.
  # If {valid_path} rejects {bph} against the remaining contacts {CTS},
  # returns five {None}s instead.
  RSS = create_scan_line(OPHS, SCS)

  Ba, Bz = make_snakes_paths(RSS, z)

  bph = path.concat(Ba, True, mp_jump)
  if Bz:
    bph_z = path.concat(Bz, True, mp_jump)
    bph = path.concat([bph, path.rev(bph_z)], True, mp_jump)

  CTS_lo, CTS, CTS_hi = get_raster_contacts(RSS, len(RSS))

  # Reject the path before spending any per-contact work on it.
  if not valid_path(bph, CTS):
    return None, None, None, None, None

  CTS_lo.sort(key = lambda ct: contact.pmid(ct)[0])
  TCVS_lo = [ contact.path_tcovs(bph, ct)[1] for ct in CTS_lo ]

  # The reversed path is the same for every top contact; compute it once.
  bph_rev = path.rev(bph)
  CTS_hi.sort(key = lambda ct: contact.pmid(ct)[0])
  TCVS_hi = [ contact.path_tcovs(bph_rev, ct)[0] for ct in CTS_hi ]

  return bph, CTS_lo, TCVS_lo, CTS_hi, TCVS_hi

###

def create_scan_line(OPHS, SCS):
  # Groups the rasters {OPHS} into scan-lines. {SCS[i]} is a list of indices
  # into {OPHS}. Returns the list {RSS} with {RSS[i] = [OPHS[j] for j in SCS[i]]},
  # keeping the order of the indices in each {SCS[i]}.
  return [ [ OPHS[j] for j in scs ] for scs in SCS ]

###

def make_snakes_paths(RSS, z):
  # {RSS} is a list of scan-lines sorted by {Y}; each {RSS[i]} is a list of
  # rasters (paths). The first and last scan-lines must be non-empty, since
  # their end rasters are indexed directly.
  #
  # Returns two lists {Ba, Bz} of oriented rasters: each raster is either taken
  # as given or reversed with {path.rev}. Each raster of {RSS} is appended at
  # most once, because it is marked in {RSS_used} when taken.
  #
  # {Ba} starts at the bottom scan-line {RSS[0]}. When {z} is 0 it starts with
  # the first raster of {RSS[0]} as given; when {z} is 1 it starts with the last
  # raster of {RSS[0]} reversed. (Presumably this means left-to-right vs.
  # right-to-left traversal, assuming rasters in a scan-line are ordered and
  # oriented left to right; TODO confirm.)
  # {Bz} starts at the top scan-line {RSS[-1]}: with its last raster reversed
  # when {z} is 0, or with its first raster as given when {z} is 1. If there is
  # only one raster in total, {Bz} stays empty.
  #
  # After that, {find_next_raster} repeatedly picks the next unused raster for
  # each list. The choice is based on the current end point of that list, and
  # the raster is oriented according to the flag it returns. The loop stops
  # when {find_next_raster} offers no candidate to either list.

  nsc = len(RSS) # Number of scan-lines in band.
  assert nsc >= 1

  Ba = []
  Bz = []

  RSS_used = [ [False]*len(row) for row in RSS ]

  # Initial rasters: bottom scan-line for {Ba}, top scan-line for {Bz}.
  ia = 0 ; iz = nsc-1
  ja = 0 if z == 0 else (len(RSS[ia])-1)
  jz = (len(RSS[iz])-1) if z == 0 else 0
  za = z ; zz = 1 - z

  # A single raster in total: only {Ba} takes it.
  if ia == iz and ja == jz: iz = None

  q_end_a = None; q_end_z = None

  while True:
    if ia is not None:
      RSS_used[ia][ja] = True
      raster_a = RSS[ia][ja] if za == 0 else path.rev(RSS[ia][ja])
      Ba.append(raster_a)
      q_end_a = path.pfin(raster_a)

    if iz is not None:
      RSS_used[iz][jz] = True
      raster_z = RSS[iz][jz] if zz == 0 else path.rev(RSS[iz][jz])
      Bz.append(raster_z)
      q_end_z = path.pfin(raster_z)

    ia, ja, za, iz, jz, zz = find_next_raster(RSS, RSS_used, nsc, q_end_a, q_end_z)
    if ia is None and iz is None: break

  return Ba, Bz

###

def find_next_raster(RSS, RSS_used, nsc, q_end_a, q_end_z):
  # Chooses the next raster for each of the two lists grown by
  # {make_snakes_paths}: list {a}, whose current end point is {q_end_a}, and
  # list {z}, whose current end point is {q_end_z}. Either end point may be
  # {None}, in which case that list gets no candidate. Only rasters not yet
  # marked in {RSS_used} are considered. {nsc} is the number of scan-lines to
  # scan (the caller passes {len(RSS)}).
  #
  # Returns {(i_a, j_a, z_a, i_z, j_z, z_z)}: for each list, the scan-line
  # index, the raster index within that scan-line, and the orientation flag
  # from {calculate_distance} (0 = nearer to the raster's start, 1 = nearer to
  # its end). A list that got no candidate has {None} in its three slots. When
  # every raster has been used, returns a list of six {None}s.
  #
  # Selection is greedy and depends on order:
  #  * Within a scan-line, each unused raster is offered first to the list
  #    whose end point is nearer (ties go to {a}). It goes to the other list
  #    only if it does not improve the nearer list's best candidate so far.
  #    A raster later displaced as one list's best is not reconsidered for the
  #    other, so each raster is a candidate for at most one list.
  #  * Then, for each list independently, the candidate with the smallest
  #    distance over all scan-lines wins (ties go to the lowest scan-line index).

  # Stop if every raster has already been used.
  if all(False not in r for r in RSS_used): return [ None ] * 6

  # Per scan-line best candidate {[dist, j, orientation]} for each list, or {None}.
  next_rasters_a = [ None ]*nsc
  next_rasters_z = [ None ]*nsc

  for i in range(nsc):
    if False in RSS_used[i]:
      dp_min_a = None ; j_min_a = None ; z_min_a = None
      dp_min_z = None ; j_min_z = None ; z_min_z = None

      for j in range(len(RSS[i])):
        if not RSS_used[i][j]:
          raster = RSS[i][j]

          p_ini = path.pini(raster)
          p_fin = path.pfin(raster)

          # Distance from each list's end point to the nearer end of this raster.
          dp_a, z_a = calculate_distance(q_end_a, p_ini, p_fin)
          dp_z, z_z = calculate_distance(q_end_z, p_ini, p_fin)

          next_a = False ; next_z = False

          # Only list {a} has an end point: offer the raster to {a}.
          if   (dp_a != None and dp_z == None) and (dp_min_a == None or dp_a < dp_min_a): next_a = True
          # Only list {z} has an end point: offer the raster to {z}.
          elif (dp_z != None and dp_a == None) and (dp_min_z == None or dp_z < dp_min_z): next_z = True
          elif (dp_a != None and dp_z != None):
            # Both have end points: try the nearer list first, then the other.
            if dp_a > dp_z:
              if (dp_min_z == None or dp_z < dp_min_z): next_z = True
              elif (dp_min_a == None or dp_a < dp_min_a): next_a = True

            else:
              if (dp_min_a == None or dp_a < dp_min_a): next_a = True
              elif (dp_min_z == None or dp_z < dp_min_z): next_z = True

          if next_a: dp_min_a, j_min_a, z_min_a = dp_a, j, z_a
          elif next_z: dp_min_z, j_min_z, z_min_z = dp_z, j, z_z

      if dp_min_a != None: next_rasters_a[i] = [dp_min_a, j_min_a, z_min_a]
      if dp_min_z != None: next_rasters_z[i] = [dp_min_z, j_min_z, z_min_z]

  # Global choice: for each list, the per-scan-line candidate with the smallest
  # distance ({dp_min_a}/{dp_min_z} are reused here for the running minimum).
  i_a = None ; j_a = None ; z_a = None ; dp_min_a = None
  i_z = None ; j_z = None ; z_z = None ; dp_min_z = None

  for index in range(len(next_rasters_a)):
    nr_a = next_rasters_a[index]
    nr_z = next_rasters_z[index]

    if nr_a != None and (dp_min_a == None or dp_min_a > nr_a[0]):
      i_a, dp_min_a, j_a, z_a = index, nr_a[0], nr_a[1], nr_a[2]

    if nr_z != None and (dp_min_z == None or dp_min_z > nr_z[0]):
      i_z, dp_min_z, j_z, z_z = index, nr_z[0], nr_z[1], nr_z[2]

  return i_a, j_a, z_a, i_z, j_z, z_z


###

def calculate_distance(q_end, p_ini, p_fin):
  # Returns {(dp, z)}, where {dp} is the distance (by {rn.dist}) from the point
  # {q_end} to the nearer of the points {p_ini} and {p_fin}. {z} is 0 if {p_ini}
  # is the nearer one (ties included) and 1 if {p_fin} is.
  # If {q_end} is {None}, returns {(None, None)}.
  if q_end is None:
    # An identity test, unlike {q_end == None}, does not depend on how the
    # point type implements {==} (e.g. element-wise comparison for arrays).
    return None, None

  dp_ini = rn.dist(q_end, p_ini)
  dp_fin = rn.dist(q_end, p_fin)

  dp = min(dp_ini, dp_fin)
  z = 0 if dp == dp_ini else 1

  return dp, z

###

def get_raster_contacts(RSS, nsc):
  # Collects contacts of the rasters in the scan-lines {RSS} via
  # {path.get_contacts(oph, side)} and returns {(CTS_lo, CTS, CTS_hi)}:
  #   {CTS_lo}: side-1 contacts of the rasters in the bottom scan-line {RSS[0]};
  #   {CTS_hi}: side-0 contacts of the rasters in the top scan-line {RSS[-1]};
  #   {CTS}:    side-1 contacts of the rasters in scan-lines {RSS[1..nsc-1]}.
  # Presumably side 1 is the lower side and side 0 the upper one, so {CTS}
  # holds the contacts between consecutive scan-lines; TODO confirm against
  # {path.get_contacts}.
  def gather(rasters, side):
    return [ ct for oph in rasters for ct in path.get_contacts(oph, side) ]

  lower = gather(RSS[0], 1)
  upper = gather(RSS[-1], 0)
  inner = [ ct for isc in range(1, nsc) for ct in gather(RSS[isc], 1) ]

  return lower, inner, upper

###

def valid_path(bph, CTS):
  # Returns {True} iff no contact {ct} in {CTS} has {contact.tcool(bph, ct)}
  # greater than {contact.tcool_limit(ct)}. Checking stops at the first
  # contact that exceeds its limit.
  def exceeds(ct):
    # The limit is evaluated before the cooling time of the path.
    limit = contact.tcool_limit(ct)
    return contact.tcool(bph, ct) > limit

  return not any(exceeds(ct) for ct in CTS)