import netCDF4
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
from datetime import datetime
from datetime import timedelta

def ReadNETCDF(FilePath, VariName):
    """Read one variable from a netCDF file fully into memory.

    The dataset is closed before returning. The ``[:]`` slice has already
    copied the data out of the file, so the returned array stays valid.

    :param FilePath: path of the netCDF file to open.
    :param VariName: name of the variable to read.
    :returns: the variable's full contents. netCDF4 typically returns a numpy
        masked array, with fill values masked.
    :raises KeyError: if the file has no variable named ``VariName``.
    """
    # Closing the file matters: callers read several files per day in long
    # loops, and unclosed handles accumulate.
    with netCDF4.Dataset(FilePath) as f:
        return f.variables[VariName][:]

def _as_float_with_nan(values):
    """Return *values* as a float64 ndarray with masked entries replaced by NaN.

    Storing a masked array in a pandas column does the same conversion. As a
    result, masked points count as missing in the statistics instead of
    contributing their raw fill values.
    """
    return np.ma.filled(np.ma.asarray(values, dtype=np.float64), np.nan)


def _add_squared_errors(sum_sq, count, err):
    """Add ``err**2`` to running per-location sums, skipping NaN samples.

    :param sum_sq: running sum of squared errors (start with ``0.0``).
    :param count: running number of valid samples (start with ``0``).
    :param err: 1-D float array of errors for one time step; NaN marks missing.
    :returns: the updated ``(sum_sq, count)`` pair.
    """
    valid = ~np.isnan(err)
    return sum_sq + np.where(valid, err ** 2, 0.0), count + valid


def _rms_from_sums(sum_sq, count):
    """Return ``sqrt(sum_sq / count)``, which is NaN wherever ``count`` is 0.

    This equals the square root of a NaN-skipping time mean of the squared
    errors, i.e. what ``DataFrame.mean(axis=1)`` gives over one column per step.
    """
    # 0/0 at locations that never had a valid sample -> NaN, silently.
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.sqrt(sum_sq / count)


def _plot_rmse_difference(VariName, Long, Lat, RMSEDiff, maxRMSE, minRMSE,
                          elevation_array):
    """Plot the RMSE difference above an elevation map and save it as a PNG.

    Both panels show lon -107..-103, lat 37..41. The figure is written to the
    module-level ``PlotPath``, which is created if it does not exist.
    """
    fig = plt.figure(figsize=(12, 8))

    # Top panel: RMSE(DA - SNODAS) - RMSE(openloop - SNODAS). With "bwr",
    # blue (negative) marks locations where DA is closer to SNODAS.
    ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree(central_longitude=0))
    ax.add_feature(cfeature.GSHHSFeature(scale='auto'))
    cs = ax.scatter(Long, Lat, c=RMSEDiff, cmap="bwr", marker=',', s=10,
                    vmax=40, vmin=-40, transform=ccrs.PlateCarree())
    cb = fig.colorbar(cs, ax=ax, orientation='horizontal', shrink=0.5, pad=.04)
    cb.set_label("%s RMSE Difference between DAghcn-SNODAS and openloop-SNODAS " % VariName, fontsize=12)
    ax.set_extent([-107, -103, 37, 41])
    ax.set_xticks(list(range(-107, -103)))
    ax.set_yticks(list(range(37, 41)))
    ax.set_title("Max=%s Min=%s" % (maxRMSE, minRMSE), fontsize=12)

    # Bottom panel: static-file elevation, shown for context.
    ax = fig.add_subplot(2, 1, 2, projection=ccrs.PlateCarree(central_longitude=0))
    ax.add_feature(cfeature.GSHHSFeature(scale='auto'))
    cs = ax.scatter(Long, Lat, c=elevation_array, cmap="YlGn", marker=',', s=10,
                    vmax=4500, vmin=500, transform=ccrs.PlateCarree())
    cb = fig.colorbar(cs, ax=ax, orientation='horizontal', shrink=0.5, pad=.04)
    # NOTE(review): assumes the static elevation is in metres, since a
    # 500-4500 range for this region only makes sense in m. Confirm against
    # the file's units attribute.
    cb.set_label("Elevation (m)", fontsize=12)
    ax.set_extent([-107, -103, 37, 41])
    ax.set_xticks(list(range(-107, -103)))
    ax.set_yticks(list(range(37, 41)))

    os.makedirs(PlotPath, exist_ok=True)
    fig.savefig(os.path.join(PlotPath, "%s_RMSE_Difference_DAghcn_Openloop_SNODAS_US_smallRegion_21JAN-31May2020.png" % VariName))
    plt.close(fig)


def Computediff(VariName):
    """Map how much DA changes the RMSE against SNODAS relative to the open loop.

    The function reads ``VariName`` from three sources for each daily restart
    between 2020-01-21 and 2020-05-31 (132 days):

    * the DA_GHCN_HR4 run,
    * the openloop_HR4 run,
    * the SNODAS C768 file.

    From these it forms, per location::

        RMSE(DA - SNODAS) - RMSE(openloop - SNODAS)

    Negative values mean DA is closer to SNODAS than the open loop.
    Locations with ``vegetation_category == 15`` are set to 0. The result is
    plotted above the static elevation and saved under the module-level
    ``PlotPath``. Static fields come from the module-level ``SpatialPath``.
    The function prints progress to stdout.

    :param VariName: variable name to read from all three file sets
        (e.g. ``"snow_depth"``).
    """
    print(VariName)
    # DA (GHCN) and open-loop test runs, plus SNODAS as the reference.
    DAFilePath = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/DA_GHCN_HR4/mem000/restarts/vector"
    OpenloopFilePath = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/openloop_HR4/mem000/restarts/vector"
    cmcFilePath = "/scratch2/NCEPDEV/land/data/evaluation/SNODAS/C768"

    RestartDate = datetime.strptime("2020-01-21_00-00-00", '%Y-%m-%d_%H-%M-%S')
    DeltaTime = 1  # day
    # 132 daily restarts cover 01/21/2020-05/31/2020 (melting phase). The
    # accumulation phase, 10/01/2019-01/20/2020, would be 112.
    SimulationDuration = 132  # days

    # Keep running per-location sums rather than one DataFrame column per day.
    # Only the time means were ever used, and 3 x 132 full C768 columns cost a
    # lot of memory.
    sumsq_DA, count_DA = 0.0, 0
    sumsq_OL, count_OL = 0.0, 0
    for SimDay in range(SimulationDuration):
        ThisRestartDate = RestartDate + timedelta(days=SimDay * DeltaTime)
        print(ThisRestartDate)
        stamp = ThisRestartDate.strftime("%Y-%m-%d_%H-%M-%S")
        DAFile = DAFilePath + "/ufs_land_restart_back.%s.nc" % stamp
        openloopFile = OpenloopFilePath + "/ufs_land_restart_back.%s.nc" % stamp
        cmcFile = cmcFilePath + "/SNODAS_snow.C768.%s.nc" % ThisRestartDate.strftime("%Y%m%d")

        DA_array = ReadNETCDF(DAFile, VariName)
        Openloop_array = ReadNETCDF(openloopFile, VariName)
        cmc = ReadNETCDF(cmcFile, VariName)[0, :]

        # Only the first record along axis 0 is used, presumably (time, location).
        sumsq_DA, count_DA = _add_squared_errors(
            sumsq_DA, count_DA, _as_float_with_nan(DA_array[0, :] - cmc))
        sumsq_OL, count_OL = _add_squared_errors(
            sumsq_OL, count_OL, _as_float_with_nan(Openloop_array[0, :] - cmc))

    RMSEDiff = _rms_from_sums(sumsq_DA, count_DA) - _rms_from_sums(sumsq_OL, count_OL)

    Lat = ReadNETCDF(SpatialPath, "latitude")
    Long = ReadNETCDF(SpatialPath, "longitude")
    elevation_array = ReadNETCDF(SpatialPath, "elevation")

    # Zero out veg_type == 15 (glaciers; no DA is done there).
    veg_type_array = ReadNETCDF(SpatialPath, "vegetation_category")
    MaskedRMSEResult = np.where(veg_type_array == 15, 0, RMSEDiff)
    # Use nan-aware extrema: locations with no valid SNODAS sample are NaN,
    # and plain max()/min() would report nan for both numbers in the title.
    maxRMSE = round(np.nanmax(MaskedRMSEResult), 3)
    minRMSE = round(np.nanmin(MaskedRMSEResult), 3)

    _plot_rmse_difference(VariName, Long, Lat, MaskedRMSEResult, maxRMSE, minRMSE,
                          elevation_array)

# Module-level settings read by Computediff: the C768 static-field file
# (latitude, longitude, elevation, vegetation_category) and the output
# directory for the figures.
SpatialPath = "/scratch2/NCEPDEV/land/data/forcing/era5/static/C768/ufs-land_C768_hr3_static_fields.nc"
PlotPath = "./plots/HR4C768_DAimsGHCN/"

VarList = ["snow_depth"]

# Only run the (long, I/O-heavy) analysis when executed as a script, not on import.
if __name__ == "__main__":
    for var in VarList:
        Computediff(var)
