# EnvSuitPipeline.py
import datetime as dt
import json
import logging
import os
import shutil

import pandas as pd

from met_processing.runner.common import job_runner

MAX_WORKERS: int = 10

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger('Processor.pipeline')

###############################################################################

def loadConfig(configFile):
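    """Load a JSON config file and return the parsed configuration."""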
    logger.info(f"Loading config file: {configFile}")

    try:
        with open(configFile) as config_file:
            config = json.load(config_file)
    except Exception:
        logger.exception(f"Failure opening config file {configFile}")
        raise

    return config


def getParameter(config, parameter):
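    """Return config[parameter], logging and re-raising if the key is missing."""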
    logger.info(f"Getting {parameter} from the config")

    try:
        result = config[parameter]
    except KeyError:
        logger.exception(f"No {parameter} in the config file!")
        raise

    return result


def generate_temporal_points(file, datestr, timeresolution, nDaysForecast):
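    """Write a time series of YYYYMMDDHHMM strings to `file`, covering
    `nDaysForecast` days from midnight of `datestr` at `timeresolution`-hour
    steps (the first point is one step after midnight; the end time itself is
    excluded). Returns the generated datetimes as a list."""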
    datenow = dt.datetime(int(datestr[:4]), int(datestr[4:6]), int(datestr[6:8]), 0, 0)

    starttime = datenow + dt.timedelta(hours=timeresolution)
    endtime = datenow + dt.timedelta(days=nDaysForecast) + dt.timedelta(hours=timeresolution)
    timediff = dt.timedelta(hours=timeresolution)

    timeseries = []
    curr = starttime

    # Generate the time series excluding the end time.
    with open(file, "w") as outfile:
        while curr < endtime:
            timeseries.append(curr)
            outfile.write("{}\n".format(curr.strftime("%Y%m%d%H%M")))
            curr += timediff

    return timeseries


def clean(workPath): # Clean temporary files and folders from the working directory
    try:
        logger.info(f"Clean temporary files and folders from {workPath}")
        shutil.rmtree(workPath + 'extraction', ignore_errors=True)
        shutil.rmtree(workPath + 'chunks', ignore_errors=True)
        shutil.rmtree(workPath + 'post_processing', ignore_errors=True)
    except Exception:
        logger.exception(f"Failure when cleaning the work directory {workPath}")
        raise

    return


def generate_all(sys_config, run_config):
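    """Dump run_config to run_config.json in its OUTPUT_DIR and generate all
    processing jobs via job_runner."""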
    # Write run_config.json
    workPath = getParameter(run_config,'OUTPUT_DIR')
    run_configName = 'run_config.json'
    run_configFile = workPath + run_configName

    with open(run_configFile, 'w') as run_configJson:
        json.dump(run_config, run_configJson, indent=4)

    # Set slurm mode
    slurm_mode = bool(getParameter(run_config,'SLURM'))

    # Run all generate
    try:
        job_runner.generate_all_jobs(run_config, sys_config, slurm_mode)
    except Exception:
        logger.exception(f"Some failure when running one of the generate job", exc_info=True)
        raise

    return


def run_extraction(work_path):
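    """Run the data extraction (regridding and chunking) step via job_runner."""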
    logger.info(f"Running regridding in multi process mode.")
    job_runner.run_extraction(work_path, **{"MAX_WORKERS": MAX_WORKERS})
    logger.info('Data extracted and chunked')


def run_post_processing(work_path):
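    """Run the post-processing step for a single processor's working directory via job_runner."""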
    logger.info(f"Running post-processing in multi process mode.")
    job_runner.run_post_processing(work_path, **{"MAX_WORKERS": MAX_WORKERS})
    logger.info('Data extracted and chunked')


def run_merger(work_path):
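    """Merge the post-processing output via job_runner, re-raising on failure."""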
    try:
        job_runner.run_merge_post_processing(work_path)
    except Exception:
        logger.exception(f"Some failure when running merge RIE", exc_info=True)
        raise


#######################################

def run_pipeline(pipeline_config, region, dateString, extracted=False, prevent_overwrite=True):
    '''
    Run the environmental suitability pipeline for one region and date.

    pipeline_config is the loaded pipeline configuration dict, region selects the
    spatial points file and output subfolder, and dateString is the run date in
    YYYYMMDD format. Set extracted to True to skip the extraction step if it has
    already been done. The prevent_overwrite parameter can be set to False if you
    want to re-run a job in-place.
    '''
    # Get parameters from the config
    resourcesPath = getParameter(pipeline_config,'RESOURCES_PATH')
    workPath = getParameter(pipeline_config,'WORK_PATH') + 'ENVIRONMENT_2.0_' + dateString + '/'
    if not os.path.exists(workPath):
        os.makedirs(workPath)

    inPath = getParameter(pipeline_config,'INPUT_PATH')
    outPath = getParameter(pipeline_config,'OUTPUT_PATH')

    runType = getParameter(pipeline_config,'RUN_TYPE')
    nDayExtraction = getParameter(pipeline_config,'EXTRACTION_DAYS')
    nDayForecast = getParameter(pipeline_config,'FORECAST_DAYS')

    sys_config_file = getParameter(pipeline_config,'SYS_CONFIG')
    sys_config = loadConfig(sys_config_file)

    templateName = 'template_' + runType + '_config.json'
    template_configFile = resourcesPath + 'configs/' + templateName
    config = loadConfig(template_configFile)

    # Before writing any files, check the output path doesn't exist already
    # We might expect outPath to exist already, but not the processed subfolder
    region_outPath = os.path.join(outPath,'ENVIRONMENT_2.0_'+dateString,'processed',region)
    if prevent_overwrite: assert not os.path.exists(region_outPath)

    # Get spatial points file for the region
    region_spatial_points_file = resourcesPath + 'assets/' + 'input_spatial_points_' + region + '.csv'
    input_spatial_points_file = workPath + 'input_spatial_points.csv'
    if prevent_overwrite: assert not os.path.exists(input_spatial_points_file)
    shutil.copy(region_spatial_points_file,input_spatial_points_file)
    spatial_points = pd.read_csv(input_spatial_points_file)
    spatial_dim = spatial_points.shape[0]

    # Generate extraction (input) and output temporal points files
    timeresolution = 3 # hours

    extraction_temporal_points_file = workPath + 'extraction_temporal_points.csv'
    try:
        logger.info(f"Generate extraction temporal points to: {extraction_temporal_points_file}")
        if prevent_overwrite: assert not os.path.exists(extraction_temporal_points_file)
        generate_temporal_points(extraction_temporal_points_file, dateString, timeresolution, nDayExtraction)
    except:
        logger.exception(f"Some failure when generate {extraction_temporal_points_file}", exc_info=True)
    # extraction_temporal_points = pd.read_csv(extraction_temporal_points_file)

    output_temporal_points_file = workPath + 'output_temporal_points.csv'
    try:
        logger.info(f"Generate output temporal points to: {output_temporal_points_file}")
        if prevent_overwrite: assert not os.path.exists(output_temporal_points_file)
        generate_temporal_points(output_temporal_points_file, dateString, timeresolution, nDayForecast)
    except:
        logger.exception(f"Some failure when generate {output_temporal_points_file}", exc_info=True)
    output_temporal_points = pd.read_csv(output_temporal_points_file)
    temporal_dim = output_temporal_points.shape[0]

    # Modify run_config
    config['TIMEPOINTS_FILE_PATH'] = extraction_temporal_points_file
    config['OUTPUT_DIR'] = workPath
    config['SPATIAL_POINTS_FILE_PATH'] = input_spatial_points_file

    config['FIELD_NAME_CONSTANTS_PATH'] = getParameter(pipeline_config,'FIELD_NAME_CONSTANTS')

    if (runType == 'operational'):
        config['NCDF_DIR_PATH'] = inPath + 'ENVIRONMENT_2.0_' + dateString + '/NAME_Met_as_netcdf/'
    else:
        config['NCDF_DIR_PATH'] = inPath

    ## START RUNS ####################
    os.chdir(workPath)

    N_processor = range(len(config['POST_PROCESSING']['PROCESSORS']))
    logger.info(f"Found {len(N_processor)} processors")

    for p in N_processor:
        processor_name = config['POST_PROCESSING']['PROCESSORS'][p]['PROCESSOR_NAME']
        if processor_name != 'RIE':
            config['POST_PROCESSING']['PROCESSORS'][p]['TIMEPOINTS_FILE_PATH'] = extraction_temporal_points_file

            # Clean if extraction is not done
            if not extracted:
                clean(workPath)

            generate_all(sys_config, config)

            # Extract
            if not extracted:
                run_extraction(workPath)
                extracted = True

            logger.info(f"Starting {processor_name} post processor ---------------------------------")
            processorPath = workPath + 'post_processing/' + processor_name + '/'
            run_post_processing(processorPath)

            run_merger(processorPath)
        else:
            strains = getParameter(pipeline_config, 'STRAINS')

            config['POST_PROCESSING']['PROCESSORS'][p]['TIMEPOINTS_FILE_PATH'] = output_temporal_points_file

            for strain in strains:
                # Modify strain specific suitability parts of the config
                if pipeline_config['PARAMS'][strain]['future_steps'] > 0:
                    for i in range(len(config['POST_PROCESSING']['PROCESSORS'][p]['FUTURE_FIELDS'])):
                        config['POST_PROCESSING']['PROCESSORS'][p]['FUTURE_FIELDS'][i]['DURATION'] = pipeline_config['PARAMS'][strain]['future_steps']
                else:
                    for i in range(len(config['POST_PROCESSING']['PROCESSORS'][p]['FUTURE_FIELDS'])):
                        config['POST_PROCESSING']['PROCESSORS'][p]['FUTURE_FIELDS'][i]['ENABLED'] = "FALSE"

                config['POST_PROCESSING']['PROCESSORS'][p]['PARAMS']['suitability_modules'] = pipeline_config['PARAMS'][strain]['suitability_modules']
                config['POST_PROCESSING']['PROCESSORS'][p]['PARAMS']['thresholds'] = pipeline_config['PARAMS'][strain]['thresholds']

                # Clean if extraction is not done
                if not extracted:
                    clean(workPath)

                generate_all(sys_config, config)

                # Extract
                if not extracted:
                    run_extraction(workPath)
                    extracted = True

                logger.info(f"Starting {strain} suitability ---------------------------------")
                envSuitPath = workPath + 'post_processing/RIE/'
                run_post_processing(envSuitPath)

                run_merger(envSuitPath)

                resultFile = envSuitPath + 'RIE.csv'
                strain_outPath = os.path.join(region_outPath,strain)
                strain_outFile = strain_outPath + '/RIE_value.csv'

                # Check results dimension
                result = pd.read_csv(resultFile)
                result_dims = result.shape

                if (result_dims[0] != spatial_dim) or (result_dims[1] != (temporal_dim + 4)): # + 4 required because there are extra columns in the result file
                    logger.error(f"Result dimensions {result_dims} do not match the expected ({spatial_dim}, {temporal_dim + 4})")
                    raise IndexError(f"Unexpected result dimensions in {resultFile}")

                if not os.path.exists(strain_outPath):
                    os.makedirs(strain_outPath)

                shutil.copy(resultFile,strain_outFile)
                logger.info(f"{strain} result successfully created and moved to {strain_outPath}/")

    logger.info('SUCCESSFULLY FINISHED')
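

# Example invocation: a minimal sketch, not part of the original pipeline code.
# The config path, region name and run date below are hypothetical placeholders;
# the pipeline config JSON is expected to provide the keys read above
# (RESOURCES_PATH, WORK_PATH, INPUT_PATH, OUTPUT_PATH, RUN_TYPE, EXTRACTION_DAYS,
# FORECAST_DAYS, SYS_CONFIG, FIELD_NAME_CONSTANTS, STRAINS and PARAMS).
if __name__ == '__main__':
    example_config = loadConfig('/path/to/pipeline_config.json')  # hypothetical path
    run_pipeline(example_config, 'ExampleRegion', dt.datetime.today().strftime('%Y%m%d'))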