#!/usr/bin/env python3
# coding: utf-8

import numpy as np
import xarray as xr
import pandas as pd
from netCDF4 import Dataset, MFDataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.colors as colors
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import glob
from datetime import datetime, timedelta
from scipy import stats
from sklearn.metrics import mean_squared_error
import calendar


import sys, warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# SettingWithCopyWarning is exposed as pandas.errors.SettingWithCopyWarning since
# pandas 1.5; importing it from pandas.core.common fails with ImportError there.
# Look it up in the new location first, fall back to the old one for older
# pandas, and skip the filter entirely if neither provides it.
SettingWithCopyWarning = getattr(pd.errors, "SettingWithCopyWarning", None)
if SettingWithCopyWarning is None:
    try:
        from pandas.core.common import SettingWithCopyWarning
    except ImportError:
        SettingWithCopyWarning = None
if SettingWithCopyWarning is not None:
    warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)

# Alternate input locations (original development machine):
#   /homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/FULLSET/
#   /homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/data/
dir = '../data/CRN/'          # station CSV files
ufs_dir = '../data/UFS/CRN/'  # UFS prototype NetCDF files

site_list = sorted(glob.glob(dir + '*.csv'))

# UFS prototype names; they appear in the UFS file names.
pts = ["P7", "P8a"]

######################################################################
# Range of station files to process (all files by default).
f_start, f_stop = 0, len(site_list)
######################################################################

target_vars = ["SSM"]
scan_vars = ["sm_5cm"]
ufs_vars = ["SOILW_0M0D1mbelowground"]

years = list(range(2012, 2014))
init_sub_dates = ["0101", "0401", "0701", "1001"]
# YYYYMMDD initialization dates, year-major: every sub-date of one year,
# then every sub-date of the next year.
init_dates = [f"{year}{month_day}" for year in years for month_day in init_sub_dates]

comp = dict(zlib=True, complevel=1)

# Allocate NaN-filled result containers, one slot per station file.
#   ufs: (prototype, variable, init date, 35-row window, station)
#   obs: (variable, init date, 35-row window, station)
n_sites = len(site_list)
ufs = np.full((len(pts), len(scan_vars), len(init_dates), 35, n_sites), np.nan)
obs = np.full((len(scan_vars), len(init_dates), 35, n_sites), np.nan)
lon = np.full(n_sites, np.nan)
lat = np.full(n_sites, np.nan)
site = [None for _ in range(n_sites)]

# open soil moisture network SSM
# For every station CSV: record station id/lon/lat, then copy a window of
# obs.shape[2] (35) consecutive rows of each `scan_vars` column, starting at
# each initialization date, into obs[var, init, :, station].
n_days = obs.shape[2]
# start=f_start keeps station indices aligned with site_list when a subset is run.
for ff, f in enumerate(site_list[f_start:f_stop], start=f_start):
    df1 = pd.read_csv(f, sep=",", na_values=-9999.0)
    df2 = df1.head(1)
    site[ff] = df2['StationID'].item()
    lon[ff] = df2['lon'].item()
    lat[ff] = df2['lat'].item()

    for i, init in enumerate(init_dates):
        yr = float(init[0:4])
        mon = float(init[4:6])
        dd = float(init[6:8])  # DD is the last two characters of YYYYMMDD
        is_init = (df1["Year"] == yr) & (df1["Month"] == mon) & (df1["Day"] == dd)
        hits = np.flatnonzero(is_init.to_numpy())
        if hits.size == 0:
            # Init date not in this station's record: leave the window as NaN.
            continue
        tloc = hits[0]  # positional row of the (first) matching date
        for v, var in enumerate(scan_vars):
            window = df1[var].iloc[tloc:tloc + n_days].to_numpy(dtype=float)
            # A record ending before the full window leaves the tail as NaN.
            obs[v, i, :window.size, ff] = window
lon = np.where(lon < 0., lon + 360., lon)  # lon=[-180~180] --> [0~360]
print("station lon=({}~{}), lat=({}~{})".format(lon.min(), lon.max(), lat.min(), lat.max()))

# open UFS prototypes
# One file per prototype/variable/init date: <ufs_dir><proto>_<YYYYMMDD>_<var>.nc4.
# The variable is assigned into ufs[pp, v, i, :, :], so it must broadcast to
# (35, number of stations) — presumably (window, station); confirm against the files.
for pp, p in enumerate(pts):
    for v, var_name in enumerate(ufs_vars):
        for i, init in enumerate(init_dates):
            pattern = ufs_dir + p + '_' + init + '_' + var_name + '.nc4'
            vfile = sorted(glob.glob(pattern))
            if not vfile:
                raise FileNotFoundError("no UFS file matches " + pattern)
            # Load values eagerly and close the file so handles do not pile up
            # across the prototype x variable x date loop.
            with xr.open_dataset(vfile[0]) as ds:
                ufs[pp, v, i, :, :] = ds[var_name].values

# Per-station statistics grouped by initialization sub-date.
# Only the mean bias (filled in below) is computed in this script.
bias = np.full((len(pts), len(scan_vars), len(init_sub_dates), len(site_list)), np.nan)
samp_var = np.full((len(init_sub_dates), len(site_list)), np.nan)

# init_dates is year-major, so the dates sharing sub-date i are at
# positions i, i + n_sub, i + 2*n_sub, ... (previously hard-coded as ::4).
n_sub = len(init_sub_dates)

# samp_var[i, ff]: number of target variables for which the first prototype
# and the observations have more than 30 co-valid samples at station ff.
for i, _sub_date in enumerate(init_sub_dates):
    for ff in range(len(lon)):
        cnt = 0.
        for v, _var in enumerate(target_vars):
            obs2 = obs[v, i::n_sub, :, ff].ravel()
            ufs2 = ufs[0, v, i::n_sub, :, ff].ravel()
            mm = ~np.isnan(obs2) & ~np.isnan(ufs2)
            if mm.sum() > 30:
                cnt += 1
        samp_var[i, ff] = cnt

# Mean bias (UFS minus obs) per prototype, variable, init sub-date and station,
# pooled over every year's init date for that sub-date and the full 35-row
# window. Stations need more than 30 co-valid samples and non-constant
# observations; otherwise the entry stays NaN.
n_sub = len(init_sub_dates)  # stride between same-sub-date entries of init_dates
for i, sub_date in enumerate(init_sub_dates):
    print(i, sub_date)
    for v, _var in enumerate(target_vars):
        for pp, _p in enumerate(pts):
            for ff in range(len(lon)):
                obs2 = obs[v, i::n_sub, :, ff].ravel()
                ufs2 = ufs[pp, v, i::n_sub, :, ff].ravel()
                mm = ~np.isnan(obs2) & ~np.isnan(ufs2)
                if mm.sum() > 30 and np.std(obs2[mm]) != 0.:
                    bias[pp, v, i, ff] = np.mean(ufs2[mm]) - np.mean(obs2[mm])

lon = np.where(lon > 180., lon - 360., lon)  # lon=[0~360] --> [-180~180]
# make spatial bias map
# Calculate data statistics
# --------------------------
network = "CRN"
ptype = "pt7"      # label used in title/file name; presumably pts[0] ("P7")
initial = "0701"   # init sub-date to map
pt_index = 0       # index into pts of the prototype being mapped
# Derive the sub-date index from `initial` instead of hard-coding it.
obarray = bias[pt_index, 0, init_sub_dates.index(initial), :].ravel()

stdev = np.nanstd(obarray)   # Standard deviation
omean = np.nanmean(obarray)  # Mean of the data
datmi = np.nanmin(obarray)   # Min of the data
datma = np.nanmax(obarray)   # Max of the data
# Number of stations with a valid bias. obarray is a plain ndarray, so
# np.ma.count would count the NaN entries too.
datcont = np.count_nonzero(~np.isnan(obarray))

# get the domain to plot: station bounding box. NaN-aware reductions are used
# because builtin min/max give order-dependent results when NaNs are present.
lonl = np.nanmin(lon)
lonr = np.nanmax(lon)
latb = np.nanmin(lat)
latu = np.nanmax(lat)

# Plotting begins here
# Set colorbar type and min/max for colorbar
# --------------------------------------------
if np.nanmin(obarray) < 0:
    # Data contains negative values: fixed symmetric diverging range.
    cmax = 0.25
    cmin = -0.25
    cmap = 'RdBu'
else:
    # Non-negative data: mean +/- one standard deviation, floored at zero.
    cmax = omean + stdev
    cmin = np.maximum(omean - stdev, 0.0)
    cmap = 'jet'

# matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9;
# pyplot.get_cmap is available on both older and current releases.
color_map = plt.get_cmap(cmap)
reversed_color_map = color_map.reversed()

# Set plot variable unit
# -----------------------
units = '-'

# Map axes: Plate Carree projection centred on 100 W.
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=-100))

# Labelled grid lines (no labels along the top edge).
grid_lines = ax.gridlines(
    crs=ccrs.PlateCarree(central_longitude=0),
    draw_labels=True,
    linewidth=1,
    color='gray',
    alpha=0.5,
    linestyle='-',
)
grid_lines.top_labels = False
label_style = {'size': 10, 'color': 'black'}
grid_lines.xlabel_style = label_style
grid_lines.ylabel_style = dict(label_style)

# Background layers: coastlines, Natural Earth 1:50m state/province lines,
# and country borders.
admin1_lines = cfeature.NaturalEarthFeature(
    category='cultural',
    name='admin_1_states_provinces_lines',
    scale='50m',
    facecolor='none',
)
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(admin1_lines, edgecolor='black')
ax.add_feature(cfeature.BORDERS)

# One marker per station, coloured by the mapped statistic.
scatter_kwargs = dict(
    c=obarray,
    s=25,
    linewidth=1,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    vmin=cmin,
    vmax=cmax,
    norm=None,
    antialiased=True,
    edgecolors='black',
)
sc = ax.scatter(lon, lat, **scatter_kwargs)

# Plot colorbar
# --------------
cbar = plt.colorbar(sc, ax=ax, orientation="horizontal", pad=.1, fraction=0.06)
# Colorbar.set_label puts the label along the long axis for either
# orientation; cbar.ax.set_ylabel put it on the short edge of a horizontal bar.
cbar.set_label(units, fontsize=10)

# Plot CONUS spatial map (fixed extent; the station bounding box
# [lonl, lonr, latb, latu] could be used instead).
# --------------
ax.set_extent([-125, -60, 25, 50])

# Add figure labels
# ------------------
vtitle = " Soil Moisture Bias (" + ptype + "-" + network + ")"
ax.set_title("init: " + initial + vtitle, pad=20)

text = f"Total Count:{datcont:0.0f}, Max/Min/Mean/Std: {datma:0.3f}/{datmi:0.3f}/{omean:0.3f}/{stdev:0.3f} {units}"
print(text)
ax.text(0.1, -0.14, text, transform=ax.transAxes, va='bottom', fontsize=6.2)

dpi = 150

# Save the figure
# -----------
fname = "spatialMap_" + network + "_" + initial + "_" + ptype + "_soilMoistureBias.png"
plt.savefig(fname, dpi=dpi)
# sys.exit is always available; the bare exit() builtin depends on the site module.
sys.exit()
