"""Managing the :class:`met_processing` pipeline for the environmental suitability run."""

import datetime as dt
import json
import logging
import os
import shutil
from string import Template

import pandas as pd

from met_processing.common.params_file_parser import ParamsFileParser
from met_processing.runner.common import job_runner

# Upper bound on worker processes; run_pipeline stores it in the run config
# under ParamsFileParser.MAX_PROCESSORS_KEY.
MAX_WORKERS: int = 10

# Module logger, plus root-logger configuration at DEBUG level.
# NOTE(review): calling basicConfig at import time configures logging for the
# whole process of any importer of this module.
logger = logging.getLogger("Processor.pipeline")
logging.basicConfig(level=logging.DEBUG)

###############################################################################

def loadConfig(configFile):
    '''Load a JSON config file.

    :param configFile: path of the JSON file to read (decoded as UTF-8).
    :returns: the decoded JSON document.
    :raises OSError: if the file cannot be opened or read.
    :raises ValueError: (``json.JSONDecodeError`` or ``UnicodeDecodeError``) if the
        content is not valid UTF-8 JSON.
    '''

    logger.info(f"Loading config file: {configFile}")

    try:
        # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
        with open(configFile, encoding='utf-8') as config_file:
            return json.load(config_file)
    except (OSError, ValueError):
        # Only log real I/O / decoding failures; KeyboardInterrupt etc. pass straight through.
        logger.exception(f"Failure opening config file {configFile}")
        raise


def getParameter(config, parameter):
    '''Get a required parameter from a config.

    :param config: config mapping, e.g. as returned by :func:`loadConfig`.
    :param parameter: key to look up.
    :returns: ``config[parameter]``.
    :raises KeyError: if ``parameter`` is missing (the failure is logged first).
    '''

    logger.info(f"Getting {parameter} from the config")

    try:
        return config[parameter]
    except KeyError:
        # Narrow catch: a missing key is the only failure this message describes.
        logger.exception(f"No {parameter} in the config file!")
        raise


def generate_temporal_points(file, datestr, timeresolution, nDaysForecast):
    '''Write a time-points file, one ``YYYYmmddHHMM`` timestamp per line (no header).

    The series starts at ``datestr`` 00:00 + ``timeresolution`` hours and advances in
    steps of ``timeresolution`` hours. It stops before ``datestr`` 00:00 +
    ``nDaysForecast`` days + ``timeresolution`` hours; that end time is excluded.
    With a step that divides 24, the last point is midnight ``nDaysForecast`` days
    after ``datestr``.

    :param file: path of the output file (overwritten).
    :param datestr: date string whose first 8 characters are ``YYYYMMDD``.
    :param timeresolution: step between points, in hours; must be positive.
    :param nDaysForecast: length of the series, in days.
    :returns: the (already closed) file object that was written.
    :raises ValueError: if ``timeresolution`` is not positive, since the series would never end.
    '''
    if timeresolution <= 0:
        raise ValueError(f"timeresolution must be positive, got {timeresolution}")

    datenow = dt.datetime(int(datestr[:4]), int(datestr[4:6]), int(datestr[6:8]), 0, 0)

    timediff = dt.timedelta(hours=timeresolution)
    starttime = datenow + timediff
    endtime = datenow + dt.timedelta(days=nDaysForecast) + timediff

    # `with` guarantees the file is closed even if a write fails part-way.
    with open(file, "w") as outfile:
        curr = starttime
        # Generate the time series excluding the end time.
        while curr < endtime:
            outfile.write(f"{curr.strftime('%Y%m%d%H%M')}\n")
            curr += timediff

    return outfile


def clean(workPath):
    '''Remove the temporary ``extraction``, ``chunks`` and ``post_processing`` folders from ``workPath``.

    Missing folders and removal errors are ignored (``ignore_errors=True``), so this is
    safe to call on a fresh or partially cleaned work directory.

    :param workPath: work directory, with or without a trailing path separator.
    '''

    logger.info(f"Clean temporary files and folders from {workPath}")
    # rmtree(..., ignore_errors=True) never raises, so the previous try/except was dead code.
    # os.path.join also works when workPath lacks a trailing '/', unlike string concatenation.
    for subfolder in ('extraction', 'chunks', 'post_processing'):
        shutil.rmtree(os.path.join(workPath, subfolder), ignore_errors=True)


# def generate_all(sys_config, run_config):
#     # Write run_config.json
#     workPath = getParameter(run_config,'OUTPUT_DIR')
#     run_configName = 'run_config.json'
#     run_configFile = workPath + run_configName
#
#     with open(run_configFile, 'w') as run_configJson:
#         json.dump(run_config, run_configJson, indent=4)
#
#     run_configJson.close()
#
#     # Run all generate
#     try:
#         job_runner.generate_all_jobs(run_config, sys_config)
#     except Exception:
#         logger.exception(f"Some failure when running one of the generate job", exc_info=True)
#         raise
#
#     return


def run_extraction(run_params: dict, sys_params: dict):
    '''Extract and chunk the weather data via :mod:`met_processing.runner.common.job_runner`.

    :param run_params: run config, passed unchanged to ``job_runner.run_extraction``.
    :param sys_params: system config, passed unchanged to ``job_runner.run_extraction``.
    '''
    start_message = "Running regridding in multi process mode."
    done_message = 'Data extracted and chunked'

    logger.info(start_message)
    job_runner.run_extraction(run_params, sys_params)
    logger.info(done_message)


def run_post_processing(run_params: dict, sys_params: dict, processor_name: str):
    '''Run one named post-processor via :mod:`met_processing.runner.common.job_runner`.

    :param run_params: run config, passed unchanged to ``job_runner.run_post_processing``.
    :param sys_params: system config, passed unchanged to ``job_runner.run_post_processing``.
    :param processor_name: name of the post-processor to run, e.g. ``'RIE'``.
    '''
    runner_args = (run_params, sys_params, processor_name)

    logger.info("Running post-processing.")
    job_runner.run_post_processing(*runner_args)
    logger.info('Data post processing is completed')


# def run_merger(run_params: dict, sys_params: dict, processor_name: str):
#     try:
#         job_runner.run_merge_post_processing(run_params, sys_params, processor_name)
#     except Exception:
#         logger.exception(f"Some failure when running merge RIE", exc_info=True)
#         raise


#######################################
# lawrence: comment back to original (prevent_overwrite=True)
def run_pipeline(pipeline_config, region, dateString, extracted=False, prevent_overwrite=True):
    '''
    Run the whole :mod:`met_processing` pipeline for environmental suitability.

    Steps:

    1. Read paths and settings from ``pipeline_config``, then load the system config
       and the ``template_<RUN_TYPE>_config.json`` run config.
    2. Write the extraction and output temporal-points files into
       ``<WORK_PATH>ENVIRONMENT_2.0_<dateString>/``.
    3. Run every configured post-processor. ``RIE`` is run once per strain in
       ``STRAINS``, using that strain's entry in ``PARAMS``. Each strain's ``RIE.nc``
       is copied to
       ``<OUTPUT_PATH>/ENVIRONMENT_2.0_<dateString>/processed/<region>/<strain>/RIE_value.nc``.

    :param pipeline_config: pipeline settings dict (RESOURCES_PATH, WORK_PATH, INPUT_PATH, ...).
    :param region: region name; selects the spatial-points asset and the output sub-folder.
    :param dateString: run date; its first 8 characters are parsed as ``YYYYMMDD``.
    :param extracted: currently unused, because the clean/extraction steps are disabled.
        Kept for backward compatibility.
    :param prevent_overwrite: when True, refuse to run if the region's output folder
        already exists, and reuse temporal-points files already present in the work
        directory instead of regenerating them. Set it to False to re-run a job in-place.
    :raises FileExistsError: if ``prevent_overwrite`` is set and the region output folder
        exists. The error is also an ``AssertionError``, for callers of the former
        ``assert``-based check.
    '''

    # Get parameters from the config
    resourcesPath = getParameter(pipeline_config, 'RESOURCES_PATH')
    # The work directory depends on the date only, so every region of one date shares it.
    workPath = getParameter(pipeline_config, 'WORK_PATH') + 'ENVIRONMENT_2.0_' + dateString + '/'
    os.makedirs(workPath, exist_ok=True)

    inPath = getParameter(pipeline_config, 'INPUT_PATH')
    outPath = getParameter(pipeline_config, 'OUTPUT_PATH')

    runType = getParameter(pipeline_config, 'RUN_TYPE')
    nDayExtraction = getParameter(pipeline_config, 'EXTRACTION_DAYS')
    nDayForecast = getParameter(pipeline_config, 'FORECAST_DAYS')

    sys_config_file = getParameter(pipeline_config, 'SYS_CONFIG')
    sys_config = loadConfig(sys_config_file)

    templateName = 'template_' + runType + '_config.json'
    template_configFile = resourcesPath + 'configs/' + templateName
    config = loadConfig(template_configFile)

    # Before writing any files, check the output path doesn't exist already.
    # We might expect outPath to exist already, but not the processed subfolder.
    # Use an explicit raise rather than `assert`, so the check survives `python -O`.
    region_outPath = os.path.join(outPath, 'ENVIRONMENT_2.0_' + dateString, 'processed', region)
    if prevent_overwrite and os.path.exists(region_outPath):
        raise _OutputExistsError(f"Output folder already exists, refusing to overwrite: {region_outPath}")

    # Generate extraction (input) and output temporal points files
    timeresolution = 3  # hours

    extraction_temporal_points_file = workPath + 'extraction_temporal_points.csv'
    _prepare_temporal_points('extraction', extraction_temporal_points_file, dateString,
                             timeresolution, nDayExtraction, prevent_overwrite)

    output_temporal_points_file = workPath + 'output_temporal_points.csv'
    _prepare_temporal_points('output', output_temporal_points_file, dateString,
                             timeresolution, nDayForecast, prevent_overwrite)
    # NOTE(review): this file has no header row, so read_csv consumes the first timestamp
    # as the header and temporal_dim is one less than the number of time points.
    # Check this before re-enabling the result-dimension check below.
    output_temporal_points = pd.read_csv(output_temporal_points_file)
    temporal_dim = output_temporal_points.shape[0]

    # Modify run_config
    config[ParamsFileParser.TIMEPOINTS_FILE_KEY] = extraction_temporal_points_file
    config[ParamsFileParser.OUTPUT_DIR_KEY] = workPath
    # note that this field will only get used by the ewa5_met_data_extraction code, which uses a multi-processor module
    config[ParamsFileParser.MAX_PROCESSORS_KEY] = MAX_WORKERS

    config['FIELD_NAME_CONSTANTS_PATH'] = getParameter(pipeline_config, 'FIELD_NAME_CONSTANTS')

    if runType == 'operational':
        config['NCDF_DIR_PATH'] = inPath + 'ENVIRONMENT_2.0_' + dateString + '/NAME_Met_as_netcdf/'
    else:
        config['NCDF_DIR_PATH'] = inPath

    ## START RUNS ####################
    # The clean/extraction steps are currently disabled (hence `extracted` is unused):
    #     if not extracted:
    #         clean(workPath)
    #         run_extraction(config, sys_config)
    #         extracted = True

    processors = config['POST_PROCESSING']['PROCESSORS']
    logger.info(f"Found {len(processors)} processor(s)")

    for proc_cfg in processors:
        processor_name = proc_cfg['PROCESSOR_NAME']

        if processor_name != 'RIE':
            proc_cfg['TIMEPOINTS_FILE_PATH'] = extraction_temporal_points_file
            logger.info(f"Starting {processor_name} post processor ---------------------------------")
            run_post_processing(config, sys_config, processor_name)
            continue

        strains = getParameter(pipeline_config, 'STRAINS')
        proc_cfg['TIMEPOINTS_FILE_PATH'] = output_temporal_points_file

        # `config` is mutated in place for each strain. Snapshot the template's
        # FUTURE_FIELDS so every strain starts from the same settings. Otherwise a strain
        # that disables them (ENABLED = "FALSE") would leave them disabled for all later strains.
        template_future_fields = [dict(field) for field in proc_cfg['FUTURE_FIELDS']]

        envSuitPath = workPath + 'post_processing/RIE/'
        resultFile = envSuitPath + 'RIE.nc'

        # read in the input_spatial_points.csv file and get the number of spatial points - this is used for
        # sense-checking the dimensions of the output (check currently disabled below).
        # Reading it also fails fast if the region's asset file is missing.
        # todo we could swap the spatial points file for a file specifying the expected dimensions - much smaller
        region_spatial_points_file = resourcesPath + 'assets/' + 'input_spatial_points_' + region + '.csv'
        spatial_points = pd.read_csv(region_spatial_points_file)
        spatial_dim = spatial_points.shape[0]

        for strain in strains:
            strain_params = pipeline_config['PARAMS'][strain]

            # Modify strain specific suitability parts of the config
            future_steps = strain_params['future_steps']
            future_fields = [dict(field) for field in template_future_fields]
            for field in future_fields:
                if future_steps > 0:
                    field['DURATION'] = future_steps
                else:
                    field['ENABLED'] = "FALSE"
            proc_cfg['FUTURE_FIELDS'] = future_fields

            proc_cfg['PARAMS']['suitability_modules'] = strain_params['suitability_modules']
            proc_cfg['PARAMS']['thresholds'] = strain_params['thresholds']

            logger.info(f"Starting {strain} suitability ---------------------------------")
            run_post_processing(config, sys_config, processor_name)

            strain_outPath = os.path.join(region_outPath, strain)
            strain_outFile = strain_outPath + '/RIE_value.nc'

            # Check results dimension (currently disabled):
            # result = pd.read_csv(resultFile)
            # result_dims = result.shape
            # if ((result_dims[0] != spatial_dim) or (result_dims[1] != (temporal_dim + 4))): # + 4 required because there are extra columns in the result file
            #     logger.error(f"Result dimension {result_dims} does not match with the expected: ({spatial_dim}, {temporal_dim + 4})")
            #     raise IndexError

            os.makedirs(strain_outPath, exist_ok=True)
            shutil.copy(resultFile, strain_outFile)

            # todo - Add a flag to this part of the code to enable/disable csv writing as an option
            # resultCSVFile = envSuitPath + 'RIE.csv'
            # if os.path.isfile(resultCSVFile):
            #     shutil.copy(resultCSVFile, strain_outPath + '/RIE_value.csv')

            logger.info(f"{strain} result successfully created and moved to {strain_outPath}/")

    logger.info('SUCCESSFULLY FINISHED')


class _OutputExistsError(FileExistsError, AssertionError):
    '''Raised by :func:`run_pipeline` when ``prevent_overwrite`` is set and the region output exists.

    Also subclasses ``AssertionError``, so callers written against the former
    ``assert``-based check keep catching it.
    '''


def _prepare_temporal_points(label, path, dateString, timeresolution, nDays, prevent_overwrite):
    '''Write the ``label`` temporal-points file at ``path``, or reuse it if present and ``prevent_overwrite`` is set.'''

    logger.info(f"Generate {label} temporal points to: {path}")
    if prevent_overwrite and os.path.exists(path):
        # The work directory is shared by all regions of a date, so an earlier region
        # may already have written this file; reuse it rather than overwrite it.
        logger.warning(f"{path} already exists; reusing it instead of overwriting")
        return
    try:
        generate_temporal_points(path, dateString, timeresolution, nDays)
    except Exception:
        # Log and re-raise: a missing or partial time-points file must not go unnoticed.
        logger.exception(f"Some failure when generate {path}")
        raise