"""
Source code for ``stk.ea.selection.selectors.stochastic_universal_sampling``.

Stochastic Universal Sampling
=============================

"""


import numpy as np

from stk.molecular import Inchi

from .selector import Selector

class StochasticUniversalSampling(Selector):
    """
    Yields batches of molecules through stochastic universal sampling.

    Stochastic universal sampling lays out batches along a line, with
    each batch taking up length proportional to its fitness. It then
    creates a set of evenly spaced pointers to different points on the
    line, each of which is occupied by a batch. Batches which are
    pointed to are yielded.

    This approach means weaker members of the population are given a
    greater chance to be chosen than in :class:`.Roulette`
    selection [#]_.

    If every remaining candidate batch has a fitness value of zero,
    each candidate is given an equal length on the line, rather than
    the selection failing.

    References
    ----------
    .. [#] J. E. Baker, "Reducing Bias and Inefficiency in the
       Selection Algorithm", Proceedings of the Second International
       Conference on Genetic Algorithms, 1987, pp. 14-21.

    Examples
    --------
    *Yielding Single Molecule Batches*

    Yielding molecules one at a time. For example, if molecules need
    to be selected for mutation or the next generation.

    .. testcode:: yielding-single-molecule-batches

        import stk

        # Make the selector.
        stochastic_sampling = stk.StochasticUniversalSampling(5)

        population = tuple(
            stk.MoleculeRecord(
                topology_graph=stk.polymer.Linear(
                    building_blocks=(
                        stk.BuildingBlock(
                            smiles='BrCCBr',
                            functional_groups=[stk.BromoFactory()],
                        ),
                    ),
                    repeating_unit='A',
                    num_repeating_units=2,
                ),
            ).with_fitness_value(i)
            for i in range(100)
        )

        # Select the molecules.
        for selected, in stochastic_sampling.select(population):
            # Do stuff with each selected molecule.
            pass

    """

    def __init__(
        self,
        num_batches=None,
        batch_size=1,
        duplicate_molecules=True,
        duplicate_batches=True,
        key_maker=Inchi(),
        fitness_modifier=None,
        random_seed=None,
    ):
        """
        Initialize a :class:`.StochasticUniversalSampling` instance.

        Parameters
        ----------
        num_batches : :class:`int`, optional
            The number of batches to yield. If ``None`` then yielding
            will continue forever or until the generator is exhausted,
            whichever comes first.

        batch_size : :class:`int`, optional
            The number of molecules yielded at once.

        duplicate_molecules : :class:`bool`, optional
            If ``True`` the same molecule can be yielded in more than
            one batch.

        duplicate_batches : :class:`bool`, optional
            If ``True`` the same batch can be yielded more than once.

        key_maker : :class:`.MoleculeKeyMaker`, optional
            Used to get the keys of molecules. If two molecules have
            the same key, they are considered duplicates.

        fitness_modifier : :class:`callable`, optional
            Takes the `population` on which :meth:`.select` is called
            and returns a :class:`dict`, which maps records in the
            `population` to the fitness values the :class:`.Selector`
            should use. If ``None``, the regular fitness values of the
            records are used.

        random_seed : :class:`int`, optional
            The random seed to use.

        """

        if fitness_modifier is None:
            fitness_modifier = self._get_fitness_values

        # An unbounded selection is represented by an infinite limit, so
        # comparisons such as ``yielded < self._num_batches`` stay valid.
        if num_batches is None:
            num_batches = float("inf")

        self._generator = np.random.RandomState(random_seed)
        self._duplicate_molecules = duplicate_molecules
        self._duplicate_batches = duplicate_batches
        self._num_batches = num_batches
        self._batch_size = batch_size
        super().__init__(
            key_maker=key_maker,
            fitness_modifier=fitness_modifier,
        )

    def _select_from_batches(self, batches, yielded_batches):
        """
        Yield batches by running one or more rounds of SUS.

        Parameters
        ----------
        batches : iterable
            The candidate batches. They are sorted in descending order
            of their own ordering (presumably by fitness value).

        yielded_batches : object
            Tracks what has been yielded so far. This method reads it
            via ``get_num``, ``has_no_yielded_molecules`` and
            ``is_unyielded_batch``; it is presumably updated by the
            caller after each yield (not visible here).

        Yields
        ------
        batch
            The selected batches.

        """

        batches = sorted(batches, reverse=True)

        # SUS may need to run multiple rounds if duplicate_molecules or
        # duplicate_batches is False. This is because in each round
        # you can generate multiple pointers to the same batch or to
        # batches sharing molecules. If this happens the lower fitness
        # batch will not be yielded. Instead a second round of SUS will
        # occur with any ineligible batches removed and a reduced
        # number of pointers, to account for batches yielded in the
        # previous rounds. This will repeat until the desired number
        # of batches has been yielded, or there are no more valid
        # batches.
        while (
            batches
            and yielded_batches.get_num() < self._num_batches
        ):
            yield from self._select_with_stochastic_universal_sampling(
                batches=batches,
                yielded_batches=yielded_batches,
            )

            if yielded_batches.get_num() < self._num_batches:
                if not self._duplicate_molecules:
                    batches = filter(
                        yielded_batches.has_no_yielded_molecules,
                        batches,
                    )
                if not self._duplicate_batches:
                    batches = filter(
                        yielded_batches.is_unyielded_batch,
                        batches,
                    )
                # Materialize the lazy filters so the emptiness check in
                # the loop condition and the next round's indexing work.
                if (
                    not self._duplicate_molecules
                    or not self._duplicate_batches
                ):
                    batches = tuple(batches)

    def _select_with_stochastic_universal_sampling(
        self,
        batches,
        yielded_batches,
    ):
        """
        Run a single round of stochastic universal sampling.

        Parameters
        ----------
        batches : sequence
            The non-empty, indexable candidate batches for this round.

        yielded_batches : object
            Tracks what has been yielded so far. It is used to cap the
            number of pointers and to skip ineligible batches.

        Yields
        ------
        batch
            Each batch hit by a pointer that is still eligible to be
            yielded.

        """

        fitness_values = [batch.get_fitness_value() for batch in batches]
        total = sum(fitness_values)
        if total == 0:
            # No batch is fitter than any other (this can happen in a
            # later round once all fitter batches were filtered out).
            # Give every batch an equal share of the line instead of
            # dividing by zero.
            fitness_values = [1] * len(batches)
            total = len(batches)

        # Cumulative right-hand edge of each batch's segment on [0, 1].
        batch_positions = []
        batch_position = 0
        for fitness_value in fitness_values:
            batch_position += fitness_value / total
            batch_positions.append(batch_position)

        num_pointers = min(
            self._num_batches - yielded_batches.get_num(),
            len(batches),
        )
        pointer_distance = 1 / num_pointers
        start = self._generator.uniform(0, pointer_distance)
        # Multiplying, rather than repeatedly adding, keeps rounding
        # error from accumulating across pointers.
        pointers = [
            start + i * pointer_distance for i in range(num_pointers)
        ]

        last_index = len(batches) - 1
        batch_index = 0
        for pointer in pointers:
            # Rounding can leave the final cumulative position a hair
            # below the last pointer, so never walk past the last batch.
            while (
                batch_index < last_index
                and pointer > batch_positions[batch_index]
            ):
                batch_index += 1
            batch = batches[batch_index]

            if (
                not self._duplicate_molecules
                and yielded_batches.has_yielded_molecules(batch)
            ):
                continue

            if (
                not self._duplicate_batches
                and yielded_batches.is_yielded_batch(batch)
            ):
                continue

            yield batch