# Source code for stk._internal.ea.generation

import typing
from collections.abc import Iterable, Iterator
from dataclasses import dataclass

from stk._internal.ea.crossover.record import CrossoverRecord
from stk._internal.ea.mutation.record import MutationRecord


[docs] @dataclass(frozen=True, slots=True) class FitnessValues: """ Fitness values of a molecule. Parameters: raw: Fitness value before normalization. normalized: Fitness value after normalization. """ raw: typing.Any """Fitness value before normalization.""" normalized: float """Fitness value after normalization."""
# Type of the molecule records held by a generation; used as the keys of
# its fitness values and as the type yielded by get_molecule_records().
T = typing.TypeVar("T")


class Generation(typing.Generic[T]):
    """
    An abstract base class for EA generations.

    A generation bundles the fitness values of its molecules together
    with the mutation and crossover records produced during it.

    Notes:

        You might notice that the public methods of this abstract base
        class are implemented. This is just a default implementation,
        which you can ignore or override when implementing subclasses.

    """

    def __init__(
        self,
        fitness_values: "dict[T, FitnessValues]",
        mutation_records: Iterable[MutationRecord[T]],
        crossover_records: Iterable[CrossoverRecord[T]],
    ) -> None:
        """
        Parameters:

            fitness_values:
                The fitness values of the molecules in the generation,
                keyed by molecule record. A shallow copy is stored, so
                later additions to or removals from the passed mapping
                do not affect the generation.

            mutation_records:
                The records of mutations done during the generation.
                Any iterable is accepted; it is consumed once and stored
                as a tuple.

            crossover_records:
                The records of crossover operations done during the
                generation. Any iterable is accepted; it is consumed once
                and stored as a tuple.

        """
        # Copy the inputs so that one-shot iterators can be read repeatedly
        # through the getters, and so that the caller's containers can be
        # modified afterwards without changing this generation.
        self._fitness_values = dict(fitness_values)
        self._mutation_records = tuple(mutation_records)
        self._crossover_records = tuple(crossover_records)

    def get_fitness_values(self) -> dict[T, FitnessValues]:
        """
        Get the fitness values of the generation.

        Returns:
            A new dictionary mapping each molecule record to its fitness
            values. Modifying the returned dictionary does not affect the
            generation.
        """
        return dict(self._fitness_values)

    def get_molecule_records(self) -> Iterator[T]:
        """
        Yield the molecule records in the generation.

        Yields:
            A molecule record. Records are the keys of the fitness values
            given to the constructor, yielded in that mapping's order.
        """
        yield from self._fitness_values

    def get_mutation_records(self) -> Iterator[MutationRecord[T]]:
        """
        Yield the mutation records in the generation.

        Yields:
            A mutation record, in the order given to the constructor.
        """
        yield from self._mutation_records

    def get_crossover_records(self) -> Iterator[CrossoverRecord[T]]:
        """
        Yield the crossover records in the generation.

        Yields:
            A crossover record, in the order given to the constructor.
        """
        yield from self._crossover_records