From eee6e47fc81871d7c6218b598a61fbf14bb2fe04 Mon Sep 17 00:00:00 2001
From: lb584 <lb584@cam.ac.uk>
Date: Thu, 18 Aug 2022 15:58:25 +0100
Subject: [PATCH] More work to unify era5_met_extractor and ews_met_extractor

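Drop the explicit SLURM flag from generate_all and stop passing bare
work paths into the extraction, post-processing and merge steps:
job_runner now receives the run and system parameter dictionaries
directly, and the run_config entries are set through the
ParamsFileParser key constants. A MAX_PROCESSORS_KEY entry replaces
the per-call MAX_WORKERS override.

Sketch of the resulting call pattern (names as in run_pipeline in the
diff below):

    config[ParamsFileParser.OUTPUT_DIR_KEY] = workPath
    config[ParamsFileParser.MAX_PROCESSORS_KEY] = MAX_WORKERS

    run_extraction(config, sys_config)        # was run_extraction(workPath)
    run_post_processing(config, sys_config)   # was run_post_processing(processorPath)
    run_merger(config, sys_config)            # was run_merger(processorPath)
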
---
 EnvSuitPipeline.py | 39 +++++++++++++++++++--------------------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/EnvSuitPipeline.py b/EnvSuitPipeline.py
index 49d24b2..c85d2f7 100644
--- a/EnvSuitPipeline.py
+++ b/EnvSuitPipeline.py
@@ -6,6 +6,7 @@ import shutil
 
 import pandas as pd
 
+from met_processing.common.params_file_parser import ParamsFileParser
 from met_processing.runner.common import job_runner
 
 MAX_WORKERS: int = 10
@@ -86,12 +87,9 @@ def generate_all(sys_config, run_config):
 
     run_configJson.close()
 
-    # Set slurm mode
-    slurm_mode = bool(getParameter(run_config,'SLURM'))
-
     # Run all generate
     try:
-        job_runner.generate_all_jobs(run_config, sys_config, slurm_mode)
+        job_runner.generate_all_jobs(run_config, sys_config)
     except Exception:
         logger.exception(f"Some failure when running one of the generate job", exc_info=True)
         raise
@@ -99,21 +97,21 @@ def generate_all(sys_config, run_config):
     return
 
 
-def run_extraction(work_path):
+def run_extraction(run_params: dict, sys_params: dict):
     logger.info(f"Running regridding in multi process mode.")
-    job_runner.run_extraction(work_path, **{"MAX_WORKERS": MAX_WORKERS})
+    job_runner.run_extraction(run_params, sys_params)
     logger.info('Data extracted and chunked')
 
 
-def run_post_processing(work_path):
+def run_post_processing(run_params: dict, sys_params: dict):
     logger.info(f"Running post-processing in multi process mode.")
-    job_runner.run_post_processing(work_path, **{"MAX_WORKERS": MAX_WORKERS})
+    job_runner.run_post_processing(run_params, sys_params)
     logger.info('Data extracted and chunked')
 
 
-def run_merger(work_path):
+def run_merger(run_params: dict, sys_params: dict):
     try:
-        job_runner.run_merge_post_processing(work_path)
+        job_runner.run_merge_post_processing(run_params, sys_params)
     except Exception:
         logger.exception(f"Some failure when running merge RIE", exc_info=True)
         raise
@@ -182,9 +180,11 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False, prevent
     temporal_dim = output_temporal_points.shape[0]
 
     # Modify run_config
-    config['TIMEPOINTS_FILE_PATH'] = extraction_temporal_points_file
-    config['OUTPUT_DIR'] = workPath
-    config['SPATIAL_POINTS_FILE_PATH'] = input_spatial_points_file
+    config[ParamsFileParser.TIMEPOINTS_FILE_KEY] = extraction_temporal_points_file
+    config[ParamsFileParser.OUTPUT_DIR_KEY] = workPath
+    config[ParamsFileParser.SPATIAL_POINTS_FILE_KEY] = input_spatial_points_file
+    # note: this field is only used by the era5_met_data_extraction code, which uses a multi-processor module
+    config[ParamsFileParser.MAX_PROCESSORS_KEY] = MAX_WORKERS
 
     config['FIELD_NAME_CONSTANTS_PATH'] = getParameter(pipeline_config,'FIELD_NAME_CONSTANTS')
 
@@ -212,14 +212,13 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False, prevent
 
             # Extract
             if (extracted == False):
-                run_extraction(workPath)
+                run_extraction(config, sys_config)
                 extracted = True
 
             logger.info(f"Starting {processor_name} post processor ---------------------------------")
-            processorPath = workPath + 'post_processing/' + processor_name + '/'
-            run_post_processing(processorPath)
+            run_post_processing(config, sys_config)
 
-            run_merger(processorPath)
+            run_merger(config, sys_config)
         else:
             strains = getParameter(pipeline_config, 'STRAINS')
 
@@ -245,14 +244,14 @@ def run_pipeline(pipeline_config, region, dateString, extracted = False, prevent
 
                 # Extract
                 if (extracted == False):
-                    run_extraction(workPath)
+                    run_extraction(config, sys_config)
                     extracted = True
 
                 logger.info(f"Starting {strain} suitability ---------------------------------")
                 envSuitPath = workPath + 'post_processing/RIE/'
-                run_post_processing(envSuitPath)
+                run_post_processing(config, sys_config)
 
-                run_merger(envSuitPath)
+                run_merger(config, sys_config)
 
                 resultFile = envSuitPath + 'RIE.csv'
                 strain_outPath = os.path.join(region_outPath,strain)
-- 
GitLab