# Source code for masci_tools.vis.parameters

# Copyright (c), Forschungszentrum Jülich GmbH, IAS-1/PGI-1, Germany.         #
#                All rights reserved.                                         #
# This file is part of the Masci-tools package.                               #
# (Material science tools)                                                    #
#                                                                             #
# The code is hosted on GitHub at https://github.com/judftteam/masci-tools.   #
# For further information on the license, see the LICENSE.txt file.           #
# For further information please visit http://judft.de/.                      #
#                                                                             #
"""
This module provides basic functionality for setting default parameters for plotting
and for ensuring that the values of these parameters stay consistent
"""
from __future__ import annotations

import copy
from functools import wraps
from contextlib import contextmanager
from collections import ChainMap
import warnings
import json

from typing import Any, Callable, Generator, TypeVar, cast, MutableMapping

from masci_tools.util.typing import FileLike

@contextmanager
def NestedPlotParameters(plotter_object: Plotter) -> Generator[None, None, None]:
    """
    Context manager for nested plot function calls

    On leaving the context (normally or because of an exception) the function defaults,
    the given parameters and the ``single_plot``/``num_plots`` properties of the plotter
    are restored to the values they had when the context was entered.

    :param plotter_object: Plotter instance
    """
    #pylint: disable=protected-access
    assert isinstance(plotter_object, Plotter), \
        'The NestedPlotParameters contextmanager should only be used for Plotter objects'

    # Snapshot of everything a nested plotting call may modify
    saved_state = (
        copy.deepcopy(plotter_object._function_defaults),
        copy.deepcopy(plotter_object._given_parameters),
        plotter_object.single_plot,
        plotter_object.num_plots,
    )
    try:
        yield
    finally:
        # The finally clause also runs if the body raises, so the state is always restored
        (plotter_object._function_defaults,
         plotter_object._given_parameters,
         plotter_object.single_plot,
         plotter_object.num_plots) = saved_state
# Type variable bound to arbitrary callables. ensure_plotter_consistency uses it
# (Callable[[F], F] together with cast(F, ...)) so a decorated function keeps its declared type
F = TypeVar('F', bound=Callable[..., Any])
"""Generic Callable type"""
def ensure_plotter_consistency(plotter_object: Plotter) -> Callable[[F], F]:
    """
    Decorator for plot functions to ensure that the Parameters are reset even if an error
    occurs in the function

    Additionally checks are performed that the parameters are reset after execution
    and the defaults are never changed in a plot function. Changes to the user defaults
    made inside the decorated function are reverted in every case; if the function
    finished without error such a change is reported with a ``ValueError``.

    :param plotter_object: Plotter instance to be checked for consistency
    """
    assert isinstance(plotter_object, Plotter), \
        'The ensure_plotter_consistency decorator should only be used for Plotter objects'

    def ensure_plotter_consistency_decorator(func: F) -> F:
        """
        Decorator that adds checks on the Plotter object
        """

        @wraps(func)
        def ensure_consistency(*args, **kwargs):
            """
            If an error is encountered in the decorated function the parameters
            and user defaults of the plotter object are reset to avoid unintended sideeffects

            Also after execution the defaults and parameters are checked to make sure
            they are consistent

            :raises ValueError: if the decorated function returned normally but changed the user defaults
            """
            #pylint: disable=protected-access
            global_defaults_before = copy.deepcopy(plotter_object._user_defaults)
            defaults_changed = False

            try:
                res = func(*args, **kwargs)
            finally:
                # Runs on success and on error: a failing plot function must not leave
                # modified user defaults behind either
                defaults_changed = plotter_object._user_defaults != global_defaults_before
                if defaults_changed:
                    plotter_object._user_defaults = global_defaults_before
                plotter_object.remove_added_parameters()
                plotter_object.reset_parameters()
                plotter_object._function_defaults = {}

            # Only reached if func did not raise; report the (already reverted) change
            if defaults_changed:
                raise ValueError(f"Defaults have changed inside the plotting function '{func.__name__}'")

            return res

        return cast(F, ensure_consistency)

    return ensure_plotter_consistency_decorator
def _generate_plot_parameters_table(defaults: dict[str, Any], descriptions: dict[str, str]) -> str: """ Generate a table for the plotting parameters for the docstrings :param defaults: dict/chainmap with the defined defaults :param descriptions: dict with the description of the keys in defaults """ #yapf: disable table = [ '.. list-table:: Plot Parameters', ' :widths: 15 60 25', ' :header-rows: 1', ' :class: tight-table', '', ' * - Name', ' - Description', ' - Default value' ] for key, value in defaults.items(): if value is None: value = 'No Default' elif not isinstance(value, dict): value = f'``{value}``' descr = descriptions.get(key, 'No Description available') descr = descr.replace('{',' ``{') descr = descr.replace('}','}`` ') table.extend([f' * - ``{key}``', f' - {descr}']) if not isinstance(value, dict): table.append(f' - {value}') else: string_value = [f"'{key}': '{val}'," if isinstance(val, str) else f"'{key}': {val}," for key, val in value.items()] if string_value: string_value[0] = '{' + string_value[0] string_value[-1] = string_value[-1].rstrip(',') + '}' else: string_value = ['{}'] table.extend([' - .. code-block::', ''] + \ [f' {string}' for string in string_value]) table.append('') #yapf: enable return '\n'.join(table)
class Plotter:
    """
    Base class for handling parameters for plotting methods. For different plotting backends
    a subclass can be created to represent the specific parameters of the backend.

    :param default_parameters: dict with hardcoded default parameters
    :param general_keys: set of str optional, defines parameters which are not allowed to change
                         for each entry in the plot data
    :param key_descriptions: dict (optional) with descriptions of the parameters
                             (shown by :py:meth:`Plotter.get_description()`)
    :param type_kwargs_mapping: dict (optional) mapping a plot type to the set of parameter keys
                                returned by :py:meth:`Plotter.plot_kwargs()` for this type
    :param kwargs_postprocess_rename: dict (optional) mapping parameter names to the names
                                      they are renamed to in :py:meth:`Plotter.plot_kwargs()`

    Kwargs in the __init__ method are forwarded to :py:func:`Plotter.set_defaults()`
    to change the current defaults away from the hardcoded parameters
    (unknown keys are skipped).

    The Plotter class creates a hierarchy of dictionaries for lookups on this object
    utilizing the `ChainMap` from the `collections` module.

    The hierarchy is as follows (First entries take precedence over later entries):

        - `parameters`: set by :py:func:`~Plotter.set_parameters()` (usually arguments passed into function)
        - `user defaults`: set by :py:func:`~Plotter.set_defaults()` (``default_type='global'``)
        - `function defaults`: set by :py:func:`~Plotter.set_defaults()` with `default_type='function'`
        - `hardcoded defaults`: passed to the constructor, used as fallback

    Only the `parameters` can represent parameters for multiple sets of plot calls.
    All others are used as fallback for specifying non-specified values for single plots

    The current parameters can be accessed by bracket indexing the class.
    An example of this is shown below.

    .. code-block:: python

        parameter_dict = {'fontsize': 16, 'linestyle': '-'}

        params = Plotter(parameter_dict)

        #Accessing a parameter
        print(params['fontsize']) # 16

        #Modifying a parameter
        params['fontsize'] = 20
        print(params['fontsize']) # 20

        #Creating a parameter set for multiple plots
        #1. Set the properties to the correct values
        params.single_plot = False
        params.num_plots = 3

        #2. Now we can set a property either by providing a list or a integer indexed dict
        # Both of the following examples set the linestyle of the second and third plot to '--'
        params['linestyle'] = [None, '--', '--']
        params['linestyle'] = {1: '--', 2: '--'}

        # Not specified values are replaced with the default value for a single plot
        print(params['linestyle']) # ['-', '--', '--']

        #In lists properties can also be indexed via tuples
        print(params[('linestyle', 0)]) # '-'
        print(params[('linestyle', 1)]) # '--'

        #Changes to the parameters and properties are reset
        params.reset_parameters()
        print(params['linestyle']) # '-'
    """

    def __init__(self,
                 default_parameters: dict[str, Any],
                 general_keys: set[str] | None = None,
                 key_descriptions: dict[str, str] | None = None,
                 type_kwargs_mapping: dict[str, set[str]] | None = None,
                 kwargs_postprocess_rename: dict[str, str] | None = None,
                 **kwargs: Any) -> None:

        # Copy, so that the dict of the caller is never modified through this object
        self._PLOT_DEFAULTS = copy.deepcopy(default_parameters)

        self._type_kwargs_mapping: dict[str, set[str]] = {}
        if type_kwargs_mapping is not None:
            self._type_kwargs_mapping = type_kwargs_mapping

        self._kwargs_postprocess_rename: dict[str, str] = {}
        if kwargs_postprocess_rename is not None:
            self._kwargs_postprocess_rename = kwargs_postprocess_rename

        #ChainMap with three dictionaries on top of the hardcoded defaults
        # maps[0]: given parameters
        # maps[1]: user (global) defaults
        # maps[2]: function defaults
        # maps[3]: hardcoded defaults
        self._params: ChainMap[str, Any] = ChainMap({}, {}, {}, self._PLOT_DEFAULTS)
        self._single_plot = True
        self._num_plots = 1
        self._added_parameters: set[str] = set()

        self._GENERAL_KEYS: set[str] = set()
        if general_keys is not None:
            self._GENERAL_KEYS = general_keys

        self._DESCRIPTIONS: dict[str, str] = {}
        if key_descriptions is not None:
            self._DESCRIPTIONS = key_descriptions

        if kwargs:
            self.set_defaults(continue_on_error=True, **kwargs)

    def __getitem__(self, indices: str | tuple[str, int]) -> Any:
        """
        Get the current value for the key

        :param indices: either str (specifies the key) or tuple of str and int (specifies the key and index to access)

        :returns: the current parameter for the given specification. If tuple is given and the parameter
                  is a list the second item is used for the list index (indices beyond the end of the list
                  return the first entry). Unknown keys return None
        """
        if isinstance(indices, tuple):
            if len(indices) != 2:
                raise ValueError('Only Key or (Key, Index) Indexing supported!')
            key, index = indices
        else:
            key = indices
            index = None

        try:
            value = self._params[key]
            # Function defaults given as lists (one entry per plot) take precedence over
            # single-valued user defaults, unless the key was given explicitly as a parameter
            if key not in self._given_parameters:
                if isinstance(self._function_defaults.get(key), list):
                    value = self._function_defaults[key]

            if isinstance(value, list):
                if index is None:
                    return value
                if index < len(value):
                    return value[index]
                return value[0]
            return value
        except KeyError:
            return None

    def get_multiple_kwargs(self, keys: set[str], ignore: str | list[str] | None = None) -> dict[str, Any]:
        """
        Get multiple parameters and return them in a dictionary

        :param keys: set of keys to process
        :param ignore: str or list of str (optional), defines keys to ignore in the creation of the dict

        :returns: dict with the current values of the keys; keys whose value is None are left out
        """
        if ignore is None:
            ignore = []
        elif not isinstance(ignore, list):
            ignore = [ignore]

        ret_dict = {}
        for key in set(keys).difference(ignore):
            value = self[key]
            if value is not None:
                ret_dict[key] = value
        return ret_dict

    def plot_kwargs(self,
                    plot_type: str = 'default',
                    ignore: str | list[str] | None = None,
                    extra_keys: set[str] | None = None,
                    post_process: bool = True,
                    list_of_dicts: bool = True,
                    **kwargs: str) -> Any:
        """
        Creates a dict or list of dicts (for multiple plots) with the defined parameters
        for the plotting calls of different types

        :param plot_type: type of plot (key in the ``type_kwargs_mapping`` given to the constructor)
        :param ignore: str or list of str (optional), defines keys to ignore in the creation of the dict
        :param extra_keys: optional set for additional keys to retrieve
        :param post_process: bool, if True the keys are renamed according to ``kwargs_postprocess_rename``
                             and (if `list_of_dicts` is True) converted via
                             :py:meth:`Plotter.dict_of_lists_to_list_of_dicts()`
        :param list_of_dicts: bool, if True (and `post_process` is True) the result is converted to
                              a list of dicts for multiple plots

        :raises ValueError: if the plot type is not known

        Kwargs are used to replace values by custom parameters:

        Example for using a custom markersize::

            p = Plotter({'marker': 'o'}, type_kwargs_mapping={'default': {'marker'}})
            p.add_parameter('marker_custom', default_from='marker')
            p.plot_kwargs(marker='marker_custom')

        This code snippet will return the standard parameters for a plot, but the value
        for the marker will be taken from the key `marker_custom`
        """
        if plot_type not in self._type_kwargs_mapping:
            raise ValueError(
                f'Unknown plot type {plot_type}. The following are known: {list(self._type_kwargs_mapping.keys())}')

        kwargs_keys = self._type_kwargs_mapping[plot_type]
        if extra_keys is not None:
            kwargs_keys = kwargs_keys | extra_keys

        #Insert custom keys to retrieve (on a copy, the stored mapping must not change)
        kwargs_keys = kwargs_keys.copy()
        for key, replace_key in kwargs.items():
            kwargs_keys.remove(key)
            kwargs_keys.add(replace_key)

        plot_kwargs = self.get_multiple_kwargs(kwargs_keys, ignore=ignore)

        #Rename replaced keys back to standard names
        for key, replace_key in kwargs.items():
            custom_val = plot_kwargs.pop(replace_key, None)
            if custom_val is not None:
                plot_kwargs[key] = custom_val

        if not post_process:
            return plot_kwargs

        for old, new in self._kwargs_postprocess_rename.items():
            if old in plot_kwargs:
                plot_kwargs[new] = plot_kwargs.pop(old)

        if list_of_dicts:
            plot_kwargs = self.dict_of_lists_to_list_of_dicts(plot_kwargs, self.single_plot,
                                                              self.num_plots)  #type:ignore[assignment]

        return plot_kwargs

    @staticmethod
    def dict_of_lists_to_list_of_dicts(dict_of_lists: dict[str, list[Any]],
                                       single_plot: bool,
                                       num_plots: int,
                                       repeat_after: int | None = None,
                                       ignore_repeat: set[str] | None = None) -> list[dict[str, Any]]:
        """
        Converts dict of lists and single values to list of length
        num_plots or single dict for single_plot=True

        :param dict_of_lists: dict to be converted (it is not modified)
        :param single_plot: boolean, if True only a single parameter set is allowed
        :param num_plots: int of the number of allowed plots
        :param repeat_after: int (optional), if given, entries at positions >= repeat_after
                             are taken cyclically (index modulo repeat_after)
        :param ignore_repeat: set of str (optional), keys excluded from the cyclic repetition

        :returns: list of dicts (for single_plot=True a single dict)

        :raises IndexError: if a list is too short for the requested number of entries
        :raises ValueError: if the number of created dicts does not match num_plots
        """
        if ignore_repeat is None:
            ignore_repeat = set()

        # Shallow copy so that the dict passed in by the caller is left untouched
        dict_of_lists = dict(dict_of_lists)

        any_list = any(isinstance(val, (list, tuple)) for val in dict_of_lists.values())

        #Make sure that every entry is actually a list
        if any_list:
            for key, val in dict_of_lists.items():
                if not isinstance(val, (list, tuple)):
                    dict_of_lists[key] = [val] * num_plots
        elif not single_plot:
            dict_of_lists = {key: [value] for key, value in dict_of_lists.items()}

        # For single plot these are equivalent
        list_of_dicts: list[dict[str, Any]] = dict_of_lists  #type:ignore[assignment]
        if not single_plot:
            list_of_dicts = []
            # enforce that all lists of the same lengths
            # (without any keys a single empty parameter set is created)
            maxlen = max(map(len, dict_of_lists.values()), default=1)
            if repeat_after is not None:
                maxlen = max(num_plots, maxlen)
            for index in range(maxlen):
                tempdict = {}
                # don't use comprehension here, otherwise the wrong key is caught
                for key, value in dict_of_lists.items():
                    try:
                        if repeat_after is not None and index >= repeat_after and key not in ignore_repeat:
                            tempdict[key] = value[index % repeat_after]
                        else:
                            tempdict[key] = value[index]
                    except IndexError as ex:
                        raise IndexError(f'List under key: {key} index: {index} out of range, '
                                         f'should have length: {maxlen}. '
                                         'It may also be that some other list is just to long.') from ex
                list_of_dicts.append(tempdict)
            if len(list_of_dicts) != num_plots:
                if len(list_of_dicts) == 1:
                    list_of_dicts = [copy.deepcopy(list_of_dicts[0]) for _ in range(num_plots)]
                else:
                    raise ValueError('Length does not match number of plots')

        return list_of_dicts

    @staticmethod
    def convert_to_complete_list(given_value: Any,
                                 single_plot: bool,
                                 num_plots: int,
                                 default: Any = None,
                                 key: str = '') -> Any:
        """
        Converts given value to list with length num_plots with None for the non-specified values

        :param given_value: value passed in, for multiple plots either list or dict with integer keys
        :param single_plot: bool, if True only a single parameter is allowed
        :param num_plots: int, if single_plot is False this defines the number of plots
        :param default: default value for unspecified entries (a list is indexed per entry)
        :param key: str of the key to process (used in error messages)

        :raises ValueError: if a list or integer-indexed dict is given but single_plot is True
        """
        if not isinstance(given_value, (dict, list)):
            return given_value

        ret_value = copy.copy(given_value)
        if isinstance(given_value, dict) and all(isinstance(plot_index, int) for plot_index in given_value):
            if single_plot:
                raise ValueError(f"Got dict with integer indices for '{key}' but only a single plot is allowed")
            #Convert to list with None for not specified indices
            ret_value = [ret_value.get(plot_index) for plot_index in range(num_plots)]

        if isinstance(ret_value, list):
            if single_plot:
                raise ValueError(f"Got list for key '{key}' but only a single plot is allowed")
            if len(ret_value) != num_plots:
                ret_value = ret_value.copy() + [None] * (num_plots - len(ret_value))
            if isinstance(default, list):
                ret_value = [val if val is not None else default[indx] for indx, val in enumerate(ret_value)]
            else:
                ret_value = [val if val is not None else default for val in ret_value]

        return ret_value

    def expand_parameters(self, original_length: int, **kwargs: Any) -> dict[str, Any]:
        """
        Expand parameters to a bigger number of plots. New length has to be a multiple of original length.
        Only lists of length <= original_length are expanded. Also expands function defaults
        (lists of length == original_length)

        :param original_length: int of the old length
        :param kwargs: arguments to expand

        :returns: expanded kwargs

        :raises ValueError: if num_plots is not a multiple of original_length
        """
        if self.num_plots == original_length:
            return kwargs

        if self.num_plots % original_length != 0:
            raise ValueError(f"Cannot expand parameters from length '{original_length}' to '{self.num_plots}'")

        length_per_param = self.num_plots // original_length

        for key, val in kwargs.items():
            if self.is_general(key):
                continue
            if isinstance(val, list) and len(val) <= original_length:
                # Every entry is repeated length_per_param times in place
                kwargs[key] = [entry for entry in val for _ in range(length_per_param)]

        # Iterate over a snapshot, since the function defaults are updated inside the loop
        for key, val in list(self._function_defaults.items()):
            if self.is_general(key):
                continue
            if isinstance(val, list) and len(val) == original_length:
                new_val = [entry for entry in val for _ in range(length_per_param)]
                self.set_single_default(key, new_val, default_type='function')

        return kwargs

    def set_single_default(self, key: str, value: Any, default_type: str = 'global') -> None:
        """
        Set default value for a single key/value pair

        :param key: str of the key to set
        :param value: value to set the key to
        :param default_type: either 'global' or 'function'. Specifies, whether to set the global defaults
                             (not reset after function) or the function defaults

        :raises KeyError: if the key is not a known parameter
        :raises ValueError: if the default_type is not known
        """
        if key not in self._params:
            raise KeyError(f'Unknown parameter: {key}')

        if default_type == 'global':
            # parents skips the given parameters, so writes go to the user defaults
            self.__update_map(self._params.parents, key, value)
        elif default_type == 'function':
            if not self.is_general(key):
                default_val = self._hardcoded_defaults.get(key)
                if key in self._user_defaults:
                    default_val = self._user_defaults[key]
                value = self.convert_to_complete_list(value,
                                                      self.single_plot,
                                                      self.num_plots,
                                                      default=default_val,
                                                      key=key)
            # parents.parents skips parameters and user defaults, so writes go to the function defaults
            self.__update_map(self._params.parents.parents, key, value)
        else:
            raise ValueError(f"Unknown default_type '{default_type}'. Expected 'global' or 'function'")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set the given key value pair on the `Plotter._params` ChainMap (Always to the top layer).
        Unknown keys are forbidden. Keys allowed for multiple plot sets are converted to complete lists

        :param key: key to update
        :param value: value to use for updating

        :raises KeyError: if the key is not a known parameter
        """
        if key not in self._params:
            raise KeyError(f'Unknown parameter: {key}')

        if not self.is_general(key):
            value = self.convert_to_complete_list(value,
                                                  self.single_plot,
                                                  self.num_plots,
                                                  default=self._params.parents[key],
                                                  key=key)

        self.__update_map(self._params, key, value)

    @staticmethod
    def __update_map(map_to_change: MutableMapping[str, Any], key: str, value: Any) -> None:
        """
        Updates the given map with the given key value pair
        If the current value is a dict it will be merged with the given dict;
        a list given for such a key is ignored (the current dict is kept)

        :param map_to_change: Mapping to change
        :param key: key to change
        :param value: value for updating the key

        :raises ValueError: if the current value is a dict and the value is neither dict nor list
        """
        if isinstance(map_to_change[key], dict):
            dict_before = copy.deepcopy(map_to_change[key])
            if not isinstance(value, dict):
                if isinstance(value, list):
                    map_to_change[key] = dict_before
                else:
                    raise ValueError(f"Expected a dict for key {key} got '{value}'")
            else:
                dict_before.update(value)
                map_to_change[key] = dict_before
        else:
            map_to_change[key] = value

    def set_defaults(self, continue_on_error: bool = False, default_type: str = 'global', **kwargs: Any) -> dict[str, Any]:
        """
        Set the current defaults.

        If an unknown key is encountered (and continue_on_error is False) all changes made by
        this call are reverted and a KeyError is raised.

        :param continue_on_error: bool, if True unknown key are simply skipped
        :param default_type: either 'global' or 'function'. Specifies, whether to set the global defaults
                             (not reset after function) or the function defaults

        Kwargs are used to set the defaults.

        :returns: dict with the unprocessed kwargs; the contents of an unprocessed
                  ``extra_kwargs`` entry are merged into it

        :raises ValueError: if the default_type is not known
        """
        if default_type == 'global':
            defaults_before = copy.deepcopy(self._user_defaults)
        elif default_type == 'function':
            defaults_before = copy.deepcopy(self._function_defaults)
        else:
            raise ValueError(f"Unknown default_type '{default_type}'. Expected 'global' or 'function'")

        kwargs_unprocessed = copy.deepcopy(kwargs)
        for key, value in kwargs.items():
            try:
                self.set_single_default(key, value, default_type=default_type)
                kwargs_unprocessed.pop(key)
            except KeyError as err:
                if not continue_on_error:
                    if default_type == 'global':
                        self._user_defaults = defaults_before
                    else:
                        self._function_defaults = defaults_before
                    raise KeyError(f'Unknown parameter: {key}') from err

        if 'extra_kwargs' in kwargs_unprocessed:
            extra_kwargs = kwargs_unprocessed.pop('extra_kwargs')
            kwargs_unprocessed.update(extra_kwargs)

        return kwargs_unprocessed

    def set_parameters(self, continue_on_error: bool = False, **kwargs: Any) -> dict[str, Any]:
        """
        Set the current parameters.

        :param continue_on_error: bool, if True unknown key are simply skipped and returned

        Kwargs are used to set the parameters.

        :returns: dict with the unprocessed kwargs; the contents of an unprocessed
                  ``extra_kwargs`` entry are merged into it
        """
        params_before = copy.deepcopy(self._given_parameters)

        kwargs_unprocessed = copy.deepcopy(kwargs)
        for key, value in kwargs.items():
            try:
                self[key] = value
                kwargs_unprocessed.pop(key)
            except KeyError:
                if not continue_on_error:
                    # Revert all changes made by this call before re-raising
                    self._given_parameters = params_before
                    raise

        if 'extra_kwargs' in kwargs_unprocessed:
            extra_kwargs = kwargs_unprocessed.pop('extra_kwargs')
            kwargs_unprocessed.update(extra_kwargs)

        return kwargs_unprocessed

    def add_parameter(self, name: str, default_from: str | None = None, default_val: Any = None) -> None:
        """
        Add a new parameter to the parameters dictionary.

        The parameter is stored in the function defaults and removed again by
        :py:meth:`Plotter.remove_added_parameters()`

        :param name: str name of the parameter
        :param default_from: str (optional), if given a entry is created in the current defaults
                             with the name and the default value of the key `default_from`
        :param default_val: value (optional) to use as default of the new parameter.
                            Cannot be combined with `default_from`

        :raises ValueError: if both `default_val` and `default_from` are given
        """
        if default_val is not None:
            if default_from is not None:
                raise ValueError('Default value specified via default_val and default_from. Please choose one option')
        elif default_from is not None:
            default_val = self._params.parents[default_from]

        # Do not share mutable default values with other parameters or the caller
        if isinstance(default_val, (dict, list)):
            default_val = copy.deepcopy(default_val)

        self._added_parameters.add(name)
        self._function_defaults[name] = default_val

    def save_defaults(self, filename: FileLike = 'plot_defaults.json', save_complete: bool = False) -> None:
        """
        Save the current defaults to a json file.

        :param filename: filename, where the defaults should be stored
        :param save_complete: bool if True not only the overwritten user defaults
                              but also the unmodified hardcoded defaults are stored

        :raises ValueError: if save_complete is True and function defaults are set
        """
        if save_complete:
            if self._function_defaults != {}:
                raise ValueError('Function defaults need to be empty before saving defaults')
            dict_to_save = dict(self._params.parents)
        else:
            dict_to_save = dict(self._user_defaults)

        with open(filename, 'w', encoding='utf-8') as file:  #type:ignore[arg-type]
            json.dump(dict_to_save, file, indent=4, sort_keys=True)

    def load_defaults(self, filename: FileLike = 'plot_defaults.json') -> None:
        """
        Load defaults from a json file and set them as global defaults.

        :param filename: filename,from where the defaults should be taken
        """
        with open(filename, encoding='utf-8') as file:  #type:ignore[arg-type]
            param_dict = json.load(file)

        self.set_defaults(**param_dict)

    def remove_added_parameters(self) -> None:
        """
        Remove the parameters added via :py:func:`Plotter.add_parameter()`
        from the function defaults and the given parameters
        """
        for key in copy.deepcopy(self._added_parameters):
            self._function_defaults.pop(key, None)
            self._given_parameters.pop(key, None)

    def reset_defaults(self) -> None:
        """
        Resets the defaults to the hardcoded defaults in _PLOT_DEFAULTS.
        Given parameters, user defaults and function defaults are all discarded
        """
        self._params = ChainMap({}, {}, {}, self._PLOT_DEFAULTS)

    def reset_parameters(self) -> None:
        """
        Reset the parameters to the current defaults. The properties single_plot and
        num_plots are also set to default values
        """
        self._given_parameters = {}

        #Reset number of plots properties
        self.single_plot = True
        self.num_plots = 1

    def get_dict(self) -> dict[str, Any]:
        """
        Return the current values of all parameters (given parameters and all default
        layers merged, higher layers taking precedence) as a dict. For use of printing
        """
        return dict(self._params)

    def get_description(self, key: str) -> None:
        """
        Print the description of the given key (warns if the key is unknown)

        :param key: str of the key, for which the description should be printed
        """
        if key in self._DESCRIPTIONS:
            print(f'{key}:\n\n{self._DESCRIPTIONS[key]}')
        elif key in self._params:
            print(f'{key}:\n\nNo Description available')
        else:
            warnings.warn(f'{key} is not a known parameter')

    def is_general(self, key: str) -> bool:
        """
        Return, whether the key is general
        (meaning only related to the whole plots)

        :param key: str of the key to check

        :returns: bool, whether the key is general
        """
        return key in self._GENERAL_KEYS

    @property
    def _hardcoded_defaults(self) -> MutableMapping[str, Any]:
        """
        Alias for the lowest map in the _params ChainMap
        """
        return self._params.maps[3]

    @_hardcoded_defaults.setter
    def _hardcoded_defaults(self, dict_value: dict[str, Any]) -> None:
        """
        Setter for the _hardcoded_defaults property
        """
        # The hardcoded defaults are the lowest map (index 3); index 2 holds the function defaults
        self._params.maps[3] = dict_value

    @property
    def _function_defaults(self) -> MutableMapping[str, Any]:
        """
        Alias for the second lowest map in the _params ChainMap
        """
        return self._params.maps[2]

    @_function_defaults.setter
    def _function_defaults(self, dict_value: dict[str, Any]) -> None:
        """
        Setter for the _function_defaults property
        """
        self._params.maps[2] = dict_value

    @property
    def _user_defaults(self) -> MutableMapping[str, Any]:
        """
        Alias for the third lowest map in the _params ChainMap
        """
        return self._params.maps[1]

    @_user_defaults.setter
    def _user_defaults(self, dict_value: dict[str, Any]) -> None:
        """
        Setter for the _user_defaults property
        """
        self._params.maps[1] = dict_value

    @property
    def _given_parameters(self) -> MutableMapping[str, Any]:
        """
        Alias for the highest map in the _params ChainMap
        """
        return self._params.maps[0]

    @_given_parameters.setter
    def _given_parameters(self, dict_value: dict[str, Any]) -> None:
        """
        Setter for the _given_parameters property
        """
        self._params.maps[0] = dict_value

    @property
    def single_plot(self) -> bool:
        """
        Boolean property if True only a single Plot parameter set is allowed
        """
        return self._single_plot

    @single_plot.setter
    def single_plot(self, boolean_value: bool) -> None:
        """
        Setter for single_plot property
        """
        self._single_plot = boolean_value

    @property
    def num_plots(self) -> int:
        """
        Integer property for number of plots produced
        """
        return self._num_plots

    @num_plots.setter
    def num_plots(self, int_value: int) -> None:
        """
        Setter for num_plots property
        """
        self._num_plots = int_value