#!/usr/bin/env python3
# coding: utf-8

import calendar
import glob
import os
from datetime import datetime, timedelta

import numpy as np
import xarray as xr
import pandas as pd
from netCDF4 import Dataset, MFDataset
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.metrics import mean_squared_error

import sys, warnings

# Silence the noisy warning categories for the whole script.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# The import location of SettingWithCopyWarning depends on the installed
# pandas release: newer releases expose it from pandas.errors, older ones only
# from pandas.core.common. Try both so the script does not fail at import
# time, and skip the filter when the class is not available at all.
try:
    from pandas.errors import SettingWithCopyWarning
except ImportError:
    try:
        from pandas.core.common import SettingWithCopyWarning
    except ImportError:
        SettingWithCopyWarning = None

if SettingWithCopyWarning is not None:
    warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)

# Input locations.
# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# later statements in this script build paths from it.
# Alternative (absolute) locations left by the original author:
# dir = '/homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/FULLSET/'
# ufs_dir = '/homes/eseo8/python/UFS/FLUXNET/subset_subpt/DD/data/'
dir = '../data/FULLSET/'
ufs_dir = '../data/UFS/'

# Every "*_DD_QC.csv" file found in `dir`, in sorted order.
site_list = sorted(glob.glob(dir+'*_DD_QC.csv'))

pts = ["P8a"]

######################################################################
# Slice of site_list to process: these files/sites [206 stations]
f_start = 0
f_stop = len(site_list)
######################################################################

target_vars = ["SSM"]

# Column name(s) read from the site CSV files.
flux_vars = ["SWC_F_MDS_1"]

# Variable name(s) read from the UFS files (also part of their file names).
ufs_vars = ["SOILW_0M0D1mbelowground"]

years = list(range(2012, 2014))
init_sub_dates = ["0101", "0401", "0701", "1001"]

# "YYYYMMDD" strings, year-major: every year combined with every sub-date.
init_dates = [str(yr) + mmdd for yr in years for mmdd in init_sub_dates]

print(init_dates)

# Not used in this part of the script; presumably netCDF compression settings
# for output written further down -- TODO confirm.
comp = dict(zlib=True, complevel=1)

# Number of consecutive rows stored per init date (third axis of `obs`).
window_len = 35

# NaN-initialised containers, so anything that is never filled stays missing.
# obs axes: (flux variable, init date, row within window, site)
obs = np.full([len(flux_vars), len(init_dates), window_len, len(site_list)], np.nan)
lon = np.full([len(site_list)], np.nan)
lat = np.full([len(site_list)], np.nan)
site = [None] * len(site_list)
igbp = [None] * len(site_list)

# open FluxTowers
# For each site file: read the observations and the matching "<prefix>site.csv"
# metadata file, record site id / lon / lat / IGBP type, and copy the rows
# starting at each init date into `obs`.

_qc_suffix = "DD_QC.csv"        # tail of every file name matched into site_list
_window = obs.shape[2]          # rows stored per init date

# Loop-invariant: "YYYYMMDD" -> "YYYY-MM-DD", the form compared with the "time" column.
_init_days = [str(pd.to_datetime(d))[:10] for d in init_dates]

for ff, f in enumerate(site_list[f_start:f_stop]):
    df1 = pd.read_csv(f, na_values=-9999)

    # Take the prefix from the file name itself rather than from a fixed
    # position in the path, so the data directory may sit at any depth and a
    # site id containing "DD" is not cut short.
    f_siteid = os.path.basename(f)[:-len(_qc_suffix)]
    site_csv = os.path.join(dir, f_siteid + 'site.csv')
    if not os.path.isfile(site_csv):
        raise FileNotFoundError(
            "site metadata file not found for {}: {}".format(f, site_csv))
    df2 = pd.read_csv(site_csv, sep=",")

    # .item() requires the metadata file to hold exactly one row.
    site[ff] = df2['Site ID'].item()
    lon[ff] = df2['Longitude'].item()
    lat[ff] = df2['Latitude'].item()
    igbp[ff] = df2['IGBP type'].item()

    time_col = df1["time"].to_numpy()
    for i, init_day in enumerate(_init_days):
        hits = np.flatnonzero(time_col == init_day)
        if hits.size == 0:
            # Leave the NaN fill in place instead of aborting the whole run.
            # Reported with print because UserWarning is filtered out at the
            # top of this script.
            print("no row with time == {} in {}; obs left as NaN".format(init_day, f),
                  file=sys.stderr)
            continue
        start = hits[0]     # first match, by position
        for v, var in enumerate(flux_vars):
            rows = df1[var].to_numpy()[start:start + _window]
            # A record that ends early fills only the rows it has; the
            # remainder of the window stays NaN.
            obs[v, i, :rows.size, ff] = rows


lon = np.where(lon<0.,lon+360.,lon) # lon=[-180~180] --> [0~360]
# nan-aware extrema: entries of sites outside [f_start, f_stop) are still NaN.
print("station lon=({}~{}), lat=({}~{})".format(
    np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)))

# open UFS prototypes
# Reads one file per (prototype, variable, init date), named
# "<prototype>_<init date>_<variable>.nc4" under ufs_dir, and combines them:
# init dates along "init_dates", variables along "variables", prototypes along
# "prototypes". A dimension is only added when there is more than one item to
# combine, so with a single variable / prototype the result has no
# "variables" / "prototypes" dimension.

def _stack(arrays, dim):
    """Concatenate arrays along a new dimension; return a lone array unchanged."""
    if not arrays:
        raise ValueError("nothing to combine along '{}'".format(dim))
    if len(arrays) == 1:
        return arrays[0]
    # One concat over the whole list instead of growing the result pairwise,
    # which re-copied the accumulated data on every iteration.
    return xr.concat(arrays, dim=dim)


_by_prototype = []
for pp, p in enumerate(pts):
    _by_variable = []
    for v, ufs_var in enumerate(ufs_vars):
        _by_init = []
        for init_date in init_dates:
            pattern = ufs_dir+p+'_'+init_date+'_'+ufs_var+'.nc4'
            matches = sorted(glob.glob(pattern))
            if not matches:
                raise FileNotFoundError("UFS file not found: {}".format(pattern))
            # Read the values into memory and close the file right away so no
            # dataset handle stays open.
            with xr.open_dataset(matches[0]) as ds:
                _by_init.append(ds[ufs_var].load())

        stacked = _stack(_by_init, "init_dates")
        if v == 0:
            # Reference time coordinate, taken from the first variable.
            time = stacked.time
        else:
            # Original author's note: the time coordinate differs for
            # temperature variables, so later variables take the first
            # variable's time coordinate before being combined.
            stacked = stacked.assign_coords({"time": (time)})
        _by_variable.append(stacked)

    _by_prototype.append(_stack(_by_variable, "variables"))

ufs = _stack(_by_prototype, "prototypes").copy(deep=True)
del _by_prototype, _by_variable, _by_init, stacked
print(ufs)
