From 839055053b3032f7b0ba19ba7517cb308c15d081 Mon Sep 17 00:00:00 2001
From: Mark Driver <mdd31@cam.ac.uk>
Date: Sat, 12 Aug 2017 18:23:46 +0100
Subject: [PATCH] added functionality to plot using split polynomials.

---
 .../polynomialanalysis/polynomialplotting.py  | 48 ++++++++++++++++--
 .../polynomialplottingtest.py                 | 49 +++++++++++++++++++
 2 files changed, 93 insertions(+), 4 deletions(-)

diff --git a/solventmapcreator/polynomialanalysis/polynomialplotting.py b/solventmapcreator/polynomialanalysis/polynomialplotting.py
index 083144f..0a24f81 100644
--- a/solventmapcreator/polynomialanalysis/polynomialplotting.py
+++ b/solventmapcreator/polynomialanalysis/polynomialplotting.py
@@ -8,6 +8,8 @@ the polynomial fit data and also the plots.
 """
 
 import logging
+import numpy as np
+import solventmapcreator.polynomialanalysis.polynomialvaluecalculator as polyvalcalc
 import resultsanalysis.resultsoutput.plottinginput as plottinginput
 import solventmapcreator.io.polynomialdatawriter as polynomialdatawriter
 import solventmapcreator.io.solvationenergyreader as solvationenergyreader
@@ -16,21 +18,59 @@ logging.basicConfig()
 LOGGER = logging.getLogger(__name__)
 LOGGER.setLevel(logging.WARN)
 
-def plot_energies_from_file(free_energy_filename, polynomial_order):
+def plot_energies_from_file(free_energy_filename, polynomial_order, **kwargs):
     """This creates in the plotting information, and outputs a plot to file.
+    """        
+    plot_data = parse_energies_create_plot_input_data(free_energy_filename,
+                                                      polynomial_order,
+                                                      kwargs.get("pos_subset_list", None),
+                                                      kwargs.get("neg_subset_list", None))
+    if type(plot_data['polynomial coefficients']) is dict:
+        return create_scatter_with_split_poly(plot_data, **kwargs)
+    else:
+        return plottinginput.create_scatter_graph_with_polynomial(plot_data)
+
+def create_scatter_with_split_poly(input_data, **kwargs):
+    """This generates the split plot when a split polynomial is required.
     """
-    plot_data = parse_energies_create_plot_input_data(free_energy_filename, polynomial_order)
-    return plottinginput.create_scatter_graph_with_polynomial(plot_data)
+    scatter_plot, scatter_plot_axis = plottinginput.plot_scatter_graph(input_data)
+    x_data = np.sort(input_data['x_data'])
+    polynomial_values = polyvalcalc.calculate_polynomial_values(x_data,
+                                                                input_data["polynomial coefficients"])
+    plottinginput.plot_polynomial_curve(scatter_plot_axis, x_data,
+                                        polynomial_values,
+                                        input_data["polynomial order"])
+    fileformat = kwargs.get("fileformat", "eps")
+    output_filename = plottinginput.create_output_filename(input_data, fileformat)
+    plottinginput.plt.savefig(output_filename, format=fileformat)
+    plottinginput.plt.close(scatter_plot)
+    return 0
 
-def parse_energies_create_plot_input_data(free_energy_filename, polynomial_order):
+def parse_energies_create_plot_input_data(free_energy_filename, polynomial_order,
+                                          pos_subset_list=None, neg_subset_list=None):
     """This creates the input data for a plot. This overrides the default figure
     label, so that the figures can be distinguished based on the solvent.
     """
     datapoints = parse_free_energy_from_file_with_data_arrays(free_energy_filename)
     plot_data = datapoints.generateScatterPlotParameters(label_type='', polynomial_order=polynomial_order)
+    if pos_subset_list != None and neg_subset_list != None:
+        poly_coeff_dict = generate_poly_pos_neg_separate_fit(datapoints, polynomial_order,
+                                                             pos_subset_list, neg_subset_list)
+        plot_data['polynomial coefficients'] = poly_coeff_dict
     plot_data['figure_label'] = free_energy_filename.replace('.xml', '')
     return plot_data
 
+def generate_poly_pos_neg_separate_fit(datapoints, order, pos_subset_list, neg_subset_list):
+    """This generates the fit for the positive and negative datapoints separately.
+    A dictionary containing the 2 fits are returned.
+    """
+    poly_dict = polynomialdatawriter.generate_polynomial_fit_data_positive_negative_separate_fit(datapoints,
+                                                                                            order,
+                                                                                            pos_subset_list,
+                                                                                            neg_subset_list)
+    return {"positive":poly_dict["positive"]["coefficients"],
+            "negative":poly_dict["negative"]["coefficients"]}
+
 def parse_poly_data_to_file_split_fit(free_energy_filename, order_list, poly_filename,
                                       pos_subset_list, neg_subset_list):
     """This reads in the energy information, and then performs the fits on the
diff --git a/solventmapcreator/test/polynomialanalysistest/polynomialplottingtest.py b/solventmapcreator/test/polynomialanalysistest/polynomialplottingtest.py
index 1a2a210..70c65a4 100644
--- a/solventmapcreator/test/polynomialanalysistest/polynomialplottingtest.py
+++ b/solventmapcreator/test/polynomialanalysistest/polynomialplottingtest.py
@@ -73,6 +73,55 @@ class PolynomialPlottingTestCase(unittest.TestCase):
             else:
                 LOGGER.debug("assert equal string")
                 self.assertEqual(actual_dict[key], expected_dict[key])
+        #test to see if new functionality is supported
+        neg_set_list = ["-5.400solute", "-4.300solute", "-9.100solute", "-11.100solute", "-15.400solute"]
+        pos_set_list = ["0.500solute", "1.200solute", "7.200solute"]
+        actual_dict2 = polynomialplotting.parse_energies_create_plot_input_data("resources/water.xml", 4,
+                                                                                pos_set_list, neg_set_list)
+        self.assertListEqual(sorted(actual_dict2.keys()), sorted(expected_dict.keys()))
+        for key in actual_dict.keys():
+            LOGGER.debug("Key: %s", key)
+            if key == 'y_data' or key == 'x_data':
+                LOGGER.debug("assert equal array x and y data")
+                np.testing.assert_array_almost_equal(actual_dict2[key],
+                                                     expected_dict[key])
+            elif key == 'polynomial coefficients':
+                self.assertListEqual(["negative", "positive"],
+                                     sorted(actual_dict2['polynomial coefficients'].keys()))
+                negative_coeffs = np.array([5.038843e+00, 2.290628e+00, -9.855830e-02,
+                                            -6.480660e-03, -1.545673e-04])
+                positive_coeffs = np.array([1.351575, -1.317976, -0.126786,
+                                            -0.015156, -0.002038])
+                np.testing.assert_array_almost_equal(negative_coeffs,
+                                                     actual_dict2['polynomial coefficients']["negative"])
+                np.testing.assert_array_almost_equal(positive_coeffs,
+                                                     actual_dict2['polynomial coefficients']["positive"])
+            elif key == 'plot_axis_range' or key == 'x_axis_range' or key == 'y_axis_range':
+                np.testing.assert_array_almost_equal(np.array(actual_dict2[key]),
+                                                     np.array(expected_dict[key]))
+            else:
+                LOGGER.debug("assert equal string")
+                self.assertEqual(actual_dict2[key], expected_dict[key])
+    def test_generate_poly_pos_neg_separate_fit(self):
+        """Test to see if expected dict is returned.
+        """
+        
+        neg_set_list = ["-5.400solute", "-4.300solute", "-9.100solute", "-11.100solute", "-15.400solute"]
+        pos_set_list = ["0.500solute", "1.200solute", "7.200solute"]
+        negative_coeffs = np.array([5.038843e+00, 2.290628e+00, -9.855830e-02,
+                                    -6.480660e-03, -1.545673e-04])
+        positive_coeffs = np.array([1.351575, -1.317976, -0.126786, -0.015156,
+                                    -0.002038])
+        actual_dict = polynomialplotting.generate_poly_pos_neg_separate_fit(self.expected_datapoints, 4,
+                                                                            pos_set_list, neg_set_list)
+        LOGGER.debug("Actual Dict")
+        LOGGER.debug(actual_dict)
+        self.assertListEqual(["negative", "positive"],
+                             sorted(actual_dict.keys()))
+        np.testing.assert_array_almost_equal(negative_coeffs,
+                                             actual_dict["negative"])
+        np.testing.assert_array_almost_equal(positive_coeffs,
+                                             actual_dict["positive"])
     def test_parse_poly_data_to_file_split_fit(self):
         """Test to see if expected fit is done on the positive and negative subsets
         after the file was read in.
-- 
GitLab