Source code for stk._internal.ea.evolutionary_algorithm.evolutionary_algorithm

import logging
import typing
from collections.abc import Iterable, Iterator

from stk._internal.ea.crossover.molecule_crosser import MoleculeCrosser
from stk._internal.ea.fitness_calculators.fitness_calculator import (
    FitnessCalculator,
)
from stk._internal.ea.fitness_normalizers.fitness_normalizer import (
    FitnessNormalizer,
)
from stk._internal.ea.generation import Generation
from stk._internal.ea.molecule_record import MoleculeRecord
from stk._internal.ea.mutation.mutator import MoleculeMutator
from stk._internal.ea.selection.selectors.selector import Selector
from stk._internal.key_makers.inchi import Inchi
from stk._internal.key_makers.molecule import MoleculeKeyMaker

from ..fitness_normalizers.null import NullFitnessNormalizer
from .implementations.parallel import Parallel
from .implementations.serial import Serial

# Module-level logger; it is handed to the Serial/Parallel implementation
# objects created by EvolutionaryAlgorithm.__init__.
logger = logging.getLogger(__name__)

# Type variable for the kind of molecule record an evolutionary algorithm
# works with; it is restricted to MoleculeRecord and its subclasses.
T = typing.TypeVar("T", bound=MoleculeRecord)


class EvolutionaryAlgorithm(typing.Generic[T]):
    """
    An abstract base class for evolutionary algorithms.

    Notes:

        The public methods of this abstract base class are
        implemented, even though it is abstract. This is purely a
        convenience: it gives users a default evolutionary algorithm
        they can use straight away. Subclasses are free to override
        the default implementation.

        If you do want to use the default implementation, here is a
        summary of the roles of the different components:

        .. image:: https://i.imgur.com/hGXboaU.png

    Examples:

        *Subclass Implementation*

        The source code of this class can work as a good example.
        A subclass of :class:`.EvolutionaryAlgorithm` only needs to
        implement one method, :meth:`.get_generations`, which yields
        :class:`.Generation` instances. These correspond to the
        generations of your evolutionary algorithm implementation.

        *Usage*

        There are a couple of tutorials on how to use the
        :class:`.EvolutionaryAlgorithm`, which can be found in the
        sidebar.

    """

    # The object that actually runs the algorithm; every public method
    # of this class delegates to it.
    _implementation: Serial | Parallel

    def __init__(
        self,
        initial_population: Iterable[T],
        fitness_calculator: FitnessCalculator[T],
        mutator: MoleculeMutator[T],
        crosser: MoleculeCrosser[T],
        generation_selector: Selector[T],
        mutation_selector: Selector[T],
        crossover_selector: Selector[T],
        # NOTE(review): the two default instances below are created
        # once, when this function is defined, so they are shared by
        # every EvolutionaryAlgorithm that relies on the defaults.
        fitness_normalizer: FitnessNormalizer[T] = NullFitnessNormalizer(),
        key_maker: MoleculeKeyMaker = Inchi(),
        num_processes: int | None = None,
    ) -> None:
        """
        Parameters:

            initial_population:
                The initial population the EA should use.

            fitness_calculator:
                Calculates fitness values.

            mutator:
                Carries out mutation operations.

            crosser:
                Carries out crossover operations.

            generation_selector:
                Selects the next generation.

            mutation_selector:
                Selects molecules for mutation.

            crossover_selector:
                Selects molecules for crossover.

            fitness_normalizer:
                Normalizes fitness values.

            key_maker:
                Used to detect duplicate molecules in the EA. If two
                molecules in a generation return the same key, one of
                them is removed.

            num_processes:
                The number of parallel processes the EA should
                create. If ``None``, all available cores will be
                used.

        """

        # Exactly one process means there is nothing to parallelize;
        # any other value (including ``None``) selects the parallel
        # implementation.
        run_serially = num_processes == 1

        if not run_serially:
            self._implementation = Parallel(
                initial_population=initial_population,
                fitness_calculator=fitness_calculator,
                mutator=mutator,
                crosser=crosser,
                generation_selector=generation_selector,
                mutation_selector=mutation_selector,
                crossover_selector=crossover_selector,
                fitness_normalizer=fitness_normalizer,
                key_maker=key_maker,
                logger=logger,
                num_processes=num_processes,
            )
            return

        self._implementation = Serial(
            initial_population=initial_population,
            fitness_calculator=fitness_calculator,
            mutator=mutator,
            crosser=crosser,
            generation_selector=generation_selector,
            mutation_selector=mutation_selector,
            crossover_selector=crossover_selector,
            fitness_normalizer=fitness_normalizer,
            key_maker=key_maker,
            logger=logger,
        )

    def get_generations(
        self,
        num_generations: int,
    ) -> Iterator[Generation[T]]:
        """
        Yield the generations of the evolutionary algorithm.

        Parameters:

            num_generations:
                The number of generations which should be yielded.
                Note that the initial population counts as a
                generation.

        Yields:
            A generation.

        """

        # All of the work is delegated to the implementation object
        # chosen in ``__init__``.
        generations = self._implementation.get_generations(
            num_generations=num_generations,
        )
        yield from generations