#! /usr/bin/python3
# Last edited on 2023-12-07 14:46:01 by stolfi

import scipy.io as sio
from math import sqrt, floor, ceil, exp, pi, sin, cos
import sys, os

# The EEG file is divided into /segments/. A segment is either a /beat/,
# the consecutive frames during one full beat of a rhythmic stimulus
# (waltz, samba, random), typically ~113 frames; or a /gap/ between
# beats, usually between different rhythms, with a variable number of
# frames.
#
# During each beat a /stimulus/ was played, which was either a 
# strong clap, a weak clap, or a silence. 
# 
# A /session/ is a set of consecutive beats of the same rhythm.
#
# Each session comprises an integer number of /bars/. A bar is three
# beats in a waltz session, four beats in a samba session, and one beat
# in a random session.
#
# The script appends two synthetic electrodes (named 'C129' and 'C130'
# in the output files when the input has 128 electrodes) and the marker
# channels 'RTHM', 'BEAT', and 'STIM'.

def main():
  # Converts the raw MATLAB files of one subject into text files.
  #
  # Reads '{subjdir}/eeg.mat' (key 'data') and '{subjdir}/sti.mat' (keys
  # 'triggerstype' and 'triggerslatency'), where {subjdir} is 'raw/v01'.
  # Writes:
  #   '{subjdir}/sti.txt' -- one line per segment with its labels;
  #   one EEG text file per rythm session (see {write_eeg_session});
  #   '{subjdir}/eeg.txt' -- the whole EEG recording.
  # Before writing, two synthetic electrodes and three marker channels
  # are appended to the EEG data.

  subj = 1
  topdir = "raw"
  subjdir = f"{topdir}/v{subj:02d}"

  # ----------------------------------------------------------------------
  # Unpack the EEG data as an array {me} with {nt} columns (frames)
  # and {nc_in} rows (channels); all input channels are electrodes.

  eeg_set = sio.loadmat(subjdir + '/eeg.mat')
  print(eeg_set.keys())

  me = eeg_set.get('data')
  assert me is not None, "no 'data' in eeg.mat"
  nc_in = len(me)
  nt = len(me[0])
  sys.stderr.write(f"nc = {nc_in} nt = {nt}\n")

  # ----------------------------------------------------------------------
  # Unpack the stimulus timing data as arrays {mt} (stimulus types) and
  # {mk} (trigger latencies), each with one row and {ng} columns:

  sti_set = sio.loadmat(subjdir + '/sti.mat')
  print(sti_set.keys())

  mt = sti_set.get('triggerstype')
  assert mt is not None, "no 'triggerstype' in sti.mat"
  mtrows = len(mt)
  mtcols = len(mt[0])
  print('mt', mtrows, mtcols)
  assert mtrows == 1
  ng = mtcols
  sys.stderr.write(f"ng = {ng}\n")

  mk = sti_set.get('triggerslatency')
  assert mk is not None, "no 'triggerslatency' in sti.mat"
  mkrows = len(mk)
  mkcols = len(mk[0])
  print('mk', mkrows, mkcols)
  assert mkrows == 1
  assert mkcols == mtcols

  # Get the segments with rythm, bar index, beat and stimulus:
  SG = get_segment_data(mt, mk, nt)

  # Write all segment data to disk (the file is closed even on error):
  with open(subjdir + '/sti.txt', 'w') as f_beat:
    for kg, SGk in enumerate(SG):
      nt_beat = SGk['tfin'] - SGk['tini'] + 1
      f_beat.write(f"{kg:04d} {SGk['sess']:04d} {SGk['tini']:6d} {SGk['tfin']:6d} {nt_beat:6d} {SGk['rthm']} {SGk['barr']:04d} {SGk['beat']:d} {SGk['stim']}\n")

  # Add synthetic electrodes; {ne_ot} counts real plus synthetic electrodes:
  me = add_synthetic_electrodes(me, SG)
  ne_ot = len(me)

  # Add marker channels (RTHM, BEAT, STIM) after the electrodes:
  me = add_marker_channels(me, SG)

  # Extract individual sessions, padded by about one beat on each side:
  nt_pad = 113
  SS = get_sessions(SG, nt, nt_pad)

  # Write out the sessions:
  for SSk in SS:
    write_eeg_session(subjdir, subj, me, ne_ot, SSk['rthm'], SSk['sess'], SSk['tini'], SSk['tfin'])

  # Write out the whole EEG data:
  write_eeg_session(subjdir, subj, me, ne_ot, None, None, 0, nt-1)

  print("Done.")
  
def get_segment_data(mt,mk,nt):
  # Splits the frames {0..nt-1} of the EEG recording into segments
  # delimited by the trigger latencies, and labels each segment.
  #
  # Parameters:
  #   {mt} -- stimulus types: one row whose entry {mt[0][kg]} is a singleton
  #     holding the (possibly blank-padded) code string of trigger {kg}.
  #   {mk} -- trigger latencies: one row with the same length as {mt[0]};
  #     {mk[0][kg]} is the first frame of the segment that starts at
  #     trigger {kg}. Latencies must be positive and strictly increasing.
  #   {nt} -- total number of frames.
  #
  # Returns a list of {len(mt[0])+1} dicts, one per segment, covering all
  # frames contiguously, each with keys:
  #   'tini', 'tfin' -- indices (0-based) of first and last frames in segment.
  #   'sess' - index (0-based) of session within same rythm (0 for gaps).
  #   'rthm' - either 'V' (waltz), 'S' (samba), 'R' (random), or '-' (gap).
  #   'barr' - index (0-based) of bar, sequential over all sessions of same rythm.
  #   'beat' - index (0-based) of beat within bar, 0-2 for 'V', 0-3 for 'S' else 0.
  #   'stim' - one-character stimulus code: '0' (silence or omission),
  #     '1' (weak), '2' (strong), 'E' ('epoc' gap), or '-' ('sil' gap).
  # The segment before the first trigger is labeled as a gap with 'stim' '-'.
  #
  # The rythm of each segment is given by {get_rythm} from its first frame.
  # In 'V' and 'S' sessions the first beat must be strong, and strong beats
  # must fall exactly on beat 0 of every bar; violations fail an assertion.
  # Per-rythm counts of sessions, bars, beats and frames are written to stderr.
  #
  # In the {mt} array, gaps have stimulus codes 'epoc' or 'sil '.
  # The code 'miss' is the omission of a weak beat, 'v0  ' is a mandatory
  # silence, 'v1  ', 'v1a ', 'v1b ' are weak beats, and 'v2  ' is a strong beat.

  # Raw stimulus code (blanks stripped) -> one-character code:
  stim_codes = {
    'miss': '0', 'v0': '0',
    'v1': '1', 'v1a': '1', 'v1b': '1',
    'v2': '2',
    'epoc': 'E',
    'sil': '-',
  }

  ng = len(mt[0]) # Number of triggers.
  assert len(mk[0]) == ng, "array size mismatch"
  rthms = ('-', 'V', 'R', 'S')
  nsess = dict.fromkeys(rthms, 0) # Count of sessions for each rythm.
  nbarr = dict.fromkeys(rthms, 0) # Count of bars for each rythm.
  nbeat = dict.fromkeys(rthms, 0) # Count of beats for each rythm.
  nfram = dict.fromkeys(rthms, 0) # Count of frames for each rythm.

  # Labels of the currently open segment, which ends just before the next trigger:
  tini_prev = 0
  sess_prev = 0
  rthm_prev = '-'
  barr_prev = 0
  beat_prev = 0
  stim_prev = '-'
  SG = []
  for kg in range(ng):

    tini = mk[0][kg]
    # Check ordering before emitting a (possibly empty or negative) segment:
    assert tini > tini_prev

    # Close the open segment:
    SG.append({'tini': tini_prev, 'tfin': tini-1, 'sess': sess_prev, 'rthm': rthm_prev, 'barr': barr_prev, 'beat': beat_prev, 'stim': stim_prev})
    nfram[rthm_prev] += tini - tini_prev

    rthm = get_rythm(tini)

    # Get the session number {sess} for this rythm:
    if rthm == '-':
      sess = 0
    elif rthm != rthm_prev:
      sess = nsess[rthm]
      nsess[rthm] += 1
    else:
      sess = sess_prev

    # Obtain the stimulus code {stim} from the stimulus data file:
    stim = mt[0][kg] # Should be a singleton list.
    assert len(stim) == 1
    stim = stim[0]
    assert isinstance(stim, str)
    stim = stim.strip()
    assert stim in stim_codes, 'unrecognized stim'
    stim = stim_codes[stim]

    # Define the phase {beat} based on rythm and strong beat position:
    if rthm != rthm_prev:
      # Changed rythm section; -1 means "no previous beat in this section".
      sys.stderr.write(f"{tini:6d} rythm {rthm_prev} changed to {rthm}\n")
      beat_prev = -1
    if rthm == 'R' or rthm == '-':
      beat = 0
    else:
      period = 3 if rthm == 'V' else 4
      if stim == '2':
        assert beat_prev == -1 or beat_prev + 1 == period, f"{tini:6d} strong beat out of sync {beat_prev:d}"
        beat = 0
      else:
        assert beat_prev != -1, "rythm does not start with strong beat"
        beat = (beat_prev + 1) % period
        assert beat != 0, "expected a strong beat"

    # Get the bar index {barr}; a new bar starts at every beat 0:
    if rthm == '-':
      barr = 0
    else:
      nbeat[rthm] += 1
      if beat == 0:
        barr = nbarr[rthm]
        nbarr[rthm] = barr + 1
      else:
        barr = barr_prev # Same bar as the previous beat.

    if rthm == '-' and stim not in ('-', 'E'):
      sys.stderr.write(f"{tini:6d} ignored non-silence {stim} ({rthm} beat {beat:d})\n")

    tini_prev = tini
    sess_prev = sess
    rthm_prev = rthm
    barr_prev = barr
    beat_prev = beat
    stim_prev = stim

  # Last segment, up to the end of the recording:
  SG.append({'tini': tini_prev, 'tfin': nt-1, 'sess': sess_prev, 'rthm': rthm_prev, 'barr': barr_prev, 'beat': beat_prev, 'stim': stim_prev})
  nfram[rthm_prev] += nt - tini_prev

  for rthm in rthms:
    sys.stderr.write(f"  {rthm}: {nsess[rthm]:4d} sessions {nbarr[rthm]:4d} bars {nbeat[rthm]:5d} beats {nfram[rthm]:8d} frames\n")

  return SG

def get_rythm(tini):
  # Returns the rythm code of the segment that starts at frame {tini}:
  #   'V' -- waltz rythm;
  #   'R' -- random independent stimuli;
  #   'S' -- samba rythm;
  #   '-' -- pauses and anything outside the session intervals below.
  # The intervals are hard-coded half-open frame ranges [lo, hi),
  # presumably valid only for the recording processed by {main} (subject 1).
  session_intervals = (
    ('V', (
      (7462, 22313), (26247, 41098), (45032, 59882),
      (316144, 330995), (334929, 349779), (353714, 368564),
    )),
    ('R', (
      (66960, 81811), (85745, 100595), (104530, 119380),
      (253746, 268597), (272531, 287382), (291316, 306166),
    )),
    ('S', (
      (129487, 144338), (148272, 163122), (167057, 181907),
      (192157, 207007), (210942, 225792), (229726, 244577),
    )),
  )
  for rythm, ranges in session_intervals:
    for lo, hi in ranges:
      if lo <= tini < hi:
        return rythm
  # Pauses and such:
  return '-'

def add_synthetic_electrodes(me, SG):
  # Returns a copy of the EEG data array {me} with two extra electrode
  # rows appended after the existing ones, at indices {len(me)} and
  # {len(me)+1} (named 'C129' and 'C130' in the output files when the
  # input has 128 electrodes):
  #   pulse channel -- baseline -50 plus Gaussian bell pulses synchronized
  #     with the rythm (see below);
  #   chirp channel -- baseline +50 plus, for each rythm session, a
  #     sinusoid whose frequency increases over the session (see {splat_chirp}).
  # The copy is shallow: the original rows are shared with {me}; only the
  # two new rows are freshly allocated lists.
  #
  # The {SG} parameter should be the segment data table: contiguous
  # segments with keys 'tini', 'tfin', 'rthm'.
  #
  # Pulses are placed by the index {kb} of the segment within its rythm
  # session ({kb} is 0 at the session's first segment and for gaps):
  #   'R': narrow pulse when kb % 3 == 1;
  #   'V': narrow pulse when kb % 9 == 1, broad pulse when kb % 9 == 5;
  #   'S': narrow pulse when kb % 8 == 1, broad pulse when kb % 8 == 4.
  # A narrow pulse is centered on its segment, with deviation 0.30 times
  # the segment length; a broad one is centered at the segment's end, with
  # deviation also multiplied by the number of beats per bar.

  nc_in = len(me)
  ne_in = nc_in  # Assumes all input channels are electrodes.

  # Indices of the new electrodes:
  ie_puls = ne_in
  ie_chrp = ne_in + 1
  nc_ot = ne_in + 2  # All output channels are electrodes too.
  nt = len(me[0])

  me_new = [me[ic] for ic in range(ne_in)]
  me_new.append([-50]*nt)  # Pulse channel baseline.
  me_new.append([+50]*nt)  # Chirp channel baseline.
  assert len(me_new) == nc_ot

  beats_per_bar = {'-': 0, 'R': 1, 'V': 3, 'S': 4}

  kb = 0           # Segment index within the current rythm session.
  rthm_prev = '-'  # Rythm of previous segment.
  tini_sess = -1   # First frame of the current (or just ended) session.
  tfin_prev = -1   # Last frame of previous segment.
  for kg, SGk in enumerate(SG):
    if kg > 0: assert SGk['tini'] == SG[kg-1]['tfin'] + 1
    rthm = SGk['rthm']
    tini = SGk['tini']
    tfin = SGk['tfin']
    nb = beats_per_bar[rthm]  # Also rejects unknown rythm codes.

    if rthm == '-' or rthm != rthm_prev:
      # Reset beat count at gaps and at the start of each session.
      kb = 0

    if rthm == 'R':
      splat_narrow = (kb % 3 == 1)
      splat_broad = False
    elif rthm == 'V':
      splat_narrow = (kb % 9 == 1)
      splat_broad = (kb % 9 == 5)
    elif rthm == 'S':
      splat_narrow = (kb % 8 == 1)
      splat_broad = (kb % 8 == 4)
    else:  # rthm == '-'
      splat_narrow = False
      splat_broad = False

    if splat_narrow:
      ctr = (tini+tfin)/2
      dev = 0.30*(tfin-tini+1)
      splat_pulse(me_new, ie_puls, ctr, dev)

    if splat_broad:
      ctr = tfin + 0.5
      dev = nb*0.30*(tfin-tini+1)
      splat_pulse(me_new, ie_puls, ctr, dev)

    if rthm_prev != '-' and rthm_prev != rthm:
      # The previous segment was the last beat of a session: put a chirp
      # over that whole session, ending at its last frame.
      splat_chirp(me_new, ie_chrp, tini_sess, tfin_prev)

    # Prepare for next segment:
    kb += 1
    if rthm != rthm_prev:
      tini_sess = tini
    rthm_prev = rthm
    tfin_prev = tfin

  return me_new
  
def splat_pulse(me, ie, ctr, dev, amp=40):
  # Adds to channel {ie} of {me} the Gaussian bell pulse
  #   amp*exp(-0.5*((it - ctr)/dev)**2)
  # with center index {ctr} (maybe fractional) and deviation {dev} > 0.
  # Only frames {it} within 4.5 deviations of {ctr} are touched, clipped
  # to the valid frame range {0..len(me[0])-1}. The row {me[ie]} is
  # modified in place.

  nc = len(me)
  nt = len(me[0])
  assert 0 <= ie < nc, "invalid channel"
  assert dev > 0, "invalid deviation"

  ilo = max(floor(ctr - 4.5*dev), 0)
  ihi = min(ceil(ctr + 4.5*dev), nt - 1)
  row = me[ie]
  for it in range(ilo, ihi+1):
    z = (it - ctr)/dev
    row[it] += amp*exp(-0.5*z*z)
    
def splat_chirp(me, ie, tini, tfin, amp=40, fmax=0.33333, wsho=5*113):
  # Adds to channel {ie} of {me}, over frames {tini..tfin}, the chirp
  #   amp*sho(it)*cos(phase(it))
  # whose instantaneous frequency increases linearly from 0 cycles/frame
  # at {tini} to {fmax} cycles/frame at {tfin}. The envelope {sho} rises
  # from 0 to 1 as a raised cosine over the first {wsho} frames and falls
  # back to 0 over the last {wsho} frames (the default {wsho} is five
  # ~113-frame beats). The row {me[ie]} is modified in place.
  #
  # Fails an assertion if the channel or frame range is invalid, or if
  # the range is not longer than {2*wsho} frames.

  nt = len(me[0])
  nc = len(me)
  assert 0 <= ie < nc, "invalid channel"
  assert 0 <= tini < tfin < nt, "invalid frame range"

  wtot = tfin-tini+1
  assert wtot > 2*wsho, "session too short for chirp"
  span = tfin - tini
  row = me[ie]
  phase = 0
  for it in range(tini, tfin+1):
    dt = min(it - tini, tfin - it)  # Distance to the nearest end.
    sho = 0.5*(1 - cos(pi*dt/wsho)) if dt < wsho else 1.0
    row[it] += amp*sho*cos(phase)
    # Phase increment = 2*pi times the current frequency (cycles/frame):
    phase += 2*fmax*pi*(it - tini)/span

def add_marker_channels(me, SG):
  # Returns a copy of {me} with the same rows (shared, not duplicated)
  # plus three extra rows for the marker channels 'RTHM', 'BEAT', 'STIM':
  #   'RTHM' is 0 for gaps, 1 for random, 2 for waltz, 3 for samba.
  #   'BEAT' is 0 for gaps, 1-3 for waltz, 1-4 for samba, 1 for random
  #     (the segment's 0-based 'beat' plus 1).
  #   'STIM' is 0 for gaps, else the segment's 'stim' code plus 1:
  #     1 for silence (omitted beat or mandatory silence), 2 for weak beat,
  #     3 for strong beat.
  # The {SG} parameter should be the segment data table: contiguous
  # segments with keys 'tini', 'tfin', 'rthm', 'beat', 'stim'. In non-gap
  # segments 'stim' must be '0', '1' or '2' (else int() raises ValueError).
  #
  # The 'BEAT' and 'STIM' are set to zero in the last two frames of each
  # segment to make the boundaries more visible even after subsampling by 2.
  # Fails an assertion if the segments do not cover all frames.

  nc_in = len(me)
  ne_in = nc_in  # Assumes all input channels are electrodes.
  nc_ot = nc_in + 3
  nt = len(me[0])
  ic_rthm, ic_beat, ic_stim = nc_in, nc_in + 1, nc_in + 2

  # New rows start as None so that uncovered frames are detected below:
  me_new = [me[ic] for ic in range(ne_in)] + [[None]*nt for _ in range(3)]
  assert len(me_new) == nc_ot

  rthm_codes = {'-': 0, 'R': 1, 'V': 2, 'S': 3}
  beats_per_bar = {'-': 0, 'R': 1, 'V': 3, 'S': 4}

  for kg, SGk in enumerate(SG):
    if kg > 0: assert SGk['tini'] == SG[kg-1]['tfin'] + 1
    rthm = SGk['rthm']

    # Define the RTHM, BEAT, STIM values for this segment:
    rthm_num = rthm_codes[rthm]
    nb = beats_per_bar[rthm]
    if rthm_num != 0:
      beat_num = SGk['beat'] + 1
      assert 1 <= beat_num <= nb
      stim_num = int(SGk['stim']) + 1
      assert 1 <= stim_num <= 3
    else:
      beat_num = 0
      stim_num = 0

    tini = SGk['tini']
    tfin = SGk['tfin']
    for it in range(tini, tfin+1):
      me_new[ic_rthm][it] = rthm_num
      inside = it < tfin-1  # False in the last two frames of the segment.
      me_new[ic_beat][it] = beat_num if inside else 0
      me_new[ic_stim][it] = stim_num if inside else 0

  # Check for complete coverage:
  for ic in range(nc_in, nc_ot):
    sys.stderr.write(f"checking marker channel {ic} of {nc_ot}...\n")
    for it in range(nt):
      assert me_new[ic][it] is not None

  return me_new
 
def get_sessions(SG,nt,nt_pad):
  # Condenses the segment table {SG} into a list of rythm sessions.
  #
  # A session is a maximal run of consecutive segments with the same
  # rythm other than '-'. Each session is expanded by {nt_pad} frames on
  # both sides: it starts {nt_pad} frames before the first frame of its
  # first segment and ends {nt_pad} frames after the last frame of its
  # last segment. The padding may overlap neighbouring segments.
  #
  # Returns a list {SS} of dicts with fields 'tini', 'tfin' (padded frame
  # range, inclusive), 'sess', 'rthm'.
  #
  # Fails an assertion if two sessions are not separated by a gap segment,
  # if a padded range falls outside {0..nt-1}, if the session index changes
  # within a session, or if the table does not end with a gap segment.

  ng = len(SG)
  SS = []
  rthm_prev = '-'  # Rythm of previous segment.
  tfin_prev = -1   # Last frame of previous segment.
  sess_prev = -1   # Session index of previous segment.
  tini_sess = -1   # Padded first frame of the open session, or -1 if none.
  for SGk in SG:
    rthm = SGk['rthm']
    tini = SGk['tini']
    tfin = SGk['tfin']
    sess = SGk['sess']
    if rthm != '-':
      if rthm != rthm_prev:
        # Start of a new session:
        assert rthm_prev == '-', f"segment {tini}..{tfin} change of rythm without gap"
        tini_sess = tini - nt_pad
        assert tini_sess >= 0, f"segment {tini}..{tfin} initial gap too short"
      else:
        # Continuation of same session:
        assert tini_sess >= 0
        assert sess == sess_prev, f"segment {tini}..{tfin} incongruent session {sess_prev} {sess}"
    elif rthm_prev != '-':
      # This gap ends the session; its last beat was the previous segment:
      assert 0 <= tini_sess <= tfin_prev
      tfin_sess = tfin_prev + nt_pad
      assert tfin_sess < nt, f"segment {tini}..{tfin} final gap too short"
      SS.append({'tini': tini_sess, 'tfin': tfin_sess, 'sess': sess_prev, 'rthm': rthm_prev})
      tini_sess = -1
    tfin_prev = tfin
    rthm_prev = rthm
    sess_prev = sess

  # File should end with a gap segment (no session left open):
  assert tini_sess == -1, "file should end with a gap segment"

  ns = len(SS)
  sys.stderr.write(f"condensed {ng} segments to {ns} sessions\n")
  return SS

def write_eeg_session(dir, subj, me, ne, rthm, sess, tini, tfin):
  # Writes frames (columns) {tini..tfin} of {me} to a text file, subtracting
  # from each electrode the average of its non-gap samples in that range.
  #
  # Parameters:
  #   {dir} -- subject directory; output goes to '{dir}/eeg.txt' when {rthm}
  #     is None, else to '{dir}/{rthm}_{sess:04d}/eeg.txt' (the subdirectory
  #     is created if missing).
  #   {subj} -- subject number, recorded in the header.
  #   {me} -- data indexed as {me[ic][it]}; rows {0..ne-1} are electrodes
  #     and row {ne} must be the 'RTHM' marker channel.
  #   {ne} -- number of electrode channels.
  #   {rthm}, {sess} -- rythm code and session index, or both None for the
  #     whole recording.
  #   {tini}, {tfin} -- first and last frames to write (inclusive).
  #
  # A frame is "non-gap" when its 'RTHM' value is nonzero. The min, max,
  # average and root-mean-square (about zero) of each electrode over those
  # frames are reported on stderr. Fails an assertion if the range has no
  # non-gap frames.

  nc = len(me)
  nt_full = len(me[0])
  nt = tfin - tini + 1

  # Choose (and create if needed) the output file:
  if rthm is None:
    assert sess is None
    fname = dir + '/eeg.txt'
  else:
    sessdir = dir + f"/{rthm}_{sess:04d}"
    if os.path.exists(sessdir):
      assert os.path.isdir(sessdir), f"{sessdir} is not a directory"
    else:
      os.mkdir(sessdir)
    fname = sessdir + "/eeg.txt"
  sys.stderr.write(f"writing frames {tini}..{tfin} (out of {nt_full}) to file '{fname}'\n")

  # Max, min, sum, and sum of squares per electrode, excluding gaps:
  vmax = [ -1e100 ]*ne
  vmin = [ +1e100 ]*ne
  sum1 = [ 0 ]*ne
  sum2 = [ 0 ]*ne
  nsmp = 0  # Number of non-gap frames in the range.

  rthm_row = me[ne]  # The 'RTHM' marker channel; 0 means gap.
  for it in range(tini, tfin+1):
    if rthm_row[it] != 0:
      nsmp += 1
      for ie in range(ne):
        el = me[ie][it]
        if el < vmin[ie]: vmin[ie] = el
        if el > vmax[ie]: vmax[ie] = el
        sum1[ie] += el
        sum2[ie] += el*el

  sys.stderr.write(f"statistics for {nsmp} frames with stimumus (out of {nt}):\n")
  assert nsmp > 0, f"frames {tini}..{tfin} contain no stimulus frames"
  avg = [ s/nsmp for s in sum1 ]
  for ie in range(ne):
    rms = sqrt(sum2[ie]/nsmp)
    sys.stderr.write(f"{ie:3d} {vmin[ie]:+10.2f} _ {vmax[ie]:+10.2f} {avg[ie]:+10.2f} {rms:10.2f}\n")

  # The file is closed even if writing fails:
  with open(fname, 'w') as ff:
    write_eeg_header(ff, subj, nt_full, tini, tfin, nc, ne, rthm)
    for it in range(tini, tfin+1):
      write_frame(ff, me, it, ne, avg)

def write_eeg_header(ff, subj, nt_full, tini, tfin, nc, ne, rthm):
  # Writes the "key = value" header of an EEG text file to the text stream {ff}.
  #
  # Parameters:
  #   {subj} -- subject number.
  #   {nt_full} -- number of frames in the full recording.
  #   {tini}, {tfin} -- frame range (inclusive) written to this file.
  #   {nc}, {ne} -- total number of channels and number of electrodes; the
  #     channel list names electrodes 'C001'.. followed by 'RTHM BEAT STIM'.
  #   {rthm} -- rythm code of the session, or None (or '-') for the whole
  #     recording, in which case no 'type' line is written.
  # When the range is not the whole recording, 'orig.*' lines record
  # where the frames came from.

  fsmp = 250  # Sampling rate written to the header (presumably Hz).
  nt = tfin - tini + 1
  ff.write("nt = %d\n" % nt)
  ff.write("nc = %d\n" % nc)
  electrodes = "".join(" C%03d" % (ie+1) for ie in range(ne))
  ff.write("channels = " + electrodes + " RTHM BEAT STIM\n")
  ff.write("ne = %d\n" % ne)
  if rthm is not None and rthm != '-':
    ff.write("type = %s\n" % rthm)  # Session type (rythm).
  ff.write("fsmp = %.10f\n" % fsmp)

  # Fields deliberately not written (no power spectrum, component,
  # polynomial shift, filtering or rebasing): kfmax, component, tdeg,
  # tkeep, band, rebase_wt.

  if tini != 0 or tfin != nt_full-1:
    ff.write(f"orig.file = v{subj:02d}/eeg.txt\n")
    ff.write(f"orig.nt = {nt_full}\n")
    ff.write(f"orig.sample_range = {tini} {tfin}\n")
    ff.write(f"orig.subject = {subj}\n")

def write_frame(ff, me, it, ne, avg):
  # Writes one line to the text stream {ff} with the samples {me[ic][it]}
  # of every channel {ic} in {0..len(me)-1}, each preceded by one space:
  #   string samples are written verbatim (and reported on stderr);
  #   channels {ic >= ne} (assumed to be marker channels) as integers;
  #   electrodes {ic < ne} as "%+7.1f" after subtracting {avg[ic]}.
  fields = []
  for ic, row in enumerate(me):
    el = row[it]
    if isinstance(el, str):
      fields.append(" %s" % el)
      sys.stderr.write(f"!! string sample: {it:6d} {ic:3d} {el}\n")
    elif ic >= ne:
      # Marker channel; the tiny offset guards against float truncation.
      fields.append(" %d" % int(el+0.000001))
    else:
      # Electrode potential:
      fields.append(" %+7.1f" % (el - avg[ic]))
  ff.write("".join(fields) + "\n")

# Run only when executed as a script, not when imported:
if __name__ == "__main__":
  main()
  sys.exit(0)