blob: 221789ec5a4853d1ed8f4e2275757f7b0e3962c1 [file] [log] [blame]
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import enum
from typing import Generic, Iterable, TypeVar
class UnexpectedStateError(RuntimeError):
    """Raised when a state machine is not in one of the expected states.

    Args:
      state: The state the machine was actually in.
      expected: The states that would have been acceptable. Any iterable is
        accepted, including one-shot generators.
    """

    def __init__(self, state: BaseState, expected: Iterable[BaseState]) -> None:
        self._state = state
        # Materialize exactly once. A generator argument would be exhausted by
        # this call, so the message below must read from the stored tuple
        # rather than iterating `expected` a second time.
        self._expected = tuple(expected)
        names = ", ".join(s.name for s in self._expected)
        super().__init__(f"Unexpected state got={state.name} expected=({names})")

    @property
    def state(self) -> BaseState:
        """The state that was actually observed."""
        return self._state

    @property
    def expected(self) -> tuple[BaseState, ...]:
        """The states that would have been accepted."""
        return self._expected
class BaseState(enum.IntEnum):
    """Integer-valued enum that all state-machine state types derive from.

    Concrete state sets subclass this and declare their members; because it
    is an IntEnum, members compare and order by their integer values.
    """
@enum.unique
class State(BaseState):
    """Default set of states, declared in increasing value order.

    The values are spelled out explicitly; they are the same 1..5 that
    enum.auto() assigns for an IntEnum. Ordering matters because range
    checks such as expect_before()/expect_at_least() compare states
    numerically.
    """

    INITIAL = 1
    SETUP = 2
    READY = 3
    RUN = 4
    DONE = 5
# The concrete state enum type held by a StateMachine.
StateT = TypeVar("StateT", bound="BaseState")


class StateMachine(Generic[StateT]):
    """Holds a single current state and validates transitions between states.

    States are ordered by their integer values, which expect_before() and
    expect_at_least() rely on.
    """

    def __init__(self, default: StateT) -> None:
        self._state: StateT = default

    @property
    def state(self) -> StateT:
        """The current state."""
        return self._state

    @property
    def name(self) -> str:
        """The name of the current state."""
        return self._state.name

    def __hash__(self) -> int:
        # Hash like the current state so that a machine and the state it is
        # in (which compare equal, see __eq__) also hash equal. The hash
        # changes whenever the state changes, so a machine should not be kept
        # in a set/dict across transitions.
        return hash(self._state)

    def __eq__(self, other: object) -> bool:
        """Compare by current state against another machine or a bare state."""
        if self is other:
            return True
        if isinstance(other, StateMachine):
            return self._state is other._state
        if isinstance(other, type(self._state)):
            return self._state is other
        # Unrelated type: let the other operand decide (Python then falls back
        # to identity, i.e. "not equal").
        return NotImplemented

    def transition(self, *args: StateT, to: StateT) -> None:
        """Move to `to`, provided the current state is one of `args`.

        Raises:
          UnexpectedStateError: if the current state is not in `args`; the
            current state is left unchanged in that case.
        """
        self.expect(*args)
        self._state = to

    def expect(self, *args: StateT) -> None:
        """Raise UnexpectedStateError unless the current state is in `args`."""
        if self._state not in args:
            raise UnexpectedStateError(self._state, args)

    def expect_before(self, state: StateT) -> None:
        """Raise UnexpectedStateError unless the current state is < `state`."""
        if self._state >= state:
            # Materialize as a tuple (as expect() does) so the error can
            # safely iterate the expected states more than once.
            valid_states = tuple(s for s in type(self._state) if s < state)
            raise UnexpectedStateError(self._state, valid_states)

    def expect_at_least(self, state: StateT) -> None:
        """Raise UnexpectedStateError unless the current state is >= `state`."""
        if self._state < state:
            # Tuple, not generator: see expect_before().
            valid_states = tuple(s for s in type(self._state) if s >= state)
            raise UnexpectedStateError(self._state, valid_states)

    def __str__(self) -> str:
        return self._state.name

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._state!r})"