##########################################
# TD_snaps_plot.py
'''
FROM NETCDF OUTPUT
CREATE SNAPSHOTS OF
SOLUTIONS IN TIME
FOR VISUALIZATION
'''
#DANIEL DAUHAJRE UCLA 2016
#########################################
################################
import os
import sys
import numpy as np
import scipy as sp
from pylab import *
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from matplotlib.ticker import NullFormatter, MultipleLocator, FormatStrFormatter

plt.ion()
################################


###################################
# LOAD OUTPUT (NETCDF)
###################################

path_output = '../nc_files/'
try:
    # Only the top level of path_output is inspected.
    path, dirs, file_names_out = next(os.walk(path_output))
except StopIteration:
    # os.walk yields nothing when the directory does not exist.
    sys.exit('Output directory not found: ' + path_output)
# os.walk lists files in arbitrary filesystem order; sort so the same file
# is opened on every run.
file_names_out = sorted(file_names_out)
if not file_names_out:
    sys.exit('No output files found in: ' + path_output)
nc_out = Dataset(path_output + file_names_out[0], 'r')

######################################

print('#######################################################################')
print('                  Solution: ' + nc_out.title)
print('#########################################################################')


base_name = input('ENTER IN BASE NAME FOR FIGURE TO GO AS SUFFIX OF run_ID: ')



run_ID = nc_out.title
dt     = nc_out.dt

nt_total = len(nc_out.dimensions['time'])

# GET TOTAL NUMBER OF OUTPUT RECORDS AND RECORDS PER DAY
# TO DETERMINE INDICES OF DAYS CHOSEN BY USER BELOW.
# Records are treated as dt seconds apart, matching the time vectors below.
tend_sec_total = nt_total * dt
tvec_sec_total = np.arange(0, tend_sec_total, dt)
tvec_days_total = tvec_sec_total / 86400.
# Records per day. Kept as an int (and at least 1) because it is used in
# index arithmetic for slicing.
len_day = max(1, int(round(86400. / dt)))
# Number of (possibly partial) days covered by the output.
num_days_total = int(np.ceil(nt_total / float(len_day)))
tsteps = range(nt_total)


################################################
# PROMPT USER TO SEE IF THEY WANT TO ZOOM
# IN ON SPECIFIC DAY(S) OF SIMULATION
################################################
# Ask until we get a yes/no answer (case-insensitive, surrounding spaces ignored).
plot_time_choice = ''
while plot_time_choice not in ('Y', 'N'):
    plot_time_choice = input('Do you want to plot for a specific time period (i.e., day) (Y/N)? ').strip().upper()

# Records are dt seconds apart, so one day spans 86400/dt records. Computed
# here as an int so the slice bounds below are valid indices.
steps_per_day = max(1, int(round(86400. / dt)))

if plot_time_choice == 'Y':
    while True:
        time_period_days = input('Enter in days in list for [start_day stop_day] (SEPARATED BY A SPACE!!) that you would like to zoom in on (IF ONLY ONE DAY, TYPE IN SAME DAY AS START AND STOP): ')
        try:
            # A list (not map()'s iterator, which cannot be indexed on Python 3);
            # unpacking also enforces exactly two values.
            start_day, end_day = [int(s) for s in time_period_days.split()]
        except ValueError:
            print('Please enter exactly two integers separated by a space, e.g. "3 5".')
            continue
        if not 1 <= start_day <= end_day:
            print('Days must satisfy 1 <= start_day <= stop_day.')
            continue
        if (start_day - 1) * steps_per_day >= nt_total:
            print('start_day is beyond the end of the output.')
            continue
        break

    # IF START DAY SAME AS STOP DAY, ndays_user IS 0 AND ONE FULL DAY IS USED
    ndays_user = end_day - start_day

    tstep_start = (start_day - 1) * steps_per_day
    # Clamp to the available records (slicing would clamp anyway).
    tstep_end = min(tstep_start + steps_per_day * (ndays_user + 1), nt_total)

    # FOR DECLARING ARRAYS IN NEXT BLOCK OF CODE
    nt = tstep_end - tstep_start
else:
    tstep_start = 0
    tstep_end = nt_total
    nt = nt_total



##############################
# LOAD VARIABLES
###############################
z_r    = nc_out.variables['z_r'][:]
z_w    = nc_out.variables['z_w'][:]
v_nt   = nc_out.variables['v'][tstep_start:tstep_end,:]
Kv_nt  = nc_out.variables['Kv'][tstep_start:tstep_end,:]

# v is indexed (record, level); nt is the number of records actually loaded,
# which can be smaller than requested if the slice ran past the end.
nt, N = v_nt.shape

# Slicing netCDF4 variables returns in-memory arrays, and nothing below reads
# from the file again, so release the file handle now.
nc_out.close()

################################

#####################################
# CALCULATE VERTICAL GRADIENTS OF v
###################################
def _ddz(f, z):
    """Return the first difference of f along axis 1 divided by the spacing of z.

    f is indexed (record, level) and z holds the level positions for f's
    columns. Masked results become NaN, matching what assigning a masked
    element into a float ndarray produced in the previous per-element loop.
    """
    df = f[:, 1:] - f[:, :-1]
    dz = z[1:] - z[:-1]
    return np.asarray(np.ma.filled(df / dz, np.nan), dtype=float)


print('Calculating vertical gradients')
# Successive first differences, vectorized over all records at once.
# v lives on z_r; each derivative is placed on the grid used when plotting:
v_z_nt   = _ddz(v_nt, z_r)              # (nt, N-1), plotted on z_w[1:-1]
v_zz_nt  = _ddz(v_z_nt, z_w[1:-1])      # (nt, N-2), plotted on z_r[1:-1]
v_zzz_nt = _ddz(v_zz_nt, z_r[1:-1])     # (nt, N-3), plotted on z_w[2:-2]


#####################################
# 	PLOTTING
#####################################
###############################
# FORM TIME VECTOR
##############################
tend_sec = nt * dt
tvec_sec = np.arange(0, tend_sec, dt)
tvec_days = tvec_sec / (86400.)
# Records are dt seconds apart (see tvec_sec), so one day spans 86400/dt
# records; at least 1 to avoid a zero spacing when dt exceeds a day.
len_day = max(1, int(round(86400. / dt)))
# Number of (possibly partial) days covered by the nt plotted records.
num_days = int(np.ceil(nt / float(len_day)))
tsteps = range(nt)



# FUNCTION TO PLOT VERTICAL
# LINES DELINEATING DAY(S)
def plot_day_lines():
    """Draw a vertical black line at every day boundary inside the plotted range.

    Positions are in record-index units (the x axis of the Kv time series):
    multiples of len_day strictly between 0 and nt, including the boundary
    before a trailing partial day.
    """
    for d in range(1, num_days):
        plt.axvline(d * len_day, color='k', linewidth=1.5)


######################################


dir1 = run_ID + '_' + base_name + '_snaps_nt'
if not os.path.exists(dir1):
    print('Creating directory: ' + dir1)
    os.makedirs(dir1)
print('Moving into directory: ' + dir1)
os.chdir(dir1)

# VERTICAL LEVEL TO PLOT Kv time series (middle of the column).
# Integer division: N/2 is a float on Python 3, and numpy rejects float indices.
k_p = N // 2
max_Kv = np.nanmax(Kv_nt)
Kv_lims = [0, max_Kv]

# Symmetric x-limits fixed across all frames so snapshots are comparable.
v_max = np.nanmax(np.abs(v_nt))
v_lims = [-v_max, v_max]

v_z_max = np.nanmax(np.abs(v_z_nt))
v_z_lims = [-v_z_max, v_z_max]

v_zz_max = np.nanmax(np.abs(v_zz_nt))
v_zz_lims = [-v_zz_max, v_zz_max]

v_zzz_max = np.nanmax(np.abs(v_zzz_nt))
v_zzz_lims = [-v_zzz_max, v_zzz_max]




axis_font = 18
leg_font = 16


def _plot_profile(loc, values, z, color, label, xlims, xlabel):
    """Draw one vertical-profile panel: values against z plus a dashed zero line.

    loc is the (row, col) origin in the 16x12 subplot grid.
    """
    plt.subplot2grid((16, 12), loc, rowspan=5, colspan=4)
    plt.plot(values, z, color=color, linewidth=3, label=label)
    plt.plot(np.zeros(len(z)), z, color='k', linestyle='--', linewidth=1.5)
    plt.xlim(xlims)
    plt.xlabel(xlabel, fontsize=axis_font)
    plt.ylabel(r'$z(m)$', fontsize=axis_font)
    plt.grid(True)
    plt.legend(loc=1, fancybox=True, framealpha=0.005, fontsize=leg_font)


plt.ioff()
# int() guards against a float level index (e.g. N/2 on Python 3).
k_idx = int(k_p)
# Zero-pad frame numbers to a common width (at least 2 digits, as before) so
# file names sort in time order even with 100+ frames.
pad_width = max(2, len(str(nt)))
try:
    for n in range(nt):
        fig = plt.figure(figsize=[15, 10])
        try:
            #############################
            # Kv time series, with a marker at the current record
            #############################
            plt.subplot2grid((16, 12), (0, 1), rowspan=3, colspan=10)
            plot_day_lines()
            plt.plot(tsteps, Kv_nt[:, k_idx], linewidth=3, color='k')
            plt.plot(tsteps[n], Kv_nt[n, k_idx], 'o', color='c', markersize=10)
            plt.ylim(Kv_lims)
            plt.xlim([0, nt])
            plt.ylabel(r'$\kappa_v (m^2s^{-1})$', fontsize=axis_font)
            plt.grid(True)
            plt.title(run_ID)
            plt.gca().xaxis.set_major_formatter(NullFormatter())

            #############################
            # v, dv/dz, d^2v/dz^2, d^3v/dz^3 vertical profiles
            #############################
            _plot_profile((5, 1), v_nt[n, :], z_r,
                          'firebrick', r'$v$', v_lims, r'$ms^{-1}$')
            _plot_profile((5, 6), v_z_nt[n, :], z_w[1:-1],
                          'darkcyan', r'$v_z$', v_z_lims, r'$s^{-1}$')
            _plot_profile((11, 1), v_zz_nt[n, :], z_r[1:-1],
                          'orange', r'$v_{zz}$', v_zz_lims, r'$m^{-1} s^{-1}$')
            _plot_profile((11, 6), v_zzz_nt[n, :], z_w[2:-2],
                          'forestgreen', r'$v_{zzz}$', v_zzz_lims, r'$m^{-2} s^{-1}$')

            t_str = str(n + 1).zfill(pad_width)
            plt.savefig(run_ID + '_snaps_' + t_str, bbox_inches='tight')
            print('Saved figure: ' + t_str)
        finally:
            # Release the figure even if rendering or saving failed.
            plt.close(fig)
finally:
    # Leave the snapshot directory entered earlier, even on error.
    os.chdir('../')


















































