Source code for masci_tools.vis.plot_methods

# -*- coding: utf-8 -*-
###############################################################################
# Copyright (c), Forschungszentrum Jülich GmbH, IAS-1/PGI-1, Germany.         #
#                All rights reserved.                                         #
# This file is part of the Masci-tools package.                               #
# (Material science tools)                                                    #
#                                                                             #
# The code is hosted on GitHub at https://github.com/judftteam/masci-tools.   #
# For further information on the license, see the LICENSE.txt file.           #
# For further information please visit http://judft.de/.                      #
#                                                                             #
###############################################################################
"""
In this module are plot routines collected to create default plots out of certain
ouput nodes from certain workflows with matplot lib.

Comment: Do not use any aiida methods, otherwise the methods in here can become
tricky to use inside a virtual environment. Make the user extract thing out of
aiida objects before hand or write something on top. Since usually parameter nodes,
or files are plotted, parse a dict or filepath.

Each of the plot_methods can take keyword arguments to modify parameters of the plots
There are keywords that are handled by a special class for defaults. All other arguments
will be passed on to the matplotlib plotting calls

For the definition of the defaults refer to :py:class:`~masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`

"""
# TODO but allow to optional parse information for saving and title,
#  (that user can put pks or structure formulas in there)
# Write/export data to file for all methods

from .matplotlib_plotter import MatplotlibPlotter
from masci_tools.vis import ensure_plotter_consistency, NestedPlotParameters
import warnings
import copy
import typing

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from pprint import pprint
import pandas as pd

plot_params = MatplotlibPlotter()


[docs]def set_mpl_plot_defaults(**kwargs): """ Set defaults for matplotib backend according to the given keyword arguments Available defaults can be seen in :py:class:`~masci_tools.vis.matplotlib_plotter.MatplotlibPlotter` """ plot_params.set_defaults(**kwargs)
[docs]def reset_mpl_plot_defaults(): """ Reset the defaults for matplotib backend to the hardcoded defaults Available defaults can be seen in :py:class:`~masci_tools.vis.matplotlib_plotter.MatplotlibPlotter` """ plot_params.reset_defaults()
[docs]def show_mpl_plot_defaults(): """ Show the currently set defaults for matplotib backend to the hardcoded defaults Available defaults can be seen in :py:class:`~masci_tools.vis.matplotlib_plotter.MatplotlibPlotter` """ pprint(plot_params.get_dict())
[docs]def get_mpl_help(key): """ Print the decription of the given key in the matplotlib backend Available defaults can be seen in :py:class:`~masci_tools.vis.matplotlib_plotter.MatplotlibPlotter` """ plot_params.get_description(key)
############################################################################### ########################## general plot routines ############################## ###############################################################################
[docs]@ensure_plotter_consistency(plot_params) def single_scatterplot(xdata, ydata, *, xlabel='', ylabel='', title='', saveas='scatterplot', axis=None, xerr=None, yerr=None, area_curve=None, **kwargs): """ Create a standard scatter plot (this should be flexible enough) to do all the basic plots. :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param title: str, title of the figure :param saveas: str specifying the filename (without file format) :param axis: Axes object, if given the plot will be applied to this object :param xerr: optional data for errorbar in x-direction :param yerr: optional data for errorbar in y-direction :param area_curve: if an area plot is made this arguments defines the other enclosing line defaults to 0 Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib functions (`errorbar` or `fill_between`) """ #DEPRECATION WARNINGS if 'plotlabel' in kwargs: warnings.warn('Please use plot_label instead of plotlabel', DeprecationWarning) kwargs['plot_label'] = kwargs.pop('plotlabel') if 'scale' in kwargs: scale = kwargs.get('scale') if isinstance(scale, list): warnings.warn("Please provide scale as dict in the form {'x': value, 'y': value2}", DeprecationWarning) scale_new = {} if scale[0] is not None: scale_new['x'] = scale[0] if scale[1] is not None: scale_new['y'] = scale[1] kwargs['scale'] = scale_new if 'limits' in kwargs: limits = kwargs.get('limits') if isinstance(limits, list): warnings.warn("Please provide limits as dict in the form {'x': value, 'y': value2}", DeprecationWarning) limits_new = {} if limits[0] is not None: limits_new['x'] = limits[0] if limits[1] is not None: limits_new['y'] = limits[1] kwargs['limits'] = limits_new plot_params.set_defaults(default_type='function', color='k', plot_label='scatterplot') kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) #ax.xaxis.set_major_formatter(DateFormatter("%b %y")) #if yerr or xerr: # p1 = ax.errorbar(xdata, ydata, linetyp, label=plotlabel, color=color, # linewidth=linewidth_g, markersize=markersize_g, yerr=yerr, xerr=xerr) #else: # p1 = ax.plot(xdata, ydata, linetyp, label=plotlabel, color=color, # linewidth=linewidth_g, markersize=markersize_g) # TODO customizable error bars fmt='o', ecolor='g', capthick=2, ... # there the if is prob better... plot_kwargs = plot_params.plot_kwargs() if area_curve is None: shift = 0 else: shift = area_curve if plot_params['area_plot']: linecolor = plot_kwargs.pop('area_linecolor', None) if plot_params['area_vertical']: result = ax.fill_betweenx(ydata, xdata, x2=shift, **plot_kwargs, **kwargs) else: result = ax.fill_between(xdata, ydata, y2=shift, **plot_kwargs, **kwargs) plot_kwargs.pop('alpha', None) plot_kwargs.pop('label', None) plot_kwargs.pop('color', None) if plot_params['area_enclosing_line']: if linecolor is None: linecolor = result.get_facecolor()[0] ax.errorbar(xdata, ydata, yerr=yerr, xerr=xerr, alpha=plot_params['plot_alpha'], color=linecolor, **plot_kwargs, **kwargs) else: ax.errorbar(xdata, ydata, yerr=yerr, xerr=xerr, **plot_kwargs, **kwargs) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def multiple_scatterplots(xdata, ydata, *, xlabel='', ylabel='', title='', saveas='mscatterplot', axis=None, xerr=None, yerr=None, area_curve=None, **kwargs): """ Create a standard scatter plot with multiple sets of data (this should be flexible enough) to do all the basic plots. :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param title: str, title of the figure :param saveas: str specifying the filename (without file format) :param axis: Axes object, if given the plot will be applied to this object :param xerr: optional data for errorbar in x-direction :param yerr: optional data for errorbar in y-direction :param area_curve: if an area plot is made this arguments defines the other enclosing line defaults to 0 Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib functions (`errorbar` or `fill_between`) """ nplots = len(ydata) if nplots != len(xdata): # todo check dimention not len, without moving to special datatype. print('ydata and xdata must have the same dimension') return if not isinstance(ydata[0], (list, np.ndarray, pd.Series)): xdata, ydata = [xdata], [ydata] plot_params.single_plot = False plot_params.num_plots = len(ydata) #DEPRECATION WARNINGS if 'plot_labels' in kwargs: warnings.warn('Please use plot_label instead of plot_labels', DeprecationWarning) kwargs['plot_label'] = kwargs.pop('plot_labels') if 'colors' in kwargs: warnings.warn('Please use color instead of colors', DeprecationWarning) kwargs['color'] = kwargs.pop('colors') if 'legend_option' in kwargs: warnings.warn('Please use legend_options instead of legend_option', DeprecationWarning) kwargs['legend_options'] = kwargs.pop('legend_option') if 'scale' in kwargs: scale = kwargs.get('scale') if isinstance(scale, list): warnings.warn("Please provide scale as dict in the form {'x': value, 'y': value2}", DeprecationWarning) scale_new = {} if scale[0] is not None: scale_new['x'] = scale[0] if scale[1] is not None: scale_new['y'] = scale[1] kwargs['scale'] = scale_new if 'limits' in kwargs: limits = kwargs.get('limits') if isinstance(limits, list): warnings.warn("Please provide limits as dict in the form {'x': value, 'y': value2}", DeprecationWarning) limits_new = {} if limits[0] is not None: limits_new['x'] = limits[0] if limits[1] is not None: limits_new['y'] = limits[1] kwargs['limits'] = limits_new if 'xticks' in kwargs: xticks = kwargs.get('xticks') if isinstance(xticks[0], list): warnings.warn('Please provide xticks and xticklabels seperately as two lists', DeprecationWarning) kwargs['xticklabels'] = xticks[0] kwargs['xticks'] = xticks[1] kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) # TODO good checks for input and setting of internals before plotting # allow all arguments as value then use for all or as lists with the righ length. plot_kwargs = plot_params.plot_kwargs() colors = [] for indx, data in enumerate(zip(xdata, ydata, plot_kwargs)): x, y, plot_kw = data if plot_params['repeat_colors_after'] is not None: if indx >= plot_params['repeat_colors_after']: plot_kw['color'] = colors[indx % plot_params['repeat_colors_after']] if isinstance(yerr, list): try: yerrt = yerr[indx] except IndexError: yerrt = yerr[0] else: yerrt = yerr if isinstance(xerr, list): try: xerrt = xerr[indx] except IndexError: xerrt = xerr[0] else: xerrt = xerr if area_curve is not None: if isinstance(area_curve, list): try: shift = area_curve[indx] except IndexError: shift = area_curve[0] else: shift = area_curve else: shift = 0 if plot_params[('area_plot', indx)]: linecolor = plot_kw.pop('area_linecolor', None) if plot_params[('area_vertical', indx)]: result = ax.fill_betweenx(y, x, x2=shift, **plot_kw, **kwargs) else: result = ax.fill_between(x, y, y2=shift, **plot_kw, **kwargs) colors.append(result.get_facecolor()[0]) plot_kw.pop('alpha', None) plot_kw.pop('label', None) plot_kw.pop('color', None) if plot_params[('area_enclosing_line', indx)]: if linecolor is None: linecolor = result.get_facecolor()[0] ax.errorbar(x, y, yerr=yerrt, xerr=xerrt, alpha=plot_params[('plot_alpha', indx)], color=linecolor, **plot_kw, **kwargs) else: result = ax.errorbar(x, y, yerr=yerrt, xerr=xerrt, **plot_kw, **kwargs) colors.append(result.lines[0].get_color()) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.show_legend(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def multi_scatter_plot(xdata, ydata, *, size_data=None, color_data=None, xlabel='', ylabel='', title='', saveas='mscatterplot', axis=None, **kwargs): """ Create a scatter plot with varying marker size Info: x, y, size and color data must have the same dimensions. :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param size_data: arraylike, data for the markersizes (optional) :param color_data: arraylike, data for the color values with a colormap (optional) :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param title: str, title of the figure :param saveas: str specifying the filename (without file format) :param axis: Axes object, if given the plot will be applied to this object :param xerr: optional data for errorbar in x-direction :param yerr: optional data for errorbar in y-direction Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib function `scatter` """ nplots = len(ydata) if nplots != len(xdata): # todo check dimention not len, without moving to special datatype. print('ydata and xdata must have the same dimension') return if not isinstance(ydata[0], (list, np.ndarray, pd.Series)): xdata, ydata, size_data, color_data = [xdata], [ydata], [size_data], [color_data] plot_params.single_plot = False plot_params.num_plots = len(ydata) #DEPRECATION WARNINGS: label/plot_labels, alpha, limits, scale, legend_option, xticks if 'label' in kwargs: warnings.warn('Please use plot_label instead of label', DeprecationWarning) kwargs['plot_label'] = kwargs.pop('label') if 'alpha' in kwargs: warnings.warn('Please use plot_alpha instead of alpha', DeprecationWarning) kwargs['plot_alpha'] = kwargs.pop('alpha') if 'legend_option' in kwargs: warnings.warn('Please use legend_options instead of legend_option', DeprecationWarning) kwargs['legend_options'] = kwargs.pop('legend_option') if 'scale' in kwargs: scale = kwargs.get('scale') if isinstance(scale, list): warnings.warn("Please provide scale as dict in the form {'x': value, 'y': value2}", DeprecationWarning) scale_new = {} if scale[0] is not None: scale_new['x'] = scale[0] if scale[1] is not None: scale_new['y'] = scale[1] kwargs['scale'] = scale_new if 'limits' in kwargs: limits = kwargs.get('limits') if isinstance(limits, list): warnings.warn("Please provide limits as dict in the form {'x': value, 'y': value2}", DeprecationWarning) limits_new = {} if limits[0] is not None: limits_new['x'] = limits[0] if limits[1] is not None: limits_new['y'] = limits[1] kwargs['limits'] = limits_new if 'xticks' in kwargs: xticks = kwargs.get('xticks') if isinstance(xticks[0], list): warnings.warn('Please provide xticks and xticklabels seperately as two lists', DeprecationWarning) kwargs['xticklabels'] = xticks[0] kwargs['xticks'] = xticks[1] plot_params.set_defaults(default_type='function', linestyle=None, area_plot=False, colorbar=False) kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) plot_kwargs = plot_params.plot_kwargs(ignore='markersize', extra_keys={'cmap'}) legend_elements = [] legend_labels = [] if size_data is None: size_data = [None] * plot_params.num_plots if color_data is None: color_data = [None] * plot_params.num_plots for indx, data in enumerate(zip(xdata, ydata, size_data, color_data, plot_kwargs)): x, y, size, color, plot_kw = data if size is None: size = plot_params['markersize'] if color is not None: plot_kw.pop('color') res = ax.scatter(x, y=y, s=size, c=color, **plot_kw, **kwargs) if plot_kw.get('label', None) is not None and color is not None: if isinstance(color, (list, np.ndarray, pd.Series)): if not isinstance(color[0], str): legend_elements.append(res.legend_elements(num=1)[0][0]) legend_labels.append(plot_kw['label']) if any(c is not None for c in color_data): legend_elements = (legend_elements, legend_labels) else: legend_elements = None plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.show_legend(ax, leg_elems=legend_elements) plot_params.show_colorbar(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def colormesh_plot(xdata, ydata, cdata, *, xlabel='', ylabel='', title='', saveas='colormesh', axis=None, **kwargs): """ Create plot with pcolormesh :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param cdata: arraylike, data for the color values with a colormap :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param title: str, title of the figure :param saveas: str specifying the filename (without file format) :param axis: Axes object, if given the plot will be applied to this object Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib function `pcolormesh` """ #Set default limits (not setting them leaves empty border) limits = kwargs.pop('limits', {}) if 'x' not in limits: limits['x'] = (xdata.min(), xdata.max()) if 'y' not in limits: limits['y'] = (ydata.min(), ydata.max()) kwargs['limits'] = limits plot_params.set_defaults(default_type='function', edgecolor='face') kwargs = plot_params.set_parameters(continue_on_error=True, area_plot=False, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) plot_kwargs = plot_params.plot_kwargs(plot_type='colormesh') ax.pcolormesh(xdata, ydata, cdata, **plot_kwargs, **kwargs) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.show_legend(ax) plot_params.show_colorbar(ax) plot_params.draw_lines(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def waterfall_plot(xdata, ydata, zdata, *, xlabel='', ylabel='', zlabel='', title='', saveas='waterfallplot', axis=None, **kwargs): """ Create a standard waterfall plot :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param zdata: arraylike, data for the z coordinate :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param zlabel: str, label written on the z axis :param title: str, title of the figure :param axis: Axes object, if given the plot will be applied to this object :param saveas: str specifying the filename (without file format) Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib function `scatter3D` """ nplots = len(ydata) if nplots != len(xdata): # todo check dimention not len, without moving to special datatype. print('ydata and xdata must have the same dimension') return if nplots != len(zdata): # todo check dimention not len, without moving to special datatype. print('ydata and zdata must have the same dimension') return if isinstance(zdata, np.ndarray): zmin = zdata.min() zmax = zdata.max() else: zmin = min(zdata) zmax = max(zdata) clim = None if 'limits' in kwargs: clim = kwargs['limits'].get('color', None) else: kwargs['limits'] = {} if clim is None: clim = (kwargs.get('vmin', zmin), kwargs.get('vmax', zmax)) kwargs['limits']['color'] = clim if not isinstance(ydata[0], (list, np.ndarray, pd.Series)): xdata, ydata, zdata = [xdata], [ydata], [zdata] plot_params.single_plot = False plot_params.num_plots = len(ydata) plot_params.set_defaults(default_type='function', markersize=30, linewidth=0, area_plot=False) kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, zlabel=zlabel, axis=axis, projection='3d') plot_kwargs = plot_params.plot_kwargs(ignore=['markersize'], extra_keys={'cmap'}) for indx, data in enumerate(zip(xdata, ydata, zdata, plot_kwargs)): x, y, z, plot_kw = data ax.scatter3D(x, y, z, c=z, s=plot_params[('markersize', indx)], **plot_kw, **kwargs) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.show_legend(ax) plot_params.show_colorbar(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def surface_plot(xdata, ydata, zdata, *, xlabel='', ylabel='', zlabel='', title='', saveas='surface_plot', axis=None, **kwargs): """ Create a standard surface plot :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param zdata: arraylike, data for the z coordinate :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param zlabel: str, label written on the z axis :param title: str, title of the figure :param axis: Axes object, if given the plot will be applied to this object :param saveas: str specifying the filename (without file format) Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib function `plot_surface` """ nplots = len(ydata) if nplots != len(xdata): # todo check dimention not len, without moving to special datatype. print('ydata and xdata must have the same dimension') return if nplots != len(zdata): # todo check dimention not len, without moving to special datatype. print('ydata and zdata must have the same dimension') return if isinstance(zdata, np.ndarray): zmin = zdata.min() zmax = zdata.max() else: zmin = min(zdata) zmax = max(zdata) clim = None if 'limits' in kwargs: clim = kwargs['limits'].get('color', None) else: kwargs['limits'] = {} if clim is None: clim = (kwargs.get('vmin', zmin), kwargs.get('vmax', zmax)) kwargs['limits']['color'] = clim plot_params.set_defaults(default_type='function', linewidth=0, area_plot=False) kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, zlabel=zlabel, axis=axis, projection='3d') plot_kwargs = plot_params.plot_kwargs(ignore=['markersize', 'marker'], extra_keys={'cmap'}) ax.plot_surface(xdata, ydata, zdata, **plot_kwargs, **kwargs) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.show_legend(ax) plot_params.show_colorbar(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def multiplot_moved(xdata, ydata, *, xlabel='', ylabel='', title='', scale_move=1.0, min_add=0, saveas='mscatterplot', **kwargs): """ Plots all the scatter plots above each other. It adds an arbitrary offset to the ydata to do this and calls `multiple_scatterplots`. Therefore you might not want to show the yaxis ticks :param xdata: arraylike, data for the x coordinate :param ydata: arraylike, data for the y coordinate :param xlabel: str, label written on the x axis :param ylabel: str, label written on the y axis :param title: str, title of the figure :param scale_move: float, max*scale_move determines size of the shift :param min_add: float, minimum shift :param saveas: str specifying the filename (without file format) Kwargs are passed on to the :py:func:`multiple_scatterplots()` call """ if 'yticks' not in kwargs: kwargs['yticks'] = [] if 'yticklabels' not in kwargs: kwargs['yticklabels'] = [] ydatanew = [] shifts = [] ymax = 0 for data in ydata: ydatanew.append(np.array(data) + ymax) shifts.append(ymax) ymax = ymax + max(data) * scale_move + min_add ax = multiple_scatterplots(xdata, ydatanew, xlabel=xlabel, ylabel=ylabel, title=title, saveas=saveas, area_curve=shifts, **kwargs) return ax
[docs]@ensure_plotter_consistency(plot_params) def histogram(xdata, density=False, histtype='bar', align='mid', orientation='vertical', log=False, axis=None, title='hist', xlabel='bins', ylabel='counts', saveas='histogram', return_hist_output=False, **kwargs): """ Create a standard looking histogram :param xdata: arraylike, Data for the histogram :param density: bool, if True the histogram is normed and a normal distribution is plotted with the same mu and sigma as the data :param histtype: str, type of the histogram :param align: str, defines where the bars for the bins are aligned :param orientation: str, is the histogram vertical or horizontal :param log: bool, if True a logarithmic scale is used for the counts :param axis: Axes object where to add the plot :param title: str, Title of the plot :param xlabel: str, label for the x-axis :param ylabel: str, label for the y-axis :param saveas: str, filename for the saved plot :param return_hist_output: bool, if True the data output from hist will be returned Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib function `hist` """ if 'label' in kwargs: warnings.warn('Please use plot_label instead of label', DeprecationWarning) kwargs['plot_label'] = kwargs.pop('label') if 'legend_option' in kwargs: warnings.warn('Please use legend_options instead of legend_option', DeprecationWarning) kwargs['legend_options'] = kwargs.pop('legend_option') if 'limits' in kwargs: limits = kwargs.get('limits') if isinstance(limits, list): warnings.warn("Please provide limits as dict in the form {'x': value, 'y': value2}", DeprecationWarning) limits_new = {} if limits[0] is not None: limits_new['x'] = limits[0] if limits[1] is not None: limits_new['y'] = limits[1] kwargs['limits'] = limits_new kwargs = plot_params.set_parameters(continue_on_error=True, set_powerlimits=not log, area_plot=False, **kwargs) if orientation == 'horizontal': if xlabel == 'bins' and ylabel == 'counts': xlabel, ylabel = ylabel, xlabel ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis, minor=True) plot_kwargs = plot_params.plot_kwargs(plot_type='histogram') n, bins, patches = ax.hist(xdata, density=density, histtype=histtype, align=align, orientation=orientation, log=log, **plot_kwargs, **kwargs) if density: mu = np.mean(xdata) sigma = np.std(xdata) y = norm.pdf(bins, mu, sigma) if orientation == 'horizontal': ax.plot(y, bins, '--') else: ax.plot(bins, y, '--') plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.show_legend(ax) plot_params.save_plot(saveas) if return_hist_output: return ax, n, bins, patches else: return ax
# todo remove default histogramm, replace it in all code by histogramm
[docs]def default_histogram(*args, **kwargs): """ Create a standard looking histogram (DEPRECATED) """ warnings.warn('Use histogram instead of default_histogram', DeprecationWarning) res = histogram(*args, **kwargs) return res
[docs]@ensure_plotter_consistency(plot_params) def barchart(xdata, ydata, *, width=0.35, xlabel='x', ylabel='y', title='', bottom=None, saveas='barchart', axis=None, xerr=None, yerr=None, **kwargs): """ Create a standard bar chart plot (this should be flexible enough) to do all the basic bar chart plots. :param xdata: arraylike data for the x coordinates of the bars :param ydata: arraylike data for the heights of the bars :param width: float determines the width of the bars :param axis: Axes object where to add the plot :param title: str, Title of the plot :param xlabel: str, label for the x-axis :param ylabel: str, label for the y-axis :param saveas: str, filename for the saved plot :param xerr: optional data for errorbar in x-direction :param yerr: optional data for errorbar in y-direction :param bottom: bottom values for the lowest end of the bars Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib function `bar` TODO: grouped barchart (meaing not stacked) """ nplots = len(ydata) if nplots != len(xdata): # todo check dimention not len, without moving to special datatype. print('ydata and xdata must have the same dimension') return if not isinstance(ydata[0], (list, np.ndarray, pd.Series)): xdata, ydata = [xdata], [ydata] plot_params.single_plot = False plot_params.num_plots = len(ydata) #DEPRECATION WARNINGS if 'plot_labels' in kwargs: warnings.warn('Please use plot_label instead of plot_labels', DeprecationWarning) kwargs['plot_label'] = kwargs.pop('plot_labels') if 'colors' in kwargs: warnings.warn('Please use color instead of colors', DeprecationWarning) kwargs['color'] = kwargs.pop('colors') if 'legend_option' in kwargs: warnings.warn('Please use legend_options instead of legend_option', DeprecationWarning) kwargs['legend_options'] = kwargs.pop('legend_option') if 'scale' in kwargs: scale = kwargs.get('scale') if isinstance(scale, list): warnings.warn("Please provide scale as dict in the form {'x': value, 'y': value2}", DeprecationWarning) scale_new = {} if scale[0] is not None: scale_new['x'] = scale[0] if scale[1] is not None: scale_new['y'] = scale[1] kwargs['scale'] = scale_new if 'limits' in kwargs: limits = kwargs.get('limits') if isinstance(limits, list): warnings.warn("Please provide limits as dict in the form {'x': value, 'y': value2}", DeprecationWarning) limits_new = {} if limits[0] is not None: limits_new['x'] = limits[0] if limits[1] is not None: limits_new['y'] = limits[1] kwargs['limits'] = limits_new if 'xticks' in kwargs: xticks = kwargs.get('xticks') if isinstance(xticks[0], list): warnings.warn('Please provide xticks and xticklabels seperately as two lists', DeprecationWarning) kwargs['xticklabels'] = xticks[0] kwargs['xticks'] = xticks[1] plot_params.set_defaults(default_type='function', linewidth=None) kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) # TODO good checks for input and setting of internals before plotting # allow all arguments as value then use for all or as lists with the righ length. if bottom: datab = bottom else: datab = np.zeros(len(ydata[0])) plot_kwargs = plot_params.plot_kwargs(plot_type='histogram') for indx, data in enumerate(zip(xdata, ydata, plot_kwargs)): x, y, plot_kw = data if isinstance(yerr, list): try: yerrt = yerr[indx] except KeyError: yerrt = yerr[0] else: yerrt = yerr if isinstance(xerr, list): try: xerrt = xerr[indx] except KeyError: xerrt = xerr[0] else: xerrt = xerr ax.bar(x, y, width, bottom=datab, **plot_kw, **kwargs) datab = datab + np.array(y) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.show_legend(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def multiaxis_scatterplot(xdata, ydata, *, axes_loc, xlabel='', ylabel='', title='', num_cols=1, num_rows=1, saveas='mscatterplot', **kwargs): """ Create a scatter plot with multiple axes. :param xdata: list of arraylikes, passed on to the plotting functions for each axis (x-axis) :param ydata: list of arraylikes, passed on to the plotting functions for each axis (y-axis) :param axes_loc: list of tuples of two integers, location of each axis :param xlabel: str or list of str, labels for the x axis :param ylabel: str or list of str, labels for the y-axis :param title: str or list of str, titles for the subplots :param num_rows: int, how many rows of axis are created :param num_cols: int, how many columns of axis are created :param saveas: str filename of the saved file Special Kwargs: :param subplot_params: dict with integer keys, can contain all valid kwargs for :py:func:`multiple_scatterplots()` with the integer key denoting to which subplot the changes are applied :param axes_kwargs: dict with integer keys, additional arguments to pass on to `subplot2grid` for the creation of each axis (e.g colspan, rowspan) Other Kwargs will be passed on to all :py:func:`multiple_scatterplots()` calls (If they are not overwritten by parameters in `subplot_params`). """ #convert parameters to list of parameters for subplots subplot_params = kwargs.pop('subplot_params', {}) axes_kwargs = kwargs.pop('axes_kwargs', {}) param_list = [None] * len(axes_loc) for indx, val in enumerate(param_list): if indx in subplot_params: param_list[indx] = subplot_params[indx] else: param_list[indx] = {} if indx in axes_kwargs: param_list[indx]['axes_kwargs'] = axes_kwargs[indx] if not isinstance(xlabel, list): param_list[indx]['xlabel'] = xlabel else: param_list[indx]['xlabel'] = xlabel[indx] if not isinstance(ylabel, list): param_list[indx]['ylabel'] = ylabel else: param_list[indx]['ylabel'] = ylabel[indx] if not isinstance(title, list): param_list[indx]['title'] = title else: param_list[indx]['title'] = title[indx] general_keys = {'figure_kwargs', 'show', 'save_plots'} general_info = {key: val for key, val in kwargs.items() if key in general_keys} kwargs = {key: val for key, val in kwargs.items() if key not in general_keys} plot_params.set_parameters(**general_info) #figsize is automatically scaled with the shape of the plot plot_shape = (num_cols, num_rows) plot_params['figure_kwargs'] = { 'figsize': ([plot_shape[indx] * size for indx, size in enumerate(plot_params['figure_kwargs']['figsize'])]) } plot_shape = tuple(reversed(plot_shape)) fig = plt.figure(**plot_params['figure_kwargs']) axis = [] for indx, subplot_data in enumerate(zip(axes_loc, xdata, ydata, param_list)): location, x, y, params = subplot_data subplot_kwargs = copy.deepcopy(kwargs) subplot_kwargs.update(params) ax = plt.subplot2grid(plot_shape, location, fig=fig, **subplot_kwargs.pop('axes_kwargs', {})) with NestedPlotParameters(plot_params): ax = multiple_scatterplots(x, y, axis=ax, **subplot_kwargs, save_plots=False, show=False) axis.append(ax) plot_params.save_plot(saveas) return axis
############################################################################### ########################## special plot routines ############################## ###############################################################################
[docs]@ensure_plotter_consistency(plot_params) def plot_convex_hull2d(hull, *, title='Convex Hull', xlabel='Atomic Procentage', ylabel='Formation energy / atom [eV]', saveas='convex_hull', axis=None, **kwargs): """ Plot method for a 2d convex hull diagramm :param hull: pyhull.Convexhull #scipy.spatial.ConvexHull :param axis: Axes object where to add the plot :param title: str, Title of the plot :param xlabel: str, label for the x-axis :param ylabel: str, label for the y-axis :param saveas: str, filename for the saved plot Function specific parameters: :param marker_hull: defaults to `marker`, marker type for the hull plot :param markersize_hull: defaults to `markersize`, markersize for the hull plot :param color_hull: defaults to `color`, color for the hull plot Kwargs will be passed on to :py:class:`masci_tools.vis.matplotlib_plotter.MatplotlibPlotter`. If the arguments are not recognized they are passed on to the matplotlib functions `plot` """ #DEPRECATE: color_line if 'color_line' in kwargs: warnings.warn('Please use color instead of color_line', DeprecationWarning) kwargs['color'] = kwargs.pop('colors') #Define function wide custom parameters plot_params.add_parameter('marker_hull', default_from='marker') plot_params.add_parameter('markersize_hull', default_from='markersize') plot_params.add_parameter('color_hull', default_from='color') kwargs = plot_params.set_parameters(continue_on_error=True, set_powerlimits=False, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) points = hull.points plot_kw = plot_params.plot_kwargs() plot_hull_kw = plot_params.plot_kwargs(marker='marker_hull', markersize='markersize_hull', color='color_hull') plot_hull_kw['linestyle'] = '' linestyle = plot_kw['linestyle'] plot_kw['linestyle'] = '' ax.plot(points[:, 0], points[:, 1], **plot_kw, **kwargs) for simplex in hull.simplices: # TODO leave out some lines, the ones about [0,0 -1,0] data = simplex.coords ax.plot(data[:, 0], data[:, 1], linestyle=linestyle, **plot_kw, **kwargs) ax.plot(data[:, 0], data[:, 1], **plot_hull_kw, **kwargs) # this section is from scipy.spatial.Convexhull interface #ax.plot(points[simplex, 0], points[simplex, 1], linestyle=linestyle, # color=color_line, linewidth=linewidth, markersize=markersize, **kwargs) #ax.plot(points[simplex, 0], points[simplex, 1], linestyle='', # color=color, markersize=markersize_hull, marker=marker_hull, **kwargs) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.show_legend(ax) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def plot_residuen(xdata, fitdata, realdata, *, errors=None, xlabel=r'Energy [eV]', ylabel=r'cts/s [arb]', title=r'Residuen', saveas='residuen', hist=True, return_residuen_data=True, **kwargs): """ Calculates and plots the residuen for given xdata fit results and the real data. If hist=True also the normed residual distribution is ploted with a normal distribution. :param xdata: arraylike data for the x-coordinate :param fitdata: arraylike fitted data for the y-coordinate :param realdata: arraylike data to plot residuen against the fit :param errors: dict, can be used to provide errordata for the x and y direction :param xlabel: str, label for the x-axis :param ylabel: str, label for the y-axis :param title: str, title for the plot :param saveas: str, filename for the saved plot :param hist: bool, if True a normed residual distribution is ploted with a normal distribution. :param return_residuen_data: bool, if True in addition to the produced axis object also the residuen data is returned Special Kwargs: :param hist_kwargs: dict, these arguments will be passed on to the :py:func:`histogram()` call (if hist=True) Other Kwargs will be passed on to all :py:func:`single_scatterplot()` call """ if errors is None: errors = {} ydata = realdata - fitdata hist_kwargs = kwargs.pop('hist_kwargs', {}) general_keys = {'figure_kwargs', 'show', 'save_plots'} general_info = {key: val for key, val in kwargs.items() if key in general_keys} kwargs = {key: val for key, val in kwargs.items() if key not in general_keys} plot_params.set_parameters(**general_info) if hist: figsize = plot_params['figure_kwargs']['figsize'] #figsize is automatically scaled with the shape of the plot plot_params['figure_kwargs'] = {'figsize': (figsize[0] * 2, figsize[1])} plt.figure(**plot_params['figure_kwargs']) if hist: ax1 = plt.subplot2grid((1, 2), (0, 0)) ax2 = plt.subplot2grid((1, 2), (0, 1), sharey=ax1) axes = [ax1, ax2] else: ax1 = plt.subplot2grid((1, 1), (0, 0)) axes = ax1 with NestedPlotParameters(plot_params): ax1 = single_scatterplot(xdata, ydata, xlabel=xlabel, ylabel=ylabel, title=title, axis=ax1, show=False, save_plots=False, xerr=errors.get('x', None), yerr=errors.get('y', None), **kwargs) if hist: with NestedPlotParameters(plot_params): ax2 = histogram(ydata, bins=20, axis=ax2, orientation='horizontal', title='Residuen distribution', density=True, show=False, save_plots=False, **hist_kwargs) plot_params.save_plot(saveas) if return_residuen_data: return axes, ydata else: return axes
[docs]@ensure_plotter_consistency(plot_params) def plot_convergence_results(iteration, distance, total_energy, *, saveas1='t_energy_convergence', axis1=None, saveas2='distance_convergence', axis2=None, **kwargs): """ Plot the total energy versus the scf iteration and plot the distance of the density versus iterations. :param iteration: array for the number of iterations :param distance: array of distances :param total_energy: array of total energies :param saveas1: str, filename for the energy convergence plot :param axis1: Axes object for the energy convergence plot :param saveas2: str, filename for the distance plot :param axis2: Axes object for the distance plot Other Kwargs will be passed on to all :py:func:`single_scatterplot()` calls """ xlabel = r'Iteration' ylabel1 = r'Total energy difference [Htr]' ylabel2 = r'Distance [me/bohr^3]' title1 = r'Total energy difference over scf-Iterations' #title2 = r'Distance over scf-Iterations' title2 = r'Convergence (log)' # since we make a log plot of the total_energy make sure to plot the absolute total energy total_energy_abs_diff = [] for en0, en1 in zip(total_energy[:-1], total_energy[1:]): total_energy_abs_diff.append(abs(en1 - en0)) #saveas3 ='t_energy_convergence2' p1 = single_scatterplot(iteration[1:], total_energy_abs_diff, xlabel=xlabel, ylabel=ylabel1, title=title1, plot_label='delta total energy', saveas=saveas1, scale={'y': 'log'}, axis=axis1, **kwargs) #single_scatterplot(total_energy, iteration, xlabel, ylabel1, title1, plotlabel='total energy', saveas=saveas3) p2 = single_scatterplot(iteration, distance, xlabel=xlabel, ylabel=ylabel2, title=title2, plot_label='distance', saveas=saveas2, scale={'y': 'log'}, axis=axis2, **kwargs) return p1, p2
[docs]@ensure_plotter_consistency(plot_params) def plot_convergence_results_m(iterations, distances, total_energies, *, modes, nodes=None, saveas1='t_energy_convergence', saveas2='distance_convergence', axis1=None, axis2=None, **kwargs): """ Plot the total energy versus the scf iteration and plot the distance of the density versus iterations. :param iterations: array for the number of iterations :param distances: array of distances :param total_energies: array of total energies :param modes: list of convergence modes (if 'force' is in the list the last distance is removed) :param saveas1: str, filename for the energy convergence plot :param axis1: Axes object for the energy convergence plot :param saveas2: str, filename for the distance plot :param axis2: Axes object for the distance plot Other Kwargs will be passed on to all :py:func:`multiple_scatterplots()` calls """ xlabel = r'Iteration' ylabel1 = r'Total energy difference [Htr]' ylabel2 = r'Distance [me/bohr^3]' title1 = r'Total energy difference over scf-Iterations' #title2 = r'Distance over scf-Iterations' title2 = r'Convergence (log)' if 'plot_labels' in kwargs: warnings.warn('Please use plot_label instead of plot_labels', DeprecationWarning) kwargs['plot_label'] = kwargs.pop('plot_labels') iterations1 = [] plot_labels1 = [] plot_labels2 = [] # since we make a log plot of the total_energy make sure to plot the absolute total energy total_energy_abs_diffs = [] for i, total_energy in enumerate(total_energies): iterations1.append(iterations[i][1:]) total_energy_abs_diff = [] for en0, en1 in zip(total_energy[:-1], total_energy[1:]): total_energy_abs_diff.append(abs(en1 - en0)) total_energy_abs_diffs.append(total_energy_abs_diff) plot_labels1.append(f'delta total energy {i}') plot_labels2.append(f'distance {i}') #saveas3 ='t_energy_convergence2' if 'plot_label' in kwargs: plot_label = plot_params.convert_to_complete_list(kwargs.pop('plot_label'), single_plot=False, num_plots=len(plot_labels1), key='plot_label') plot_labels1 = [label if label is not None else plot_labels1[indx] for indx, label in enumerate(plot_label)] plot_labels2 = [label if label is not None else plot_labels2[indx] for indx, label in enumerate(plot_label)] p1 = multiple_scatterplots(iterations1, total_energy_abs_diffs, xlabel=xlabel, ylabel=ylabel1, title=title1, plot_label=plot_labels1, saveas=saveas1, scale={'y': 'log'}, axis=axis1, **kwargs) for i, mode in enumerate(modes): if mode == 'force': iterations[i].pop() print('Drop the last iteration because there was no charge distance, mode=force') p2 = multiple_scatterplots(iterations, distances, xlabel=xlabel, ylabel=ylabel2, title=title2, plot_label=plot_labels2, saveas=saveas2, scale={'y': 'log'}, axis=axis2, **kwargs) return p1, p2
[docs]@ensure_plotter_consistency(plot_params) def plot_lattice_constant(scaling, total_energy, *, fit_y=None, relative=True, ref_const=None, multi=False, title=r'Equation of states', saveas='lattice_constant', axis=None, **kwargs): """ Plot a lattice constant versus Total energy Plot also the fit. On the x axis is the scaling, it :param scaling: arraylike, data for the scaling factor :param total_energy: arraylike, data for the total energy :param fit_y: arraylike, optional data of fitted data :param relative: bool, scaling factor given (True), or lattice constants given? :param ref_const: float (optional), or list of floats, lattice constant for scaling 1.0 :param multi: bool default False are they multiple plots? Function specific parameters: :param marker_fit: defaults to `marker`, marker type for the fit data :param markersize_fit: defaults to `markersize`, markersize for the fit data :param linewidth_fit: defaults to `linewidth`, linewidth for the fit data :param plotlabel_fit: str label for the fit data Other Kwargs will be passed on to all :py:func:`single_scatterplot()` or :py:func:`multiple_scatterplots()` calls """ # TODO: make box which shows fit results. (fit resuls have to be past) # TODO: multiple plots in one use multi_scatter_plot for this... if 'plotlables' in kwargs: warnings.warn('plotlables is deprecated. Use plot_label and plot_label_fit instead', DeprecationWarning) if multi: plot_label = [] plot_label_fit = [] for indx in range(len(scaling)): plot_label.append(kwargs['plotlables'][2 * indx]) plot_label_fit.append(kwargs['plotlables'][2 * indx + 1]) kwargs['plot_label'] = plot_label kwargs['plot_label_fit'] = plot_label_fit else: kwargs['plot_label'] = kwargs['plotlables'][0] kwargs['plot_label_fit'] = kwargs['plotlables'][1] #print markersize_g if relative: if ref_const: xlabel = rf'Relative Volume [a/{ref_const}$\AA$]' else: xlabel = r'Relative Volume' else: xlabel = r'Volume [$\AA$]' if multi: ylabel = r'Total energy norm[0] [eV]' else: ylabel = r'Total energy [eV]' #Add custom parameters for fit plot_params.add_parameter('marker_fit', default_from='marker') plot_params.add_parameter('markersize_fit', default_from='markersize') plot_params.add_parameter('linewidth_fit', default_from='linewidth') plot_params.add_parameter('plot_label_fit') plot_params.set_defaults(default_type='function', marker_fit='s', plot_label='simulation data', plot_label_fit='fit results') general_keys = {'figure_kwargs', 'show', 'save_plots'} general_info = {key: val for key, val in kwargs.items() if key in general_keys} kwargs = {key: val for key, val in kwargs.items() if key not in general_keys} plot_params.set_parameters(**general_info) kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) if multi: plot_params.single_plot = False plot_params.num_plots = len(scaling) plot_kw = plot_params.plot_kwargs(post_process=False) plot_fit_kw = plot_params.plot_kwargs(post_process=False, marker='marker_fit', markersize='markersize_fit', linewidth='linewidth_fit', plot_label='plot_label_fit') if multi: # TODO test if dim of total_e = dim of scaling, dim plot lables... # or parse on scaling? with NestedPlotParameters(plot_params): ax = multiple_scatterplots(scaling, total_energy, xlabel=xlabel, ylabel=ylabel, title=title, axis=ax, show=False, save_plots=False, **plot_kw, **kwargs) if fit_y: with NestedPlotParameters(plot_params): ax = multiple_scatterplots(scaling, fit_y, xlabel=xlabel, ylabel=ylabel, title=title, axis=ax, show=False, save_plots=False, **plot_fit_kw, **kwargs) else: with NestedPlotParameters(plot_params): ax = single_scatterplot(scaling, total_energy, xlabel=xlabel, ylabel=ylabel, title=title, axis=ax, show=False, save_plots=False, **plot_kw, **kwargs) if fit_y: with NestedPlotParameters(plot_params): ax = single_scatterplot(scaling, fit_y, xlabel, ylabel, title, axis=ax, show=False, save_plots=False, **plot_fit_kw, **kwargs) plot_params.draw_lines(ax) plot_params.save_plot(saveas) return ax
[docs]def plot_relaxation_results(): """ Plot from the result node of a relaxation workflow, All forces of every atom type versus relaxation cycle. Average force of all atom types versus relaxation cycle. Absolut relaxation in Angstroem of every atom type. Relative realxation of every atom type to a reference structure. (if none given use the structure from first relaxation cycle as reference) """ pass
[docs]@ensure_plotter_consistency(plot_params) def plot_dos(energy_grid, dos_data, *, saveas='dos_plot', energy_label=r'$E-E_F$ [eV]', dos_label=r'DOS [1/eV]', title=r'Density of states', xyswitch=False, e_fermi=0, **kwargs): """ Plot the provided data for a density of states (not spin-polarized). Can be done horizontally or vertical via the switch `xyswitch` :param energy_grid: arraylike data for the energy grid of the DOS :param dos_data: arraylike data for all the DOS components to plot :param title: str, Title of the plot :param energy_label: str, label for the energy-axis :param dos_label: str, label for the DOS-axis :param saveas: str, filename for the saved plot :param e_fermi: float (default 0), place the line for the fermi energy at this value :param xyswitch: bool if True, the enrgy axis will be plotted vertically All other Kwargs are passed on to the :py:func:`multiple_scatterplots()` call """ import seaborn as sns if 'limits' in kwargs: limits = kwargs.pop('limits') if xyswitch: limits['x'], limits['y'] = limits.pop('dos', None), limits.pop('energy', None) else: limits['x'], limits['y'] = limits.pop('energy', None), limits.pop('dos', None) kwargs['limits'] = {k: v for k, v in limits.items() if v is not None} lines = {'horizontal': 0} lines['vertical'] = e_fermi if xyswitch: lines['vertical'], lines['horizontal'] = lines['horizontal'], lines['vertical'] color_cycle = ('black',) + tuple(sns.color_palette('muted')) plot_params.set_defaults(default_type='function', marker=None, legend=True, lines=lines, color_cycle=color_cycle) if xyswitch: figsize = plot_params['figure_kwargs']['figsize'] plot_params.set_defaults(default_type='function', figure_kwargs={'figsize': figsize[::-1]}) if isinstance(dos_data[0], (list, np.ndarray)) and \ not isinstance(energy_grid[0], (list, np.ndarray)): energy_grid = [energy_grid] * len(dos_data) if xyswitch: x, y = dos_data, energy_grid xlabel, ylabel = dos_label, energy_label plot_params.set_defaults(default_type='function', area_vertical=True) else: xlabel, ylabel = energy_label, dos_label x, y = energy_grid, dos_data ax = multiple_scatterplots(x, y, xlabel=xlabel, ylabel=ylabel, title=title, saveas=saveas, **kwargs) return ax
[docs]@ensure_plotter_consistency(plot_params) def plot_spinpol_dos(energy_grid, spin_up_data, spin_dn_data, *, saveas='spinpol_dos_plot', energy_label=r'$E-E_F$ [eV]', dos_label=r'DOS [1/eV]', title=r'Density of states', xyswitch=False, energy_grid_dn=None, e_fermi=0, spin_dn_negative=True, **kwargs): """ Plot the provided data for a density of states (spin-polarized). Can be done horizontally or vertical via the switch `xyswitch` :param energy_grid: arraylike data for the energy grid of the DOS :param spin_up_data: arraylike data for all the DOS spin-up components to plot :param spin_dn_data: arraylike data for all the DOS spin-down components to plot :param title: str, Title of the plot :param energy_label: str, label for the energy-axis :param dos_label: str, label for the DOS-axis :param saveas: str, filename for the saved plot :param e_fermi: float (default 0), place the line for the fermi energy at this value :param xyswitch: bool if True, the enrgy axis will be plotted vertically :param energy_grid_dn: arraylike data for the energy grid of the DOS of the spin-down component (optional) :param spin_dn_negative: bool, if True (default) the spin-down components are plotted downwards All other Kwargs are passed on to the :py:func:`multiple_scatterplots()` call """ import seaborn as sns if 'limits' in kwargs: limits = kwargs.pop('limits') if xyswitch: limits['x'], limits['y'] = limits.pop('dos', None), limits.pop('energy', None) else: limits['x'], limits['y'] = limits.pop('energy', None), limits.pop('dos', None) kwargs['limits'] = {k: v for k, v in limits.items() if v is not None} if isinstance(spin_up_data[0], (list, np.ndarray)): if len(spin_up_data) != len(spin_dn_data): raise ValueError(f'Dimensions do not match: Spin-up: {len(spin_up_data)} Spin-dn: {len(spin_dn_data)}') max_dos = max(data.max() for data in spin_up_data) max_dos = max(max_dos, max(data.max() for data in spin_dn_data)) max_dos *= 1.1 if spin_dn_negative: if isinstance(spin_dn_data, np.ndarray): spin_dn_data *= -1 elif isinstance(spin_up_data[0], list): spin_dn_data = [-value for data in spin_dn_data for value in data] else: spin_dn_data = [-value for value in spin_dn_data] lines = {'horizontal': 0} lines['vertical'] = e_fermi if xyswitch: lines['vertical'], lines['horizontal'] = lines['horizontal'], lines['vertical'] if xyswitch: limits = {'x': (-max_dos, max_dos)} else: limits = {'y': (-max_dos, max_dos)} if isinstance(spin_up_data[0], (list, np.ndarray)): num_plots = len(spin_up_data) else: num_plots = 1 color_cycle = ('black',) + tuple(sns.color_palette('muted')) plot_params.set_defaults(default_type='function', marker=None, legend=True, lines=lines, limits=limits, repeat_colors_after=num_plots, color_cycle=color_cycle) if xyswitch: figsize = plot_params['figure_kwargs']['figsize'] plot_params.set_defaults(default_type='function', figure_kwargs={'figsize': figsize[::-1]}) save_keys = {'show', 'save_plots', 'save_format', 'save_options'} save_options = {key: val for key, val in kwargs.items() if key in save_keys} kwargs = {key: val for key, val in kwargs.items() if key not in save_keys} plot_params.set_parameters(**save_options) if xyswitch: plot_params.set_defaults(default_type='function', invert_xaxis=True) dos_data = spin_up_data if not isinstance(spin_up_data[0], (list, np.ndarray)): dos_data = [dos_data, spin_dn_data] else: dos_data = np.concatenate((dos_data, spin_dn_data), axis=0) if isinstance(dos_data[0], (list, np.ndarray)) and \ not isinstance(energy_grid[0], (list, np.ndarray)): energy_grid = [energy_grid] * len(dos_data) if xyswitch: x, y = dos_data, energy_grid xlabel, ylabel = dos_label, energy_label plot_params.set_defaults(default_type='function', area_vertical=True) else: xlabel, ylabel = energy_label, dos_label x, y = energy_grid, dos_data with NestedPlotParameters(plot_params): ax = multiple_scatterplots(x, y, xlabel=xlabel, ylabel=ylabel, title=title, save_plots=False, show=False, **kwargs) if xyswitch: ax.annotate(r'$\uparrow$', xy=(0.125, 0.9), xycoords='axes fraction', ha='center', va='center', size=40) ax.annotate(r'$\downarrow$', xy=(0.875, 0.9), xycoords='axes fraction', ha='center', va='center', size=40) else: ax.annotate(r'$\uparrow$', xy=(0.05, 0.875), xycoords='axes fraction', ha='center', va='center', size=40) ax.annotate(r'$\downarrow$', xy=(0.05, 0.125), xycoords='axes fraction', ha='center', va='center', size=40) plot_params.save_plot(saveas) return ax
[docs]@ensure_plotter_consistency(plot_params) def plot_bands(kpath, bands, *, size_data=None, special_kpoints=None, e_fermi=0, xlabel='', ylabel=r'$E-E_F$ [eV]', title='', saveas='bandstructure', markersize_min=0.5, markersize_scaling=5.0, scale_color=True, **kwargs): """ Plot the provided data for a bandstrucuture (non spin-polarized). Can be used to illustrate weights on bands via `size_data` :param kpath: arraylike data for the kpoint data :param bands: arraylike data for the eigenvalues :param size_data: arraylike data the weights to emphasize (optional) :param title: str, Title of the plot :param xlabel: str, label for the x-axis :param ylabel: str, label for the y-axis :param saveas: str, filename for the saved plot :param e_fermi: float (default 0), place the line for the fermi energy at this value :param special_kpoints: list of tuples (str, float), place vertical lines at the given values and mark them on the x-axis with the given label :param markersize_min: minimum value used in scaling points for weight :param markersize_scaling: factor used in scaling points for weight :param scale_color: bool, if True (default) the weight will be additionally shown via a colormapping All other Kwargs are passed on to the :py:func:`multi_scatter_plot()` call """ if special_kpoints is None: special_kpoints = [] xticks = [] xticklabels = [] for label, pos in special_kpoints: if label in ('Gamma', 'g'): label = r'$\Gamma$' xticklabels.append(label) xticks.append(pos) if size_data is not None: ylimits = (-15, 15) if 'limits' in kwargs: if 'y' in kwargs['limits']: ylimits = kwargs['limits']['y'] weight_max = max(size_data[np.logical_and(bands > ylimits[0], bands < ylimits[1])]) if 'vmax' not in kwargs: kwargs['vmax'] = weight_max if scale_color: kwargs['color_data'] = copy.copy(size_data) plot_params.set_defaults(default_type='function', cmap='Blues') if 'cmap' not in kwargs: #Cut off the white end of the Blues/Reds colormap plot_params.set_defaults(default_type='function', sub_colormap=(0.15, 1.0)) size_data = (markersize_min + markersize_scaling * size_data / weight_max)**2 lines = {'vertical': xticks, 'horizontal': e_fermi} limits = {'x': (min(kpath), max(kpath)), 'y': (-15, 15)} plot_params.set_defaults(default_type='function', lines=lines, limits=limits, xticks=xticks, xticklabels=xticklabels, color='k', linewidth=0, line_options={'zorder': -1}, colorbar=False) ax = multi_scatter_plot(kpath, bands, size_data=size_data, xlabel=xlabel, ylabel=ylabel, title=title, saveas=saveas, **kwargs) return ax
[docs]@ensure_plotter_consistency(plot_params) def plot_spinpol_bands(kpath, bands_up, bands_dn, *, size_data=None, show_spin_pol=True, special_kpoints=None, e_fermi=0, xlabel='', ylabel=r'$E-E_F$ [eV]', title='', saveas='bandstructure', markersize_min=0.5, markersize_scaling=5.0, scale_color=True, **kwargs): """ Plot the provided data for a bandstrucuture (spin-polarized). Can be used to illustrate weights on bands via `size_data` :param kpath: arraylike data for the kpoint data :param bands_up: arraylike data for the eigenvalues (spin-up) :param bands_dn: arraylike data for the eigenvalues (spin-dn) :param size_data: arraylike data the weights to emphasize BOTH SPINS (optional) :param title: str, Title of the plot :param xlabel: str, label for the x-axis :param ylabel: str, label for the y-axis :param saveas: str, filename for the saved plot :param e_fermi: float (default 0), place the line for the fermi energy at this value :param special_kpoints: list of tuples (str, float), place vertical lines at the given values and mark them on the x-axis with the given label :param markersize_min: minimum value used in scaling points for weight :param markersize_scaling: factor used in scaling points for weight :param show_spin_pol: bool, if True (default) the two different spin channles will be shown in blue and red by default :param scale_color: bool, if True (default) the weight will be additionally shown via a colormapping All other Kwargs are passed on to the :py:func:`multi_scatter_plot()` call """ if special_kpoints is None: special_kpoints = {} if size_data is not None: if len(size_data) != 2: raise ValueError('size_data has to be a list of length 2') ylimits = (-15, 15) if 'limits' in kwargs: if 'y' in kwargs['limits']: ylimits = kwargs['limits']['y'] weight_max = max(size_data[0][np.logical_and(bands_up > ylimits[0], bands_up < ylimits[1])]) weight_max = max(weight_max, max(size_data[1][np.logical_and(bands_dn > ylimits[0], bands_dn < ylimits[1])])) if 'vmax' not in kwargs: kwargs['vmax'] = weight_max color_data = [] for indx, data in enumerate(size_data): color_data.append(copy.copy(data)) size_data[indx] = (markersize_min + markersize_scaling * data / weight_max)**2 if scale_color: kwargs['color_data'] = color_data xticks = [] xticklabels = [] for label, pos in special_kpoints: if label in ('Gamma', 'g'): label = r'$\Gamma$' xticklabels.append(label) xticks.append(pos) lines = {'vertical': xticks, 'horizontal': e_fermi} cmaps = None if show_spin_pol: color = ['blue', 'red'] if scale_color: cmaps = ['Blues', 'Reds'] else: color = 'k' if scale_color: cmaps = 'Blues' limits = {'x': (min(kpath), max(kpath)), 'y': (-15, 15)} plot_params.set_defaults(default_type='function', lines=lines, limits=limits, xticks=xticks, xticklabels=xticklabels, color=color, cmap=cmaps, linewidth=0, legend=True, line_options={'zorder': -1}, zorder=[2, 1], colorbar=False) if 'cmap' not in kwargs: #Cut off the white end of the Blues/Reds colormap plot_params.set_defaults(default_type='function', sub_colormap=(0.15, 1.0)) ax = multi_scatter_plot([kpath, kpath], [bands_up, bands_dn], size_data=size_data, xlabel=xlabel, ylabel=ylabel, title=title, saveas=saveas, **kwargs) return ax
[docs]def plot_certain_bands(): """ Plot only certain bands from a bands.1 file from FLEUR """ pass
[docs]def plot_bands_and_dos(): """ PLot a Bandstructure with a density of states on the right side. """ pass
[docs]def plot_corelevels(coreleveldict, compound='', axis=None, saveas='scatterplot', **kwargs): """ Ploting function to visualize corelevels and corelevel shifts """ for elem, corelevel_dict in coreleveldict.items(): # one plot for each element axis = plot_one_element_corelv(corelevel_dict, elem, compound=compound, axis=axis, saveas=saveas, **kwargs) return axis
[docs]@ensure_plotter_consistency(plot_params) def plot_one_element_corelv(corelevel_dict, element, compound='', axis=None, saveas='scatterplot', **kwargs): """ This routine creates a plot which visualizes all the binding energies of one element (and currenlty one corelevel) for different atomtypes. example: corelevels = {'W' : {'4f7/2' : [123, 123.3, 123.4 ,123.1], '4f5/2' : [103, 103.3, 103.4, 103.1]}, 'Be' : {'1s': [118, 118.2, 118.4, 118.1, 118.3]}} """ corelevels_names = [] xdata_all = [] ydata_all = [] for corelevel, corelevel_list in corelevel_dict.items(): #print corelevel n_atom = len(corelevel_list) x_axis = list(range(0, n_atom, 1)) y_axis = corelevel_list xdata_all.append(x_axis) ydata_all.append(y_axis) corelevels_names.append(corelevel) elem = element xdata = xdata_all[0] ydata = ydata_all[0] xlabel = f'{elem} atomtype' ylabel = 'energy in eV' title = f'Element: {elem} from {compound} cl {corelevels_names}' #plotlabel ='corelevel shifts' #linetyp='o-' xmin = xdata[0] - 0.5 xmax = xdata[-1] + 0.5 ymin = min(ydata) - 1 ymax = max(ydata) + 1 plot_params.set_defaults(default_type='function', font_options={'color': 'darkred'}, color='k', linewidth=2, limits={ 'x': (xmin, xmax), 'y': (ymin, ymax) }) kwargs = plot_params.set_parameters(continue_on_error=True, **kwargs) ax = plot_params.prepare_plot(title=title, xlabel=xlabel, ylabel=ylabel, axis=axis) for ydata in ydata_all: for x, y in zip(xdata, ydata): lenx = xmax - xmin length = 0.5 / lenx offset = 0.5 / lenx xminline = x / lenx + offset - length / 2 xmaxline = x / lenx + offset + length / 2 ax.axhline(y=y, xmin=xminline, xmax=xmaxline, linewidth=plot_params['linewidth'], color=plot_params['color']) text = r'{}'.format(y) ax.text(x - 0.25, y + 0.3, text, fontdict=plot_params['font_options']) plot_params.set_scale(ax) plot_params.set_limits(ax) plot_params.draw_lines(ax) plot_params.show_legend(ax) plot_params.save_plot(saveas) return ax
[docs]def construct_corelevel_spectrum(coreleveldict, natom_typesdict, exp_references=None, scale_to=-1, fwhm_g=0.6, fwhm_l=0.1, energy_range=None, xspec=None, energy_grid=0.2, peakfunction='voigt', alpha_l=1.0, beta_l=1.5): """ Constructrs a corelevel spectrum from a given corelevel dict :params: :returns: list: [xdata_spec, ydata_spec, ydata_single_all, xdata_all, ydata_all, xdatalabel] """ if energy_range is None: energy_range = (None, None) xdata_all = [] ydata_all = [] ydata_spec = [] xdata_spec = [] xdatalabel = [] energy_grid = round(energy_grid, 5) # eV #count = 0 #compound_info_new = compound_info for elem, corelevel_dict in coreleveldict.items(): natom = natom_typesdict.get(elem, 0) #elem_count = 0 for corelevel_name, corelevel_list in corelevel_dict.items(): # get number of electron if fully occ: nelectrons = 1 if 's' in corelevel_name: nelectrons = 2 else: max_state_occ_spin = {'1/2': 2, '3/2': 4, '5/2': 6, '7/2': 8} # check if spin in name for key, val in max_state_occ_spin.items(): if key in corelevel_name: nelectrons = val for i, corelevel in enumerate(corelevel_list): xdatalabel.append(elem + ' ' + corelevel_name) xdata_all.append(corelevel) ydata_all.append(natom[i] * nelectrons) #count = count + 1 #elem_count = elem_count + 1 ''' not working yet bad design if compound_info: for compound, element_dict in compound_info.iteritems(): for elemt, number in element_dict.iteritems(): print number if elemt == elem: # we need to set the index that we find it later, group it if isinstance(number, list): continue compound_info_new[compound][elemt] = [count-elem_count, count-elem_count+number] ''' xmin = min(xdata_all) - 2 #0.5 xmax = max(xdata_all) + 2 #0.5 if energy_range[0]: xmin = energy_range[0] if energy_range[1]: xmax = energy_range[1] # xdata_spec = np.array(np.arange(xmax,xmin, -energy_grid)) if xspec is not None: xdata_spec = xspec else: xdata_spec = np.array(np.arange(xmin, xmax + energy_grid, energy_grid)) ydata_spec = np.zeros(len(xdata_spec), dtype=float) ydata_single_all = [] for i, xpoint in enumerate(xdata_all): if peakfunction == 'gaus': data_f = np.array(gaussian(xdata_spec, fwhm_g, xpoint)) #, 1.0)) elif peakfunction == 'voigt': data_f = np.array(voigt_profile(xdata_spec, fwhm_g, fwhm_l, xpoint)) # different fwhn for g und l elif peakfunction == 'pseudo-voigt': data_f = np.array(pseudo_voigt_profile(xdata_spec, fwhm_g, fwhm_l, xpoint)) elif peakfunction == 'lorentz': data_f = np.array(lorentzian(xdata_spec, fwhm_l, xpoint)) elif peakfunction == 'doniach-sunjic': data_f = np.array(doniach_sunjic(xdata_spec, scale=1.0, E_0=xpoint, gamma=fwhm_l, alpha=fwhm_g)) elif peakfunction == 'asymmetric_lorentz_gauss_conv': #print(xpoint, xdata_spec) data_f = np.array( asymmetric_lorentz_gauss_conv(xdata_spec, xpoint, fwhm_g=fwhm_g, fwhm_l=fwhm_l, alpha=alpha_l, beta=beta_l)) else: print('given peakfunction type not known') data_f = [] return # sometimes we get a point to much if constructed from new mesh.. if len(ydata_spec) < len(data_f): # TODO: further adjustements? we assume only one point difference data_f = data_f[:-1] #print('length', len(ydata_spec), len(data_f)) #gaus_f = lorentzian(xdata_spec, xpoint, 0.6, 100.0) if peakfunction == 'doniach-sunjic': ydata_spec = ydata_spec + ydata_all[i] * data_f ydata_single_all.append(ydata_all[i] * data_f) else: ydata_spec = ydata_spec + ydata_all[i] * data_f ydata_single_all.append(ydata_all[i] * data_f) # we scale after and not before, because the max intensity is not neccesary # the number of electrons. if scale_to > 0.0: y_valmax = max(ydata_spec) scalingfactor = scale_to / y_valmax ydata_spec = ydata_spec * scalingfactor ydata_single_all_new = [] for ydata_single in ydata_single_all: ydata_single_all_new.append(ydata_single * scalingfactor) ydata_single_all = ydata_single_all_new return [xdata_spec, ydata_spec, ydata_single_all, xdata_all, ydata_all, xdatalabel]
[docs]@ensure_plotter_consistency(plot_params) def plot_corelevel_spectra(coreleveldict, natom_typesdict, exp_references=None, scale_to=-1, show_single=True, show_ref=True, energy_range=None, title='', fwhm_g=0.6, fwhm_l=0.1, energy_grid=0.2, peakfunction='voigt', linestyle_spec='-', marker_spec='o', color_spec='k', color_single='g', xlabel='Binding energy [eV]', ylabel='Intensity [arb] (natoms*nelectrons)', saveas=None, xspec=None, alpha_l=1.0, beta_l=1.0, **kwargs): """ Plotting function of corelevel in the form of a spectrum. Convention: Binding energies are positiv! Args: coreleveldict: dict of corelevels with a list of corelevel energy of atomstypes # (The given corelevel accounts for a weight (number of electrons for full occupied corelevel) in the plot.) natom_typesdict: dict with number of atom types for each entry Kwargs: exp_references: dict with experimental refereces, will be ploted as vertical lines show_single (bool): plot all single peaks. scale_to float: the maximum 'intensity' will be scaled to this value (useful for experimental comparisons) title (string): something for labeling fwhm (float): full width half maximum of peaks (gaus, lorentz or voigt_profile) energy_grid (float): energy resolution linetyp_spec : linetype for spectrum peakfunction (string): what the peakfunction should be {'voigt', 'pseudo-voigt', 'lorentz', 'gaus'} example: coreleveldict = {u'Be': {'1s1/2' : [-1.0220669053033051, -0.3185614920138805,-0.7924091040092139]}} n_atom_types_Be12Ti = {'Be' : [4,4,4]} """ #show_compound=True, , compound_info={} compound_info dict: dict that can be used to specify what component should be shown together compound_info = {'Be12Ti' : {'Be' : 4, 'Ti' : 1}, 'BeTi' : {'Be' : 1, 'Ti' : 1}} # TODO feature to make singles of different compounds a different color if energy_range is None: energy_range = (None, None) if exp_references is None: exp_references = {} [xdata_spec, ydata_spec, ydata_single_all, xdata_all, ydata_all, xdatalabel] = construct_corelevel_spectrum(coreleveldict, natom_typesdict, exp_references=exp_references, scale_to=scale_to, fwhm_g=fwhm_g, fwhm_l=fwhm_l, energy_range=energy_range, xspec=xspec, energy_grid=energy_grid, peakfunction=peakfunction, alpha_l=alpha_l, beta_l=beta_l) xmin = min(xdata_all) - 2 #0.5 xmax = max(xdata_all) + 2 #0.5 if energy_range[0]: xmin = energy_range[0] if energy_range[1]: xmax = energy_range[1] xdata = xdata_all ydata = ydata_all ymax2 = max(ydata_spec) + 1 ymin = -0.3 ymax = max(ydata) + 1 limits = {'x': (xmin, xmax), 'y': (ymin, ymax)} limits_spec = {'x': (xmin, xmax), 'y': (ymin, ymax2)} #title = title #'Spectrum of {}'.format(compound) """ # ToDo redesign to use multiple_scatterplot axis = multiple_scatterplots(ydata, xdata, xlabel, ylabel, title, plot_labels, linestyle='', marker='o', markersize=markersize_g, legend=legend_g, legend_option={}, saveas='mscatterplot', limits=limits, scale=[None, None], axis=None, xerr=None, yerr=None, colors=[], linewidth=[], xticks=[], title=title, xlabel=xlabel, ylabel=ylabel, **kwargs) """ #print len(xdata), len(ydata) if 'plot_label' not in kwargs: kwargs['plot_label'] = 'corelevel shifts' if 'linestyle' not in kwargs: kwargs['linestyle'] = '' if saveas is None: saveas = f'XPS_theo_{fwhm_g}_{title}' saveas1 = f'XPS_theo_2_{fwhm_g}_{title}' else: saveas1 = saveas[1] saveas = saveas[0] #################################### ##### PLOT 1, plot raw datapoints if not plot_params['show']: return [xdata_spec, ydata_spec, ydata_single_all, xdata_all, ydata_all, xdatalabel] states = [] if show_ref and exp_references: for elm, ref_list_dict in exp_references.items(): for state, ref_list in ref_list_dict.items(): states.extend(ref_list) ax = single_scatterplot(xdata_all, ydata_all, xlabel=xlabel, ylabel=ylabel, title=title, line_options={ 'color': 'k', 'linestyle': '-', 'linewidth': 2 }, lines={'vertical': { 'pos': states, 'ymin': 0, 'ymax': 0.1 }}, limits=limits, saveas=saveas, **kwargs) ''' TODO for j,y in enumerate(ydata_all): for i,x in enumerate(xdata): lenx = xmax-xmin length = 0.5/lenx offset = 0.5/lenx xminline = x/lenx + offset - length/2 xmaxline = x/lenx + offset + length/2 plt.axhline(y=y[i], xmin=xminline, xmax=xmaxline, linewidth=2, color='k') text = r'{}'.format(y[i]) plt.text(x-0.25, y[i]+0.3, text, fontdict=font) ''' ############################################################## ##### PLOT 2, plot spectra, voigts around datapoints ######### kwargs.pop('linestyle', None) kwargs.pop('marker', None) kwargs.pop('color', None) kwargs.pop('save', None) kwargs.pop('save_plots', None) ax2 = single_scatterplot(xdata_spec, ydata_spec, xlabel=xlabel, ylabel=ylabel, title=title, marker=marker_spec, linestyle=linestyle_spec, color=color_spec, line_options={ 'color': 'k', 'linestyle': '-', 'linewidth': 2 }, lines={'vertical': { 'pos': states, 'ymin': 0, 'ymax': 0.1 }}, show=False, save_plots=False, limits=limits_spec, **kwargs) if show_single: ax2 = multiple_scatterplots([xdata_spec] * len(ydata_single_all), ydata_single_all, xlabel=xlabel, ylabel=ylabel, title=title, show=False, save_plots=False, axis=ax2, linestyle='-', color=color_single, limits=limits_spec, **kwargs) '''TODO if show_compound and compound_info: for i,compound_data in enumerate(ydata_compound): plotlabel = compound_plot_label[i] plt.plot(xdata_spec, compound_data, '-', label=plotlabel, color = color, linewidth=linewidth_g1, markersize = markersize_g) ''' plot_params.save_plot(saveas1) # for plotting or file writting return [xdata_spec, ydata_spec, ydata_single_all, xdata_all, ydata_all, xdatalabel, ax, ax2]
[docs]def asymmetric_lorentz(x, fwhm, mu, alpha=1.0, beta=1.5): """ asymetric lorentz function L^alpha for x<=mu L^beta for x>mu See casexps LA """ index = 0 for i, entry in enumerate(x): if entry <= mu: index = i else: break ydata1 = lorentzian_one(x[:index], fwhm, mu)**alpha ydata2 = lorentzian_one(x[index:], fwhm, mu)**beta return np.array(list(ydata1) + list(ydata2))
[docs]def lorentzian_one(x, fwhm, mu): """ Returns a Lorentzian line shape at x with FWHM fwhm and mean mu """ return 1.0 / (1 + 4 * ((x - mu) / fwhm)**2)
[docs]def gauss_one(x, fwhm, mu): """ Returns a Lorentzian line shape at x with FWHM fwhm and mean mu """ x = np.array(x) return np.exp(-4 * np.log(2) * ((x - mu) / fwhm)**2)
[docs]def asymmetric_lorentz_gauss_sum(x, mu, fwhm_l, fwhm_g, alpha=1.0, beta=1.5): """ asymmetric Lorentzian with Gauss convoluted """ ygaus = np.array(gauss_one(x, fwhm_g, mu)) ylorentz = np.array(asymmetric_lorentz(x, fwhm_l, mu, alpha=alpha, beta=beta)) ydata = ylorentz + ygaus return ydata
[docs]def asymmetric_lorentz_gauss_conv(x, mu, fwhm_l, fwhm_g, alpha=1.0, beta=1.5): """ asymmetric Lorentzian with Gauss convoluted """ from scipy.signal import fftconvolve #from scipy import signal # only one function has to be translated # gaus has to be symmetric arround 0 for convolution # and on the same equidistant grid xstep = abs(round(x[-1] - x[-2], 6)) rangex = abs(x[-1] - x[0]) #print(xstep, rangex) xgaus = np.arange(-rangex / 2.0, rangex / 2.0 + xstep, xstep) #print(xgaus[:10], xgaus[-1]) ygaus = np.array(gauss_one(xgaus, fwhm_g, mu=0.0), dtype=np.float64) ylorentz = np.array(asymmetric_lorentz(x, fwhm_l, mu=mu, alpha=alpha, beta=beta), dtype=np.float64) ydata = np.convolve(ylorentz, ygaus, mode='same') return ydata
''' def asymmetric_lorentz_gauss_conv_interp(x, mu, fwhm_l,fwhm_g,alpha=1.0, beta=1.5, grid_factor=10): """ asymmetric Lorentzian with Gauss convoluted. Real convolution. For the convolution to work we construct a finer mesh, with mu shifted to 0.0 on which we convolute. Then we linear interpolate on the original mesh points. """ import numpy as np from scipy.interpolate import interp1d # convolution has to be symmetric arround 0 # check if xmu is right or left, # double longest side, shift xmu to 0.0 # then interpolate at original mesh points x = np.array(x, dtype=np.float64) xstep = round(x[-1]-x[-2],6) xstepmesh = xstep/grid_factor xmesh = np.arange(x[0], x[-1]+xstepmesh/2.0, xstepmesh) xmu = np.float64(0.0) muindex = 0 for i, en in enumerate(xmesh): if en <=mu: xmu = mu muindex = i else: break if muindex <= len(xmesh)/2.0: xtrans = np.arange(-x[-1] + xmu - xstep, x[-1] - xmu + xstep, xstepmesh) else: xtrans = np.arange(x[0] - xmu - xstep, -x[0] + xmu + xstep, xstepmesh) ygaus = np.array(gauss_one(xtrans, fwhm_g, mu=0.0), dtype=np.float64) ylorentz = np.array(asymmetric_lorentz(xtrans,fwhm_l, mu=0.0, alpha=alpha, beta=beta), dtype=np.float64) ydata = np.convolve(ylorentz,ygaus,mode='same') # iterpolate function and evalutate at original xdata f = interp1d(xtrans+xmu, ydata, assume_sorted=True) ydata_return = f(x) return ydata_return def asymmetric_lorentz_gauss_conv1(x, mu, fwhm_l,fwhm_g,alpha=1.0, beta=1.5): """ asymmetric Lorentzian with Gauss convoluted """ import numpy as np from scipy import signal ygaus = np.array(gauss_one(x, fwhm_g, mu)) ylorentz = np.array(asymmetric_lorentz(x,fwhm_l, mu, alpha=alpha, beta=beta)) #ydata = np.convolve(ylorentz,np.flip(ygaus, axis=0),mode='same') ydata = np.convolve(ylorentz,ygaus,mode='same') #ydata = ylorentz+ygaus #ydata = direct_convolution(ylorentz,ygaus) #ydata = signal.convolve(ylorentz,ygaus) return ydata def asymmetric_lorentz_gauss_conv_linear(x, mu, fwhm_l,fwhm_g,alpha=1.0, beta=1.5): """ asymmetric Lorentzian with Gauss convoluted """ import numpy as np #from scipy import signal # convolution has to be symmetric arround 0 # check if xmu is right or left, # double longest side, shift xmu to 0.0 # then shift back and cut off the rest # we asume equidistant mesh x = np.array(x, dtype=np.float64) xstep = round(x[-1]-x[-2],6) xmu = np.float64(0.0) muindex = 0 for i, en in enumerate(x): if en <=mu: xmu = en muindex = i else: break #print(x[0]-xmu, -x[0]+xmu,xstep) if muindex <= len(x)/2.0: xtrans = np.arange(-x[-1]+xmu, x[-1]-xmu,xstep) else: xtrans = np.arange(x[0]-xmu, -x[0]+xmu,xstep) # To keep mu continous we parse the exact mu to the lorentz and gauss... # the convolution will not be totally correct... # todo maybe combine with gridfactor... ygaus = np.array(gauss_one(xtrans, fwhm_g, mu=(xmu-mu)/2.0), dtype=np.float64) ylorentz = np.array(asymmetric_lorentz(xtrans,fwhm_l, mu=(xmu-mu)/2.0, alpha=alpha, beta=beta), dtype=np.float64) ydata = np.convolve(ylorentz,ygaus,mode='same') # shift data back... through cutting it if muindex <= len(x)/2.0: ydata_new = np.array(ydata[len(ydata)-len(x):], dtype=np.float64) else: ydata_new = np.array(ydata[:len(x)], dtype=np.float64) return ydata_new def asymmetric_lorentz_gauss_conv(x, mu, fwhm_l,fwhm_g,alpha=1.0, beta=1.5, grid_factor=10): """ asymmetric Lorentzian with Gauss convoluted """ import numpy as np #from scipy import signal # convolution has to be symmetric arround 0 # check if xmu is right or left, # double longest side, shift xmu to 0.0 # then shift back and cut off the rest # TODO: overall a bit slow, can we speed this up? # cone idea for speed up would be only increase the mesh fineness between the x where mu lives... # this way npoints is len(x)+gridfactor and not len(x)*gridfoctor # logic becomes harder... # convolution is n^2 # we asume equidistant mesh # we increase the mesh by a factor of grid_factor # because mu can only vary by the meshstep... x = np.array(x, dtype=np.float64) xstep = round(x[-1]-x[-2],6) xstepmesh = xstep/grid_factor xmesh1 = np.arange(x[0], x[-1]+xstepmesh/2.0, xstepmesh) xmesh = np.round(xmesh1, 6) xmu = np.float64(0.0) muindex = 0 for i, en in enumerate(xmesh): if en <=mu: xmu = en#mu muindex = i else: break if muindex <= len(xmesh)/2.0: xtrans = np.arange(-x[-1]+xmu, x[-1]-xmu,xstepmesh) else: xtrans = np.arange(x[0]-xmu, -x[0]+xmu,xstepmesh) ygaus = np.array(gauss_one(xtrans, fwhm_g, mu=0.0), dtype=np.float64) ylorentz = np.array(asymmetric_lorentz(xtrans,fwhm_l, mu=0.0, alpha=alpha, beta=beta), dtype=np.float64) ydata = np.convolve(ylorentz,ygaus,mode='same') # shift data back... through cutting it if muindex <= len(xmesh)/2.0: ydata_new = np.array(ydata[len(ydata)-len(xmesh):], dtype=np.float64) else: ydata_new = np.array(ydata[:len(xmesh)], dtype=np.float64) # back to original mesh ydata_return = ydata_new[0::grid_factor] return ydata_return def direct_convolution(a,b): """ convolution, a, b same length, arrays """ import numpy as np ydata = np.zeros(len(a)) for i, entry in enumerate(a): for j, entry2 in enumerate(a): ydata[i] = ydata[i] + (entry2*b[i-j]) return ydata '''
[docs]def doniach_sunjic(x, scale=1.0, E_0=0, gamma=1.0, alpha=0.0): """ Doniach Sunjic asymmetric peak function. tail to higher binding energies. param x: list values to evaluate this function param scale: multiply the function with this factor param E_0: position of the peak param gamma, 'lifetime' broadening param alpha: 'asymmetry' parametera See Doniach S. and Sunjic M., J. Phys. 4C31, 285 (1970) or http://www.casaxps.com/help_manual/line_shapes.htm """ arg = (E_0 - x) / gamma alpha2 = (1.0 - alpha) #scale = scale/(gamma**alpha2) don_su = np.cos(np.pi * alpha + alpha2 * np.arctan(arg)) / (1 + arg**2)**(alpha2 / 2) return np.array(scale * don_su)
[docs]def gaussian(x, fwhm, mu): """ Returns Gaussian line shape at x with FWHM fwhm and mean mu """ #hwhm = fwhm/2.0 sigma = fwhm / (2 * np.sqrt(2 * np.log(2))) #return np.sqrt(np.log(2) / np.pi) / hwhm\ # * np.exp(-((x-mu) / hwhm)**2 * np.log(2)) return np.exp(-(x - mu)**2 / (2 * (sigma**2))) / (np.sqrt(2 * np.pi) * sigma)
[docs]def lorentzian(x, fwhm, mu): """ Returns a Lorentzian line shape at x with FWHM fwhm and mean mu """ hwhm = fwhm / 2.0 return hwhm / np.pi / ((x - mu)**2 + hwhm**2)
[docs]def voigt_profile(x, fwhm_g, fwhm_l, mu): """ Return the Voigt line shape at x with Lorentzian component FWHM fwhm_l and Gaussian component FWHM fwhm_g and mean mu. There is no closed form for the Voigt profile, but it is related to the real part of the Faddeeva function (wofz), which is used here. """ from scipy.special import wofz #pylint: disable=no-name-in-module hwhm_l = fwhm_l / 2.0 sigma = fwhm_g / (2 * np.sqrt(2 * np.log(2))) # complex 1j return np.real(wofz(((x - mu) + 1j * hwhm_l) / sigma / np.sqrt(2))) / sigma / np.sqrt(2 * np.pi)
[docs]def CDF_voigt_profile(x, fwhm_g, fwhm_l, mu): """ Cumulative distribution function of a voigt profile implementation of formula found here: https://en.wikipedia.org/wiki/Voigt_profile # TODO: is there an other way then to calc 2F2? # or is there an other way to calc the integral of wofz directly, or use different error functions. """ from scipy.special import erf #pylint: disable=no-name-in-module pass
[docs]def hyp2f2(a, b, z): """ Calculation of the 2F2() hypergeometric function, since it is not part of scipy with the identity 2. from here: https://en.wikipedia.org/wiki/Generalized_hypergeometric_function a, b,z array like inputs TODO: not clear to me how to do this... the identity is only useful if we mange the adjust the arguments in a way that we can use them... also maybe go for the special case we need first: 1,1,3/2;2;-z2 """ from scipy.special import hyp0f1 pass
[docs]def pseudo_voigt_profile(x, fwhm_g, fwhm_l, mu, mix=0.5): """ Linear combination of gaussian and loretzian instead of convolution Args: x: array of floats fwhm_g: FWHM of gaussian fwhm_l: FWHM of Lorentzian mu: Mean mix: ratio of gaus to lorentz, mix* gaus, (1-mix)*Lorentz """ #pseudo_voigt = [] if mix > 1: print('mix has to be smaller than 1.') return [] gaus = gaussian(x, fwhm_g, mu) lorentz = lorentzian(x, fwhm_l, mu) return mix * gaus + (1 - mix) * lorentz
class PDF(object): def __init__(self, pdf, size=(200, 200)): """Display a PDF file inside a Jupyter notebook. Note: alternative to using aiida.tools.visualization.Graph class. Example: https://aiida-tutorials.readthedocs.io/en/latest/pages/2020_Intro_Week/notebooks/querybuilder-tutorial.html#generating-a-provenance-graph Reference: https://stackoverflow.com/a/19470377/8116031 :example: >>> # !verdi node graph generate 23 >>> PDF('23.dot.pdf',size=(800,600)) :param pdf: :param size: """ self.pdf = pdf self.size = size def _repr_html_(self): return '<iframe src={0} width={1[0]} height={1[1]}></iframe>'.format(self.pdf, self.size) def _repr_latex_(self): return r'\includegraphics[width=1.0\textwidth]{{{0}}}'.format(self.pdf)
[docs]def plot_colortable(colors: typing.Dict, title: str, sort_colors: bool = True, emptycols: int = 0): """Plot a legend of named colors. Reference: https://matplotlib.org/3.1.0/gallery/color/named_colors.html :param colors: a dict color_name : color_value (hex str, rgb tuple, ...) :param title: plot title :param sort_colors: sort :param emptycols: :return: figure """ import matplotlib.colors as mcolors cell_width = 212 cell_height = 22 swatch_width = 48 margin = 12 topmargin = 40 # Sort colors by hue, saturation, value and name. if sort_colors is True: by_hsv = sorted((tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(color))), name) for name, color in colors.items()) names = [name for hsv, name in by_hsv] else: names = list(colors) n = len(names) ncols = 4 - emptycols nrows = n // ncols + int(n % ncols > 0) width = cell_width * 4 + 2 * margin height = cell_height * nrows + margin + topmargin dpi = 72 fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi) fig.subplots_adjust(margin / width, margin / height, (width - margin) / width, (height - topmargin) / height) ax.set_xlim(0, cell_width * 4) ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.) ax.yaxis.set_visible(False) ax.xaxis.set_visible(False) ax.set_axis_off() ax.set_title(title, fontsize=24, loc='left', pad=10) for i, name in enumerate(names): row = i % nrows col = i // nrows y = row * cell_height swatch_start_x = cell_width * col swatch_end_x = cell_width * col + swatch_width text_pos_x = cell_width * col + swatch_width + 7 ax.text(text_pos_x, y, name, fontsize=14, horizontalalignment='left', verticalalignment='center') ax.hlines(y, swatch_start_x, swatch_end_x, color=colors[name], linewidth=18) return fig