import netCDF4
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
from datetime import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta

def ReadNETCDF(FilePath, VariName):
    """Read one variable from a netCDF file and return it flattened to 1-D.

    Parameters
    ----------
    FilePath : str
        Path of the netCDF file to open.
    VariName : str
        Name of the variable to read.

    Returns
    -------
    array
        All values of the variable, flattened with ``ravel()``.  Depending on
        netCDF4's auto-masking settings this may be a numpy masked array
        (NOTE(review): confirm whether callers rely on the mask).
    """
    # The context manager closes the file handle once the data is in memory;
    # previously every call leaked an open Dataset.
    with netCDF4.Dataset(FilePath) as f:
        MyArray = f.variables[VariName][:].ravel()
    return MyArray

def _masked_to_nan(values):
    """Return *values* as a float ndarray in which masked entries are NaN.

    Plain ndarrays are only converted to float; masked arrays have their
    masked cells filled with NaN so ``np.nanmean`` skips them as missing.
    """
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan)


def _read_variables(file_path, var_names):
    """Open *file_path* once and read every variable named in *var_names*.

    Returns a dict mapping each name to its flattened float values (masked
    cells converted to NaN).  The file is closed before returning.
    """
    with netCDF4.Dataset(file_path) as ds:
        return {name: _masked_to_nan(ds.variables[name][:].ravel())
                for name in var_names}


def _basin_mean(values, in_basin):
    """Return the NaN-ignoring mean of *values* over cells flagged in *in_basin*.

    Gives NaN (with numpy's "Mean of empty slice" RuntimeWarning) when the
    basin has no cells or all selected values are NaN.
    """
    return np.nanmean(values[in_basin])


def Computediff(VariName, basin_path=None, plot_path=None, plot_name=None,
                basin_id=6202):
    """Plot basin-mean daily time series for two DA runs and the open loop.

    For each simulated day this reads the ``ufs_land_restart_back.<date>.nc``
    restart file of three runs (GHCN DA, IMS+GHCN DA, open loop). It averages
    four variables over the grid cells whose basin ID equals *basin_id* and
    saves a 2x2 figure:

      (a) *VariName*, labelled as snow depth (mm)
      (b) ``snow_water_equiv`` (mm)
      (c) ``runoff_surface``, converted from mm/s to mm/day
      (d) ``runoff_baseflow``, converted from mm/s to mm/day

    Parameters
    ----------
    VariName : str
        Variable plotted in panel (a); the panel labels assume snow depth.
    basin_path, plot_path, plot_name : str, optional
        Basin-mask file, output directory and output file name.  When None
        they default to the module-level ``BasinPath``, ``PlotPath`` and
        ``fname``.
    basin_id : int, optional
        Basin ID whose cells are averaged (default 6202; the default output
        file name suggests this is the Danube -- TODO confirm).

    Raises
    ------
    The error ``netCDF4.Dataset`` raises for a missing or unreadable basin or
    restart file (typically FileNotFoundError/OSError).  No day is skipped
    silently.
    """
    print(VariName)
    if basin_path is None:
        basin_path = BasinPath
    if plot_path is None:
        plot_path = PlotPath
    if plot_name is None:
        plot_name = fname

    DAFilePath1 = "/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/DA_GHCN_2015to2020/mem000/restarts/vector"
    DAFilePath2 = "/scratch2/NCEPDEV/stmp3/Youlong.Xia/landDA/cycle_land/C768/DA_imsGHCN_2015to2020/mem000/restarts/vector"
    OpenloopFilePath = "/scratch1/NCEPDEV/stmp4/Youlong.Xia/landDA/cycle_land/C768/openloop_ghcnHOFX_2015to2020/mem000/restarts/vector"

    start_date = datetime.strptime("2019-10-01_00-00-00", "%Y-%m-%d_%H-%M-%S")
    DeltaTime = 1  # days between consecutive restart files
    SimulationDuration = 365  # number of restart files (time steps) to read
    seconds_per_day = 86400  # converts mm/s fluxes to mm/day

    # (legend label, matplotlib line style, restart directory), in plot order.
    runs = [
        ("GHCN DA", "r-", DAFilePath1),
        ("openloop", "b-", OpenloopFilePath),
        ("IMS+GHCN DA", "g-", DAFilePath2),
    ]
    # (variable, scale factor, y-limits, y-label, title), one per subplot.
    panels = [
        (VariName, 1, (0.0, 150.0), "SNOD (mm)", "(a) Snow Depth"),
        ("snow_water_equiv", 1, (0.0, 30.0), "SWE (mm)",
         "(b) Snow Water Equivalent"),
        ("runoff_surface", seconds_per_day, (0.0, 2.0), "ssrun (mm/day)",
         "(c) Surface Runoff"),
        ("runoff_baseflow", seconds_per_day, (0.0, 2.0), "runBase (mm/day)",
         "(d) Baseflow Runoff"),
    ]
    var_names = [panel[0] for panel in panels]

    # Build the basin selection once as a boolean mask; masked basin cells
    # become NaN and therefore never match basin_id.
    in_basin = _masked_to_nan(ReadNETCDF(basin_path, "basins")) == basin_id

    # series[run, panel, day]; stays NaN where nothing was computed.
    series = np.full((len(runs), len(panels), SimulationDuration), np.nan)

    for day in range(SimulationDuration):
        this_date = start_date + timedelta(days=day * DeltaTime)
        print(this_date)
        restart_name = this_date.strftime(
            "ufs_land_restart_back.%Y-%m-%d_%H-%M-%S.nc")
        for r, run in enumerate(runs):
            run_dir = run[2]
            # Open each restart file once per day for all four variables.
            values = _read_variables(os.path.join(run_dir, restart_name),
                                     var_names)
            for p, panel in enumerate(panels):
                var, scale = panel[0], panel[1]
                series[r, p, day] = scale * _basin_mean(values[var], in_basin)

    # x axis: days elapsed since the first restart file.
    elapsed_days = np.arange(SimulationDuration) * DeltaTime
    x_label = start_date.strftime("Days since %Y%m%d")

    fig, axes = plt.subplots(2, 2, figsize=(15, 8))
    for p, (ax, panel) in enumerate(zip(axes.flat, panels)):
        _, _, ylim, ylabel, title = panel
        ax.set_ylim(*ylim)
        for r, (label, style, _) in enumerate(runs):
            ax.plot(elapsed_days, series[r, p], style, label=label)
        ax.set_ylabel(ylabel, fontsize=15)
        ax.set_xlabel(x_label, fontsize=15)
        ax.set_title(title, pad=20, fontsize=15)
        ax.legend()

    fig.tight_layout()
    # savefig does not create missing directories.
    if plot_path:
        os.makedirs(plot_path, exist_ok=True)
    fig.savefig(os.path.join(plot_path, plot_name))
    plt.close(fig)

# Basin-mask file, output directory and output file name.  Computediff()
# reads these as module-level globals, so they must stay at module scope.
BasinPath = "/scratch2/NCEPDEV/land/data/evaluation/basins/C768/GRDC_C768.nc"
PlotPath = "./plots/C768_multiYear/"
fname = "Danube_river_basin_sweRunoff_da_openloop_2020.png"

# Variables passed to Computediff() as the first-panel variable.
VarList = ["snow_depth"]

# Guard the long-running processing so importing this module has no side
# effects beyond defining the functions and constants above.
if __name__ == "__main__":
    for var in VarList:
        Computediff(var)
