Source code for pandora2d.state_machine

#!/usr/bin/env python
#
# Copyright (c) 2026 Centre National d'Etudes Spatiales (CNES).
# Copyright (c) 2026 CS GROUP France
#
# This file is part of PANDORA2D
#
#     https://github.com/CNES/Pandora2D
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains the classes associated with the pandora2d state machine.
"""

import copy
import logging
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import TYPE_CHECKING, Annotated, Literal, TypedDict

import xarray as xr
from pandora.margins import GlobalMargins
from transitions import MachineError

from pandora2d import common, criteria, disparity, estimation, img_tools, refinement
from pandora2d.cost_volume_confidence import CostVolumeConfidenceRegistry
from pandora2d.matching_cost import BaseMatchingCost, MatchingCostRegistry
from pandora2d.profiling import mem_time_profile

if find_spec("graphviz") is None or TYPE_CHECKING:
    # The condition on TYPE_CHECKING is here to make Mypy believe we always import Machine, in order to avoid this
    # message from Mypy:
    # Incompatible import of "Machine" \
    # (imported name has type "type[Machine]", local name has type "type[GraphMachine]")
    from transitions import Machine  # pylint: disable=ungrouped-imports
else:
    from transitions.extensions import GraphMachine

    class Machine(GraphMachine):
        """
        A GraphMachine which defaults to graphviz engine.

        :param args: positional arguments forwarded unchanged to GraphMachine
        :param kwargs: keyword arguments forwarded to GraphMachine; ``graph_engine`` is set to
            "graphviz" only when the caller did not provide it
        """

        def __init__(self, *args, **kwargs) -> None:
            # Use setdefault so that "graphviz" is a real default: passing graph_engine explicitly no longer
            # raises "got multiple values for keyword argument 'graph_engine'".
            kwargs.setdefault("graph_engine", "graphviz")
            super().__init__(*args, **kwargs)
class MarginsProperties(TypedDict):
    """
    Properties of Margins used in Margins transitions.

    Keys:
        - type: one of the literal values "aggregate" or "maximum"
        - margins: list of integers; the Annotated metadata records the expected order of the values
    """

    # Restricted to the two literal values listed here
    type: Literal["aggregate", "maximum"]
    # The metadata string documents the ordering of the four values (quoting of "left" fixed)
    margins: Annotated[list[int], '["left", "up", "right", "down"]']
class BaseMachine(Machine, ABC):
    """
    Base model and state machine for pandora2d.

    The machine owns four states ("begin", "assumption", "cost_volumes", "disparity_map") and one
    trigger per pipeline step. Each transition runs the matching ``<step>_run`` method as a
    ``before`` callback; concrete subclasses provide those methods.
    """

    def __init__(self) -> None:
        # For communication between matching_cost and refinement steps
        self.step: list | None = None
        self.pipeline_cfg: dict = {"pipeline": {}}
        self.window_size: int | None = None
        self.margins_img = GlobalMargins()
        self.margins_disp = GlobalMargins()

        # Available states
        state_names = ["begin", "assumption", "cost_volumes", "disparity_map"]

        # One row per transition, columns in the order given by transition_keys
        transition_keys = ("trigger", "source", "dest", "before")
        transition_rows = (
            ("estimation", "begin", "assumption", "estimation_run"),
            ("matching_cost", "begin", "cost_volumes", "matching_cost_run"),
            ("matching_cost", "assumption", "cost_volumes", "matching_cost_run"),
            ("cost_volume_confidence", "cost_volumes", "cost_volumes", "cost_volume_confidence_run"),
            ("disparity", "cost_volumes", "disparity_map", "disparity_run"),
            ("refinement", "disparity_map", "disparity_map", "refinement_run"),
        )
        transition_table = [dict(zip(transition_keys, row)) for row in transition_rows]

        # Initialize the underlying machine, starting in the "begin" state
        Machine.__init__(
            self,
            states=state_names,
            initial="begin",
            transitions=transition_table,
            auto_transitions=False,
        )

        logging.getLogger("transitions").setLevel(logging.WARNING)

    @abstractmethod
    def check_conf(self, cfg):
        """
        Check the configuration and the sequence of transitions.

        :param cfg: pipeline configuration
        :return:
        """

    @abstractmethod
    def estimation_run(self, cfg, input_step):
        """
        Callback run before the "estimation" transition.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def matching_cost_run(self, cfg, input_step):
        """
        Callback run before the "matching_cost" transition.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def cost_volume_confidence_run(self, cfg, input_step):
        """
        Callback run before the "cost_volume_confidence" transition.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def disparity_run(self, cfg, input_step):
        """
        Callback run before the "disparity" transition.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def refinement_run(self, cfg, input_step):
        """
        Callback run before the "refinement" transition.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
class CheckMachine(BaseMachine):
    """
    State Machine that checks Pandora2d configuration.

    Each ``*_run`` callback instantiates the object of its step from the user configuration and
    stores the resulting ``cfg`` attribute in ``self.pipeline_cfg["pipeline"]``.
    """

    def check_conf(self, cfg: dict[str, dict]) -> None:
        """
        Trigger every step of the pipeline, in order, to check configuration and transitions.

        :param cfg: pipeline configuration
        :return:
        """
        # Snapshot of the step names taken before any transition is triggered
        step_names = list(cfg["pipeline"])
        for step_name in step_names:
            try:
                self.trigger(step_name, cfg, step_name)
            except (MachineError, KeyError, AttributeError):
                logging.error("Problem occurs during %s check. Be sure of your sequencement step", step_name)
                raise

    def estimation_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the estimation configuration.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        checked = estimation.AbstractEstimation(step_cfg)  # type: ignore[abstract]
        self.pipeline_cfg["pipeline"][input_step] = checked.cfg

    def matching_cost_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the matching cost configuration.

        Also records the step, the window size and the margins of the matching cost.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        matching_cost_cls = MatchingCostRegistry.get(step_cfg["matching_cost_method"])
        checked = matching_cost_cls(step_cfg)

        self.pipeline_cfg["pipeline"][input_step] = checked.cfg
        self.step = checked.step
        self.window_size = checked.window_size
        self.margins_img.add_cumulative(input_step, checked.margins)

    def cost_volume_confidence_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the cost volume confidence configuration.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        confidence_cls = CostVolumeConfidenceRegistry.get(step_cfg["confidence_method"])
        checked = confidence_cls(step_cfg)
        self.pipeline_cfg["pipeline"][input_step] = checked.cfg

    def disparity_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the disparity configuration and record its image margins.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        checked = disparity.Disparity(cfg["pipeline"][input_step])
        self.pipeline_cfg["pipeline"][input_step] = checked.cfg
        self.margins_img.add_cumulative(input_step, checked.margins)

    def refinement_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the refinement configuration and record its margins.

        The refinement margins are added to the disparity margins (non cumulative) and to the
        image margins (cumulative).

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        checked = refinement.AbstractRefinement(step_cfg, self.step, self.window_size)  # type: ignore[abstract]

        self.pipeline_cfg["pipeline"][input_step] = checked.cfg
        self.margins_disp.add_non_cumulative(input_step, checked.margins)
        self.margins_img.add_cumulative(input_step, checked.margins)
class Pandora2DMachine(BaseMachine):
    """
    Pandora2DMachine class to create and use a state machine.

    The instance keeps the left and right images, the cost volumes and the disparity maps
    dataset, which the ``*_run`` callbacks read and update as the pipeline steps are triggered.
    """

    def __init__(
        self,
    ) -> None:
        """
        Initialize Pandora2D Machine
        """
        super().__init__()

        # matching_cost_prepare is registered as a "prepare" callback on every matching_cost transition
        for transition in self.get_transitions("matching_cost"):
            transition.add_callback("prepare", self.matching_cost_prepare)

        # Left image
        self.left_img: xr.Dataset | None = None
        # Right image
        self.right_img: xr.Dataset | None = None
        self.completed_cfg: dict = {}
        self.cost_volumes: xr.Dataset = xr.Dataset()
        self.dataset_disp_maps: xr.Dataset = xr.Dataset()
        # Instance matching_cost
        self.matching_cost_: BaseMatchingCost | None = None

    def run_prepare(self, img_left: xr.Dataset, img_right: xr.Dataset, cfg: dict) -> None:
        """
        Prepare the machine before running

        :param img_left: left Dataset image containing :

                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param img_right: right Dataset image containing :

                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param cfg: configuration; a shallow copy of it is stored in ``completed_cfg``
        :return: None
        """
        self.left_img = img_left
        self.right_img = img_right
        self.completed_cfg = copy.copy(cfg)

    def run(self, input_step: str, cfg: dict[str, dict]) -> None:
        """
        Run pandora 2D step by triggering the corresponding machine transition

        The trigger name is the part of ``input_step`` located before the first "."; the full
        ``input_step`` is still forwarded to the callbacks.

        :param input_step: step to trigger
        :param cfg: pipeline configuration
        :return: None
        :raises MachineError, KeyError, AttributeError: logged, then re-raised
        """
        try:
            # For a name without ".", split(".")[0] is the name itself, so one call covers both cases
            self.trigger(input_step.split(".")[0], cfg, input_step)
        except (MachineError, KeyError, AttributeError):
            logging.error("Problem occurs during Pandora2D running %s. Be sure of your sequencement step", input_step)
            raise

    def run_exit(self) -> None:
        """
        Return to state begin

        :return: None
        """
        self.set_state("begin")

    def check_conf(self, cfg: dict[str, dict]) -> None:
        """
        Check configuration and transitions

        The check is delegated to a CheckMachine; its step, window size, pipeline configuration
        and margins are then copied onto this machine.

        :param cfg: pipeline configuration
        :return:
        """
        check_machine = CheckMachine()
        check_machine.check_conf(cfg)

        self.step = copy.deepcopy(check_machine.step)
        self.window_size = check_machine.window_size
        self.pipeline_cfg = copy.deepcopy(check_machine.pipeline_cfg)
        self.margins_img = copy.deepcopy(check_machine.margins_img)
        self.margins_disp = copy.deepcopy(check_machine.margins_disp)

        # Coming back to the initial state
        self.set_state("begin")

    def remove_transitions(self, transition_list: list[dict]) -> None:
        """
        Delete all transitions defined in the input list

        :param transition_list: list of transitions, each one a dict with a "trigger" key
        :return: None
        """
        # Transition is removed using trigger name. But one trigger name can be used by multiple transitions
        # In this case, the "remove_transition" function removes all transitions using this trigger name
        # The set of already removed triggers avoids calling "remove_transition" twice with the same name.
        removed_triggers: set = set()
        for trans in transition_list:
            trigger_name = trans["trigger"]
            if trigger_name not in removed_triggers:
                self.remove_transition(trigger_name)
                removed_triggers.add(trigger_name)

    def matching_cost_prepare(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Matching cost prepare

        Instantiates and allocates the matching cost object, then allocates the disparity maps dataset.

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        MatchingCost = MatchingCostRegistry.get(  # pylint:disable=invalid-name # NOSONAR
            step_cfg["matching_cost_method"]
        )
        self.matching_cost_ = MatchingCost(step_cfg)
        self.matching_cost_.allocate(self.left_img, self.right_img, cfg, self.margins_disp.get("refinement"))

        allocated = self.matching_cost_.cost_volumes

        # Compute validity dataset
        dataset_validity = criteria.get_validity_dataset(
            allocated["criteria"],
            allocated.attrs["row_disparity_source"],
            allocated.attrs["col_disparity_source"],
        )

        # Check if cost volume confidence step is in the pipeline
        cost_volume_confidence_step = "cost_volume_confidence" in cfg["pipeline"]

        roi_cfg = cfg.get("ROI", {})
        # NOTE(review): the step and the invalid disparity are read from the fixed "matching_cost" and
        # "disparity" keys of the pipeline, not from input_step.
        row_step, col_step = cfg["pipeline"]["matching_cost"]["step"][0], cfg["pipeline"]["matching_cost"]["step"][1]

        # Allocate disparity maps dataset
        self.dataset_disp_maps = common.dataset_disp_maps(
            allocated.coords,
            dataset_validity,
            {
                "origin_coordinates": {
                    "row": roi_cfg.get("row", {}).get("first", 0),
                    "col": roi_cfg.get("col", {}).get("first", 0),
                },
                "step": {
                    "row": row_step,
                    "col": col_step,
                },
                "invalid_disp": cfg["pipeline"]["disparity"]["invalid_disparity"],
                "crs": self.left_img.crs,
                "transform": self.left_img.transform,
            },
            allocated["cost_volumes"].dtype,
            cost_volume_confidence_step,
        )

    @mem_time_profile(name="Estimation step")
    def estimation_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Shift's estimation step

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Estimation computation...")

        estimation_ = estimation.AbstractEstimation(cfg["pipeline"][input_step])  # type: ignore[abstract]

        row_disparity, col_disparity, shifts, extra_dict = estimation_.compute_estimation(self.left_img, self.right_img)

        self.completed_cfg = estimation_.update_cfg_with_estimation(
            cfg, col_disparity, row_disparity, shifts, extra_dict
        )

        if "ROI" in cfg:
            # Update ROI margins with correct disparities
            roi = img_tools.get_roi_processing(cfg["ROI"], cfg["input"]["col_disparity"], cfg["input"]["row_disparity"])
            # Recreate left and right image datasets with correct disparities and ROI margins
            self.left_img, self.right_img = img_tools.create_datasets_from_inputs(
                input_config=cfg["input"], roi=roi, estimation_cfg=None
            )
        else:
            # Update disparities for left and right image datasets
            self.left_img = img_tools.add_disparity_grid(self.left_img, col_disparity, row_disparity)
            self.right_img = img_tools.add_disparity_grid(self.right_img, col_disparity, row_disparity)

    @mem_time_profile(name="Matching cost step")
    def matching_cost_run(self, _, __) -> None:
        """
        Matching cost computation

        Uses the matching cost object created by matching_cost_prepare; both arguments are ignored.

        :return: None
        """
        logging.info("Matching cost computation...")

        self.cost_volumes = self.matching_cost_.compute_cost_volumes(
            self.left_img,
            self.right_img,
            self.margins_disp.get("refinement"),
        )

    @mem_time_profile(name="Cost volume confidence step")
    def cost_volume_confidence_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Cost volume confidence computation

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Cost volume confidence computation...")

        step_cfg = cfg["pipeline"][input_step]
        CostVolumeConfidence = CostVolumeConfidenceRegistry.get(  # pylint:disable=invalid-name # NOSONAR
            step_cfg["confidence_method"]
        )
        confidence_ = CostVolumeConfidence(step_cfg)

        self.cost_volumes, self.dataset_disp_maps = confidence_.confidence_prediction(
            self.left_img,
            self.cost_volumes,
            self.dataset_disp_maps,
        )

    @mem_time_profile(name="Disparity step")
    def disparity_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Disparity computation and validity mask

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Disparity computation...")

        disparity_ = disparity.Disparity(cfg["pipeline"][input_step])

        map_col, map_row, correlation_score = disparity_.compute_disp_maps(self.cost_volumes)

        common.complete_dataset_disp_maps(self.dataset_disp_maps, map_row, map_col, correlation_score)

        cv_coords = (self.cost_volumes.row.values, self.cost_volumes.col.values)

        criteria.apply_peak_on_edge(
            self.dataset_disp_maps["validity"],
            self.left_img,
            cv_coords,
            self.dataset_disp_maps["row_map"].data,
            self.dataset_disp_maps["col_map"].data,
        )

    @mem_time_profile(name="Refinement step")
    def refinement_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Subpixel disparity refinement

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Refinement computation...")

        step_cfg = cfg["pipeline"][input_step]
        if step_cfg["refinement_method"] == "optical_flow":
            logging.warning("The optical flow method is still in an experimental phase.")
            logging.warning("The correlation score map is at a disparity level for the optical flow method.")

        # Named refinement_ (not refinement_run) so the local does not shadow this method's name
        refinement_ = refinement.AbstractRefinement(step_cfg, self.step, self.window_size)  # type: ignore[abstract]

        refine_map_col, refine_map_row, correlation_score = refinement_.refinement_method(
            self.cost_volumes, self.dataset_disp_maps, self.left_img, self.right_img
        )

        common.complete_dataset_disp_maps(self.dataset_disp_maps, refine_map_row, refine_map_col, correlation_score)