# Copyright (c) 2026 Centre National d'Etudes Spatiales (CNES).
# Copyright (c) 2026 CS GROUP France
#
# This file is part of PANDORA2D
#
# https://github.com/CNES/Pandora2D
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains functions associated to the matching cost computation step.
"""
import copy
import numpy as np
import xarray as xr
from json_checker import And
from pandora import matching_cost
from pandora.margins import Margins as PandoraMargins
from pandora2d import img_tools
from pandora2d.common import build_usable_data_mask
from pandora2d.margins import Margins
from .base import BaseMatchingCost
class PandoraMatchingCostMethods(BaseMatchingCost):
    """
    Matching Cost class delegating the similarity measure to pandora.

    The pandora-specific parameters (window_size, subpix, spline_order) are owned by the
    ``pandora_matching_cost_`` object built in ``__init__``; the properties below only read them back.
    Cost volumes are obtained by running pandora's cost-volume computation once per row disparity,
    on a right image shifted along rows, and writing each result into the pandora2d cost volumes.
    """

    def __init__(self, cfg: dict) -> None:
        """
        Initialisation of matching_cost class

        :param cfg: user_config for matching cost
        :return: None
        """
        # Must run first: get_config_for_pandora() reads self._step_col, which is not set in this
        # class (presumably set by BaseMatchingCost.__init__).
        super().__init__(cfg)

        # Check the matching_cost parameters specific to pandora by building pandora's own
        # matching cost object from the translated configuration.
        self.pandora_matching_cost_ = matching_cost.AbstractMatchingCost(  # type: ignore[abstract]
            **self.get_config_for_pandora(cfg)
        )

        # Pandora cost-volume template: filled by allocate(), consumed by compute_cost_volumes().
        self.grid: xr.Dataset | None = None

    @property
    def cfg(self) -> dict[str, str | int | list[int]]:
        """
        Get used configuration

        :return: cfg: dictionary with all parameters
        """
        return {
            "matching_cost_method": self._method,
            "step": self.step,
            "window_size": self.window_size,
            "subpix": self._subpix,
            "spline_order": self._spline_order,
            "float_precision": str(self._float_precision),
        }

    @property
    def schema(self):
        """
        Get the configuration schema.

        Starts from the base class schema and overrides the ``float_precision`` entry so that only
        float32 is accepted.

        :return: schema: the updated base schema
        """
        schema = super().schema
        schema.update(
            {
                "float_precision": And(str, lambda x: np.dtype(x) in [np.float32]),
            }
        )
        return schema

    def check_conf(self, cfg: dict) -> dict[str, str]:
        """
        Check the matching cost configuration
        and raise error if float precision is float64

        The float64 case is rejected here with an explicit message; every other check is left to the
        base class (which presumably validates against ``schema`` — confirm in BaseMatchingCost).

        :param cfg: matching cost configuration
        :raises ValueError: if ``float_precision`` is given and resolves to float64
        :return: the configuration returned by the base class check
        """
        if "float_precision" in cfg and np.dtype(cfg["float_precision"]) == np.float64:
            raise ValueError(
                "With sad, ssd, mc_cnn, and zncc_python methods, only the float32 type is accepted for float_precision"
            )
        return super().check_conf(cfg)

    @property
    def window_size(self) -> int:
        """
        Get window_size, parameter specific to pandora

        :return: window_size: window used to compute correlation
        """
        return self.pandora_matching_cost_._window_size  # pylint: disable=W0212 protected-access

    @window_size.setter
    def window_size(self, value: int) -> None:
        """
        Set window_size, parameter specific to pandora

        Intentionally a no-op: the value is owned by ``pandora_matching_cost_``. This overrides the
        superclass behavior (presumably the base class assigns this attribute during ``__init__``).

        :param value: window used to compute correlation (ignored)
        """

    @property
    def _subpix(self) -> int:
        """
        Get subpix, parameter specific to pandora

        :return: subpix: subpix used
        """
        return self.pandora_matching_cost_._subpix  # pylint: disable=W0212 protected-access

    @_subpix.setter
    def _subpix(self, value: int) -> None:
        """
        Set subpix, parameter specific to pandora

        Intentionally a no-op: the value is owned by ``pandora_matching_cost_``. This overrides the
        superclass behavior (presumably the base class assigns this attribute during ``__init__``).

        :param value: subpix (ignored)
        """

    @property
    def _spline_order(self) -> int:
        """
        Get spline_order, parameter specific to pandora

        :return: spline_order: spline_order used
        """
        return self.pandora_matching_cost_._spline_order  # pylint: disable=W0212 protected-access

    @_spline_order.setter
    def _spline_order(self, value: int) -> None:
        """
        Set spline_order, parameter specific to pandora

        Intentionally a no-op: the value is owned by ``pandora_matching_cost_``. This overrides the
        superclass behavior (presumably the base class assigns this attribute during ``__init__``).

        :param value: spline_order (ignored)
        """

    @property
    def margins(self) -> PandoraMargins:
        """
        Get margins from pandora correlation measurement

        :return: the margins reported by ``pandora_matching_cost_``
        """
        return self.pandora_matching_cost_.margins

    def get_config_for_pandora(self, cfg: dict) -> dict[str, str]:
        """
        Get configuration for matching_cost in pandora

        Translation applied to a deep copy of ``cfg`` (the caller's dict is left untouched):
          - ``step`` is replaced by the column step (``self._step_col``);
          - ``float_precision`` is removed, as it is not forwarded to pandora;
          - the ``zncc_python`` method is renamed to pandora's ``zncc``.

        :param cfg: user_config for matching cost pandora2d
        :return: cfg: matching cost pandora configuration
        """
        copy_cfg = copy.deepcopy(cfg)
        copy_cfg["step"] = self._step_col
        copy_cfg.pop("float_precision", None)
        if copy_cfg["matching_cost_method"] == "zncc_python":
            copy_cfg["matching_cost_method"] = "zncc"
        return copy_cfg

    def allocate(
        self,
        img_left: xr.Dataset,
        img_right: xr.Dataset,
        cfg: dict,
        margins: Margins | None = None,
    ) -> None:
        """
        Allocate the cost volume for pandora

        Builds the pandora cost-volume template (stored in ``self.grid``) from the column disparity
        grids of the left image, widened by the refinement margins, then lets the base class
        allocate the pandora2d cost volumes.

        :param img_left: xarray.Dataset containing :
                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param img_right: xarray.Dataset containing :
                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param cfg: matching_cost computation configuration
        :param margins: refinement margins
        :return: None
        """
        grid_min_col = img_left["col_disparity"].sel(band_disp="min").data.copy()
        grid_max_col = img_left["col_disparity"].sel(band_disp="max").data.copy()

        # Replace invalid initial disparity values by np.nan
        # so that they are not included in the 3D Pandora cost-volume allocation
        no_data_disp = img_left["col_disparity"].attrs["no_data"]
        grid_min_col = np.where(build_usable_data_mask(grid_min_col, no_data_disp), grid_min_col, np.nan)
        grid_max_col = np.where(build_usable_data_mask(grid_max_col, no_data_disp), grid_max_col, np.nan)

        if margins is not None:
            grid_min_col -= margins.left
            grid_max_col += margins.right

        # Get updated ROI left margin for pandora method get_coordinates()
        # To get right coordinates in cost_volume when initial left_margin > cfg["ROI"]["col"]["first"]
        # We need to have left_margin = cfg["ROI"]["col"]["first"]
        cfg_for_get_coordinates = BaseMatchingCost.cfg_for_get_coordinates(cfg)

        # Initialize pandora an empty grid for cost volume
        self.grid = self.pandora_matching_cost_.allocate_cost_volume(
            img_left, (grid_min_col, grid_max_col), cfg_for_get_coordinates
        )

        super().allocate(img_left, img_right, cfg, margins)

    def set_shifted_right_images(self, img_right: xr.Dataset) -> None:
        """
        Compute shifted by subpix right image and assign `shifted_right_images` attribute.

        :param img_right: xarray.Dataset containing :
                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :return: None
        """
        self.shifted_right_images = img_tools.shift_subpix_img(img_right, self._subpix, order=self._spline_order)

    def compute_cost_volumes(
        self,
        img_left: xr.Dataset,
        img_right: xr.Dataset,
        margins: Margins | None = None,
    ) -> xr.Dataset:
        """
        Computes the cost volumes

        For each row disparity, the matching subpixel-shifted right image is shifted along rows,
        pandora computes its cost volume on ``self.grid`` (built by ``allocate``), and the rows of
        interest are copied into the slice of ``self.cost_volumes`` for that row disparity.
        Points outside the disparity range are then set to -inf (``type_measure == "max"``) or +inf.

        :param img_left: xarray.Dataset containing :
                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param img_right: xarray.Dataset containing :
                - im : 2D (row, col) xarray.DataArray
                - msk : 2D (row, col) xarray.DataArray
        :param margins: refinement margins
        :return: cost_volumes: 4D Dataset containing the cost_volumes
        """
        # Get disparity coordinates for cost_volumes
        disps_row = self.get_disp_row_coords(img_left, margins)

        # Positions of the cost_volumes rows inside pandora's cost volume, whose first row is the
        # first row of img_left.
        row_index = self.cost_volumes.coords["row"] - img_left.coords["row"].data[0]

        for idx, disp_row in enumerate(disps_row):
            i_right = int((disp_row % 1) * self._subpix)

            # Images contained in self.shifted_right_images are already shifted by 1/subpix.
            # In order for img_right_shift to contain the right image shifted from disp_row,
            # we call img_tools.shift_disp_row_img with np.floor(disp_row).

            # For example if subpix=2 and disp_row=1.5
            # i_right=1
            # self.shifted_right_images[i_right] is shifted by 0.5
            # In img_tools.shift_disp_row_img we shift it by np.floor(1.5)=1 --> In addition it is shifted by 1.5

            # Another example if subpix=4 and disp_row=-1.25
            # i_right=3
            # self.shifted_right_images[i_right] is shifted by 0.75
            # In img_tools.shift_disp_row_img we shift it by np.floor(-1.25)=-2 --> In addition it is shifted by -1.25

            # Shift image in the y axis
            img_right_shift = img_tools.shift_disp_row_img(self.shifted_right_images[i_right], np.floor(disp_row))

            # Compute cost volume
            cost_volume = self.pandora_matching_cost_.compute_cost_volume(img_left, img_right_shift, self.grid)
            # Add current cost volume to the cost_volumes dataset
            self.cost_volumes["cost_volumes"].data[:, :, idx, :] = cost_volume["cost_volume"].data[row_index, :, :]

            # Add type measure to attributes for WTA
            self.cost_volumes.attrs["type_measure"] = cost_volume.attrs["type_measure"]

        # Value to set on cost volumes points where the row or column disparity is out of the range defined
        # by disparity grids
        value = -np.inf if self.cost_volumes.attrs["type_measure"] == "max" else np.inf
        self.set_out_of_disparity_range_to_other_value(img_left, value)

        return self.cost_volumes