Source code for stk._internal.ea.plotters.selection

import pathlib
import typing
from collections import Counter
from collections.abc import Callable, Iterator, Set
from functools import wraps
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from stk._internal.ea.molecule_record import MoleculeRecord
from stk._internal.ea.selection.batch import Batch, BatchKey
from stk._internal.ea.selection.selectors.selector import Selector
from stk._internal.key_makers.inchi_key import InchiKey

plt.switch_backend("agg")

T = typing.TypeVar("T", bound=MoleculeRecord)


class SelectMethod(typing.Protocol[T]):
    def __call__(
        self,
        population: dict[T, float],
        included_batches: Set[BatchKey] | None = None,
        excluded_batches: Set[BatchKey] | None = None,
    ) -> Iterator[Batch[T]]:
        pass


[docs] class SelectionPlotter(typing.Generic[T]): """ Plots which molecule records a :class:`.Selector` selects. Examples: *Plotting Which Molecule Records Got Selected* .. testcode:: plotting-which-molecule-records-got-selected import stk # Make a selector. roulette = stk.Roulette(num_batches=10) # Make a population. population = { stk.MoleculeRecord( topology_graph=stk.polymer.Linear( building_blocks=( stk.BuildingBlock( smiles='BrCCBr', functional_groups=[stk.BromoFactory()], ), ), repeating_unit='A', num_repeating_units=2, ), ): i for i in range(100) } # Make a plotter. You do not have to assign it to a variable. stk.SelectionPlotter('roulette_counter', roulette) # Select the molecule records. selected = tuple(roulette.select(population)) # There should now be a file called "roulette_counter_1.png" # which shows a graph of all the selected records. # Select records again. selected2 = tuple(roulette.select(population)) # There should now be a file called "roulette_counter_2.png" # which shows a graph of all the selected molecules. # And so on every time you use "roulette.select()". .. testcode:: plotting-which-molecule-records-got-selected :hide: import os assert os.path.exists('roulette_counter_1.png') assert os.path.exists('roulette_counter_1.csv') assert os.path.exists('roulette_counter_2.png') assert os.path.exists('roulette_counter_2.csv') os.remove('roulette_counter_1.png') os.remove('roulette_counter_1.csv') os.remove('roulette_counter_2.png') os.remove('roulette_counter_2.csv') """ def __init__( self, filename: pathlib.Path | str, selector: Selector[T], x_label: str = "Molecule: InChIKey - Fitness Value", record_label: Callable[ [MoleculeRecord[Any], float], str ] = lambda record, fitness_value: ( f"{InchiKey().get_key(record.get_molecule())} - {fitness_value}" ), heat_map_value: Callable[ [MoleculeRecord[Any], float], float ] = lambda record, fitness_value: fitness_value, heat_map_label: str = "Fitness", order_by: Callable[ [MoleculeRecord[Any], float], float ] = lambda record, fitness_value: fitness_value, ) -> None: """ Parameters: filename: The basename of the files. This means it should not include file extensions. selector: The :class:`.Selector` whose selection of molecule records is plotted. x_label: The label use for the x axis. record_label: A function which takes a :class:`.MoleculeRecord` and its fitness value and returns a string, which is the label used for the :class:`.MoleculeRecord` on the plot. heat_map_value: A function which takes a :class:`.MoleculeRecord` and its fitness value and returns a value. The value is used for coloring the heat map used in the plot. heat_map_label: The label used for the heat map key. order_by: A function which takes a :class:`.MoleculeRecord` and its fitness value and returns a value. The value is used to sort the plotted records along the x-axis in descending order. """ self._plots = 0 self._filename = filename self._x_label = x_label self._record_label = record_label self._order_by = order_by self._heat_map_value = heat_map_value self._heat_map_label = heat_map_label selector.select = self._update_counter(selector.select) # type: ignore def _update_counter(self, select: SelectMethod[T]) -> SelectMethod[T]: """ Decorate :meth:`.Selector.select`. This is a decorator which makes sure that every time :meth:`.Selector.select` selects a :class:`.MoleculeRecord` a counter keeping track of selected records is updated. Parameters: select: The :meth:`Selector.select` method to decorate. Returns: The decorated :meth:`.Selector.select` method. """ @wraps(select) def inner( population: dict[T, float], included_batches: Set[BatchKey] | None = None, excluded_batches: Set[BatchKey] | None = None, ) -> Iterator[Batch[T]]: counter = Counter({record: 0 for record in population}) for selected in select( population=population, included_batches=included_batches, excluded_batches=excluded_batches, ): counter.update(selected) yield selected self._plot(population, counter) return inner def _plot(self, population: dict[T, float], counter: Counter) -> None: """ Plot a selection counter. Parameters: counter: A counter specifying which records were selected and how many times. """ self._plots += 1 sns.set(style="darkgrid") data = [] for record, selection_count in counter.items(): label = self._record_label(record, population[record]) data.append( pd.DataFrame( data={ self._x_label: label, "Number of Times Selected": selection_count, "order": self._order_by(record, population[record]), "heat_map": self._heat_map_value( record, population[record], ), }, index=[self._x_label], ) ) df = pd.concat(data, ignore_index=True) df = df.sort_values( ["Number of Times Selected", "order"], ascending=[False, False], ) norm = plt.Normalize(df["heat_map"].min(), df["heat_map"].max()) sm = plt.cm.ScalarMappable(cmap="magma_r", norm=norm) sm.set_array([]) df.to_csv(f"{self._filename}_{self._plots}.csv") fig, ax = plt.subplots(figsize=(11.7, 8.28)) sns.scatterplot( x="Number of Times Selected", y=self._x_label, hue="heat_map", palette="magma_r", data=df, size=[200 for _ in range(len(counter.keys()))], ax=ax, ) ax.get_legend().remove() # https://tinyurl.com/2p9drmkh plt.rcParams["axes.grid"] = False ax.figure.colorbar(sm, ax=ax).set_label(self._heat_map_label) plt.tight_layout() fig.savefig(f"{self._filename}_{self._plots}.png", dpi=fig.dpi) plt.close("all")