import netCDF4
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
from datetime import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta

def ReadNETCDF(FilePath, VariName):
    """Read one variable from a netCDF file and return it flattened to 1-D.

    Parameters
    ----------
    FilePath : str
        Path of the netCDF file to open.
    VariName : str
        Name of the variable to read from ``FilePath``.

    Returns
    -------
    The full contents of ``VariName`` flattened with ``ravel()``.
    NOTE(review): netCDF4 usually hands back a numpy masked array here, so
    fill values may come back masked rather than as plain numbers.
    """
    # The Dataset is used as a context manager so the file handle is released
    # after every read; the [:] slice has already copied the data into memory,
    # so the returned array stays valid after the file is closed.
    with netCDF4.Dataset(FilePath) as f:
        return f.variables[VariName][:].ravel()

def _to_float_array(values):
    """Return ``values`` as a plain float ndarray, with any masked entries set to NaN."""
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan)


def _regional_mean(values, valid):
    """Return the NaN-ignoring mean of ``values`` over points where ``valid`` is True.

    Points outside ``valid`` and masked/NaN points are excluded. When no point
    remains, the result is NaN (numpy emits a "Mean of empty slice" warning).
    """
    return np.nanmean(np.where(valid, _to_float_array(values), np.nan))


def Computediff(VariName):
    """Plot daily regional means of ``VariName`` for several open-loop runs against CMC.

    For each day starting at ``RestartDate`` (one sample every ``DeltaTime``
    days, ``SimulationDuration`` samples), this reads the restart file of every
    experiment and the matching CMC snow-depth file. Each field is then averaged
    over the points that:

    * lie inside the lat/lon box below (coordinates taken from the first CMC file),
    * have ``vegetation_category == 10`` in the static file ``SpatialPath``, and
    * have a CMC ``snowDepth`` in [0, 2000] on that same day.

    The selection is rebuilt every day, so a point that drops out on a given
    day does not contribute an old value. All time series go on one figure,
    which is saved to ``PlotPath/fname``.

    Relies on the module-level globals ``SpatialPath``, ``PlotPath`` and ``fname``.
    """
    print(VariName)

    # Earlier experiment set (DA snow parameter table tests), kept for reference:
    #DAFilePath1="/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/snowParDA/DA_GHCN_baseline/mem000/restarts/vector"
    #DAFilePath2="/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/snowParDA/DA_GHCN_newSnowtbl/mem000/restarts/vector"
    #DAFilePath3="/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/snowParDA/DA_imsGHCN_baseline/mem000/restarts/vector"
    #DAFilePath4="/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/snowParDA/DA_imsGHCN_newSnowtbl/mem000/restarts/vector"

    # Open-loop experiments: albedo, snow thermal conductivity, HR4 test11.
    DAFilePath1 = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/openloop_HR4_ncarsnowtblV1/mem000/restarts/vector/"
    DAFilePath2 = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/openloop_Tksno/mem000/restarts/vector"
    DAFilePath3 = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/openloop_albedoTksno/mem000/restarts/vector"
    DAFilePath4 = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/openloop_HR4test11/mem000/restarts/vector"

    OpenloopFilePath1 = "/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/snowParDA/DA_openloop_baseline/mem000/restarts/vector"
    OpenloopFilePath2 = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/cycle_land/C768/snowParDA/DA_openloop_newSnowtbl/mem000/restarts/vector"
    cmcFilePath = "/scratch2/NCEPDEV/land/data/evaluation/CMC/C768"

    # (restart directory, line style, legend label), listed in legend order.
    experiments = [
        (DAFilePath1, 'r--', "openloop HR4ncarsnowtblV1"),
        (OpenloopFilePath1, 'b--', "openloop baseline"),
        (DAFilePath3, 'g--', "openloop albedoTksno"),
        (DAFilePath2, 'r-', "openloop Tksno"),
        (OpenloopFilePath2, 'b-', "openloop snowtbl"),
        (DAFilePath4, 'g-', "openloop HR4test11"),
    ]

    RestartDate = datetime.strptime("2019-10-01_00-00-00", '%Y-%m-%d_%H-%M-%S')
    DeltaTime = 1  # days between consecutive samples
    SimulationDuration = 244  # number of samples (366 for a full year)

    # Region box: Asia [45-80E, 40-50N]. (The North Central box used lon 255-265, lat 40-50.)
    latmin, latmax = 40.0, 50.0
    lonmin, lonmax = 45.0, 80.0
    # NOTE(review): the meaning of category 10 depends on the static file's
    # vegetation classification — confirm before reusing for another class.
    target_vegetation = 10
    # CMC snow depths outside this range are treated as invalid for the day.
    cmc_min, cmc_max = 0.0, 2000.0

    vegetation = _to_float_array(ReadNETCDF(SpatialPath, "vegetation_category"))

    exp_means = np.full((len(experiments), SimulationDuration), np.nan)
    cmc_means = np.full(SimulationDuration, np.nan)
    # Time-invariant part of the point selection; built from the first CMC file.
    in_region = None

    for day in range(SimulationDuration):
        this_date = RestartDate + timedelta(days=day * DeltaTime)
        print(this_date)

        restart_name = this_date.strftime("ufs_land_restart_back.%Y-%m-%d_%H-%M-%S.nc")
        cmc_file = os.path.join(
            cmcFilePath, this_date.strftime("C768_cmc_snow_depth_%Y%m%d00.nc"))

        cmc = _to_float_array(ReadNETCDF(cmc_file, 'snowDepth'))

        if in_region is None:
            lat = _to_float_array(ReadNETCDF(cmc_file, "lat"))
            lon = _to_float_array(ReadNETCDF(cmc_file, "lon"))
            in_region = ((lat >= latmin) & (lat <= latmax)
                         & (lon >= lonmin) & (lon <= lonmax)
                         & (vegetation == target_vegetation))

        # Rebuilt every day: a point whose CMC value becomes invalid must drop
        # out of today's average instead of carrying a previous day's value.
        # NaN (masked) CMC values compare False and are therefore excluded.
        valid = in_region & (cmc >= cmc_min) & (cmc <= cmc_max)

        for k, (path, _style, _label) in enumerate(experiments):
            values = ReadNETCDF(os.path.join(path, restart_name), VariName)
            exp_means[k, day] = _regional_mean(values, valid)
        cmc_means[day] = _regional_mean(cmc, valid)

    days = np.arange(SimulationDuration)

    plt.figure(figsize=(12, 8))
    ax1 = plt.subplot(1, 1, 1)
    ax1.set_ylim(0.0, 300)
    for (_path, style, label), series in zip(experiments, exp_means):
        ax1.plot(days, series, style, label=label)
    ax1.plot(days, cmc_means, 'k-', label="CMC")
    ax1.set_ylabel("Snow Depth (mm)", fontsize=20)
    ax1.set_xlabel("Simulation days since 01 October 2019", fontsize=20)
    ax1.set_title("Asia [45-80E, 40-50N]", pad=20, fontsize=20)
    ax1.legend()
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(PlotPath, exist_ok=True)
    plt.savefig(os.path.join(PlotPath, fname))
    plt.close()

# Static-field file; Computediff reads "vegetation_category" from it.
SpatialPath = "/scratch2/NCEPDEV/land/data/forcing/era5/static/C768/ufs-land_C768_hr3_static_fields.nc"
# Output directory and file name for the figure written by Computediff.
PlotPath = "./plots/C768_snowPar/"
fname = "Asia_openloop_HR4ncarsnowtblV1_snowDepth_CMC_2019-2020.png"

# Variables to process, one Computediff call each.
# NOTE: every call saves to the same PlotPath/fname, so with several
# variables only the last figure would be kept.
VarList = ["snow_depth"]

if __name__ == "__main__":
    # Guarded so that importing this module does not start the long, I/O-heavy run.
    for var in VarList:
        Computediff(var)
