Source code for stk._internal.ea.selection.selectors.above_average

import itertools
import typing
from collections.abc import Callable, Iterable, Iterator, Sequence

import numpy as np

from stk._internal.ea.molecule_record import MoleculeRecord
from stk._internal.ea.selection.batch import Batch
from stk._internal.ea.selection.selectors.yielded_batches import YieldedBatches
from stk._internal.key_makers.inchi import Inchi
from stk._internal.key_makers.molecule import MoleculeKeyMaker

from .selector import Selector

# Type variable for the kind of record a selector operates on. It is
# bounded, so any type substituted for it must be a MoleculeRecord or a
# subclass of it.
T = typing.TypeVar("T", bound=MoleculeRecord)


class AboveAverage(Selector[T]):
    """
    Yields above average batches of molecules.

    A batch is selected when its fitness value is strictly greater
    than the mean fitness value of all batches. Selected batches are
    yielded from the highest fitness value downwards. When both
    duplicate molecules and duplicate batches are allowed, a batch
    whose fitness value is ``k`` times the mean (rounded down) is
    yielded ``k`` times in a row.

    Examples:

        *Yielding Single Molecule Batches*

        Yielding molecules one at a time. For example, if molecules
        need to be selected for mutation or the next generation.

        .. testcode:: yielding-single-molecule-batches

            import stk

            # Make the selector.
            above_avg = stk.AboveAverage()

            population = {
                stk.MoleculeRecord(
                    topology_graph=stk.polymer.Linear(
                        building_blocks=[
                            stk.BuildingBlock('BrCCBr', stk.BromoFactory()),
                        ],
                        repeating_unit='A',
                        num_repeating_units=2,
                    ),
                ): 1,
                stk.MoleculeRecord(
                    topology_graph=stk.polymer.Linear(
                        building_blocks=[
                            stk.BuildingBlock('BrCCBr', stk.BromoFactory()),
                        ],
                        repeating_unit='A',
                        num_repeating_units=2,
                    ),
                ): 2,
            }

            # Select the molecules.
            for selected, in above_avg.select(population):
                # Do stuff with each selected molecule.
                pass

        *Yielding Batches Holding Multiple Molecules*

        Yielding multiple molecules at once. For example, if molecules
        need to be selected for crossover.

        .. testcode:: yielding-batches-holding-multiple-molecules

            import stk

            # Make the selector.
            above_avg = stk.AboveAverage(batch_size=2)

            population = {
                stk.MoleculeRecord(
                    topology_graph=stk.polymer.Linear(
                        building_blocks=[
                            stk.BuildingBlock('BrCCBr', stk.BromoFactory()),
                        ],
                        repeating_unit='A',
                        num_repeating_units=2,
                    ),
                ): 1,
                stk.MoleculeRecord(
                    topology_graph=stk.polymer.Linear(
                        building_blocks=[
                            stk.BuildingBlock('BrCCBr', stk.BromoFactory()),
                        ],
                        repeating_unit='A',
                        num_repeating_units=2,
                    ),
                ): 2,
            }

            # Select the molecules.
            for selected1, selected2 in above_avg.select(population):
                # Do stuff with the selected molecules.
                pass

    """

    def __init__(
        self,
        num_batches: int | None = None,
        batch_size: int = 1,
        duplicate_molecules: bool = True,
        duplicate_batches: bool = True,
        key_maker: MoleculeKeyMaker = Inchi(),
        fitness_modifier: Callable[
            [dict[T, float]], dict[T, float]
        ] = lambda x: x,
    ) -> None:
        """
        Parameters:

            num_batches:
                The number of batches to yield. If ``None`` then
                yielding will continue forever or until the generator
                is exhausted, whichever comes first.

            batch_size:
                The number of molecule records in each yielded
                :class:`.Batch`.

            duplicate_molecules:
                If ``True`` the same molecule can be yielded in more
                than one batch.

            duplicate_batches:
                If ``True`` the same batch can be yielded more than
                once.

            key_maker:
                Used to get the keys of molecules. If two molecules
                have the same key, they are considered duplicates.

            fitness_modifier:
                Takes the `population` on which
                :meth:`~Selector.select` is called and returns a
                :class:`dict`, which maps records in the `population`
                to the fitness values the :class:`.Selector` should
                use.

        """

        super().__init__(key_maker, fitness_modifier, batch_size)
        self._duplicate_molecules = duplicate_molecules
        self._duplicate_batches = duplicate_batches
        self._num_batches = num_batches

    def _select_from_batches(
        self,
        batches: Sequence[Batch[T]],
        yielded_batches: YieldedBatches[T],
    ) -> Iterator[Batch[T]]:
        """
        Yield the batches whose fitness value exceeds the mean.

        Parameters:

            batches:
                The candidate batches to select from.

            yielded_batches:
                Queried through ``has_no_yielded_molecules`` and
                ``is_unyielded_batch`` to reject duplicates when they
                are not allowed.

        Yields:
            Above average batches, highest fitness value first, at
            most ``num_batches`` of them when that limit was set.

        """

        # Nothing to select from. Returning here also avoids taking
        # the mean of an empty list, which gives NaN and a
        # RuntimeWarning.
        if len(batches) == 0:
            return

        mean = float(
            np.mean([batch.get_fitness_value() for batch in batches])
        )

        # Yield highest fitness batches first.
        selected_batches: Iterable[Batch[T]] = sorted(
            batches,
            reverse=True,
        )

        # Yield only batches with a fitness larger than the mean.
        # Because the batches are sorted in descending order, the
        # first batch which fails the test ends the selection.
        selected_batches = itertools.takewhile(
            lambda batch: batch.get_fitness_value() > mean,
            selected_batches,
        )

        # Yield batches which are multiple times better than the mean
        # multiple times.
        selected_batches = (
            batch
            for batch in selected_batches
            for _ in range(self._get_num_duplicates(batch, mean))
        )

        # The filters below are deliberately lazy, so each batch is
        # tested only at the moment it is about to be yielded.
        # NOTE(review): presumably the caller updates
        # `yielded_batches` as batches are consumed - confirm against
        # Selector.

        # If duplicate molecules are not allowed, allow only
        # batches with no yielded molecules.
        if not self._duplicate_molecules:
            selected_batches = filter(
                yielded_batches.has_no_yielded_molecules,
                selected_batches,
            )

        # If duplicate batches are not allowed, allow only
        # unyielded batches.
        if not self._duplicate_batches:
            selected_batches = filter(
                yielded_batches.is_unyielded_batch,
                selected_batches,
            )

        # Limit the number of yielded batches to _num_batches. A
        # limit of None means no limit.
        yield from itertools.islice(selected_batches, self._num_batches)

    def _get_num_duplicates(self, batch: Batch[T], mean: float) -> int:
        """
        Get the number of times `batch` should be yielded in a row.

        Parameters:

            batch:
                The batch being yielded.

            mean:
                The mean fitness value of all batches.

        Returns:
            ``1`` if duplicate batches or duplicate molecules are not
            allowed, or if `mean` is not strictly positive. Otherwise,
            the fitness value of `batch` floor-divided by `mean`.

        """

        if not (self._duplicate_batches and self._duplicate_molecules):
            return 1

        # "Times better than the mean" has no meaning for a zero or
        # negative mean. A zero mean would raise ZeroDivisionError and
        # a negative mean would give a count <= 0, which silently
        # drops every above average batch. Yield the batch once
        # instead.
        if mean <= 0:
            return 1

        return int(batch.get_fitness_value() // mean)