From e296c1c9c767dbe9eec1cd79c663d0037c9d5408 Mon Sep 17 00:00:00 2001 From: Jake Smith <jws52@cam.ac.uk> Date: Thu, 30 Jun 2022 14:47:35 +0100 Subject: [PATCH] feat: Env suit will not overwrite unless prompted --- EnvSuitPipeline.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/EnvSuitPipeline.py b/EnvSuitPipeline.py index 191f6f4..666b95e 100644 --- a/EnvSuitPipeline.py +++ b/EnvSuitPipeline.py @@ -161,7 +161,11 @@ def run_merger(workPath): ####################################### -def run_pipeline(pipeline_config, region, dateString, extracted = False): +def run_pipeline(pipeline_config, region, dateString, extracted = False, prevent_overwrite = True): + ''' + The prevent_overwrite parameter can be set to False if you want to re-run + a job in-place. + ''' # Get parameters from the config resourcesPath = getParameter(pipeline_config,'RESOURCES_PATH') workPath = getParameter(pipeline_config,'WORK_PATH') + 'ENVIRONMENT_2.0_' + dateString + '/' @@ -182,9 +186,15 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False): template_configFile = resourcesPath + templateName config = loadConfig(template_configFile) + # Before writing any files, check the output path doesn't exist already + # We might expect outPath to exist already, but not the processed subfolder + region_outPath = os.path.join(outPath,'ENVIRONMENT_2.0_'+dateString,'processed',region) + if prevent_overwrite: assert not os.path.exists(region_outPath) + # Get spatial points file for the region region_spatial_points_file = resourcesPath + 'input_spatial_points_' + region + '.csv' input_spatial_points_file = workPath + 'input_spatial_points.csv' + if prevent_overwrite: assert not os.path.exists(input_spatial_points_file) shutil.copy(region_spatial_points_file,input_spatial_points_file) spatial_points = pd.read_csv(input_spatial_points_file) spatial_dim = spatial_points.shape[0] @@ -195,6 +205,7 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False): extraction_temporal_points_file = workPath + 'extraction_temporal_points.csv' try: logger.info(f"Generate extraction temporal points to: {extraction_temporal_points_file}") + if prevent_overwrite: assert not os.path.exists(extraction_temporal_points_file) generate_temporal_points(extraction_temporal_points_file, dateString, timeresolution, nDayExtraction) except: logger.exception(f"Some failure when generate {extraction_temporal_points_file}", exc_info=True) @@ -203,6 +214,7 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False): output_temporal_points_file = workPath + 'output_temporal_points.csv' try: logger.info(f"Generate output temporal points to: {output_temporal_points_file}") + if prevent_overwrite: assert not os.path.exists(output_temporal_points_file) generate_temporal_points(output_temporal_points_file, dateString, timeresolution, nDayForecast) except: logger.exception(f"Some failure when generate {output_temporal_points_file}", exc_info=True) @@ -283,7 +295,7 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False): run_merger(envSuitPath) resultFile = envSuitPath + 'RIE.csv' - strain_outPath = os.path.join(outPath,'ENVIRONMENT_2.0_'+dateString,'processed',region,strain) + strain_outPath = os.path.join(region_outPath,strain) strain_outFile = strain_outPath + '/RIE_value.csv' # Check results dimension -- GitLab