import netCDF4
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
from datetime import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta

def ReadNETCDF(FilePath, VariName):
    """Read a whole variable from a netCDF file and return it flattened to 1-D.

    Parameters
    ----------
    FilePath : str
        Path of the netCDF file to open.
    VariName : str
        Name of the variable to read.

    Returns
    -------
    The variable's full contents (as produced by netCDF4 ``[:]`` slicing),
    flattened with ``ravel()``. With netCDF4's default auto-masking this is
    presumably a masked array -- confirm if fill values matter to the caller.
    """
    # Use the dataset as a context manager so the file handle is released;
    # previously every call leaked an open handle.
    with netCDF4.Dataset(FilePath) as f:
        # ``[:]`` reads the data into memory, so it stays valid after close.
        MyArray = f.variables[VariName][:].ravel()
    return MyArray

def ReadNETCDF_2D(FilePath, VariName):
    """Read a whole variable from a netCDF file, keeping its original shape.

    Parameters
    ----------
    FilePath : str
        Path of the netCDF file to open.
    VariName : str
        Name of the variable to read.

    Returns
    -------
    The variable's full contents as produced by netCDF4 ``[:]`` slicing
    (not flattened). Despite the name, the rank is whatever the file stores.
    """
    # Context manager closes the file; the data was already read by ``[:]``.
    with netCDF4.Dataset(FilePath) as f:
        MyArray = f.variables[VariName][:]
    return MyArray

# Restart-file name suffix shared by the DA and open-loop runs
# (filled with year, month, day, hour, minute, second).
_RESTART_TEMPLATE = "/ufs_land_restart_back.%4d-%02d-%02d_%02d-%02d-%02d.nc"

# Panel definitions for Computediff: (title, MODIS has snow?, model has snow?).
_PANELS = (
    ("(a) MODIS and Model Snow", True, True),
    ("(b) MODIS and Model no Snow", False, False),
    ("(c) MODIS snow and Model no Snow", True, False),
    ("(d) MODIS no Snow and Model Snow", False, True),
)

# Series drawn in every panel: (key, matplotlib format string, legend label).
_SERIES = (
    ("openloop", "b-", "openloop"),
    ("da1", "r-", "DA GHCN"),
    ("da2", "g-", "DA IMS+GHCN"),
    ("modis", "k-", "MODIS"),
)


def _to_float_nan(values):
    """Return *values* as a plain float ndarray with any masked entries set to NaN."""
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan)


def _mean_where(values, mask):
    """Return the mean of the non-NaN entries of ``values[mask]`` (NaN if none)."""
    selected = values[mask]
    selected = selected[~np.isnan(selected)]
    return selected.mean() if selected.size else np.nan


def _restart_albedo(FilePath, VariName):
    """Read *VariName* from a restart file and average its first two slots.

    Computes ``0.5 * (a[0, 0, :] + a[0, 1, :])`` of the 3-D variable, i.e. the
    mean of the first two entries along axis 1 at index 0 of axis 0
    (presumably two albedo components at the single restart time -- confirm).
    Masked entries become NaN.
    """
    arr = ReadNETCDF_2D(FilePath, VariName)
    return _to_float_nan(0.5 * (arr[0, 0, :] + arr[0, 1, :]))


def _plot_panels(xx, averages, start_date, out_path):
    """Draw the 2x2 figure of daily category means and save it to *out_path*.

    ``averages[i]`` maps each ``_SERIES`` key to the daily series for panel i.
    """
    plt.figure(figsize=(15, 8))
    xlabel = "Days since " + start_date.strftime("%Y%m%d")
    for idx, (title, _, _) in enumerate(_PANELS):
        ax = plt.subplot(2, 2, idx + 1)
        ax.set_ylim(0.0, 0.7)
        for key, fmt, label in _SERIES:
            ax.plot(xx, averages[idx][key], fmt, label=label)
        ax.set_ylabel("albedo", fontsize=15)
        ax.set_xlabel(xlabel, fontsize=15)
        ax.set_title(title, pad=20, fontsize=15)
        ax.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def Computediff(VariName):
    """Plot daily mean albedo for four MODIS/model snow categories.

    For each of 210 days starting 2019-11-01, reads:

    * the MODIS MCD43C3 albedo field and ``percent_snow``;
    * *VariName* from two DA restart runs (GHCN, IMS+GHCN) and an open-loop run;
    * the open-loop ``snow_cover_fraction``.

    Only grid cells inside the lat/lon box with MODIS albedo > 0 are used.
    They are split into four categories by whether MODIS snow percent and
    model snow cover are > 0 or == 0. The per-day mean of each dataset over
    each category is plotted in a 2x2 figure saved under ``PlotPath``.

    Parameters
    ----------
    VariName : str
        Restart-file variable to compare. "albedo_direct" is compared with the
        MODIS ``albedo_bsa_shortwave`` field, anything else with
        ``albedo_wsa_shortwave``.

    Reads the module-level globals ``BasinPath`` (only the grid length is
    used) and ``PlotPath``. Returns None.
    """
    print(VariName)
    DAFilePath1 = "/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/DA_GHCN_2015to2020/mem000/restarts/vector"
    DAFilePath2 = "/scratch2/NCEPDEV/stmp3/Youlong.Xia/landDA/cycle_land/C768/DA_imsGHCN_2015to2020/mem000/restarts/vector"
    OpenloopFilePath = "/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/openloop_ghcnHOFX_2015to2020/mem000/restarts/vector"
    modisFilePath = "/scratch2/NCEPDEV/land/data/evaluation/MODIS/albedo/C768/"

    fname = "FourPanels_C768_NorthAmerica_" + VariName + "_timeseries_2020.png"

    RestartDate = datetime.strptime("2019-11-01_00-00-00", '%Y-%m-%d_%H-%M-%S')
    DeltaTime = 1  # day
    SimulationDuration = 210  # days

    modis_var = ("albedo_bsa_shortwave" if VariName == "albedo_direct"
                 else "albedo_wsa_shortwave")

    # North America [205 - 300o, 20 -70oN]
    # Europe [0 -45E, 40-70N]
    # Russia [45-150, 40-70]
    # NOTE(review): latmin is 40, not the 20N quoted above for North America;
    # confirm which is intended. lonmin/lonmax assume 0-360 longitudes.
    latmin = 40
    latmax = 70.0
    lonmin = 205.0
    lonmax = 300.0

    # Only the length of the basin mask is used: it fixes how many grid cells
    # are compared.
    ncell = len(ReadNETCDF(BasinPath, "basins"))

    def day_files(when):
        """Return the (MODIS, DA GHCN, DA IMS+GHCN, open-loop) files for *when*."""
        suffix = _RESTART_TEMPLATE % (when.year, when.month, when.day,
                                      when.hour, when.minute, when.second)
        modis_file = modisFilePath + "MCD43C3.A" + when.strftime('%Y%m%d') + ".061.C768.nc"
        return (modis_file, DAFilePath1 + suffix, DAFilePath2 + suffix,
                OpenloopFilePath + suffix)

    def read_flat(FilePath, name):
        """Read a 1-D field with masked -> NaN, truncated to the basin grid length."""
        return _to_float_nan(ReadNETCDF(FilePath, name))[:ncell]

    # Region mask, computed once from the first day's MODIS coordinates.
    first_modis = day_files(RestartDate)[0]
    latout = read_flat(first_modis, "lat")
    lonout = read_flat(first_modis, "lon")
    in_region = ((latout >= latmin) & (latout <= latmax)
                 & (lonout >= lonmin) & (lonout <= lonmax))

    xx = np.arange(SimulationDuration, dtype=float)
    averages = [{key: np.full(SimulationDuration, np.nan) for key, _, _ in _SERIES}
                for _ in _PANELS]

    for SimDay in range(SimulationDuration):
        ThisRestartDate = RestartDate + timedelta(days=SimDay * DeltaTime)
        print(ThisRestartDate)
        modisFile, DAFile1, DAFile2, openloopFile = day_files(ThisRestartDate)

        modis = read_flat(modisFile, modis_var)
        snow_cover = read_flat(modisFile, "percent_snow")
        model_snow = read_flat(openloopFile, "snow_cover_fraction")
        values = {
            "openloop": _restart_albedo(openloopFile, VariName)[:ncell],
            "da1": _restart_albedo(DAFile1, VariName)[:ncell],
            "da2": _restart_albedo(DAFile2, VariName)[:ncell],
            "modis": modis,
        }

        # Masks are rebuilt every day, so each daily mean only uses cells that
        # fall in the category on that day. Previously, persistent arrays were
        # never cleared: stale values from earlier days leaked into later means,
        # and a cell could sit in several categories at once.
        # NaN (formerly masked) inputs compare False and are therefore excluded.
        usable = in_region & (modis > 0.0)
        obs_snow = {True: snow_cover > 0.0, False: snow_cover == 0.0}
        mod_snow = {True: model_snow > 0.0, False: model_snow == 0.0}
        for panel, (_, modis_has_snow, model_has_snow) in enumerate(_PANELS):
            mask = usable & obs_snow[modis_has_snow] & mod_snow[model_has_snow]
            for key, series in averages[panel].items():
                series[SimDay] = _mean_where(values[key], mask)

    # savefig fails if the output directory does not exist yet.
    os.makedirs(PlotPath, exist_ok=True)
    _plot_panels(xx, averages, RestartDate, os.path.join(PlotPath, fname))
    
# Module-level settings read by Computediff.
BasinPath = "/scratch2/NCEPDEV/land/data/evaluation/basins/C768/GRDC_C768.nc"
PlotPath = "./plots/C768_multiYear/"

# Model variables to process; Computediff is called once per entry.
VarList = ["albedo_direct", "albedo_diffuse"]

# Run only when executed as a script, so importing this module (e.g. to reuse
# the netCDF readers) does not trigger the full file-reading/plotting run.
if __name__ == "__main__":
    for var in VarList:
        Computediff(var)
