# Source code for stk.ea.plotters.selection

"""
Selection Plotter
=================

"""

from collections import Counter
from functools import wraps

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from stk.molecular import InchiKey

# Switch pyplot to the "agg" backend as soon as this module is imported.
# The plotter defined below only ever writes its figures to disk (see the
# ``fig.savefig`` call in ``SelectionPlotter._plot``) and never shows them.
# NOTE(review): this is a process-wide pyplot setting, so importing this
# module also changes the backend for any other code using pyplot.
plt.switch_backend("agg")


class SelectionPlotter:
    """
    Plots which molecule records a :class:`.Selector` selects.

    Creating a plotter replaces the ``select`` attribute of the given
    selector with a wrapped version. Every time the wrapped ``select``
    generator is run to completion, a ``.csv`` file and a ``.png`` file
    describing the selection are written, using the supplied basename
    followed by a running number.

    Examples
    --------
    *Plotting Which Molecule Records Got Selected*

    .. testcode:: plotting-which-molecule-records-got-selected

        import stk

        # Make a selector.
        roulette = stk.Roulette(num_batches=10)

        # Make a population.
        population = tuple(
            stk.MoleculeRecord(
                topology_graph=stk.polymer.Linear(
                    building_blocks=(
                        stk.BuildingBlock(
                            smiles='BrCCBr',
                            functional_groups=[stk.BromoFactory()],
                        ),
                    ),
                    repeating_unit='A',
                    num_repeating_units=2,
                ),
            ).with_fitness_value(i)
            for i in range(100)
        )

        # Make a plotter. You do not have to assign it to a variable.
        stk.SelectionPlotter('roulette_counter', roulette)

        # Select the molecule records.
        selected = tuple(roulette.select(population))

        # There should now be a file called "roulette_counter_1.png"
        # which shows a graph of all the selected records.

        # Select records again.
        selected2 = tuple(roulette.select(population))

        # There should now be a file called "roulette_counter_2.png"
        # which shows a graph of all the selected molecules.

        # And so on every time you use "roulette.select()".

    .. testcode:: plotting-which-molecule-records-got-selected
        :hide:

        import os

        assert os.path.exists('roulette_counter_1.png')
        assert os.path.exists('roulette_counter_1.csv')
        assert os.path.exists('roulette_counter_2.png')
        assert os.path.exists('roulette_counter_2.csv')

        os.remove('roulette_counter_1.png')
        os.remove('roulette_counter_1.csv')
        os.remove('roulette_counter_2.png')
        os.remove('roulette_counter_2.csv')

    """

    def __init__(
        self,
        filename,
        selector,
        x_label="Molecule: InChIKey - Fitness Value",
        record_label=lambda record: (
            f"{InchiKey().get_key(record.get_molecule())} - "
            f"{record.get_fitness_value()}"
        ),
        heat_map_value=lambda record: record.get_fitness_value(),
        heat_map_label="Fitness",
        order_by=lambda record: record.get_fitness_value(),
    ):
        """
        Initialize a :class:`.SelectionPlotter` instance.

        Parameters
        ----------
        filename : :class:`str`
            The basename of the files. This means it should not
            include file extensions.

        selector : :class:`.Selector`
            The :class:`.Selector` whose selection of molecule
            records is plotted. Its ``select`` attribute is replaced
            by a wrapper that records and plots every selection.

        x_label : :class:`str`, optional
            The label of the axis along which the molecule records
            are listed. Note that, despite the name, the plot places
            the records on the y axis and the number of times each
            record was selected on the x axis.

        record_label : :class:`callable`, optional
            A :class:`callable` which takes a :class:`.MoleculeRecord`
            for each record, which is to be included in the counter
            plot. It should return a string, which is the label used
            for the :class:`.MoleculeRecord` on the plot.

        heat_map_value : :class:`callable`, optional
            A :class:`callable`, which takes a
            :class:`.MoleculeRecord` for each record, which is to be
            included in the plot, and returns a value. The value is
            used for coloring the heat map used in the plot.

        heat_map_label : :class:`str`, optional
            The label used for the heat map key.

        order_by : :class:`callable`, optional
            A :class:`callable`, which takes a
            :class:`.MoleculeRecord` for each record, which is to be
            included in the plot, and returns a value. Records are
            sorted by the number of times they were selected first,
            and by this value second, both in descending order.

        """

        # Number of plots written so far; used as the file suffix.
        self._plots = 0
        self._filename = filename
        self._x_label = x_label
        self._record_label = record_label
        self._order_by = order_by
        self._heat_map_value = heat_map_value
        self._heat_map_label = heat_map_label
        # Monkey-patch the selector instance so that each selection
        # it performs is counted and plotted.
        selector.select = self._update_counter(selector.select)

    def _update_counter(self, select):
        """
        Decorate :meth:`.Selector.select`.

        This is a decorator which makes sure that every time
        :meth:`.Selector.select` selects a :class:`.MoleculeRecord`
        a counter keeping track of selected records is updated.
        The counter is plotted once the decorated generator has been
        exhausted; a generator which is abandoned early produces no
        plot.

        Parameters
        ----------
        select : :class:`callable`
            The :meth:`Selector.select` method to decorate.

        Returns
        -------
        :class:`function`
            The decorated :meth:`.Selector.select` method.

        """

        @wraps(select)
        def inner(population, *args, **kwargs):
            # The population is iterated here to seed the counter and
            # again by ``select``. A one-shot iterator would reach
            # ``select`` already exhausted, so materialize it first.
            # Re-iterable containers are passed through untouched.
            if iter(population) is population:
                population = tuple(population)

            # Start every record at zero so that records which are
            # never selected still show up in the plot.
            counter = Counter(dict.fromkeys(population, 0))
            for selected in select(population, *args, **kwargs):
                counter.update(selected)
                yield selected
            self._plot(population, counter)

        return inner

    def _get_selection_frame(self, counter):
        """
        Build the sorted table of records which gets plotted.

        Parameters
        ----------
        counter : :class:`collections.Counter`
            A counter specifying which records were selected and
            how many times.

        Returns
        -------
        :class:`pandas.DataFrame`
            One row per record in `counter`, sorted by the number of
            times the record was selected and then by its ``order``
            value, both descending.

        """

        columns = [
            self._x_label,
            "Number of Times Selected",
            "order",
            "heat_map",
        ]
        rows = [
            {
                self._x_label: self._record_label(record),
                "Number of Times Selected": selection_count,
                "order": self._order_by(record),
                "heat_map": self._heat_map_value(record),
            }
            for record, selection_count in counter.items()
        ]
        # A single frame is built from all rows at once, rather than
        # concatenating one single-row frame per record. The columns
        # are named explicitly so that the frame has the expected
        # columns even if there are no rows.
        df = pd.DataFrame(rows, columns=columns)
        return df.sort_values(
            ["Number of Times Selected", "order"],
            ascending=[False, False],
        )

    def _plot(self, population, counter):
        """
        Plot a selection counter.

        Writes ``{filename}_{n}.csv`` holding the plotted data and
        ``{filename}_{n}.png`` holding the plot, where ``n`` is the
        running number of the plot.

        Parameters
        ----------
        population : :class:`tuple` of :class:`.MoleculeRecord`
            The population from which molecule records were
            selected. It is not used by this method; everything
            which is plotted is taken from `counter`.

        counter : :class:`collections.Counter`
            A counter specifying which records were selected and
            how many times.

        Returns
        -------
        None : :class:`NoneType`

        """

        self._plots += 1
        basename = f"{self._filename}_{self._plots}"

        # NOTE: this changes the global seaborn / matplotlib style.
        sns.set(style="darkgrid")

        df = self._get_selection_frame(counter)

        # The color bar is drawn from a stand-alone mappable which
        # spans the range of heat map values.
        norm = plt.Normalize(
            df["heat_map"].min(),
            df["heat_map"].max(),
        )
        sm = plt.cm.ScalarMappable(cmap="magma_r", norm=norm)
        sm.set_array([])

        df.to_csv(f"{basename}.csv")

        fig, ax = plt.subplots(figsize=(11.7, 8.28))
        try:
            sns.scatterplot(
                x="Number of Times Selected",
                y=self._x_label,
                hue="heat_map",
                palette="magma_r",
                data=df,
                # One marker size for every point.
                s=200,
                ax=ax,
            )
            # The color bar replaces the hue legend. Guard against
            # there being no legend to remove.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

            # https://tinyurl.com/2p9drmkh
            plt.rcParams["axes.grid"] = False
            # The mappable is not attached to any axes, so the axes
            # to take space from has to be named explicitly.
            fig.colorbar(sm, ax=ax).set_label(self._heat_map_label)
            fig.tight_layout()
            fig.savefig(f"{basename}.png", dpi=fig.dpi)
        finally:
            # Close only the figure created here, even on failure,
            # and leave any other open figures alone.
            plt.close(fig)