#!/usr/bin/env python3
# coding: utf-8

import numpy as np
import xarray as xr
import pandas as pd
from netCDF4 import Dataset, MFDataset
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import glob
from datetime import datetime, timedelta
from scipy import stats
from sklearn.metrics import mean_squared_error
import calendar

import sys
import warnings

# Silence the noisy warning categories raised by the numeric/plotting stack.
for _category in (FutureWarning, RuntimeWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_category)

# SettingWithCopyWarning has lived in different places across pandas releases
# (pandas.core.common in older versions, pandas.errors later) and may be
# absent altogether; try each location and skip the filter if none has it,
# instead of crashing at import time.
try:
    from pandas.errors import SettingWithCopyWarning
except ImportError:
    try:
        from pandas.core.common import SettingWithCopyWarning
    except ImportError:
        SettingWithCopyWarning = None

if SettingWithCopyWarning is not None:
    warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)

# dir = '/homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/FULLSET/'
# ufs_dir = '/homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/data/'
dir = '../data/SCAN/'
ufs_dir = '../data/UFS/SCAN/'

# One CSV per station, processed in sorted filename order.
site_list = sorted(glob.glob(dir + '*.csv'))

# UFS prototype identifiers; they form the leading axis of `ufs`.
pts = ["P7", "P8a"]

######################################################################
f_start, f_stop = 0, len(site_list)     # These files/sites [206 stations]
######################################################################

# Parallel lists: logical variable name, CSV column name, UFS variable name.
target_vars = ["SSM"]
scan_vars = ["sm_5cm"]
ufs_vars = ["SOILW_0M0D1mbelowground"]

years = list(range(2012, 2014))
init_sub_dates = ["0101", "0401", "0701", "1001"]
# "YYYYMMDD" strings in year-major order: every sub-date of the first year,
# then every sub-date of the next year, and so on.
init_dates = [str(year) + mmdd for year in years for mmdd in init_sub_dates]

comp = {"zlib": True, "complevel": 1}

# NaN-filled containers so that anything never written stays missing.
# ufs axes: (prototype, variable, init date, 35 steps, site)
ufs = np.full(
    [len(pts), len(scan_vars), len(init_dates), 35, len(site_list)], np.nan
)

# obs axes: (variable, init date, 35 steps, site)
obs = np.full([len(scan_vars), len(init_dates), 35, len(site_list)], np.nan)

# Per-site metadata, filled while the station files are read.
lon = np.full([len(site_list)], np.nan)
lat = np.full([len(site_list)], np.nan)
site = [None for _ in site_list]

# open soil moisture network SSM
# For every station CSV: record the station id and coordinates from the first
# row, then, for each initialisation date, copy the rows starting at that date
# (as many as `obs` has along its third axis) into `obs`.
for ff, f in enumerate(site_list[f_start:f_stop]):
    df1 = pd.read_csv(f, sep=",", na_values=-9999.0)
    first_row = df1.head(1)
    site[ff] = first_row['Station ID'].item()
    lon[ff] = first_row['lon'].item()
    lat[ff] = first_row['lat'].item()

    n_lead = obs.shape[2]

    for i, init in enumerate(init_dates):
        # init is "YYYYMMDD"; the day is characters 6-7 (the former [7:8]
        # slice kept only the last digit and was right only for days 01-09).
        yr = int(init[0:4])
        mon = int(init[4:6])
        dd = int(init[6:8])

        is_start = (df1["year"] == yr) & (df1["month"] == mon) & (df1["day"] == dd)
        hits = np.flatnonzero(is_start.to_numpy())
        if hits.size != 1:
            raise ValueError(
                "expected exactly one row for {} in {}, found {}".format(
                    init, f, hits.size
                )
            )
        tloc = hits[0]

        for v, var in enumerate(scan_vars):
            window = df1[var].to_numpy(dtype=float)[tloc:tloc + n_lead]
            # A record that ends early fills only what exists; the rest of the
            # NaN-initialised slot stays missing instead of raising a
            # shape-mismatch error.
            obs[v, i, :window.size, ff] = window

lon = np.where(lon<0.,lon+360.,lon) # lon=[-180~180] --> [0~360]
print("station lon=({}~{}), lat=({}~{})".format(lon.min(),lon.max(),lat.min(),lat.max()))

# open UFS prototypes
# One file per (prototype, init date, variable), named
# "<prototype>_<YYYYMMDD>_<variable>.nc4"; its variable is copied into the
# matching slot of `ufs`.
for pp, p in enumerate(pts):
    for v, ufs_var in enumerate(ufs_vars):
        for i, init in enumerate(init_dates):
            pattern = ufs_dir + p + '_' + init + '_' + ufs_var + '.nc4'
            vfile = sorted(glob.glob(pattern))
            if not vfile:
                raise FileNotFoundError("UFS file not found: {}".format(pattern))
            # Read the values inside the context manager so the file handle
            # is released as soon as the data is copied.
            with xr.open_dataset(vfile[0]) as ds:
                ufs[pp, v, i, :, :] = ds[ufs_var].values

# compute bias (ufs - obs), pooled over years for each init date of the year
n_sub = len(init_sub_dates)        # init dates per year == stride through init_dates
n_lead = obs.shape[2]              # steps stored per init date
n_pooled = len(years) * n_lead     # samples per init-of-year once years are pooled
min_valid_pairs = 45               # paired samples must exceed this to count

# Only the first n_sub entries of the third axis are filled below; the rest
# stay NaN and are ignored by the nan-aware statistics.
bias = np.full(
    [len(pts), len(scan_vars), len(init_dates), n_pooled, len(site_list)], np.nan
)

# samp_var[i, site]: number of target variables with more than
# min_valid_pairs valid obs / first-prototype pairs for init-of-year i.
samp_var = np.full([n_sub, len(site_list)], np.nan)
for i in range(n_sub):
    enough = np.zeros(len(site_list))
    for v in range(len(target_vars)):
        # i::n_sub picks the same init-of-year from every year, because
        # init_dates is ordered year-major.
        paired = ~np.isnan(obs[v, i::n_sub, :, :]) & ~np.isnan(ufs[0, v, i::n_sub, :, :])
        enough += paired.sum(axis=(0, 1)) > min_valid_pairs
    samp_var[i, :] = enough

for i, mmdd in enumerate(init_sub_dates):
    print(i, mmdd)
    for v in range(len(target_vars)):
        # (prototype, year, step, site) minus (year, step, site); merging the
        # year and step axes reproduces the per-site year-major flatten.
        diff = ufs[:, v, i::n_sub, :, :] - obs[v, i::n_sub, :, :]
        bias[:, v, i, :, :] = diff.reshape(len(pts), n_pooled, len(site_list))

# Overall mean / median bias for each prototype.
bias1 = np.nanmean(bias[0, :, :, :, :])
median1 = np.nanmedian(bias[0, :, :, :, :])

bias2 = np.nanmean(bias[1, :, :, :, :])
median2 = np.nanmedian(bias[1, :, :, :, :])

# Overlay the bias histograms of both prototypes, mark their mean and median,
# and save the figure.
limit = 0.4   # histogram covers biases in [-limit, limit]

plt.figure(figsize=(12, 8))
for pp, (face, tag) in enumerate([('blue', 'pt7'), ('orange', 'pt8a')]):
    plt.hist(x=bias[pp, 0, :, :, :].ravel(), range=[-limit, limit], bins=51,
             facecolor=face, align='mid', alpha=0.5, label=tag)
# NOTE(review): the histograms show raw counts (density is not requested),
# while the axis label says "Probability density" - confirm which is intended.
plt.ylabel('Probability density', fontsize=12)
plt.xlabel('Bias (ufs-SCAN)', fontsize=12)
plt.title('Bias histogram for 8 ICs during 2012-2013', fontsize=16)
plt.axvline(x=bias1, color='black', label='Pt7 Mean')
plt.axvline(x=median1, color='lime', label='Pt7 Median')
plt.axvline(x=bias2, color='black', linestyle='dashed', label='Pt8a Mean')
plt.axvline(x=median2, color='lime', linestyle='dashed', label='Pt8a Median')
plt.legend(loc='upper right')
plt.savefig('SCANsurfaceSoilMoisture_histogram_overall.png')
plt.close()

# sys.exit() rather than exit(): the latter is an interactive helper installed
# by the site module and is not guaranteed to exist.
sys.exit()
