#ProcessorEpidemiology.py
'''Functions to process the epidemiology component.'''

import datetime
from glob import glob
import json
import logging
from pathlib import Path
import os
import shutil

from numpy import argmax, unique
from pandas import read_csv, DataFrame, to_datetime
from rasterio import open as rio_open

# gitlab projects
# TODO: Package these projects so they are robust for importing
from EpiModel import ( # created by rs481
    EpiAnalysis,
    EpiModel,
    EpiPrep,
    EpiPrepLister,
    EpiPrepLoader,
    plotRaster
)
from ews_postprocessing.epi.epi_post_processor import EPIPostPostProcessor

from ProcessorUtils import (
        open_and_check_config,
        get_only_existing_globs,
        endJob,
        add_filters_to_sublogger,
        query_past_successes,
        short_name,
        disease_latin_name_dict
)

logger = logging.getLogger('Processor.Epi')
add_filters_to_sublogger(logger)

def calc_epi_date_range(init_str,span_days=(0,6)):
    '''Determine the calculation date range relative to an initial date.

    Args:
        init_str: Reference date as a 'YYYYMMDD' string.
        span_days: Two-element sequence (start, end), usually defined in the
            job config file. Each element is either a day offset relative to
            init_str (zero is the current day, negative values point to past
            (historical or analysis) days, and positive values point to
            forecast days), or an absolute date written as 'YYYYMMDD'.
            An element is taken to be a date when it has more than three
            characters, not counting a leading minus sign (i.e. it cannot be
            an offset of up to 999 days).

    Returns:
        A tuple (start_date, end_date) of datetime.datetime objects.
        start_date is at 0300 on the first day; end_date is at 0000 on the
        day after the last day, so that the final day is fully included.

    Raises:
        ValueError: if init_str, or an element of span_days that is taken to
            be a date, does not match the 'YYYYMMDD' format.
    '''

    init_date = datetime.datetime.strptime(init_str,'%Y%m%d')

    # note that filename date represents preceding 3 hours, so day's data
    #  starts at file timestamp 0300 UTC
    threehour_shift = datetime.timedelta(hours=3)

    # add 24hrs so that final day is fully included
    day_shift = datetime.timedelta(days=1)

    def is_date_string(value):
        # the sign is ignored so that offsets such as -100 are still treated
        # as a number of days rather than as a date
        return len(str(value).lstrip('-')) > 3

    span_start, span_end = span_days[0], span_days[1]

    if is_date_string(span_start):
        # str() so that a date provided as an integer is also accepted
        start_date = datetime.datetime.strptime(str(span_start)+'0300','%Y%m%d%H%M')
    else:
        start_date = init_date + datetime.timedelta(days=span_start) + threehour_shift

    if is_date_string(span_end):
        end_date = datetime.datetime.strptime(str(span_end)+'0000','%Y%m%d%H%M') + day_shift
    else:
        end_date = init_date + datetime.timedelta(days=span_end) + day_shift

    return start_date, end_date

def process_pre_job_epi(input_args):
    '''Run the pre-job checks for the epidemiology component.

    Calls query_past_successes(input_args) to check pre-requisite jobs, then,
    for every config file in input_args.config_paths, works out the
    calculation date range and logs a warning when it spans more than 100
    days.

    Returns True, i.e. the job is ready for full processing; problems are
    reported through exceptions raised by the checks themselves.'''

    logger.info('started process_pre_job_epi()')

    # pre-requisite jobs must have completed
    query_past_successes(input_args)

    for config_path in input_args.config_paths:

        # the configs should already be known to work by this point, so
        # there is no error handling around loading them
        job_config = open_and_check_config(config_path)

        # the calculation period is defined in the config, relative to the
        # start date given in the arguments
        issue_date_str = input_args.start_date
        span = job_config['Epidemiology']['CalculationSpanDays']
        assert len(span) == 2

        period_start, period_end = calc_epi_date_range(issue_date_str,span)

        # a long timespan is allowed, but worth flagging
        if (period_end - period_start).days > 100:
            logger.warning("More than 100 days will be calculated over, likely longer than any single season")

    return True


def create_epi_config_string(config,jobPath,startString,endString):
    '''Build a filename stem that describes the job configuration.

    The stem combines the basename of config['ConfigFilePath'] (without
    '.json'), the start and end strings, and a label for every entry of
    config['Epidemiology']['Epi'] made from its model name and model
    arguments.

    Side effect: each entry of config['Epidemiology']['Epi'] is given an
    "infectionRasterFileName" value located under jobPath.

    Returns the stem including the epi case labels, or, when that would be
    longer than 254 characters, the stem without them.'''

    template_path = config['ConfigFilePath']
    template_stem = os.path.basename(template_path).replace('.json','')
    base_name = f"{template_stem}_{startString}-{endString}"

    # one label per epi calc configuration
    case_labels = []
    for epi_case in config['Epidemiology']['Epi']:

        kwargs_label = ''.join(f"{key}{value}" for key,value in epi_case['modelArguments'].items())

        # shorten the repetitive parts of the argument names
        kwargs_label = kwargs_label.replace('infectionprevious','').replace('capbeta','cb')

        case_label = f"{epi_case['model'].lower()}{kwargs_label}"

        # record the output filename for this case in the configuration
        epi_case["infectionRasterFileName"] = f"{jobPath}/infections_{base_name}_{case_label}"

        case_labels.append(case_label)

    joined_labels = '-'.join(case_labels)
    full_name = f"{base_name}_{joined_labels}"

    logger.debug(f"length of config filename is {len(full_name)}.")

    if len(full_name) <= 254:
        return full_name

    logger.info(f"filename length is too long, it will raise an OSError, using a short form instead")

    # the short form leaves out the epi cases; an interested user must look
    # in the json file for those details
    assert len(base_name) <= 254

    return base_name


def raster_to_csv(raster_fn,csv_fn):
    '''Save band 1 of a raster file as a csv file.

    The csv version is kept in the job directory so that the host raster can
    be compared with the deposition and environmental suitability csvs. The
    values are written as a single row, labelled with a nominal date, whose
    columns are a multi-index built from the two spatial coordinates of each
    cell. Additional rows could make this time-varying.

    Args:
        raster_fn: path of the raster to read (opened with rasterio).
        csv_fn: path of the csv file to write.
    '''

    with rio_open(raster_fn,'r') as src:
        band = src.read(1)
        shape = src.shape

        # collect the coordinates of every cell
        xs = []
        ys = []
        for row in range(shape[0]):
            for col in range(shape[1]):
                x, y = src.xy(row,col)
                xs.append(x)
                ys.append(y)

        lons = unique(xs)
        lats = unique(ys)

        # the cells must form a regular grid of lats by lons
        assert shape == (lats.size,lons.size)

    # rasters start in the top left, so rows are labelled with descending
    # latitudes, and then rearranged to ascending latitude coordinates
    grid = DataFrame(data=band,index=lats[::-1],columns=lons).sort_index(axis='rows')

    # spatial coordinates become a multi-index, like for dep and env suit csvs
    flattened = grid.stack()

    # a nominal date of validity enables a time column; so far the host map
    # (mapspam) is static, so the time is irrelevant
    flattened.name = '201908150000'

    DataFrame(flattened).T.to_csv(csv_fn)

    return

def process_in_job_epi(jobPath,status,config,component):
    '''Prepare the inputs for, run, and post-process the epidemiology model.

    For each disease in config['Epidemiology']['DiseaseNames'] this function:
      - creates {jobPath}/{region}/{disease}/ with an input_data/ directory;
      - prepares the deposition and environmental suitability csv inputs
        with EpiPrep.prep_input, for whichever of the 'Deposition' and
        'Environment' sections are present in the epidemiology config;
      - copies the host raster to the job directory and writes a csv of it;
      - writes the complete epi configuration as json and passes it to
        EpiModel.run_epi_model;
      - logs a summary value and saves a png for each infection raster;
      - saves a figure comparing the epi cases;
      - splits each '<infectionRasterFileName>_progression.csv' into
        '_seasonsofar.csv' and '_weekahead.csv' files.

    Args:
        jobPath: directory in which the job outputs are written.
        status: job status object. When the preparation of deposition or
            environment data fails, it is reset to 'ERROR' and handed to
            endJob.
        config: job configuration dict. It is modified in place: time keys
            ('ReferenceTime', 'StartTime', 'StartTimeShort', 'EndTime') are
            added, and because only shallow copies are taken of the
            'Epidemiology' section, changes to its nested sections (e.g.
            'Host', 'Deposition', 'Environment', 'Epi') are visible to the
            caller too.
        component: not used in this function.

    Returns:
        None.

    Raises:
        AssertionError: when the config or the results are not as expected
            (e.g. unknown disease name, unexpected column name).
        Exception: anything raised by EpiModel.run_epi_model is logged and
            re-raised.
    '''
    logger.info('started process_in_job_epi()')

    # TODO: Some of this is modifying config before epi model is run. Determine
    # how to account for that

    # initialise any needed variables

    reference_date_str = config['StartString']

    start_date, end_date = calc_epi_date_range(reference_date_str,config['Epidemiology']['CalculationSpanDays'])

    date_diff = end_date - start_date

    # calculations spanning more than a week default to the 'historical'
    # file listers, when the config does not name a file lister
    is_long_run = date_diff > datetime.timedelta(days=7)

    start_string = start_date.strftime('%Y-%m-%d-%H%M')
    start_string_short = start_date.strftime('%Y%m%d%H%M')
    end_string = end_date.strftime('%Y-%m-%d-%H%M')

    # update config accordingly
    config['ReferenceTime'] = reference_date_str
    config['StartTime'] = start_string
    config['StartTimeShort'] = start_string_short
    config['EndTime'] = end_string

    diseases = config['Epidemiology']['DiseaseNames']

    def gather_deposition(config_epi,config,variable_name,start_date,end_date,jobDataPath,status):
        '''Prepare the deposition input csv for one disease.

        On failure the error is logged, status is reset to 'ERROR' and
        endJob is called.'''

        # TODO: Simplify the set of required arguments . Check if config is necessary.

        config_epi['Deposition']['VariableName'] = variable_name # disease_latin_name_dict[disease]+'_DEPOSITION'

        config_epi['Deposition']['FileNamePrepared'] = f"{jobDataPath}/data_input_deposition.csv"

        # Use config-defined file lister in config file instead of here
        file_lister_dep_name = config_epi['Deposition'].get('FileListerFunction',None)

        # when it isn't defined, guess what it should be
        if file_lister_dep_name is None:

            file_lister_dep_name = 'list_deposition_files_operational'

            if is_long_run:

                file_lister_dep_name = 'list_deposition_files_historical'
                logger.info('Using historical method to prepare data on spore deposition')

        file_lister_dep = getattr(EpiPrepLister,file_lister_dep_name)

        config_for_lister = config.copy()
        config_for_lister.update(config_epi)

        # get bounds of host map, to exclude redundant deposition datapoints
        hostRasterFileName = config_for_lister["Host"]["HostRaster"]
        with rio_open(hostRasterFileName) as hostRaster:
            bounds = hostRaster.bounds

        lister_kwargs = {}
        lister_kwargs['reference_date']=config['ReferenceTime']

        loader_kwargs= {}
        loader_kwargs['VariableName']= config_for_lister['Deposition'].get('VariableName')
        loader_kwargs['VariableNameAlternative']= config_for_lister['Deposition'].get('VariableNameAlternative')
        loader_kwargs['bounds'] = bounds

        try:

            EpiPrep.prep_input(config_for_lister,start_date,end_date,
                    component='Deposition',
                    file_lister=file_lister_dep,
                    file_loader=EpiPrepLoader.load_NAME_file,
                    lister_kwargs=lister_kwargs,
                    **loader_kwargs)

            assert os.path.isfile(config_epi['Deposition']['FileNamePrepared'])

        except Exception:
            # Exception (rather than a bare except) so that SystemExit and
            # KeyboardInterrupt are not reported as data preparation errors

            logger.exception(f"Unexpected error in deposition data preparation")
            status.reset('ERROR')
            endJob(status,premature=True)

        return

    # get list of variable names to be loaded from deposition input
    depo_variable_names =  config['Epidemiology']['Deposition']['VariableNames']
    assert len(depo_variable_names) == len(diseases)

    # loop over each sub region

    region = config['RegionName']
    #for region in config['SubRegionNames']:

    for disease in diseases:

        assert disease in disease_latin_name_dict.keys()

        # note this is a shallow copy, so nested sections are shared with config
        config_epi = config['Epidemiology'].copy()

        # TODO: CAUTION: Any iterations (e.g. disease or sub-region) are hidden
        # in jobPath, and not retained in the config file. This is a problem for
        # process_EWS_plotting_epi which receives a single config file and must
        # try a fudge to retrieve details for each iteration.
        # This should be improved, either by making the one config file
        # aware of all of the iterations, or looping over iterations in
        # Processor.py with one iteration-specific config.
        case_specific_path = f"{jobPath}/{region}/{disease}/"
        Path(case_specific_path).mkdir(parents=True, exist_ok=True)

        logger.info(f"Preparing for epidemiology calc of {disease} in {region}")

        # create config_filename to describe job configuration
        # (this also sets the output filename of each epi case in config)
        config_filename = create_epi_config_string(config,case_specific_path,start_string,end_string)

        # prepare a directory for input data
        jobDataPath = f"{case_specific_path}/input_data/"
        Path(jobDataPath).mkdir(parents=True, exist_ok=True)

        # prepare the deposition data

        if 'Deposition' in config_epi:

            # determine which variable name to load for this disease
            # (the first position at which the disease is listed)
            disease_idx = diseases.index(disease)

            variable_name = depo_variable_names[disease_idx]

            gather_deposition(config_epi,config,variable_name,start_date,end_date,jobDataPath,status)

        # prepare the environmental suitability data

        if 'Environment' in config_epi:

            logger.info('Preparing environmental suitability data')

            config_epi['SubRegionName'] = region

            config_epi['DiseaseName'] = disease

            config_epi['Environment']['FileNamePrepared'] = f"{jobDataPath}/data_input_environment.csv"

            # Use config-defined file lister in config file instead of here
            file_lister_env_name = config_epi['Environment'].get('FileListerFunction',None)

            # when it isn't defined, guess what it should be
            if file_lister_env_name is None:

                use_monthly_chunk=False # hard-coded for historical analysis
                file_lister_env_name = 'list_env_suit_files_operational'

                if is_long_run and ('ENVIRONMENT_2.0' in config_epi['Environment']['PathTemplate']) and use_monthly_chunk:

                    logger.info('Using monthly-chunk method to prepare data on environmental suitability')
                    file_lister_env_name = 'list_env_suit_files_historical_monthlychunk'

                elif is_long_run:

                    logger.info('Using historical method to prepare data on environmental suitability')
                    file_lister_env_name = 'list_env_suit_files_historical'

            file_lister_env = getattr(EpiPrepLister,file_lister_env_name)

            config_for_lister = config.copy()
            config_for_lister.update(config_epi)

            try:

                EpiPrep.prep_input(config_for_lister,start_date,end_date,
                        component='Environment',
                        file_loader=EpiPrepLoader.load_env_file,
                        file_lister=file_lister_env)

                assert os.path.isfile(config_epi['Environment']['FileNamePrepared'])

            except Exception:
                # Exception (rather than a bare except) so that SystemExit and
                # KeyboardInterrupt are not reported as data preparation errors

                logger.exception(f"Unexpected error in env data preparation")
                status.reset('ERROR')
                endJob(status,premature=True)

        # prepare a copy of the host data

        logger.info('Preparing a copy of the host raster data')

        src_host = config_epi['Host']['HostRaster']
        fn_host = os.path.basename(src_host)
        dst_host = f"{jobDataPath}/{fn_host}"

        # copy the tif to the job directory and refer to that instead
        shutil.copyfile(src_host,dst_host)
        config_epi['Host']['HostRaster'] = dst_host

        logger.info('Preparing a copy of the host data as csv')

        dst_host_csv = dst_host.replace('.tif','.csv')

        raster_to_csv(dst_host,dst_host_csv)

        config_epi['Host']['HostCSV'] = dst_host_csv

        # provide fundamental config elements to config_epi
        for k,v in config.items():
            if k not in short_name.keys():
                config_epi[k]=v

        logger.debug('Incremental configuration looks like:')
        def print_item(item):
            logger.debug(f"Item {item}")
            logger.debug(json.dumps(item,indent=2))
        def iterate(items,prefix=''):
            for key,value in items.items():
                if hasattr(value,'items'):
                    # descend into nested sections, so that each entry is
                    # logged individually under its dotted key path
                    iterate(value,prefix=f"{prefix}{key}.")
                else:
                    print_item((f"{prefix}{key}",value))
        iterate(config_epi)

        logger.debug('Complete configuration looks like:')
        logger.debug(json.dumps(config_epi,indent=2))

        # write the complete configuration file to job directory
        with open(f"{case_specific_path}/{config_filename}.json",'w') as write_file:
            json.dump(config_epi,write_file,indent=4)

        # run epi model

        try:
            EpiModel.run_epi_model(f"{case_specific_path}/{config_filename}.json")
        except Exception:
            logger.exception('Unexpected error in EpiModel')
            raise

        # perform calc on output

        def calc_total(arr):
            return 'total', arr.sum()

        def calc_max(arr):
            return 'maximum', arr.max()

        def calc_mean(arr):
            return 'mean', arr.mean()

        for epiconf in config['Epidemiology']['Epi']:

            outfile = epiconf["infectionRasterFileName"]

            with rio_open(outfile+'.tif','r') as infectionRaster:
                infection = infectionRaster.read(1)

                # define function to quantify overall result, for easy check
                # TODO: Create a more meaningful result?
                # TODO: make this configurable
                analysis_func = calc_mean

                analysis_desc, analysis_value = analysis_func(infection)

                logger.info(f"For case {outfile}")
                logger.info('Infection {:s} is {:.2e}'.format( analysis_desc, analysis_value))

                # to save tif as png for easy viewing
                logger.debug('Saving tif output as png for easier viewing')
                plotRaster.save_raster_as_png(outfile)

        # comparison figure

        # TODO: make this plot configurable? with function or args?
        #logger.info('Plotting epi output alongside contributing components')
        # figure_func = getattr(EpiAnalysis,'plot_compare_host_env_dep_infection')
        logger.info('Plotting composite image of epi formulations')
        figure_func = getattr(EpiAnalysis,'plot_compare_epi_cases')

        # isolate the config for this function, in case of modifications
        # (a shallow copy, so only top-level modifications are isolated)
        config_epi_for_comparison = config_epi.copy()

        fig,axes,cases = figure_func(
                config_epi_for_comparison,
                start_str = start_string,
                end_str = end_string)

        SaveFileName = f"{case_specific_path}/EPI_{config_filename}_comparison"

        fig.savefig(SaveFileName+'.png',dpi=300)

        # slice the epi results into before forecast and in forecast

        for epiconf in config['Epidemiology']['Epi']:

            outfile = epiconf["infectionRasterFileName"]+'_progression.csv'

            fn_seasonsofar = epiconf["infectionRasterFileName"]+'_seasonsofar.csv'
            fn_weekahead = epiconf["infectionRasterFileName"]+'_weekahead.csv'

            # load the full epi results
            df_full = read_csv(outfile,header=[0],index_col=[0,1])
            column_date_fmt = f"X{config['StartTimeShort']}_X%Y%m%d%H%M"
            df_full_dates = to_datetime(df_full.columns.astype('str'),format=column_date_fmt)

            # determine date to cut with
            # plus 1 minute so midnight is associated with preceding day
            date_to_cut = datetime.datetime.strptime(config['StartString']+'0001','%Y%m%d%H%M')
            dates_after_cut = df_full_dates >= date_to_cut
            # the column just before the first date after the cut
            # (if the first column is after the cut, or none are, this is -1,
            # i.e. the last column; the name check below should then fail)
            idx = argmax(dates_after_cut)-1

            # build seasonsofar dataframe (only need the last date)
            df_seasonsofar = df_full.iloc[:,idx]

            # check column name is defined as expected
            # from epi start time to forecast start time
            column_name = f"X{config['StartTimeShort']}_X{config['StartString']}0000"
            assert df_seasonsofar.name == column_name

            #  save to csv
            df_seasonsofar.to_csv(fn_seasonsofar,header=True,index=True)

            # build weekahead dataframe and save to csv
            # (the change between the forecast start and the final date)
            df_fc_start = df_full.iloc[:,idx]
            df_fc_start_name = df_fc_start.name.split('_')[-1]

            df_fc_end = df_full.iloc[:,-1]
            df_fc_end_name = df_fc_end.name.split('_')[-1]

            df_weekahead = df_fc_end - df_fc_start

            # defined column name
            df_weekahead.name = '_'.join([df_fc_start_name,df_fc_end_name])

            # save to csv
            df_weekahead.to_csv(fn_weekahead,header=True,index=True)

    return

def process_EWS_plotting_epi(jobPath,config):
    '''Run EWS-Plotting on the epidemiology results of each disease.

    For every disease in config['Epidemiology']['DiseaseNames'], the
    EPIPostPostProcessor is run twice: once with the 'seasonsofar' run config
    and csv, and once with the 'seasonplusforecast' run config and csv. The
    images are written beneath {jobPath}/plotting/.

    Nothing is run when config['Epidemiology']['EWS-Plotting']['EpiCase'] is
    'none'.

    Raises RuntimeError when no output images are found.

    Returns a list of output files for transfer.'''

    logger.info('started process_EWS_plotting_epi()')

    # initialise necessary variables from config

    calc_start, _calc_end = calc_epi_date_range(config['StartString'],config['Epidemiology']['CalculationSpanDays'])

    start_string = calc_start.strftime('%Y%m%d')

    plot_section = config['Epidemiology']['EWS-Plotting']

    epi_case_operational = plot_section['EpiCase']

    if epi_case_operational == 'none':
        logger.info('Config specifies not to call to EWS-Plotting')
        return []

    diseases = config['Epidemiology']['DiseaseNames']

    # initialise environment
    sys_config = plot_section['SysConfig']

    chart_config = plot_section['ChartConfig']

    # use the first matching epi formulation
    # TODO: Is there a more efficient way to select?
    matching_filenames = [
            epi_case['infectionRasterFileName']
            for epi_case in config['Epidemiology']['Epi']
            if epi_case['model']==epi_case_operational]
    epi_filename = matching_filenames[0]

    # TODO get deposition_dir from config['Epidemiology']['Deposition']['PathTemplate']
    # (the other region name in use is 'SouthAsia')
    dep_regionname = 'Ethiopia'

    deposition_dir = f"{config['WorkspacePath']}DEPOSITION_{start_string}/WR_NAME_{dep_regionname}_{start_string}/"

    # TODO: handle multiple diseases and regions in Processor as a loop, or in the config
    deposition_names = [disease_latin_name_dict[disease]+'_DEPOSITION' for disease in diseases]
    deposition_disease_name = deposition_names[0]

    ews_plot_dir = f"{jobPath}/plotting/"

    Path(ews_plot_dir).mkdir(parents=True, exist_ok=True)

    # each plotting run is described by: key of the run config,
    # suffix of the input csv, suffix of the disease type
    plot_runs = (
            ('RunConfig_seasonsofar','_seasonsofar.csv','_seasontodate'),
            # for weekahead, the csv suffix would be '_weekahead.csv'
            ('RunConfig_seasonplusforecast','.csv','_seasonincforecast'),
    )

    # loop over diseases
    output_globs = []
    for disease in diseases:
        disease_short = disease.lower().replace('rust','')

        # a fudge, guess disease type
        # because config['Epidemiology']['ProcessInJob'] handles disease loop internally
        # assumes disease name is the last directory before the filename
        # TODO: handle multiple diseases and regions in Processor as a loop, or in the config
        disease_to_drop = os.path.dirname(epi_filename).split('/')[-1].replace('Rust','')
        disease_to_add = disease.replace('Rust','')
        epi_filename = epi_filename.replace(disease_to_drop,disease_to_add)

        map_title = ''.join(["Integrated prediction of Wheat $\\bf{",disease_to_add,"}$ Rust infection"])

        if 'PlottingRegionName' in plot_section:
            plotting_region_name_lower = plot_section['PlottingRegionName'].lower()
        else:
            plotting_region_name_lower = config['RegionName'].lower()

        # arguments that are the same for both plotting runs of this disease
        shared_kwargs = {
                'sys_params_file_arg' : sys_config,
                'chart_params_file_arg' : chart_config,
                'issue_date_arg' : start_string,
                'output_dir_arg' : ews_plot_dir,
                'wheat_sources_dir_arg' : deposition_dir,
                'wheat_source_disease_name_arg' : deposition_disease_name,
                'map_title_arg' : map_title,
                'chart_area_prefix' : plotting_region_name_lower}

        for run_config_key,csv_suffix,disease_type_suffix in plot_runs:

            run_config = plot_section[run_config_key]

            logger.info(f"Running EWS-Plotting with the following configs:\n{sys_config}\n{run_config}\n{chart_config}")

            epi_processor = EPIPostPostProcessor()
            epi_processor.set_param_config_files(
                    run_params_file_arg=run_config,
                    epi_input_csv_arg=epi_filename+csv_suffix,
                    disease_type_arg=disease_short+disease_type_suffix,
                    **shared_kwargs)
            epi_processor.process()

        # check the output
        images_dir = f"{ews_plot_dir}/images/"
        # TODO: Make this smarter, connected to the results of EWSPlottingEPIBase.plot_epi()
        output_globs += [f"{images_dir}infection_{plotting_region_name_lower}_*{disease_short}*.png"]

        output_globs = get_only_existing_globs(output_globs,inplace=False)

        # check there is some output from EWS-plotting
        if not output_globs:
            logger.error('EWS-Plotting did not produce any output')
            raise RuntimeError

    # provide to list for transfer
    return [filename for pattern in output_globs for filename in glob(pattern)]