Source code for stk._internal.ea.selection.selectors.tournament

import typing
from collections.abc import Callable, Iterator, Sequence

import numpy as np

from stk._internal.ea.molecule_record import MoleculeRecord
from stk._internal.ea.selection.batch import Batch
from stk._internal.ea.selection.selectors.yielded_batches import YieldedBatches
from stk._internal.key_makers.inchi import Inchi
from stk._internal.key_makers.molecule import MoleculeKeyMaker

from .selector import Selector

# Type variable for the kind of record a selector works with. It is bound to
# MoleculeRecord, so only MoleculeRecord or its subclasses may be substituted.
T = typing.TypeVar("T", bound=MoleculeRecord)


class Tournament(Selector[T]):
    """
    Yields batches of molecules through tournament selection.

    In tournament selection, a random number of batches is chosen from
    the population to undergo a competition. In each competition, the
    batch with the highest fitness value is yielded. This is repeated
    until `num_batches` are yielded, or until fewer than two batches
    remain eligible for a competition.

    Examples:

        *Yielding Single Molecule Batches*

        Yielding molecules one at a time. For example, if molecules need
        to be selected for mutation or the next generation.

        .. testcode:: yielding-single-molecule-batches

            import stk

            # Make the selector.
            tournament = stk.Tournament(
                num_batches=5,
                batch_size=1
            )

            population = {
                stk.MoleculeRecord(
                    topology_graph=stk.polymer.Linear(
                        building_blocks=[
                            stk.BuildingBlock('BrCCBr', stk.BromoFactory()),
                        ],
                        repeating_unit='A',
                        num_repeating_units=2,
                    ),
                ): i
                for i in range(100)
            }

            # Select the molecules.
            for selected, in tournament.select(population):
                # Do stuff with each selected molecule.
                pass

    """

    def __init__(
        self,
        num_batches: int | None = None,
        batch_size: int = 1,
        duplicate_molecules: bool = True,
        duplicate_batches: bool = True,
        key_maker: MoleculeKeyMaker = Inchi(),
        fitness_modifier: Callable[
            [dict[T, float]], dict[T, float]
        ] = lambda x: x,
        random_seed: int | np.random.Generator | None = None,
    ) -> None:
        """
        Parameters:

            num_batches:
                The number of batches to yield. If ``None`` then yielding
                will continue forever or until the generator is exhausted,
                whichever comes first.

            batch_size:
                The number of molecules yielded at once.

            duplicate_molecules:
                If ``True`` the same molecule can be yielded in more than
                one batch.

            duplicate_batches:
                If ``True`` the same batch can be yielded more than once.

            key_maker:
                Used to get the keys of molecules. If two molecules have
                the same key, they are considered duplicates.

            fitness_modifier:
                Takes the `population` on which :meth:`~.Selector.select`
                is called and returns a :class:`dict`, which maps records
                in the `population` to the fitness values the
                :class:`.Selector` should use.

            random_seed:
                The random seed to use. An existing
                :class:`numpy.random.Generator` is used as-is.

        """

        super().__init__(key_maker, fitness_modifier, batch_size)

        # Normalise the seed into a Generator; a Generator passed in by
        # the caller is used directly.
        if random_seed is None or isinstance(random_seed, int):
            random_seed = np.random.default_rng(random_seed)

        self._generator = random_seed
        self._duplicate_molecules = duplicate_molecules
        self._duplicate_batches = duplicate_batches
        # Infinity makes the "fewer than num_batches yielded" check in
        # _select_from_batches always true when no limit was requested.
        self._num_batches = (
            float("inf") if num_batches is None else num_batches
        )

    def _select_from_batches(
        self,
        batches: Sequence[Batch[T]],
        yielded_batches: YieldedBatches[T],
    ) -> Iterator[Batch[T]]:
        """
        Yield the winner of successive tournaments between `batches`.

        Parameters:

            batches:
                The batches which can take part in the tournaments.

            yielded_batches:
                Queried for the number of batches yielded so far and
                for whether a batch, or any of its molecules, has
                already been yielded.

        Yields:
            The largest batch, according to the batch ordering, of
            each tournament.

        """

        # The tournament can only take place if there is more than 1
        # batch.
        while (
            len(batches) > 1
            and yielded_batches.get_num() < self._num_batches
        ):
            # `high` is exclusive, so the tournament holds between 2
            # and len(batches) competitors.
            tournament_size = self._generator.integers(
                low=2,
                high=len(batches) + 1,
            )
            competitors = self._generator.choice(
                a=batches,  # type: ignore
                size=tournament_size,
                replace=False,
            )
            yield max(competitors)

            # NOTE(review): presumably the caller records the batch
            # just yielded in `yielded_batches` before resuming this
            # generator - confirm against Selector.select.
            #
            # The filters are applied cumulatively. Previously the
            # second filter was run on the unfiltered batches, which
            # discarded the molecule filter whenever both kinds of
            # duplicate were disallowed.
            if not self._duplicate_molecules:
                batches = tuple(
                    filter(
                        yielded_batches.has_no_yielded_molecules,
                        batches,
                    )
                )
            if not self._duplicate_batches:
                batches = tuple(
                    filter(
                        yielded_batches.is_unyielded_batch,
                        batches,
                    )
                )