#!/usr/bin/env python
#
# Copyright (c) 2026 Centre National d'Etudes Spatiales (CNES).
# Copyright (c) 2026 CS GROUP France
#
# This file is part of PANDORA2D
#
# https://github.com/CNES/Pandora2D
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains class associated to the pandora state machine
"""
import copy
import logging
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import TYPE_CHECKING, Annotated, Literal, TypedDict
import xarray as xr
from pandora.margins import GlobalMargins
from transitions import MachineError
from pandora2d import common, criteria, disparity, estimation, img_tools, refinement
from pandora2d.cost_volume_confidence import CostVolumeConfidenceRegistry
from pandora2d.matching_cost import BaseMatchingCost, MatchingCostRegistry
from pandora2d.profiling import mem_time_profile
if find_spec("graphviz") is None or TYPE_CHECKING:
# Condition on TYPE_CHECKING is here to make Mypy believe we always import Machine in order de avoid this message
# from Mypy:
# Incompatible import of "Machine" \
# (imported name has type "type[Machine]", local name has type "type[GraphMachine]")
from transitions import Machine # pylint: disable=ungrouped-imports
else:
from transitions.extensions import GraphMachine
[docs]
class Machine(GraphMachine):
"""A GraphMachine which defaults to graphviz engine."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs, graph_engine="graphviz")
class MarginsProperties(TypedDict):
    """Properties of Margins used in Margins transitions."""

    # Kind of margins; only "aggregate" or "maximum" are allowed.
    type: Literal["aggregate", "maximum"]
    # Margin values, ordered as described by the Annotated metadata: left, up, right, down.
    margins: Annotated[list[int], '["left", "up", "right", "down"]']
class BaseMachine(Machine, ABC):
    """
    Base model and state machine for pandora2d.

    The machine declares the states and transitions of the pipeline; the callbacks named in the
    transitions (``*_run``) are abstract methods that concrete subclasses implement.
    """

    def __init__(self) -> None:
        # For communication between matching_cost and refinement steps
        self.step: list | None = None
        # Pipeline configuration completed step by step
        self.pipeline_cfg: dict = {"pipeline": {}}
        # For communication between matching_cost and refinement steps
        self.window_size: int | None = None
        # Margins needed on the images
        self.margins_img = GlobalMargins()
        # Margins needed on the disparity ranges
        self.margins_disp = GlobalMargins()

        state_names = ["begin", "assumption", "cost_volumes", "disparity_map"]

        # One row per transition: (trigger, source state, destination state, "before" callback name).
        # "matching_cost" appears twice because it can start either from "begin" or after an estimation.
        transition_table = [
            ("estimation", "begin", "assumption", "estimation_run"),
            ("matching_cost", "begin", "cost_volumes", "matching_cost_run"),
            ("matching_cost", "assumption", "cost_volumes", "matching_cost_run"),
            ("cost_volume_confidence", "cost_volumes", "cost_volumes", "cost_volume_confidence_run"),
            ("disparity", "cost_volumes", "disparity_map", "disparity_run"),
            ("refinement", "disparity_map", "disparity_map", "refinement_run"),
        ]
        transition_dicts = [
            {"trigger": trigger_name, "source": source, "dest": dest, "before": callback_name}
            for trigger_name, source, dest, callback_name in transition_table
        ]

        # Initialize the machine; auto_transitions=False keeps only the transitions declared above.
        Machine.__init__(
            self,
            states=state_names,
            initial="begin",
            transitions=transition_dicts,
            auto_transitions=False,
        )

        # Only keep WARNING and above from the transitions library logger.
        logging.getLogger("transitions").setLevel(logging.WARNING)

    @abstractmethod
    def check_conf(self, cfg):
        """
        Check the configuration and the sequence of its steps.

        :param cfg: pipeline configuration
        :return: None
        """

    @abstractmethod
    def estimation_run(self, cfg, input_step):
        """
        Estimation's computation step (callback of the "estimation" transition).

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def matching_cost_run(self, cfg, input_step):
        """
        Matching cost computation step (callback of the "matching_cost" transitions).

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def cost_volume_confidence_run(self, cfg, input_step):
        """
        Cost volume confidence's computation (callback of the "cost_volume_confidence" transition).

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def disparity_run(self, cfg, input_step):
        """
        Disparity's computation (callback of the "disparity" transition).

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """

    @abstractmethod
    def refinement_run(self, cfg, input_step):
        """
        Refinement's computation (callback of the "refinement" transition).

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
class CheckMachine(BaseMachine):
    """
    State Machine that checks Pandora2d configuration.

    Each transition instantiates the corresponding step object from its configuration and records
    the resulting step configuration and margins on the machine.
    """

    def check_conf(self, cfg: dict[str, dict]) -> None:
        """
        Check configuration and transitions by triggering one transition per pipeline step, in order.

        :param cfg: pipeline configuration
        :return: None
        :raises MachineError, KeyError, AttributeError: logged, then re-raised, when a step cannot be
            triggered or checked
        """
        # Snapshot of the step names, taken before any transition runs.
        step_names = list(cfg["pipeline"])
        for step_name in step_names:
            try:
                self.trigger(step_name, cfg, step_name)
            except (MachineError, KeyError, AttributeError):
                logging.error("Problem occurs during %s check. Be sure of your sequencement step", step_name)
                raise

    def estimation_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the estimation computation configuration.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        estimation_ = estimation.AbstractEstimation(step_cfg)  # type: ignore[abstract]
        self.pipeline_cfg["pipeline"][input_step] = estimation_.cfg

    def matching_cost_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the matching cost computation configuration.

        Also stores the step and window size, which the refinement check reads later.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        matching_cost_class = MatchingCostRegistry.get(step_cfg["matching_cost_method"])
        matching_cost = matching_cost_class(step_cfg)

        self.pipeline_cfg["pipeline"][input_step] = matching_cost.cfg
        self.step = matching_cost.step
        self.window_size = matching_cost.window_size
        self.margins_img.add_cumulative(input_step, matching_cost.margins)

    def cost_volume_confidence_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the cost volume confidence computation configuration.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        confidence_class = CostVolumeConfidenceRegistry.get(step_cfg["confidence_method"])
        confidence = confidence_class(step_cfg)
        self.pipeline_cfg["pipeline"][input_step] = confidence.cfg

    def disparity_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the disparity computation configuration.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        disparity_ = disparity.Disparity(cfg["pipeline"][input_step])
        self.pipeline_cfg["pipeline"][input_step] = disparity_.cfg
        self.margins_img.add_cumulative(input_step, disparity_.margins)

    def refinement_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Check the refinement configuration.

        Uses the step and window size stored by the matching cost check.

        :param cfg: configuration
        :param input_step: current step
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        refinement_ = refinement.AbstractRefinement(step_cfg, self.step, self.window_size)  # type: ignore[abstract]
        self.pipeline_cfg["pipeline"][input_step] = refinement_.cfg
        # Refinement margins are recorded both as (non cumulative) disparity margins and as
        # (cumulative) image margins.
        self.margins_disp.add_non_cumulative(input_step, refinement_.margins)
        self.margins_img.add_cumulative(input_step, refinement_.margins)
class Pandora2DMachine(BaseMachine):
    """
    Pandora2DMachine class to create and use a state machine
    """

    def __init__(
        self,
    ) -> None:
        """
        Initialize Pandora2D Machine
        """
        super().__init__()

        # Run matching_cost_prepare before every "matching_cost" transition, whatever its source state.
        for transition in self.get_transitions("matching_cost"):
            transition.add_callback("prepare", self.matching_cost_prepare)

        # Left image
        self.left_img: xr.Dataset | None = None
        # Right image
        self.right_img: xr.Dataset | None = None
        # Configuration set by run_prepare and replaced by the estimation step when it runs
        self.completed_cfg: dict = {}
        # Cost volumes computed by the matching cost step
        self.cost_volumes: xr.Dataset = xr.Dataset()
        # Disparity maps dataset, allocated in matching_cost_prepare and completed by later steps
        self.dataset_disp_maps: xr.Dataset = xr.Dataset()
        # Instance matching_cost, created in matching_cost_prepare
        self.matching_cost_: BaseMatchingCost | None = None

    def run_prepare(self, img_left: xr.Dataset, img_right: xr.Dataset, cfg: dict) -> None:
        """
        Prepare the machine before running

        :param img_left: left Dataset image containing :

                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param img_right: right Dataset image containing :

                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param cfg: configuration
        """
        self.left_img = img_left
        self.right_img = img_right
        # Shallow copy: nested dictionaries are still shared with the caller's cfg.
        self.completed_cfg = copy.copy(cfg)

    def run(self, input_step: str, cfg: dict[str, dict]) -> None:
        """
        Run pandora 2D step by triggering the corresponding machine transition

        :param input_step: step to trigger; a suffixed name such as "<trigger>.<suffix>" triggers
            "<trigger>" while the full name is forwarded to the step
        :param cfg: pipeline configuration
        :return: None
        :raises MachineError, KeyError, AttributeError: logged, then re-raised, when the step fails
        """
        try:
            # Without a dot, split(".")[0] is input_step itself.
            self.trigger(input_step.split(".")[0], cfg, input_step)
        except (MachineError, KeyError, AttributeError):
            logging.error("Problem occurs during Pandora2D running %s. Be sure of your sequencement step", input_step)
            raise

    def run_exit(self) -> None:
        """
        Return the machine to the "begin" state.

        :return: None
        """
        self.set_state("begin")

    def check_conf(self, cfg: dict[str, dict]) -> None:
        """
        Check configuration and transitions with a dedicated CheckMachine, then copy its results
        (step, window size, checked pipeline configuration and margins) onto this machine.

        :param cfg: pipeline configuration
        :return: None
        """
        check_machine = CheckMachine()
        check_machine.check_conf(cfg)

        self.step = copy.deepcopy(check_machine.step)
        self.window_size = check_machine.window_size
        self.pipeline_cfg = copy.deepcopy(check_machine.pipeline_cfg)
        self.margins_img = copy.deepcopy(check_machine.margins_img)
        self.margins_disp = copy.deepcopy(check_machine.margins_disp)

        # Coming back to the initial state
        self.set_state("begin")

    def remove_transitions(self, transition_list: list[dict[str, str]]) -> None:
        """
        Delete all transitions defined in the input list

        :param transition_list: list of transitions, each a dict holding at least a "trigger" key
        :return: None
        """
        # A transition is removed using its trigger name, but one trigger name can be used by multiple
        # transitions: "remove_transition" then removes all of them at once. deleted_triggers avoids
        # calling "remove_transition" several times with the same trigger name.
        deleted_triggers: set[str] = set()
        for trans in transition_list:
            trigger_name = trans["trigger"]
            if trigger_name not in deleted_triggers:
                self.remove_transition(trigger_name)
                deleted_triggers.add(trigger_name)

    def matching_cost_prepare(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Matching cost prepare: instantiate and allocate the matching cost, then allocate the
        disparity maps dataset completed by the following steps.

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        step_cfg = cfg["pipeline"][input_step]
        matching_cost_class = MatchingCostRegistry.get(step_cfg["matching_cost_method"])
        self.matching_cost_ = matching_cost_class(step_cfg)
        self.matching_cost_.allocate(self.left_img, self.right_img, cfg, self.margins_disp.get("refinement"))

        # Compute validity dataset
        dataset_validity = criteria.get_validity_dataset(
            self.matching_cost_.cost_volumes["criteria"],
            self.matching_cost_.cost_volumes.attrs["row_disparity_source"],
            self.matching_cost_.cost_volumes.attrs["col_disparity_source"],
        )

        # Check if cost volume confidence step is in the pipeline
        cost_volume_confidence_step = "cost_volume_confidence" in cfg["pipeline"]

        roi_cfg = cfg.get("ROI", {})
        # Allocate disparity maps dataset
        self.dataset_disp_maps = common.dataset_disp_maps(
            self.matching_cost_.cost_volumes.coords,
            dataset_validity,
            {
                "origin_coordinates": {
                    "row": roi_cfg.get("row", {}).get("first", 0),
                    "col": roi_cfg.get("col", {}).get("first", 0),
                },
                "step": {
                    "row": cfg["pipeline"]["matching_cost"]["step"][0],
                    "col": cfg["pipeline"]["matching_cost"]["step"][1],
                },
                "invalid_disp": cfg["pipeline"]["disparity"]["invalid_disparity"],
                "crs": self.left_img.crs,
                "transform": self.left_img.transform,
            },
            self.matching_cost_.cost_volumes["cost_volumes"].dtype,
            cost_volume_confidence_step,
        )

    @mem_time_profile(name="Estimation step")
    def estimation_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Shift's estimation step: estimate the disparities, complete the configuration with them and
        rebuild or update the image datasets accordingly.

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Estimation computation...")

        estimation_ = estimation.AbstractEstimation(cfg["pipeline"][input_step])  # type: ignore[abstract]
        row_disparity, col_disparity, shifts, extra_dict = estimation_.compute_estimation(self.left_img, self.right_img)
        self.completed_cfg = estimation_.update_cfg_with_estimation(
            cfg, col_disparity, row_disparity, shifts, extra_dict
        )

        if "ROI" in cfg:
            # Update ROI margins with correct disparities
            # NOTE(review): reads cfg["input"] disparities after update_cfg_with_estimation, presumably
            # because that call updates cfg in place — TODO confirm.
            roi = img_tools.get_roi_processing(
                cfg["ROI"], cfg["input"]["col_disparity"], cfg["input"]["row_disparity"]
            )
            # Recreate left and right image datasets with correct disparities and ROI margins
            self.left_img, self.right_img = img_tools.create_datasets_from_inputs(
                input_config=cfg["input"], roi=roi, estimation_cfg=None
            )
        else:
            # Update disparities for left and right image datasets
            self.left_img = img_tools.add_disparity_grid(self.left_img, col_disparity, row_disparity)
            self.right_img = img_tools.add_disparity_grid(self.right_img, col_disparity, row_disparity)

    @mem_time_profile(name="Matching cost step")
    def matching_cost_run(self, _, __) -> None:
        """
        Matching cost computation, using the matching cost allocated in matching_cost_prepare.

        :return: None
        """
        logging.info("Matching cost computation...")

        self.cost_volumes = self.matching_cost_.compute_cost_volumes(
            self.left_img,
            self.right_img,
            self.margins_disp.get("refinement"),
        )

    @mem_time_profile(name="Cost volume confidence step")
    def cost_volume_confidence_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Cost volume confidence computation; updates both the cost volumes and the disparity maps dataset.

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Cost volume confidence computation...")

        step_cfg = cfg["pipeline"][input_step]
        confidence_class = CostVolumeConfidenceRegistry.get(step_cfg["confidence_method"])
        confidence_ = confidence_class(step_cfg)
        self.cost_volumes, self.dataset_disp_maps = confidence_.confidence_prediction(
            self.left_img,
            self.cost_volumes,
            self.dataset_disp_maps,
        )

    @mem_time_profile(name="Disparity step")
    def disparity_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Disparity computation and validity mask

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Disparity computation...")

        disparity_ = disparity.Disparity(cfg["pipeline"][input_step])
        map_col, map_row, correlation_score = disparity_.compute_disp_maps(self.cost_volumes)
        common.complete_dataset_disp_maps(self.dataset_disp_maps, map_row, map_col, correlation_score)

        cv_coords = (self.cost_volumes.row.values, self.cost_volumes.col.values)
        criteria.apply_peak_on_edge(
            self.dataset_disp_maps["validity"],
            self.left_img,
            cv_coords,
            self.dataset_disp_maps["row_map"].data,
            self.dataset_disp_maps["col_map"].data,
        )

    @mem_time_profile(name="Refinement step")
    def refinement_run(self, cfg: dict[str, dict], input_step: str) -> None:
        """
        Subpixel disparity refinement

        :param cfg: pipeline configuration
        :param input_step: step to trigger
        :return: None
        """
        logging.info("Refinement computation...")

        step_cfg = cfg["pipeline"][input_step]
        if step_cfg["refinement_method"] == "optical_flow":
            logging.warning("The optical flow method is still in an experimental phase.")
            logging.warning("The correlation score map is at a disparity level for the optical flow method.")

        # Named refinement_ (not refinement_run) so the local does not shadow this method's name.
        refinement_ = refinement.AbstractRefinement(step_cfg, self.step, self.window_size)  # type: ignore[abstract]
        refine_map_col, refine_map_row, correlation_score = refinement_.refinement_method(
            self.cost_volumes, self.dataset_disp_maps, self.left_img, self.right_img
        )
        common.complete_dataset_disp_maps(self.dataset_disp_maps, refine_map_row, refine_map_col, correlation_score)