From e296c1c9c767dbe9eec1cd79c663d0037c9d5408 Mon Sep 17 00:00:00 2001
From: Jake Smith <jws52@cam.ac.uk>
Date: Thu, 30 Jun 2022 14:47:35 +0100
Subject: [PATCH] feat: Env suit will not overwrite unless prompted

---
 EnvSuitPipeline.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/EnvSuitPipeline.py b/EnvSuitPipeline.py
index 191f6f4..666b95e 100644
--- a/EnvSuitPipeline.py
+++ b/EnvSuitPipeline.py
@@ -161,7 +161,11 @@ def run_merger(workPath):
 
 #######################################
 
-def run_pipeline(pipeline_config, region, dateString, extracted = False):
+def run_pipeline(pipeline_config, region, dateString, extracted = False, prevent_overwrite = True):
+    '''
+    The prevent_overwrite parameter can be set to False if you want to re-run
+    a job in-place.
+    '''
     # Get parameters from the config
     resourcesPath = getParameter(pipeline_config,'RESOURCES_PATH')
     workPath = getParameter(pipeline_config,'WORK_PATH') + 'ENVIRONMENT_2.0_' + dateString + '/'
@@ -182,9 +186,15 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False):
     template_configFile = resourcesPath + templateName
     config = loadConfig(template_configFile)
 
+    # Before writing any files, check the output path doesn't exist already
+    # We might expect outPath to exist already, but not the processed subfolder
+    region_outPath = os.path.join(outPath,'ENVIRONMENT_2.0_'+dateString,'processed',region)
+    if prevent_overwrite: assert not os.path.exists(region_outPath)
+
     # Get spatial points file for the region
     region_spatial_points_file = resourcesPath + 'input_spatial_points_' + region + '.csv'
     input_spatial_points_file = workPath + 'input_spatial_points.csv'
+    if prevent_overwrite: assert not os.path.exists(input_spatial_points_file)
     shutil.copy(region_spatial_points_file,input_spatial_points_file)
     spatial_points = pd.read_csv(input_spatial_points_file)
     spatial_dim = spatial_points.shape[0]
@@ -195,6 +205,7 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False):
     extraction_temporal_points_file = workPath + 'extraction_temporal_points.csv'
     try:
         logger.info(f"Generate extraction temporal points to: {extraction_temporal_points_file}")
+        if prevent_overwrite: assert not os.path.exists(extraction_temporal_points_file)
         generate_temporal_points(extraction_temporal_points_file, dateString, timeresolution, nDayExtraction)
     except:
         logger.exception(f"Some failure when generate {extraction_temporal_points_file}", exc_info=True)
@@ -203,6 +214,7 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False):
     output_temporal_points_file = workPath + 'output_temporal_points.csv'
     try:
         logger.info(f"Generate output temporal points to: {output_temporal_points_file}")
+        if prevent_overwrite: assert not os.path.exists(output_temporal_points_file)
         generate_temporal_points(output_temporal_points_file, dateString, timeresolution, nDayForecast)
     except:
         logger.exception(f"Some failure when generate {output_temporal_points_file}", exc_info=True)
@@ -283,7 +295,7 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False):
                 run_merger(envSuitPath)
 
                 resultFile = envSuitPath + 'RIE.csv'
-                strain_outPath = os.path.join(outPath,'ENVIRONMENT_2.0_'+dateString,'processed',region,strain)
+                strain_outPath = os.path.join(region_outPath,strain)
                 strain_outFile = strain_outPath + '/RIE_value.csv'
 
                 # Check results dimension
-- 
GitLab