##########################################
# TD_zt_plot.py
'''
Read model output from a NetCDF file and create (z, t) contour plots
of the solution: Kv, v, dv/dz and dv/dt versus time and depth.
'''
#DANIEL DAUHAJRE UCLA 2016
#########################################
################################
import os
import sys
import numpy as np
import scipy as sp
from pylab import *
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from matplotlib.ticker import NullFormatter, MultipleLocator, FormatStrFormatter


import custom_cmap as cm_cust
plt.ion()
################################


###################################
# LOAD OUTPUT (NETCDF)
###################################

path_output = '../nc_files/'
if not os.path.isdir(path_output):
    sys.exit('Output directory not found: ' + path_output)
# Only the top level of the directory is needed: (dirpath, dirnames, filenames).
path, dirs, file_names_out = next(os.walk(path_output))
# os.walk lists names in arbitrary filesystem order; sort so the same file
# is picked on every run, and prefer NetCDF (.nc) files over strays such as
# editor backups or .DS_Store.
file_names_out = sorted(file_names_out)
nc_candidates = [f for f in file_names_out if f.endswith('.nc')] or file_names_out
if not nc_candidates:
    sys.exit('No output files found in ' + path_output)
nc_out = Dataset(os.path.join(path_output, nc_candidates[0]), 'r')

######################################

# Announce which solution was loaded, then ask for the figure-name suffix.
banner_lines = (
    '#######################################################################',
    '                  Solution: ' + nc_out.title,
    '#########################################################################',
)
print(*banner_lines, sep='\n')


base_name = input('ENTER IN BASE NAME FOR FIGURE TO GO AS SUFFIX OF run_ID: ')



run_ID = nc_out.title
dt     = nc_out.dt   # time step in seconds (converted to days via /86400 below)

nt_total = len(nc_out.dimensions['time'])

# GET TOTAL TIME STEPS AND LENGTH
# TO DETERMINE INDICES OF DAYS
# CHOSEN BY USER BELOW
tend_sec_total = nt_total * dt
# np.arange(n) * dt always has exactly n samples; a float-step arange may not.
tvec_sec_total = np.arange(nt_total) * dt
tvec_days_total = tvec_sec_total / (86400.)
# Whole number of time steps per day. It must be an int because it is used
# to build slice indices (floats are rejected as indices in Python 3).
len_day = max(1, int(round(86400. / dt)))
# Number of (possibly partial) days covered by the whole record.
num_days_total = int(np.ceil(nt_total / float(len_day)))
tsteps = range(nt_total)


################################################
# PROMPT USER TO SEE IF THEY WANT TO ZOOM
# IN ON SPECIFIC DAY(S) OF SIMULATION
################################################
# Ask (case-insensitively) until a Y/N answer is given.
plot_time_choice = ''
while plot_time_choice not in ('Y', 'N'):
    plot_time_choice = input('Do you want to plot for a specific time period (i.e., day) (Y/N)? ').strip().upper()

if plot_time_choice == 'Y':
    # Re-prompt until two integer, 1-based days inside the record are given.
    while True:
        time_period_days = input('Enter in days in list for [start_day stop_day] (SEPARATED BY A SPACE!!) that you would like to zoom in on (IF ONLY ONE DAY, TYPE IN SAME DAY AS START AND STOP): ')
        try:
            # Unpacking raises ValueError for non-integers or a count != 2.
            start_day, end_day = (int(s) for s in time_period_days.split())
        except ValueError:
            print('Please enter exactly two whole numbers separated by a space.')
            continue

        #IF START DAY SAME AS STOP DAY, ndays_user == 0 (one day is plotted)
        ndays_user = end_day - start_day

        # Slice bounds must be ints in Python 3, whatever type len_day has.
        tstep_start = int(round((start_day - 1) * len_day))
        tstep_end   = min(int(round(tstep_start + len_day * (ndays_user + 1))), nt_total)
        if start_day >= 1 and ndays_user >= 0 and tstep_start < nt_total:
            break
        print('Requested days are outside the simulation; please try again.')

    # FOR DECLARING ARRAYS IN NEXT BLOCK OF CODE
    nt = tstep_end - tstep_start
else:
    tstep_start = 0
    tstep_end   = nt_total
    nt = nt_total



##############################
# LOAD VARIABLES
###############################
# Vertical coordinate arrays are read in full; v and Kv only for the
# selected window of time steps.
nc_vars = nc_out.variables
z_r = nc_vars['z_r'][:]
z_w = nc_vars['z_w'][:]
time_window = slice(tstep_start, tstep_end)
v_nt = nc_vars['v'][time_window, :]
Kv_nt = nc_vars['Kv'][time_window, :]

# Number of time steps actually read, and number of vertical levels of v.
nt, N = v_nt.shape

################################

#####################################
# CALCULATE VERTICAL GRADIENTS OF v
###################################
# Successive first differences of v in the vertical. The spacings used
# imply this staggering (rows = time steps):
#   v     on z_r[0:N]    -> v_z   (nt, N-1) on z_w[1:N]
#   v_z   on z_w[1:N]    -> v_zz  (nt, N-2) on z_r[1:N-1]
#   v_zz  on z_r[1:N-1]  -> v_zzz (nt, N-3)
# np.diff performs the same pairwise differences as an explicit loop over
# every time step and level, but in native code. The results are cast to
# float64 ndarrays.
v_z_nt   = np.asarray(np.diff(v_nt, axis=1) / np.diff(z_r[:N]), dtype=np.float64)
v_zz_nt  = np.asarray(np.diff(v_z_nt, axis=1) / np.diff(z_w[1:N]), dtype=np.float64)
v_zzz_nt = np.asarray(np.diff(v_zz_nt, axis=1) / np.diff(z_r[1:N-1]), dtype=np.float64)


# Time tendency of v: np.gradient along the time axis (central differences
# inside, one-sided at the ends) in units of "per step", divided by dt.
# Stored as a float64 array.
dv_dt_nt = np.asarray(np.gradient(v_nt, axis=0) / dt, dtype=np.float64)


#####################################
# 	PLOTTING
#####################################
###############################
# FORM TIME VECTOR
##############################
# Time axis of the window actually loaded (dt is in seconds).
tend_sec = nt * dt
tvec_sec = np.arange(nt) * dt        # exactly nt samples, unlike a float-step arange
tvec_days = tvec_sec / (86400.)
# Time steps per day.
len_day = 86400. / dt
# Number of (possibly partial) days spanned by the window.
num_days = int(np.ceil(nt / len_day))
tsteps = range(nt)



# FUNCTION TO PLOT VERTICAL
# LINES DELINEATING DAY(S)
num_days = int((nt*dt)/86400)


def plot_day_lines(color='lime', linewidth=1.5):
    """Draw a vertical line on the current axes at every interior day boundary.

    The x-axis of each panel is the time-step index (0 .. nt-1). Day d
    therefore begins at index d * 86400 / dt. A line is drawn for every
    d >= 1 whose boundary lies strictly inside the record, including the
    boundary before a partial final day.

    color, linewidth -- passed to plt.axvline; defaults keep the original look.
    """
    steps_per_day = 86400. / dt
    for x in np.arange(steps_per_day, nt, steps_per_day):
        plt.axvline(x, color=color, linewidth=linewidth)

######################################

# CREATE MESHGRID FOR CONTOUR PLOTS
# Integer (time-index, level-index) grids for contourf, one pair per field shape.
t, z = np.indices(v_nt.shape)
t_vz, z_vz = np.indices(v_z_nt.shape)
t_w_p, z_w_p = np.indices(Kv_nt.shape)


# MAKE CUSTOM COLORBAR BASED ON RGB VALUES
# RGB anchor colours (0-255 per channel) for the custom colormap used by
# the v, dv/dz and dv/dt panels. bit=True presumably tells make_cmap the
# channels are 8-bit values; confirm in custom_cmap.
cols_cust_jet = [
    (255, 0, 255),
    (0, 0, 255),
    (0, 255, 255),
    (105, 188, 105),
    (255, 255, 0),
    (255, 165, 0),
    (255, 0, 0),
]
cm_cust_jet = cm_cust.make_cmap(cols_cust_jet, bit=True)
cmap_v = cm_cust_jet
# Kv uses matplotlib's reversed Spectral colormap.
cmap_Kv = cm.Spectral_r

def make_color_levs(var, nticks):
    """
    Return contour levels symmetric about zero for a signed variable.

    The levels span [-max|var|, +max|var|] in `nticks` equal intervals.
    The result therefore has exactly nticks + 1 values, and both end points
    are included exactly. NaNs in `var` are ignored. If `var` is all zero
    (or has no finite values), the range falls back to [-1, 1] so the levels
    stay strictly increasing, as contourf requires.

    var    --> variable (array-like)
    nticks --> number of equal intervals between the levels
    """
    var_max = np.nanmax(np.abs(var))
    if not np.isfinite(var_max) or var_max == 0:
        var_max = 1.0
    # linspace (unlike arange with a float step) fixes the number of levels
    # and makes the last level exactly var_max.
    return np.linspace(-var_max, var_max, nticks + 1)



# Symmetric (about zero) levels for the signed fields.
levs_v     = make_color_levs(v_nt, 100)
levs_v_z   = make_color_levs(v_z_nt, 100)
levs_dv_dt = make_color_levs(dv_dt_nt, 100)

# Kv levels run from 0 to the largest |Kv| in the window (NaNs ignored).
Kv_max = np.nanmax(np.abs(Kv_nt))
Kv_min = 0
nlevs_Kv = 100
if not np.isfinite(Kv_max) or Kv_max <= Kv_min:
    # Degenerate field: contourf needs strictly increasing levels.
    Kv_max = 1.0
# linspace gives exactly nlevs_Kv + 1 levels ending exactly on Kv_max;
# np.arange with a float step could add a spurious extra level.
levs_Kv = np.linspace(Kv_min, Kv_max, nlevs_Kv + 1)


# Font sizes for the figure title, axis/colorbar labels and tick labels.
title_font, axis_font = 30, 22
cbar_tick_size = axis_tick_size = 20

# Put a y-axis tick on every z_step-th vertical level.
z_step = 50


def _plot_zt_panel(position, t_grid, z_grid, field, levels, cmap, depths, cbar_label):
    """Draw one filled-contour (time index, level index) panel of the 4x1 figure.

    position        -- subplot index (1-4) in the 4x1 layout
    t_grid, z_grid  -- integer index grids with the same shape as `field`
    field           -- 2-D array (time, vertical level) to contour
    levels, cmap    -- passed to contourf
    depths          -- depth of each vertical level of `field`, used as y-tick labels
    cbar_label      -- colorbar label (LaTeX string)
    """
    plt.subplot(4, 1, position)
    plot_day_lines()
    # Label the ticks by indexing `depths` with the tick positions themselves.
    # The number of labels then always equals the number of ticks, and each
    # label belongs to the grid the field actually lives on.
    tick_idx = z_grid[0, ::z_step]
    plt.yticks(tick_idx, depths[tick_idx])
    plt.contourf(t_grid, z_grid, field, levels=levels, cmap=cmap, extend='both')
    cbar = plt.colorbar()
    cbar.set_label(cbar_label, fontsize=axis_font)
    plt.ylabel(r'$z(m)$', fontsize=axis_font)
    # Hide the x-axis tick labels.
    plt.gca().xaxis.set_major_formatter(NullFormatter())


fig = plt.figure(figsize=[15, 15])
plt.suptitle(run_ID, fontsize=title_font)

# Kv is on the w-grid (its index grids are named t_w_p/z_w_p), so label with z_w.
_plot_zt_panel(1, t_w_p, z_w_p, Kv_nt, levs_Kv, cmap_Kv, z_w,
               r'$\kappa_v \; \left(m^2s^{-1}\right)$')
_plot_zt_panel(2, t, z, v_nt, levs_v, cmap_v, z_r,
               r'$v \; \left(ms^{-1}\right)$')
# dv/dz is computed with z_r spacings, i.e. it sits on the interior w-levels.
_plot_zt_panel(3, t_vz, z_vz, v_z_nt, levs_v_z, cmap_v, z_w[1:-1],
               r'$\partial v/ \partial z \; \left(s^{-1}\right)$')
_plot_zt_panel(4, t, z, dv_dt_nt, levs_dv_dt, cmap_v, z_r,
               r'$\partial v/ \partial t \; \left(ms^{-2}\right)$')
plt.xlabel('Time', fontsize=axis_font)




# Output name: <run_ID>_<base_name>_zt_solns (matplotlib picks the default format).
fig_name = '_'.join([run_ID, base_name, 'zt_solns'])
plt.savefig(fig_name, bbox_inches='tight')
print('Saved figure as: ' + fig_name)





