#! /usr/bin/python3
# Last edited on 2024-04-01 16:17:05 by stolfi

import sys
import math
import copy

import jspnm
from jspnm import quantize,write_ppm_image,write_pgm_image
      
def make_weave_pic(qx, qy, nb, threadpix) :
  # Renders a complete weave image that is mirror-symmetric about both mid-axes.
  # {qx,qy} X and Y size of image (squarelets).
  # {nb} size of a squarelet (pixels).
  # {threadpix(ru,rv,su,sv,tk,st)} thread-rendering procedure.
  # Only the first quadrant is evaluated; each value is stored at the pixel
  # and at its three mirror images. Returns the image as a list of {ny}
  # rows of {nx} entries. Writes progress dots to {stderr}.
  nx = qx*nb  # X size of complete image (pixels).
  ny = qy*nb  # Y size of complete image (pixels).
  pic = make_array(nx, ny)
  sys.stderr.write('rendering [')
  for iy in range((ny+1)//2) :
    sys.stderr.write('.')
    my = ny - 1 - iy  # Mirror row.
    for ix in range((nx+1)//2) :
      mx = nx - 1 - ix  # Mirror column.
      # Try the top thread, then the bottom thread, then the background:
      pix = None
      for st in (1, 0, -1) :
        pix = render_weave_layer(ix,iy,nb,nx,ny,st,threadpix)
        if pix is not None :
          break
      assert pix is not None
      for (py, px) in ((iy, ix), (iy, mx), (my, ix), (my, mx)) :
        pic[py][px] = pix
  sys.stderr.write(']\n')
  return pic
  # ----------------------------------------------------------------------

def render_weave_layer(ix,iy,nb,nx,ny,st,threadpix) :
  # Evaluates one layer of the weave at a pixel of the first image quadrant.
  # {ix,iy} pixel indices in first quarter of whole image.
  # {nb} size of a squarelet (pixels)
  # {nx,ny} dimensions of whole image (pixels).
  # {st} weave layer to evaluate (1 = top, 0 = bottom, -1 = background).
  # {threadpix(ru,rv,su,sv,tk,st)} thread-rendering procedure. It receives
  #   the pixel coords {ru,rv} relative to its tile (U along the thread,
  #   V across it), the tile size {su,sv} in the same frame, the thread
  #   kind {tk} (0 = X-thread, 1 = Y-thread, -1 = background) and {st}.
  # Returns the result of {threadpix}, which is expected to be {None}
  #   if the pixel is not on the layer.
  assert ix <= nx - 1 - ix
  assert iy <= ny - 1 - iy
  if (st == -1) :
    # Background: {threadpix} is called with a dummy 1x1 tile.
    return threadpix(0,0,1,1,-1,0)
  # Indices {kx,ky} of current squarelet:
  kx = ix // nb
  ky = iy // nb
  # Index {t} of the diagonal tile band containing squarelet {kx,ky}:
  t = (kx + ky) // 3
  # Which kind of thread is on top in this band:
  tt = t % 2 # 0 = X-thread on top, 1 = Y-thread on top.
  # Which kind {tk} of thread is to be painted on layer {st}:
  tk = (tt + st + 1) % 2
  # Upper corner {bx,by} and pixel size {sx,sy} of tile:
  if (tk == 0) :
    # Tile of the X-oriented thread (3 squarelets long in X):
    bx = nb*(3*t - ky)
    by = nb*ky
    sx = 3*nb
    sy = nb
  else :
    # Tile of the Y-oriented thread (3 squarelets long in Y):
    bx = nb*kx
    by = nb*(3*t - kx)
    sx = nb
    sy = 3*nb
  # Adjust tile size to account for mid-axis mirroring. A tile that
  # strictly crosses a mid-axis is merged with its own mirror image, so it
  # spans from its near edge to the mirrored near edge. A tile that ends
  # exactly on the mid-axis is left alone, because its mirror is a separate,
  # adjacent tile. Merging in that case would make the central row or column
  # of threads twice as wide when the image has an even number of squarelets.
  if (2*(bx + sx) > nx) :
    sx = nx - 2*bx
  if (2*(by + sy) > ny) :
    sy = ny - 2*by
  # Thread-oriented pixel coords {ru,rv} rel to tile, and tile size {su,sv}:
  if (tk == 0) :
    # X-oriented thread, U=X, V=Y:
    ru = ix - bx
    rv = iy - by
    su = sx
    sv = sy
  else :
    # Y-oriented thread, U=Y, V=X:
    ru = iy - by
    rv = ix - bx
    su = sy
    sv = sx
  if ((ru < 0) or (ru >= su) or (rv < 0) or (rv >= sv)) :
    # Diagnostic for the assertions below.
    sys.stderr.write(" pixel (%4d, %4d) t = %d --> (%4d, %4d) in (%4d,%4d)\n" % (ix,iy,t,ru,rv,su,sv))
  assert (ru >= 0) and (ru < su)
  assert (rv >= 0) and (rv < sv)
  return threadpix(ru,rv,su,sv,tk,st)
  # ----------------------------------------------------------------------

def color_field(ix,iy,sx,sy,tk,st) :
  # Thread-rendering procedure that yields an RGB color list.
  # {ix,iy} pixel indices relative to a basic weave tile.
  # {sx,sy} dimensions of tile (pixels).
  # {tk} which kind of thread is to be painted (-1 = background).
  # {st} layer where the thread is (0 = bottom, 1 = top).
  # Assumes that the thread runs in the X direction.
  # Returns the background color if {tk} is -1; {None} if the pixel is
  # not on the thread; otherwise the shaded thread color as [R,G,B].
  if tk == -1 :
    # Background color:
    return [ 0.1, 0.1, 0.1]
  on, hz = weave_params(ix,iy,sx,sy,st)
  if not on :
    # Pixel is not on layer:
    return None
  # Shading factor {fc} from the relative height {hz} of the thread surface:
  fc = 0.7 + 0.3*hz
  # Per channel: pair (color for tk = 1, color for tk = 0).
  # Blend the pair by {tk}, then shade toward 0.1.
  channels = ((0.8700, 0.1500), (0.8500, 0.2500), (0.7500, 0.3800))
  return [ (tk*c1 + (1-tk)*c0)*fc + 0.1000*(1-fc) for (c1, c0) in channels ]
  # ----------------------------------------------------------------------

def height_field(ix,iy,sx,sy,tk,st) :
  # Thread-rendering procedure that yields a scalar height.
  # {ix,iy} pixel indices relative to a basic tile.
  # {sx,sy} dimensions of tile (pixels).
  # {tk} which kind of thread is to be painted (-1 = background).
  # {st} layer where the thread is (0 = bottom, 1 = top).
  # Assumes that the thread runs in the X direction.
  # Returns 0.0 for the background, {None} if the pixel is not on the
  # thread, otherwise the relative height {hz} mapped into {[0_1]}.
  if tk == -1 :
    return 0.0
  on, hz = weave_params(ix,iy,sx,sy,st)
  return (0.500 + 0.450 * hz) if on else None
  # ----------------------------------------------------------------------

def weave_params(ix,iy,sx,sy,st) :
  # Evaluates the shape of one thread at a pixel of its tile.
  # {ix,iy} pixel indices relative to a basic tile ({ix} along the thread,
  #   {iy} across it).
  # {sx,sy} dimensions of tile (pixels): {sx} is the thread's length and
  #   {sy} its width including the margins on both sides.
  # {st} layer where the thread is (0 = bottom, 1 = top).
  # Assumes that the thread runs in the X direction.
  # Returns pair {(on,hz)} where
  # {on} is a boolean that tells whether the pixel is over the thread.
  # {hz} rel height of thread's surface ([-1.0 _ +1.0]), or 0 if not {on}.
  # All profile dimensions come from {sx,sy}, not from a global squarelet
  # size. The function therefore works outside the script, and the cross-section
  # stays consistent with the margin {mb} on tiles resized by mid-axis
  # mirroring (mixing the two sizes could give {gz1 < 0} and a math domain error).
  mb = sy // 10 # Margin pixels on each side of the thread.
  if ((iy < mb) or (iy >= sy - mb)) :
    # Pixel is not over thread:
    return (False, 0)
  # Thread's surface height relative to thread axis (elliptical cross-section):
  gw = sy - 2*mb  # Width of the thread proper (pixels), always >= 1.
  gz0 = (iy - mb + 0.5)/gw
  gz1 = (sy - mb - 0.5 - iy)/gw
  gz = math.sqrt(4*gz0*gz1)
  # X factor -- axis height, a parabola over the tile length:
  fw = sx
  fz0 = (ix + 0.5)/fw
  fz1 = (sx - 0.5 - ix)/fw
  fz = 4*fz0*fz1
  fz = 1 - (1 - fz)*(1 - fz) # Flatten it.
  # Negate axis height {fz} if thread is in bottom layer:
  if (st != 1) :
    fz = - fz
  # Height is sum of axis and surface:
  hz = 0.500*(gz + fz)
  return (True, hz)
  # ----------------------------------------------------------------------

def make_array(nx,ny) :
  # Allocates an {ny} by {nx} array as a list of {ny} rows, each a
  # distinct list of {nx} entries set to {None}.
  # Writes a progress note to {stderr}.
  sys.stderr.write("allocating array\n")
  # The comprehension evaluates {nx*[None]} once per row, so every row is a
  # fresh list (no aliasing) and no explicit copy is needed.
  return [ nx*[None] for _ in range(ny) ]

# MAIN

if __name__ == "__main__" :
  import os

  # Get arguments. Names bound here are still module-level globals,
  # because an {if} statement does not open a new scope.
  if len(sys.argv) < 5 :
    sys.stderr.write("usage: %s NAME QX QY NB\n" % sys.argv[0])
    sys.exit(1)
  name = sys.argv[1] # Name of output files sans extension.
  qx = int(sys.argv[2]) # Number of squarelets in X.
  qy = int(sys.argv[3]) # Number of squarelets in Y.
  nb = int(sys.argv[4]) # Side of each squarelet (pixels).
  if (qx < 1) or (qy < 1) or (nb < 1) :
    sys.stderr.write("QX, QY, NB must be positive\n")
    sys.exit(1)

  nx = qx*nb
  ny = qy*nb

  # Output images are written under "out/"; make sure it exists:
  os.makedirs("out", exist_ok=True)

  sys.stderr.write('creating color image\n')
  pic_color = make_weave_pic(qx, qy, nb, color_field)
  write_ppm_image("out/" + name + "-tx.ppm", pic_color, nx, ny)

  sys.stderr.write('creating height map\n')
  pic_height = make_weave_pic(qx, qy, nb, height_field)
  write_pgm_image("out/" + name + "-ht.pgm", pic_height, nx, ny)

