#!/usr/bin/env python3
# -*- coding: utf-8 -*-
##********************************************************************************************************************************************************
##
##  This module plots the fit results together with the experimental data into a file
##  Copyright (C) 2009 - 2024  Thomas Moeller
##
##  I. Physikalisches Institut, University of Cologne
##
##
##
##  The following subroutines and functions are included in this module:
##
##      - subroutine histOutline:                       subroutine for plotting histograms in outline format
##      - subroutine plot:                              plot experimental data, fit function and values of chi^2
##
##
##
##  Versions of the program:
##
##  Who           When         What
##
##  T. Moeller    2009-07-09   initial version
##  T. Moeller    2012-01-16   improve documentation of source code
##  T. Moeller    2020-01-02   porting to python 3, minor improvements
##
##
##
##  License:
##
##    GNU GENERAL PUBLIC LICENSE
##    Version 3, 29 June 2007
##    (Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>)
##
##
##    This program is free software: you can redistribute it and/or modify
##    it under the terms of the GNU General Public License as published by
##    the Free Software Foundation, either version 3 of the License, or
##    (at your option) any later version.
##
##    This program is distributed in the hope that it will be useful,
##    but WITHOUT ANY WARRANTY; without even the implied warranty of
##    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##    GNU General Public License for more details.
##
##    You should have received a copy of the GNU General Public License
##    along with this program.  If not, see <http://www.gnu.org/licenses/>.
##
##********************************************************************************************************************************************************

"""

Plot data:
----------

This module displays the experimental data together with the calculated
fit function and the chi^2 values using the matplotlib package.

"""


##******************************************************************** load packages *********************************************************************
from __future__ import print_function                                                       ## for python 2 usage
import numpy as np                                                                          ## import numpy package
import sys                                                                                  ## load package for system manipulation
if (not 'matplotlib' in sys.modules):
    import matplotlib                                                                       ## import matplotlib package
import pylab                                                                                ## load python package for plotting pylab
##--------------------------------------------------------------------------------------------------------------------------------------------------------


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  plot experimental data, fit function and values of chi^2
##
##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  replace non-finite (and optionally huge) values of one data row by zero
##
def _sanitize_values(values, clip_large = False):
    """
    Return a new list in which every NaN or infinite entry of values is replaced by 0.0.

    input parameter:    values:                     iterable of numbers (one row of a data, model or chi^2 array)
                        clip_large:                 if True, entries with an absolute value above 1e20 are replaced by 0.0 as well

    output parameter:   cleaned:                    list containing the sanitized values
    """
    cleaned = []
    for value in values:
        if (np.isnan(value) or np.isinf(value) or (clip_large and abs(value) > 1.e20)):
            value = 0.0
        cleaned.append(value)
    return cleaned


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  build the sanitized arrays required to plot one observation data file
##
def _collect_curves(site, NumFile, exp_data_x, exp_data_y, FitFunctionValues, Chi2Values, ErrorModelFunctionValues, ErrorChi2Values, \
                    NumberSites, error_flag):
    """
    Collect the per-point curves of observation data file NumFile for the given site.

    input parameter:    site:                       index of the current site
                        NumFile:                    index of the observation data file
                        exp_data_x ... NumberSites: see function plot
                        error_flag:                 if True, the upper/lower error curves are collected as well

    output parameter:   curves:                     dict with the following lists (one entry per data point):
                                                    "x":            first x-coordinate (higher x-dimensions are ignored)
                                                    "exp":          observation data (NaN/inf -> 0)
                                                    "fit":          model function values (NaN/inf/|v|>1e20 -> 0)
                                                    "chi2":         chi^2 values (NaN/inf -> 0)
                                                    "upper", "lower":
                                                                    model function for upper/lower error values (only if error_flag)
                                                    "upper_chi2", "lower_chi2":
                                                                    chi^2 for upper/lower error values (only if error_flag)
    """
    curves = {"x": [], "exp": [], "fit": [], "chi2": [], "upper": [], "lower": [], "upper_chi2": [], "lower_chi2": []}
    for i, element in enumerate(exp_data_x[NumFile]):                                       ## loop over all data points

        ## with a single site the model / chi^2 arrays carry no site index
        if (NumberSites == 1):
            model_array = FitFunctionValues[NumFile][i]
            chi2_array = Chi2Values[NumFile][i]
        else:
            model_array = FitFunctionValues[site][NumFile][i]
            chi2_array = Chi2Values[site][NumFile][i]

        curves["fit"].append(_sanitize_values(model_array, clip_large = True))
        curves["chi2"].append(_sanitize_values(chi2_array))
        curves["exp"].append(_sanitize_values(exp_data_y[NumFile][i]))
        curves["x"].append(element[0])

        ## index [0] / [1] selects the upper / lower error values
        if error_flag:
            curves["upper"].append(_sanitize_values(ErrorModelFunctionValues[0][0][NumFile][i], clip_large = True))
            curves["lower"].append(_sanitize_values(ErrorModelFunctionValues[1][0][NumFile][i], clip_large = True))
            curves["upper_chi2"].append(_sanitize_values(ErrorChi2Values[0][0][NumFile][i], clip_large = True))
            curves["lower_chi2"].append(_sanitize_values(ErrorChi2Values[1][0][NumFile][i], clip_large = True))
    return curves


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  common y-limits of observation data and model function
##
def _y_range(exp_data, fit_data):
    """
    Return (ymin, ymax) shared by observation data and model function.

    The rows are compared as Python lists (lexicographically), so the limits are the minimum / maximum of the first column.
    """
    ymin = min(min(exp_data), min(fit_data))[0]
    ymax = max(max(exp_data), max(fit_data))[0]
    return ymin, ymax


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  plot layout for exactly one observation data file
##
def _plot_single_file(curves, error_flag, plot_option, is_molfit):
    """
    Draw data and fit (left panel) and chi^2 (right panel) for the single-file layout.
    """
    x = curves["x"]
    pylab.subplots_adjust(hspace = 0.45, wspace = 0.2, bottom = 0.1, top = 0.98)


    ##----------------------------------------------------------------------------------------------------------------------------------------------------
    ## left panel: observation data, model function and (optionally) error curves
    pl = pylab.subplot(1, 2, 1)
    pl.plot(x, curves["exp"], '-', color = 'black', linewidth = 1.0, label = 'data', drawstyle = 'steps-mid')
    pl.plot(x, curves["fit"], 'g-', label = 'fit')
    if error_flag:
        pl.plot(x, curves["upper"], 'c--', label = 'upper error values')
        pl.plot(x, curves["lower"], 'k-.', label = 'lower error values')
    pl.grid(True)
    if is_molfit:
        pl.xaxis.set_major_formatter(pylab.matplotlib.ticker.FormatStrFormatter('%.5e'))
    pl.legend()
    pl.set_xlabel(r"%s" % plot_option[1])
    pl.set_ylabel(r"%s" % plot_option[2])
    ymin, ymax = _y_range(curves["exp"], curves["fit"])
    pl.set_ylim(ymin, ymax)


    ## three evenly spaced x-ticks; skipped when all x values coincide, because np.arange would get a zero step
    xminLocal = min(x)
    xmaxLocal = max(x)
    xticks = None
    if (xmaxLocal > xminLocal):
        xticks = np.arange(xminLocal, xmaxLocal, (xmaxLocal - xminLocal) / 2.1)
        pl.set_xticks(xticks)


    ##----------------------------------------------------------------------------------------------------------------------------------------------------
    ## right panel: chi^2 drawn on a second y-axis
    prv = pylab.subplot(1, 2, 2)
    pr = prv.twinx()
    pr.plot(x, curves["chi2"], 'r', label = 'chi^2')
    ymin_chi2 = min(curves["chi2"])[0]
    ymax_chi2 = max(curves["chi2"])[0]
    pr.set_ylim(ymin_chi2, ymax_chi2)
    prv.set_ylim(ymin_chi2, ymax_chi2)
    if error_flag:
        pr.plot(x, curves["upper_chi2"], 'c--', label = 'upper chi2 values')
        pr.plot(x, curves["lower_chi2"], 'k-.', label = 'lower chi2 values')
    pr.set_xlim(xminLocal, xmaxLocal)
    prv.set_xlim(xminLocal, xmaxLocal)
    if xticks is not None:
        pr.set_xticks(xticks)
        prv.set_xticks(xticks)
    pr.grid(True)
    if is_molfit:
        pr.xaxis.set_major_formatter(pylab.matplotlib.ticker.FormatStrFormatter('%.5e'))
    pr.legend()
    prv.set_xlabel(r"%s" % plot_option[1])
    pr.set_xlabel(r"%s" % plot_option[1])
    prv.set_yticklabels([])                                                                 ## only the twin (right) y-axis carries labels
    pr.set_ylabel(r"$\chi^2$")


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  one row of the stacked layout used for several observation data files
##
def _plot_multi_file(curves, NumFile, number_files, error_flag, plot_option, is_molfit):
    """
    Draw row NumFile (data and fit left, chi^2 right) of the multi-file layout with number_files rows.
    """
    x = curves["x"]
    is_last_row = ((NumFile + 1) == number_files)                                           ## x-axis labels only on the lowest row
    xlabel = (r"%s" % plot_option[1]) if is_last_row else ""
    pylab.subplots_adjust(hspace = 0.45, wspace = 0.2, bottom = 0.02, top = 0.98)


    ##----------------------------------------------------------------------------------------------------------------------------------------------------
    ## left panel: observation data, model function and (optionally) error curves
    pl = pylab.subplot(number_files, 2, 2 * NumFile + 1)
    if is_molfit:
        pl.plot(x, curves["exp"], '-', color = 'black', linewidth = 1.0, label = 'data', drawstyle = 'steps-mid')
    else:
        pl.plot(x, curves["exp"], 'b.', label = 'data')
    pl.plot(x, curves["fit"], 'g-', linewidth = 2, label = 'fit')
    ymin, ymax = _y_range(curves["exp"], curves["fit"])
    pl.set_ylim(ymin, ymax)
    if error_flag:
        pl.plot(x, curves["upper"], 'c--', label = 'upper error values')
        pl.plot(x, curves["lower"], 'k-.', label = 'lower error values')
    pl.grid(True)
    if (NumFile == 0):
        pl.legend()                                                                         ## legend only in the first row
    if is_molfit:
        pl.xaxis.set_major_formatter(pylab.matplotlib.ticker.FormatStrFormatter('%.5e'))
    pl.set_ylabel(r"%s" % plot_option[2])
    pl.set_xlabel(xlabel)


    ##----------------------------------------------------------------------------------------------------------------------------------------------------
    ## right panel: chi^2, labelled on a second y-axis
    prv = pylab.subplot(number_files, 2, 2 * NumFile + 2)
    prv.plot(x, curves["chi2"], 'r', label = 'chi^2')
    ymin_chi2 = min(curves["chi2"])[0]
    ymax_chi2 = max(curves["chi2"])[0]
    prv.set_ylim(ymin_chi2, ymax_chi2)
    if error_flag:
        prv.plot(x, curves["upper_chi2"], 'c--', label = 'upper chi2 values')
        prv.plot(x, curves["lower_chi2"], 'k-.', label = 'lower chi2 values')
    pr = prv.twinx()
    prv.set_yticklabels([])                                                                 ## only the twin (right) y-axis carries labels
    prv.grid(True)
    if (NumFile == 0):
        prv.legend()                                                                        ## legend only in the first row
    prv.set_ylabel("")
    pr.set_ylabel(r"$\chi^2$")
    pr.set_xlabel(xlabel)
    prv.set_xlabel(xlabel)
    pr.set_ylim((0, ymax_chi2))


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  name of the output file (without ".png")
##
def _output_filename(fit_log, FilenameExtension, NumberSites, site):
    """
    Return the output file name (without the ".png" suffix), located in the directory of fit_log.

    For more than one site, a trailing ".out" part of the file name (if any) is replaced by "__site_<n>.out".
    Only the file name is searched for ".out", so directory names containing ".out" are left untouched.
    """
    head, sep, _ = fit_log.rpartition("/")                                                  ## directory part incl. trailing "/", or ""
    directory = head + sep
    filename = "final_plot." + FilenameExtension
    if (NumberSites > 1):
        ii = filename.rfind(".out")
        if (ii != -1):
            filename = filename[:ii]
        filename += "__site_" + str(site + 1) + ".out"
    return directory + filename


##--------------------------------------------------------------------------------------------------------------------------------------------------------
##
##  plot experimental data, fit function and values of chi^2
##
def plot(plot_option, exp_data_x, exp_data_y, exp_data_error, number_files, lengthexpdata, ColumnY, FitFunctionValues, Chi2Values, \
         fit_log, plotflag, FilenameExtension, ErrorModelFunctionValues, ErrorChi2Values, NumberSites, parameter_file_orig):
    """
    Plot observation data, model function and chi^2 values (one figure per site) and optionally save them as PNG files.

    input parameter:    plot_option:                plotting options: [0] == "no" disables saving, [1] x-axis label,
                                                    [2] y-axis label, [3] == "true" switches on interactive mode
                        exp_data_x:                 experimental data points; exp_data_x[NumFile][i] holds the x-coordinates of point i
                                                    (only the first one is plotted; files with 3 or more x-dimensions are skipped)
                        exp_data_y:                 observation data, indexed as exp_data_y[NumFile][i]
                        exp_data_error:             error of observation data (not used by this function)
                        number_files:               number of observation data files
                        lengthexpdata:              number of observation data points (not used by this function)
                        ColumnY:                    number of observation data at each data point (not used by this function)
                        FitFunctionValues:          values of the fit function, indexed [NumFile][i] (one site) or [site][NumFile][i]
                        Chi2Values:                 values of chi^2, indexed like FitFunctionValues
                        fit_log:                    path and name of log file; the plot is written into the same directory
                        plotflag:                   "true" or "saveonly" redraws the figure after saving
                        FilenameExtension:          filename extension used in the name of the output file
                        ErrorModelFunctionValues:   if the last algorithm is the error estim. algorithm, model func. for the upper ([0])
                                                    and lower ([1]) error values; an empty list (or None) disables these curves
                        ErrorChi2Values:            if the last algorithm is the error estim. algorithm, chi2 func. for the upper ([0])
                                                    and lower ([1]) error values
                        NumberSites:                number of sites
                        parameter_file_orig:        non modified name of the file including the starting values of each fit parameter;
                                                    a ".molfit" ending selects step-style data curves (multi-file layout) and
                                                    scientific x-tick labels

    output parameter:   None
    """

    ## the error-estimation curves are only drawn if the caller supplied them
    error_flag = (ErrorModelFunctionValues is not None and len(ErrorModelFunctionValues) > 0)
    is_molfit = (parameter_file_orig is not None and parameter_file_orig.endswith(".molfit"))


    ## make a plot for each site
    for site in range(NumberSites):


        ## set interactive flag
        if (plot_option[3] == "true"):
            pylab.ion()
        else:
            pylab.ioff()


        ## set dimension of figure (in inches); if the current backend fails, fall back to the non-interactive Agg backend
        try:
            fig = pylab.figure(figsize=(15, number_files * 10))
        except Exception:
            from matplotlib import pyplot as plt
            plt.switch_backend('agg')
            fig = pylab.figure(figsize=(15, number_files * 10))

        ## the figure is closed even if plotting or saving fails
        try:
            fig.clear()
            plot_was_done = False
            for NumFile in range(number_files):                                             ## loop over all experimental data files

                ## files without data points or with 3d (or higher dimensional) x-data cannot be plotted
                if (len(exp_data_x[NumFile]) == 0 or len(exp_data_x[NumFile][0]) >= 3):
                    continue
                curves = _collect_curves(site, NumFile, exp_data_x, exp_data_y, FitFunctionValues, Chi2Values, \
                                         ErrorModelFunctionValues, ErrorChi2Values, NumberSites, error_flag)
                plot_was_done = True
                if (number_files == 1):
                    _plot_single_file(curves, error_flag, plot_option, is_molfit)
                else:
                    _plot_multi_file(curves, NumFile, number_files, error_flag, plot_option, is_molfit)


            ##--------------------------------------------------------------------------------------------------------------------------------------------
            ## save the plot as a PNG image file (optional)
            if (plot_option[0] != "no" and plot_was_done):
                pngFile = _output_filename(fit_log, FilenameExtension, NumberSites, site)
                pylab.savefig(pngFile + ".png", dpi=100)
                if (plotflag in ("true", "saveonly")):
                    pylab.draw()
        finally:
            pylab.close(fig)


    ## we're done
    return
##--------------------------------------------------------------------------------------------------------------------------------------------------------
##--------------------------------------------------------------------------------------------------------------------------------------------------------
##--------------------------------------------------------------------------------------------------------------------------------------------------------

