import netCDF4
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
from datetime import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta

def ReadNETCDF(FilePath, VariName):
    """Read the full contents of one variable from a netCDF file.

    Args:
        FilePath: path of the netCDF file to open (read-only).
        VariName: name of the variable to read.

    Returns:
        ``variable[:]`` as produced by netCDF4 -- an in-memory NumPy array
        (a masked array when netCDF4 auto-masking applies).

    The Dataset is closed before returning so repeated calls do not
    accumulate open file handles.
    """
    with netCDF4.Dataset(FilePath) as f:
        # Slicing copies the data into memory, so it remains valid after the
        # file is closed by the context manager.
        return f.variables[VariName][:]

def ReadNETCDF1d(FilePath, VariName):
    """Read one variable from a netCDF file and flatten it to 1-D.

    Args:
        FilePath: path of the netCDF file to open (read-only).
        VariName: name of the variable to read.

    Returns:
        ``variable[:].ravel()`` -- the variable's values as a flat in-memory
        array (masked when netCDF4 auto-masking applies).

    The Dataset is closed before returning so repeated calls do not
    accumulate open file handles.
    """
    with netCDF4.Dataset(FilePath) as f:
        # Read into memory before the file is closed.
        return f.variables[VariName][:].ravel()

def _masked_to_nan(values):
    """Return *values* as a flat float ndarray with masked entries set to NaN."""
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan).ravel()


def _valid_mean(values, valid):
    """Return the NaN-ignoring mean of ``values[valid]``.

    Returns NaN (without numpy's "Mean of empty slice" warning) when nothing
    is selected or every selected value is NaN.
    """
    selected = values[valid]
    if not np.any(~np.isnan(selected)):
        return np.nan
    return float(np.nanmean(selected))


def Computediff(VariName):
    """Plot basin-mean time series for the DA, no-spin-up and SNODAS data.

    For each of 91 weekly time steps starting at 2018-01-03 00:00, this reads
    the DA restart variable *VariName*, the no-spin-up initial-condition
    variable ``snow_depth`` and the SNODAS ``snow_water_equivalent`` field,
    averages each over the points in basin 4405 where the SNODAS value is
    non-negative on that step, and saves a line plot to
    ``os.path.join(PlotPath, fname)``.

    Relies on the module-level globals ``BasinPath``, ``PlotPath`` and
    ``fname``.

    Args:
        VariName: variable name to read from the DA restart files.
    """
    print(VariName)
    DAFilePath1 = "/scratch2/NCEPDEV/land/Michael.Barlage/spinup_hr3/with_DA/C384"
    NospinupPath = "/scratch2/NCEPDEV/land/Michael.Barlage/spinup_hr3/no_spinup/C384"
    cciPath = "/scratch2/NCEPDEV/land/data/evaluation/SNODAS/vector/C384"

    RestartDate = datetime.strptime("2018-01-03_00-00-00", '%Y-%m-%d_%H-%M-%S')
    DeltaTime = 7  # days between consecutive samples
    # Number of samples, not days: 91 weekly steps (last one is 2019-09-25).
    SimulationDuration = 91
    # Code in the "basins" variable to average over; the plot title below
    # calls it the Colorado River Basin.
    BASIN_ID = 4405

    basin = ReadNETCDF(BasinPath, "basins")
    # Basin membership does not change over time, so compute it once.
    # Masked basin cells count as outside the basin.
    in_basin = np.ravel(np.ma.filled(basin == BASIN_ID, False))

    xx = np.arange(SimulationDuration, dtype=float)
    DA_avg = np.full(SimulationDuration, np.nan)
    nspin_avg = np.full(SimulationDuration, np.nan)
    cci_avg = np.full(SimulationDuration, np.nan)

    for SimDay in range(SimulationDuration):
        ThisRestartDate = RestartDate + timedelta(days=SimDay * DeltaTime)
        print(ThisRestartDate)

        stamp = ThisRestartDate.strftime("%Y-%m-%d_%H-%M-%S")
        DAFile1 = os.path.join(DAFilePath1, "ufs_land_restart_anal.%s.nc" % stamp)
        NspinFile = os.path.join(NospinupPath, "nospinup.landIC.%s.nc" % stamp)
        cciFile = os.path.join(
            cciPath, "SNODAS_snow.C384.%s.nc" % ThisRestartDate.strftime("%Y%m%d")
        )

        # First record of each model file, flattened; masked entries -> NaN.
        DA_array = _masked_to_nan(ReadNETCDF(DAFile1, VariName)[0, :])
        # NOTE(review): the no-spin-up series reads "snow_depth" while the DA
        # series reads VariName and the y-axis is labelled SWE -- confirm this
        # is the intended variable in the landIC files.
        nspin_array = _masked_to_nan(ReadNETCDF(NspinFile, "snow_depth")[0, :])
        cci_array = _masked_to_nan(ReadNETCDF1d(cciFile, "snow_water_equivalent"))

        # The selection is rebuilt every step, so a point excluded today
        # (negative or missing SNODAS) cannot carry a value from an earlier
        # step into this step's mean. NaN >= 0 is False, so missing
        # observations are excluded.
        valid = in_basin & (cci_array >= 0.0)

        DA_avg[SimDay] = _valid_mean(DA_array, valid)
        nspin_avg[SimDay] = _valid_mean(nspin_array, valid)
        cci_avg[SimDay] = _valid_mean(cci_array, valid)

    # savefig does not create missing directories.
    os.makedirs(PlotPath, exist_ok=True)

    plt.figure(figsize=(12, 8))
    ax1 = plt.subplot(1, 1, 1)
    ax1.set_ylim(0.0, 80)
    ax1.plot(xx, DA_avg, 'r-', label="Spinup")
    ax1.plot(xx, nspin_avg, 'm-', label="No Spinup")
    ax1.plot(xx, cci_avg, 'k-', label="snodas")
    ax1.set_ylabel("Snow Water Equivalent (mm)", fontsize=20)
    ax1.set_xlabel("Number of Seven-Day since 03 January 2018", fontsize=20)
    ax1.set_title("Colorado River Basin", pad=20, fontsize=20)
    ax1.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(PlotPath, fname))
    plt.close()

# Module-level settings; Computediff reads these as globals at call time.
BasinPath = "/scratch2/NCEPDEV/land/data/evaluation/basins/C384/GRDC_C384.nc"
PlotPath = "./plots/"
# NOTE(review): "Spe2019" looks like a typo for "Sep2019"; left unchanged so
# the output file name stays the same.
fname = "Colorado_river_basin_spinup_nospinup_snodasSWE_Jan2018-Spe2019.png"

# Run Computediff once per variable name. Every run saves to the same
# PlotPath/fname, so with more than one entry the last run's plot wins.
VarList = ["snow_water_equiv"]
for VariName in VarList:
    Computediff(VariName)
