#! /usr/bin/python3 -t
# Last edited on 2021-02-18 22:51:20 by jstolfi

# Identification of this module.
MODULE_NAME = "rn"
MODULE_DESC = "Linear algebra operations on numeric vectors"
MODULE_VERS = "1.0"

MODULE_COPYRIGHT = "Copyright © 2009 State University of Campinas"

# Long description; adjacent string literals are joined by the parser.
MODULE_INFO = (
  "A library module to perform linear algebra operations on numeric vectors.\n"
  "\n"
  "  Input vectors can be tuples or lists.  Output vectors will be tuples.\n"
)

import sys
import copy
from math import sqrt,sin,cos

def add(x, y):
  """Returns the coordinate-wise sum {x+y} as a tuple.

  The vectors {x} and {y} must have the same length."""
  n = len(x)
  assert len(y) == n, "incompatible {x,y} lenghts"
  return tuple(xi + yi for xi, yi in zip(x, y))
  # ----------------------------------------------------------------------

def sub(x, y):
  """Returns the coordinate-wise difference {x-y} as a tuple.

  The vectors {x} and {y} must have the same length."""
  n = len(x)
  assert len(y) == n, "incompatible {x,y} lenghts"
  return tuple(xi - yi for xi, yi in zip(x, y))
  # ----------------------------------------------------------------------

def scale(s, x):
  """Scales the vector {x} by {s}, which may be a number or a vector.

  If {s} is a tuple or list, it must have the same length as {x}, and
  the result has coordinates {s[i]*x[i]}.  If {s} is an {int} or a
  {float}, every coordinate of {x} is multiplied by {s}.  The result
  is always a tuple.

  Raises {AssertionError} if {s} is of any other type, or if {s} is a
  vector whose length differs from that of {x}."""
  n = len(x)
  # {isinstance} (rather than an exact {type} test) also accepts
  # subclasses: named tuples, {bool}, and subclasses of {float}.
  if isinstance(s, (tuple, list)):
    assert len(s) == n, "incompatible {x,s} lengths"
    return tuple(si*xi for si, xi in zip(s, x))
  if isinstance(s, (int, float)):
    return tuple(s*xi for xi in x)
  # Raised explicitly (not via {assert}) so that an unsupported {s} is
  # still rejected when assertions are disabled with {python -O}.
  raise AssertionError("invalid scale {s}")
  # ----------------------------------------------------------------------

def mix(s, x, t, y):
  """Returns the linear combination {s*x+t*y} as a tuple.

  Here {s} and {t} multiply each coordinate of the vectors {x} and {y},
  which must have the same length."""
  n = len(x)
  assert len(y) == n, "incompatible {x,y} lenghts"
  return tuple(s*xi + t*yi for xi, yi in zip(x, y))
  # ----------------------------------------------------------------------

def mix3(s, x, t, y, u, z):
  """Returns the linear combination {s*x+t*y+u*z} as a tuple.

  The vectors {x}, {y} and {z} must all have the same length."""
  n = len(x)
  assert n == len(y) == len(z), "incompatible {x,y,z} lenghts"
  return tuple(s*xi + t*yi + u*zi for xi, yi, zi in zip(x, y, z))
  # ----------------------------------------------------------------------

def mix4(s, x, t, y, u, z, v, o):
  """Returns the linear combination {s*x+t*y+u*z+v*o} as a tuple.

  The vectors {x}, {y}, {z} and {o} must all have the same length."""
  n = len(x)
  assert n == len(y) == len(z) == len(o), "incompatible {x,y,z,o} lenghts"
  return tuple(
    s*xi + t*yi + u*zi + v*oi
    for xi, yi, zi, oi in zip(x, y, z, o)
  )
  # ----------------------------------------------------------------------

def dir(x):
  """Returns {x} scaled to (nearly) unit Euclidean length, and its norm.

  The result is a pair {(u, e)} where {e} is {norm(x)} plus the tiny
  constant {1.0e-290}, and {u} is the tuple {x/e}.  The constant is
  presumably there to avoid a division by zero when {x} is the zero
  vector; note that it is included in the returned {e}.

  NOTE: inside this module this name shadows the builtin {dir}."""
  length = norm(x) + 1.0e-290
  unit = tuple(xi/length for xi in x)
  return unit, length
  # ----------------------------------------------------------------------

def dot(x, y):
  """Returns the scalar (inner) product of {x} and {y}.

  The vectors {x} and {y} must have the same length.  The result
  is {0} (an integer) for empty vectors."""
  n = len(x)
  assert len(y) == n, "incompatible {x,y} lenghts"
  # Plain left-to-right accumulation, kept as an explicit loop so the
  # rounding of float results does not depend on how {sum} is implemented.
  total = 0
  for xi, yi in zip(x, y):
    total += xi * yi
  return total
  # ----------------------------------------------------------------------

def norm_sqr(x):
  """Returns the square of the Euclidean norm of {x}.

  That is, the sum of the squares of the coordinates of {x};
  {0} (an integer) if {x} is empty."""
  # Explicit left-to-right accumulation, so that the rounding of float
  # results does not depend on how {sum} is implemented.
  total = 0
  for xi in x:
    total += xi*xi
  return total
  # ----------------------------------------------------------------------

def norm(x):
  """Returns the Euclidean norm of {x}: the square root of {norm_sqr(x)}."""
  sum_of_squares = norm_sqr(x)
  return sqrt(sum_of_squares)
  # ----------------------------------------------------------------------

def dist(x, y):
  """Returns the Euclidean distance between the points {x} and {y}.

  Computed as the norm of the difference {x-y}; the two vectors
  must therefore have the same length."""
  delta = sub(x, y)
  return norm(delta)
  # ----------------------------------------------------------------------

def cross2d(x, y):
  """Returns the cross product of two vectors of R^2 (a single number).

  The result is {x[0]*y[1] - x[1]*y[0]}.  Both {x} and {y} must
  have exactly two coordinates."""
  assert len(x) == 2, "{x} must be a point of R^2"
  assert len(y) == 2, "{y} must be a point of R^2"
  x0, x1 = x[0], x[1]
  y0, y1 = y[0], y[1]
  return x0*y1 - x1*y0
  # ----------------------------------------------------------------------

def cross3d(x, y):
  """Returns the cross product of two vectors of R^3, as a 3-tuple.

  Both {x} and {y} must have exactly three coordinates."""
  assert len(x) == 3, "{x} must be a point of R^3"
  assert len(y) == 3, "{y} must be a point of R^3"
  x0, x1, x2 = x[0], x[1], x[2]
  y0, y1, y2 = y[0], y[1], y[2]
  r0 = x1*y2 - x2*y1
  r1 = x2*y0 - x0*y2
  r2 = x0*y1 - x1*y0
  return (r0, r1, r2)
  # ----------------------------------------------------------------------

def rotate2(x, ang):
  """Rotates the first two coords of {x} by {ang} radians around the origin.

  Any coordinates of {x} beyond the first two are copied unchanged.
  The result is a tuple; {x} must have at least two coordinates."""
  assert len(x) >= 2, "{x} must have at least 2 coords"
  c, s = cos(ang), sin(ang)
  x0, x1 = x[0], x[1]
  rotated = (c*x0 - s*x1, s*x0 + c*x1)
  untouched = tuple(x[2:])
  return rotated + untouched
  # ----------------------------------------------------------------------

# BOXES

# An {n}-dimensional /box/ is a subset of {R^n} that is the Cartesian
# product of {n} intervals. It is represented here as a pair (2-tuple)
# of points {(plo,phi)}, whose coordinates are respectively the low and
# high ends of those intervals.  A valid box must have {plo[i] <= phi[i]}
# for all {i}.

def box_from_point(x):
  """Returns the box that contains the single point {x}.

  The result is the pair {(p, p)} where {p} is {x} converted to a tuple."""
  # Convert once and reuse: calling {tuple(x)} twice would leave the
  # second corner empty if {x} were a one-shot iterator.
  p = tuple(x)
  return (p, p)
  
def box_include_point(B, x):
  """Returns the smallest box that includes the box {B} and the point {x}.

  However, if {x} is {None}, returns {B}; else, if {B} is {None}, returns
  a box with the single point {x}.  Otherwise both corners of {B} must
  have the same length as {x}."""
  # Identity tests against {None}: an {==} comparison would defer to the
  # operand's own {__eq__}, which need not return a plain boolean.
  if x is None:
    return B
  if B is None:
    return box_from_point(x)
  n = len(x)
  assert (len(B[0]) == n) and (len(B[1]) == n), "incompatible {B[i],x} lenghts"
  plo = tuple(min(lo, xi) for lo, xi in zip(B[0], x))
  phi = tuple(max(hi, xi) for hi, xi in zip(B[1], x))
  return (plo, phi)
  # ----------------------------------------------------------------------
  
def box_join(A, B):
  """Returns the smallest box that includes the two boxes {A} and {B}.

  However, if either box is {None}, returns the other box (which may
  itself be {None}).  Otherwise all four corners must have the same
  length."""
  # Identity tests against {None}: an {==} comparison would defer to the
  # operand's own {__eq__}, which need not return a plain boolean.
  if A is None:
    return B
  if B is None:
    return A
  n = len(A[0])
  assert (len(A[1]) == n), "incompatible {A[0],A[1]} lenghts"
  assert (len(B[0]) == n) and (len(B[1]) == n), "incompatible {A[i],B[i]} lenghts"
  plo = tuple(min(a, b) for a, b in zip(A[0], B[0]))
  phi = tuple(max(a, b) for a, b in zip(A[1], B[1]))
  return (plo, phi)
  # ----------------------------------------------------------------------
 
def box_expand(B, dlo, dhi):
  """Returns a copy of box {B} with its corners displaced by {dlo} and {dhi}.

  The lower corner is displaced by the vector {dlo} and the upper corner
  by the vector {dhi} -- outwards if positive, inwards if negative.  That
  is, the new corners are {B[0][i] - dlo[i]} and {B[1][i] + dhi[i]}.
  Both {dlo} and {dhi} must have the same length as the corners of {B}.

  The procedure fails (with {AssertionError}) if the displacements would
  result in an invalid box.  If {B} is {None}, returns {None}."""
  # Identity test against {None}: an {==} comparison would defer to the
  # operand's own {__eq__}, which need not return a plain boolean.
  if B is None:
    return B
  n = len(B[0])
  assert (len(B[1]) == n), "incompatible {B[0],B[1]} lenghts"
  assert (len(dlo) == n) and (len(dhi) == n), "incompatible {dlo,dhi} lenghts"
  plo = tuple(lo - d for lo, d in zip(B[0], dlo))
  phi = tuple(hi + d for hi, d in zip(B[1], dhi))
  # A valid box needs {plo[i] <= phi[i]} on every axis.
  assert all(lo <= hi for lo, hi in zip(plo, phi)), "invalid displacements for this box"
  return (plo, phi)
  # ----------------------------------------------------------------------
 
# ----------------------------------------------------------------------
