import netCDF4
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
import sys
from datetime import datetime
from datetime import timedelta

def ReadNETCDF(FilePath, GroupName, VariName):
    """Read a whole variable from one group of a NetCDF-4 file.

    Parameters
    ----------
    FilePath : str
        Path to the NetCDF file.
    GroupName : str
        Name of the group holding the variable (e.g. "ombg", "MetaData").
    VariName : str
        Name of the variable inside that group.

    Returns
    -------
    The variable's full contents as produced by netCDF4 slicing (``[:]``),
    loaded into memory before the file is closed. With netCDF4's default
    auto-masking this is presumably a numpy masked array — confirm for the
    installed netCDF4 version.
    """
    # The context manager guarantees the file is closed even when the group
    # or variable lookup raises; the old explicit close() was skipped then.
    with netCDF4.Dataset(FilePath) as f:
        return f.groups[GroupName].variables[VariName][:]

# Observation-space group and variable to analyse (also used in the plot name).
group = "ombg"
varname = "totalSnowDepth"

# hofx output of the DA run and of the openloop run, plus the plot directory.
daHofxPath = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/cycle_land/DA_era5/DA/hofx"
mdlHofxPath = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/cycle_land/openloop_era5/DA/hofx"
PlotPath = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/cycle_land/DA_era5/DA/plots/ombg_snap/"
fname = PlotPath + group + "_" + varname + "_WestCanada_timeseries_WY2020.png"

# Time axis: SimulationDuration cycles, DeltaTime days apart, from RestartDate.
RestartDate = "2019-10-01"
DeltaTime = 1  # day
SimulationDuration = 244  # days
RestartDate = datetime.strptime(RestartDate, '%Y-%m-%d')

# Per-cycle series filled by the main loop; NaN marks cycles not yet computed.
model_avg = np.full(SimulationDuration, np.nan)
DA_avg = np.full(SimulationDuration, np.nan)
ABS_diff = np.full(SimulationDuration, np.nan)
zero = np.zeros(SimulationDuration)  # zero reference line for the plot

# x-axis values: cycle index 0 .. SimulationDuration-1 (days when DeltaTime == 1).
xx = np.arange(SimulationDuration, dtype=float)

# Candidate regions [lon range, lat range]:
#   western CONUS   [125-110W, 30-50N]
#   Europe          [30-70E,   40-60N]
#   North GP        [100-80W,  40-50N]
#   Western Canada  [155-125W, 50-65N]   <- active selection below

latmin = 50.0
latmax = 65.0
lonmin = -155.0
lonmax = -125.0

# Number of hofx files written per cycle (suffixes _0000 .. _0005).
# NOTE(review): assumed fixed at 6 for every cycle — confirm against the DA job setup.
NTILES = 6


def _read_tiles(base_path, date_str, group_name, var_name):
    """Read one variable from every hofx file of a cycle and join them.

    Parameters
    ----------
    base_path : str
        Directory holding the ``letkf_hofx_ghcn_<date>_<NNNN>.nc`` files.
    date_str : str
        Cycle time string as built by the caller (``YYYYMMDDHH``).
    group_name, var_name : str
        NetCDF group and variable to read.

    Returns
    -------
    numpy.ndarray
        1-D float64 array with all files' values in file order; entries that
        were masked (fill values) are replaced by NaN.
    """
    pieces = [
        ReadNETCDF(base_path + "/letkf_hofx_ghcn_" + date_str + "_%04d.nc" % i,
                   group_name, var_name)
        for i in range(NTILES)
    ]
    # Plain np.concatenate discards the masks of MaskedArray inputs, letting
    # fill values leak into the averages; np.ma.concatenate keeps them, and
    # filling with NaN makes np.nanmean skip them.
    joined = np.ma.concatenate(pieces, axis=None).astype(np.float64)
    return np.ma.filled(joined, np.nan)


for SimDay in range(SimulationDuration):
    ThisRestartDate = RestartDate + timedelta(days=SimDay * DeltaTime)
    print(ThisRestartDate)
    date = ThisRestartDate.strftime('%Y%m%d') + "18"  # 18-hour cycle of that day

    xx[SimDay] = SimDay

    # Observation locations come from the DA files only.
    # NOTE(review): assumes the openloop files hold the same obs in the same
    # order — confirm.
    latout = _read_tiles(daHofxPath, date, "MetaData", "latitude")
    lonout = _read_tiles(daHofxPath, date, "MetaData", "longitude")
    dataout1 = _read_tiles(daHofxPath, date, group, varname)   # DA
    dataout2 = _read_tiles(mdlHofxPath, date, group, varname)  # openloop

    # Select obs inside the lat/lon box (NaN coordinates compare False).
    # NOTE(review): lonmin/lonmax are negative, so longitudes are assumed to
    # be in [-180, 180]; a 0-360 convention would select nothing.
    in_region = ((latout >= latmin) & (latout <= latmax)
                 & (lonout >= lonmin) & (lonout <= lonmax))

    # An empty or all-NaN selection yields NaN (with a RuntimeWarning).
    avg1 = np.nanmean(dataout1[in_region])
    avg2 = np.nanmean(dataout2[in_region])
    DA_avg[SimDay] = avg1
    model_avg[SimDay] = avg2
    # Negative when the DA mean O-F is smaller in magnitude than the openloop's.
    ABS_diff[SimDay] = abs(avg1) - abs(avg2)

# Plot the regional (lat/lon box) mean O-F snow depth time series.
# Create the output directory first so savefig does not fail on a fresh run.
os.makedirs(PlotPath, exist_ok=True)

fig, ax1 = plt.subplots(1, 1, figsize=(12, 8))
ax1.set_ylim(-900, 100.0)  # fixed y-range; values outside it are not shown
ax1.plot(xx, model_avg, 'r-', label="openloop: O-F")
ax1.plot(xx, DA_avg, 'b-', label="DA: O-F")
ax1.plot(xx, ABS_diff, 'g-', label="ombg: abs(DA)-abs(openloop)")
ax1.plot(xx, zero, 'k-')  # zero reference line
ax1.set_ylabel("Snow Depth (mm)", fontsize=20)
# Build the start-date label from RestartDate (e.g. "1OCT2019") so it cannot go stale.
start_label = "%d%s" % (RestartDate.day, RestartDate.strftime('%b%Y').upper())
ax1.set_xlabel("simulation days since " + start_label, fontsize=20)
ax1.legend()
fig.tight_layout()
fig.savefig(fname)
plt.close(fig)
