#!/usr/bin/env python3
# coding: utf-8

import numpy as np
import xarray as xr
import pandas as pd
from netCDF4 import Dataset, MFDataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.colors as colors
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import glob
from datetime import datetime, timedelta
from scipy import stats
from sklearn.metrics import mean_squared_error
import calendar

import sys, warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Silence pandas' SettingWithCopyWarning.  The class is not importable from
# pandas.core.common in newer pandas releases (it is exposed publicly as
# pandas.errors.SettingWithCopyWarning), so try the public location first and
# fall back to the legacy one instead of crashing at import time.
try:
    from pandas.errors import SettingWithCopyWarning
except ImportError:
    try:
        from pandas.core.common import SettingWithCopyWarning
    except ImportError:
        # This pandas build does not define the warning: nothing to silence.
        SettingWithCopyWarning = None

if SettingWithCopyWarning is not None:
    warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)

# ---------------------------------------------------------------------------
# Run configuration: input locations, station list, variable names and the
# list of forecast initialization dates.
# ---------------------------------------------------------------------------

# Alternate absolute input locations (disabled):
# dir = '/homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/FULLSET/'
# ufs_dir = '/homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/data/'
dir = '../data/OKM/'                    # station CSV files
ufs_dir = '../data/UFS/OKM/70cm/'       # UFS NetCDF files

network = "OKM"
pts = ["P7", "P8a"]                     # UFS prototype names (file prefixes)
site_list = sorted(glob.glob(dir+'*.csv'))

######################################################################
# Sub-range of site_list to process (defaults to every file found).
f_start = 0
f_stop = len(site_list)     # These files/sites [206 stations]
######################################################################

# Parallel, single-entry lists: target name, station CSV column, UFS variable.
target_vars = ["70cmSM"]
scan_vars = ["sm_25cm"]
ufs_vars = ["SOILW_0D4M1mbelowground"]

years = list(range(2012, 2014))
str_years = [str(year) for year in years]
init_sub_dates = ["0101", "0401", "0701", "1001"]

# Initialization dates as "YYYYMMDD" strings, ordered year-major: all MMDD
# entries of the first year, then all entries of the next year, and so on.
init_dates = [str(year) + mmdd for year in years for mmdd in init_sub_dates]

#print(init_dates)
comp = dict(zlib=True, complevel=1)

# Pre-allocate NaN-filled containers so that anything never written stays
# missing.  Axis layout:
#   ufs : [prototype, variable, init date, lead index (35), station]
#   obs : [variable, init date, lead index (35), station]
n_sites = len(site_list)
n_lead = 35
obs_shape = (len(scan_vars), len(init_dates), n_lead, n_sites)

ufs = np.full((len(pts),) + obs_shape, np.nan)
obs = np.full(obs_shape, np.nan)

# Per-station metadata: coordinates and station identifier.
lon = np.full(n_sites, np.nan)
lat = np.full(n_sites, np.nan)
site = [None] * n_sites

# Read every station CSV of the soil moisture network.  For each station this
# stores its ID and coordinates, then, for every initialization date, copies
# the run of consecutive rows starting at that date from each scan_vars column
# into obs[variable, init date, :, station].
for ff, f in enumerate(site_list[f_start:f_stop]):
    df1 = pd.read_csv(f, sep=",", na_values=-9999.0)

    # Station metadata is taken from the first row of the file.
    df2 = df1.head(1)
    site[ff] = df2['Station_ID'].item()
    lon[ff] = df2['lon'].item()
    lat[ff] = df2['lat'].item()

    for i, init_date in enumerate(init_dates):
        # init_date is "YYYYMMDD": year = [0:4], month = [4:6], day = [6:8].
        # (The day must be read from both of its digits; taking only the last
        # one would turn e.g. day 15 into day 5.)
        yr = int(init_date[0:4])
        mon = int(init_date[4:6])
        dd = int(init_date[6:8])

        # .item() raises unless exactly one row matches the date, so a missing
        # or duplicated date stops the script here.
        tloc = df1.loc[(df1["year"] == yr) & (df1["month"] == mon) & (df1["day"] == dd)].index.item()

        # Window length follows the lead axis of obs so both stay in sync.
        # NOTE(review): tloc is an index label used as a row position; this
        # presumes the default 0..N-1 index produced by read_csv.
        n_lead = obs.shape[2]
        for v, var in enumerate(scan_vars):
            obs[v, i, :, ff] = df1[var][tloc:tloc + n_lead]
        #print(ff, f)
        #print(i, init_date)
        #print(v, var)

# Shift negative longitudes by 360 degrees: lon=[-180~180] --> [0~360].
is_negative = lon < 0.
lon = lon.copy()
lon[is_negative] = lon[is_negative] + 360.

# Report the bounding box of the stations.
lon_bounds = (lon.min(), lon.max())
lat_bounds = (lat.min(), lat.max())
print("station lon=({}~{}), lat=({}~{})".format(*lon_bounds, *lat_bounds))

# Load the UFS prototype output.  There is one NetCDF file per
# (prototype, variable, initialization date), named
# <prototype>_<YYYYMMDD>_<variable>.nc4 under ufs_dir; its variable is copied
# into ufs[prototype, variable, init date, :, :].
# NOTE(review): assumes each file's variable has shape (lead, station)
# matching the last two axes of ufs - confirm against the files.
for pp, p in enumerate(pts):
    for v, ufs_var in enumerate(ufs_vars):
        for i, init_date in enumerate(init_dates):
            vpath = ufs_dir + p + '_' + init_date + '_' + ufs_var + '.nc4'
            vfile = sorted(glob.glob(vpath))
            if not vfile:
                # Fail with the offending path instead of a bare IndexError.
                raise FileNotFoundError("UFS file not found: " + vpath)
            # The context manager closes the file once the values have been
            # read into memory, so no handles are left open across iterations.
            with xr.open_dataset(vfile[0]) as ds:
                ufs[pp, v, i, :, :] = ds[ufs_var].values

# Compute station-averaged (spatially averaged) obs and ufs series.
# For each (init date, variable, prototype, lead index) the mean is taken over
# the stations where both obs and ufs are non-NaN.  A value is only stored
# when more than 30 such stations exist and the observations among them are
# not all identical (non-zero standard deviation); otherwise it stays NaN.
n_steps = obs.shape[2]
obsAve = np.full((len(scan_vars), len(init_dates), n_steps), np.nan)
ufsAve = np.full((len(pts), len(scan_vars), len(init_dates), n_steps), np.nan)

for i in range(len(init_dates)):
    for v in range(len(target_vars)):
        for pp in range(len(pts)):
            for t in range(n_steps):
                obs_row = obs[v, i, t, :]
                ufs_row = ufs[pp, v, i, t, :]
                paired = ~(np.isnan(obs_row) | np.isnan(ufs_row))

                if paired.sum() <= 30:
                    continue
                if np.std(obs_row[paired]) == 0.:
                    continue

                # NOTE: obsAve has no prototype axis, so it is rewritten on
                # every prototype pass that reaches this point; the value
                # kept is the one from the last such prototype.
                obsAve[v, i, t] = obs_row[paired].mean()
                ufsAve[pp, v, i, t] = ufs_row[paired].mean()


# x-axis for the plots: 0, 1, 2, ... as floats, one value per lead index.
xx = np.arange(obs.shape[2], dtype=float)

# Make one figure per year, with one panel per initialization date of that
# year, showing the station-averaged observation and the two UFS prototypes.
#
# init_dates is ordered year-major (every MMDD of the first year, then every
# MMDD of the next year, ...), so the entry for year i and sub-date k sits at
# i * len(init_sub_dates) + k.  Indexing with i + k instead would plot the
# wrong initialization for every year after the first.
n_sub = len(init_sub_dates)
n_cols = 2
n_rows = -(-n_sub // n_cols)          # ceiling division
# (colour, legend label) for ufsAve[0] and ufsAve[1] respectively.
model_lines = (("red", "pt7"), ("blue", "pt8a"))

for i, iy in enumerate(str_years):
    first = i * n_sub

    # make time series plots
    fig = plt.figure(figsize=(18, 10))

    for k, sub_date in enumerate(init_sub_dates):
        idx = first + k
        ax = plt.subplot(n_rows, n_cols, k + 1)
        ax.set_title(" (" + chr(ord("a") + k) + ") init: " + iy + sub_date + " ")
        ax.set_ylim(0.1, 0.4)
        ax.plot(xx, obsAve[0, idx, :], 'k-o', label="OKM25cm")
        for pp, (line_color, line_label) in enumerate(model_lines):
            ax.plot(xx, ufsAve[pp, 0, idx, :], color=line_color, label=line_label)

        # Axis labels only on the outer panels: y on the left column,
        # x on the bottom row.
        if k % n_cols == 0:
            ax.set_ylabel("Soil Moisture [m3/m3]", fontsize=20)
        if k >= (n_rows - 1) * n_cols:
            ax.set_xlabel("forecast days", fontsize=20)
        ax.legend()

    fname = "spatiallyAveraged_"+network+"_ufs_25cmSM"+iy+"timeseries.png"
    plt.tight_layout()
    plt.savefig(fname)
    # Release the figure so figures do not pile up in memory across years.
    plt.close(fig)

# sys.exit() does not depend on the interactive-only `exit` helper.
sys.exit()
