Source code for pandora2d.refinement.optical_flow

#!/usr/bin/env python
#
# Copyright (c) 2026 CS GROUP France
#
# This file is part of PANDORA2D
#
#     https://github.com/CNES/Pandora2D
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains functions associated to the optical flow method used in the refinement step.
"""

from typing import Literal

import numpy as np
import xarray as xr
from json_checker import And
from numpy.typing import NDArray
from scipy.ndimage import map_coordinates

import pandora2d.schema as cst_schema
from pandora2d.margins import Margins

from . import refinement


@refinement.AbstractRefinement.register_subclass("optical_flow")
[docs] class OpticalFlow(refinement.AbstractRefinement): """ OpticalFLow class allows to perform the subpixel cost refinement step """
[docs] _invalid_disp = None
[docs] _ITERATIONS = 4
[docs] schema = { "refinement_method": And(str, lambda x: x in ["optical_flow"]), "iterations": And(int, lambda it: it > 0), "window_size": And(int, lambda input: input > 1 and (input % 2) != 0), "step": cst_schema.STEP_SCHEMA, }
def __init__(self, cfg: dict = None, step: list = None, window_size: int = 5) -> None: """ :param cfg: optional configuration, {} :param step: list containing row and col step :param window_size: window size :return: None """ # Update user configuration with step and window_size parameters to check them cfg["window_size"] = window_size cfg["step"] = [1, 1] if step is None else step super().__init__(cfg)
[docs] self._iterations = self.cfg["iterations"]
[docs] self._refinement_method = self.cfg["refinement_method"]
[docs] self._window_size = self.cfg["window_size"]
[docs] self._step = self.cfg["step"]
@classmethod
[docs] def check_conf(cls, cfg: dict) -> dict: """ Check the refinement configuration :param cfg: user_config for refinement :return: cfg: global configuration """ cfg["iterations"] = cfg.get("iterations", cls._ITERATIONS) cfg = super().check_conf(cfg) return cfg
@property
[docs] def margins(self): values = (self._window_size // 2 * ele for _ in range(2) for ele in self._step) return Margins(*values)
[docs] def reshape_to_matching_cost_window( self, img: xr.Dataset, cost_volumes: xr.Dataset, coordinates: tuple[list, list], disp_row: NDArray = None, disp_col: NDArray = None, ) -> NDArray: """ Transform image from (nb_col, nb_row) to (window_size, window_size, nbcol*nbrow) :param img: image to reshape :param cost_volumes: cost_volumes 4D row, col, disp_col, disp_row :param coordinates: min and max index coordinate for row and col [(first_row,last_row),(first_col,last_col)] :param disp_row: array dim [] containing all the row shift :param disp_col: array dim [] containing all the columns shift :return: array containing reshaped image [window_size, window_size, nbcol*nbrow] """ # get numpy array data for image img_data = img["im"].data # get general offset value offset = cost_volumes.offset_row_col # get cost volume sub xarray with offset coordinates values offset_row, offset_col = coordinates cost_volumes_sub = cost_volumes.sel( row=slice(offset_row[0], offset_row[-1]), col=slice(offset_col[0], offset_col[-1]) ) # get computable cost volume data in row and col computable_col = cost_volumes_sub.col.data computable_row = cost_volumes_sub.row.data if disp_row is None and disp_col is None: # define image patches in one dim patches = np.lib.stride_tricks.sliding_window_view(img_data, [self._window_size, self._window_size]) flattened_patches = patches.reshape(-1, self._window_size, self._window_size) # get patches id from original image id_patches_img = [ int(c_row * img.sizes["col"]) + c_col for c_row in img["row"].data[offset:-offset] for c_col in img["col"].data[offset:-offset] ] # Associate each patches of the one dim image to the id of the true image patches patch_dict = {id_patches_img[i]: flattened_patches[i] for i in range(len(id_patches_img))} id_patches = [int(c_row * img.sizes["col"]) + c_col for c_row in computable_row for c_col in computable_col] # Filter patches to keep only id calculated with offset and step filtered_patches_list = [patch_dict[key] for key in id_patches if key in patch_dict] reshaped_patches = np.stack(filtered_patches_list, axis=-1).reshape( (self._window_size, self._window_size, len(filtered_patches_list)) ) return reshaped_patches # initiate values for right reshape computation offset = self._window_size // 2 patches = np.ndarray((self._window_size, self._window_size, len(computable_row) * len(computable_col))) idx = 0 for row in computable_row: for col in computable_col: shift_col = 0 if np.isnan(disp_col[idx]) or disp_col[idx] == self._invalid_disp else int(disp_col[idx]) shift_row = 0 if np.isnan(disp_row[idx]) or disp_row[idx] == self._invalid_disp else int(disp_row[idx]) # get right pixel with his matching cost window patch_row_start = row - offset + shift_row patch_row_end = row + offset + shift_row patch_col_start = col - offset + shift_col patch_col_end = col + offset + shift_col patch = img.sel(row=slice(patch_row_start, patch_row_end), col=slice(patch_col_start, patch_col_end)) patch = patch["im"].data # stock matching_cost window if patch.shape == (self._window_size, self._window_size): patches[:, :, idx] = patch else: patches[:, :, idx] = np.ones([self._window_size, self._window_size]) * np.nan idx += 1 return patches
[docs] def warped_img( self, right_reshape: NDArray, delta_row: NDArray, delta_col: NDArray, index_to_compute: list ) -> NDArray: """ Shifted matching_cost window with computed disparity :param right_reshape: image right reshaped with dims (window_size, window_size, nbcol*nb_row) :param delta_row: rows disparity map :param delta_col: columns disparity map :param index_to_compute: list containing all valid pixel for computing optical flow :return: new array containing shifted matching_cost windows """ x, y = np.meshgrid(range(self._window_size), range(self._window_size)) new_img = np.empty_like(right_reshape) # resample matching cost right windows for idx in index_to_compute: shifted_img = map_coordinates( right_reshape[:, :, idx], [y - delta_row[idx], x - delta_col[idx]], order=5, mode="reflect" ) new_img[:, :, idx] = shifted_img return new_img
[docs] def lucas_kanade_core_algorithm(self, left_data: NDArray, right_data: NDArray) -> tuple[float, float]: """ Implement lucas & kanade algorithm core :param left_data: matching_cost window for one pixel from left image :param right_data: matching_cost window for one pixel from left image :return: sub-pixel disparity computed by Lucas & Kanade optical flow """ grad_y, grad_x = np.gradient(left_data) grad_t = right_data - left_data # Create A (grad_matrix) et B (time_matrix) matrix for Lucas Kanade grad_matrix = np.vstack((grad_x.flatten(), grad_y.flatten())).T time_matrix = grad_t.flatten() # Apply least-squares to solve the matrix equation AV= B where A is matrix containing partial derivate of (x,y) # B the matrix of partial derivate of t and V the motion we want to find try: motion = np.linalg.lstsq(grad_matrix, time_matrix, rcond=None)[0] # if matrix is full of NaN or 0 except np.linalg.LinAlgError: motion = (self._invalid_disp, self._invalid_disp) return motion[1], motion[0]
[docs] def optical_flow( self, left_img: NDArray, right_img: NDArray, list_idx_to_compute: list, ) -> tuple[NDArray, NDArray, list]: """ Computing optical flow between left and right image :param left_img: reshaped left image array :param right_img: reshaped right image array :param list_idx_to_compute: list of valid pixel :return: computed sub-pixel disparity map """ new_list_to_compute = [] final_dec_row = np.zeros(left_img.shape[2]) final_dec_col = np.zeros(left_img.shape[2]) for idx in list_idx_to_compute: left_matching_cost = left_img[:, :, idx] right_matching_cost = right_img[:, :, idx] computed_delta_row, computed_delta_col = self.lucas_kanade_core_algorithm( left_matching_cost, right_matching_cost ) # hypothesis from algorithm: shifts are < 1 if abs(computed_delta_col) < 1 and abs(computed_delta_row) < 1: new_list_to_compute.append(idx) else: if abs(computed_delta_col) > 1: computed_delta_col = 0 if abs(computed_delta_row) > 1: computed_delta_row = 0 final_dec_row[idx] = computed_delta_row final_dec_col[idx] = computed_delta_col return final_dec_row, final_dec_col, new_list_to_compute
@staticmethod
[docs] def find_nearest_column(value: int, data: NDArray, direction: Literal["+", "-"]) -> NDArray: """ Return the nearest column from initial column index coordinate in a given direction :param value: initial column index :param data: cost volume coordinates :param direction: direction sign (must be + or -) """ if direction == "+": return data[np.searchsorted(data, value, side="left")] if direction == "-": return data[np.searchsorted(data, value, side="right") - 1] raise ValueError("Direction must be '+' or '-'")
[docs] def refinement_method( self, cost_volumes: xr.Dataset, disp_map: xr.Dataset, img_left: xr.Dataset, img_right: xr.Dataset ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Return the subpixel disparity maps :param cost_volumes: cost_volumes 4D row, col, disp_col, disp_row :param disp_map: pixels disparity maps :param img_left: left image dataset :param img_right: right image dataset :return: the refined disparity maps and disparity correlation score """ # get invalid_disp value self._invalid_disp = disp_map.attrs["invalid_disp"] # get offset offset = cost_volumes.offset_row_col # get first and last coordinates for row and col in cost volume dataset first_col_coordinate = cost_volumes.col.data[0] + offset last_col_coordinate = cost_volumes.col.data[-1] - offset col_extrema_coordinates = [ self.find_nearest_column(first_col_coordinate, cost_volumes.col.data, "+"), self.find_nearest_column(last_col_coordinate, cost_volumes.col.data, "-"), ] first_row_coordinate = cost_volumes.row.data[0] + offset last_row_coordinate = cost_volumes.row.data[-1] - offset row_extrema_coordinates = [ self.find_nearest_column(first_row_coordinate, cost_volumes.row.data, "+"), self.find_nearest_column(last_row_coordinate, cost_volumes.row.data, "-"), ] # get displacement map in row and col - from disparity min/max coordinates row_slice = slice(row_extrema_coordinates[0], row_extrema_coordinates[-1]) col_slice = slice(col_extrema_coordinates[0], col_extrema_coordinates[-1]) cost_volume_sub = cost_volumes.sel(row=row_slice, col=col_slice) disp_map_sub = disp_map.sel(row=cost_volume_sub.row, col=cost_volume_sub.col) delta_row = disp_map_sub["row_map"].data.flatten() delta_col = disp_map_sub["col_map"].data.flatten() # reshape left and right data # from (nbcol, nbrow) to (window_size, window_size, nbcol*nbrow) reshaped_left = self.reshape_to_matching_cost_window( img_left, cost_volumes, (row_extrema_coordinates, col_extrema_coordinates) ) reshaped_right = self.reshape_to_matching_cost_window( img_right, cost_volumes, (row_extrema_coordinates, col_extrema_coordinates), delta_row, delta_col, ) idx_to_compute = np.arange(reshaped_left.shape[2]).tolist() for _ in range(self._iterations): computed_drow, computed_dcol, idx_to_compute = self.optical_flow( reshaped_left, reshaped_right, idx_to_compute ) reshaped_right = self.warped_img(reshaped_right, computed_drow, computed_dcol, idx_to_compute) # Pandora convention is left - d = right # Lucas&Kanade convention is left + d = right delta_col = delta_col - computed_dcol delta_row = delta_row - computed_drow # get finals disparity map dimensions, add +1 because it began at 0 nb_valid_points_row = int((row_extrema_coordinates[-1] - row_extrema_coordinates[0]) / cost_volumes.step[0] + 1) nb_valid_points_col = int((col_extrema_coordinates[-1] - col_extrema_coordinates[0]) / cost_volumes.step[1] + 1) delta_col = delta_col.reshape([nb_valid_points_row, nb_valid_points_col]) delta_row = delta_row.reshape([nb_valid_points_row, nb_valid_points_col]) # add border padding_top = (disp_map.sizes["row"] - delta_row.shape[0]) // 2 padding_bottom = disp_map.sizes["row"] - delta_row.shape[0] - padding_top padding_left = (disp_map.sizes["col"] - delta_row.shape[1]) // 2 padding_right = disp_map.sizes["col"] - delta_row.shape[1] - padding_left delta_row = np.pad( delta_row, pad_width=((padding_top, padding_bottom), (padding_left, padding_right)), constant_values=self._invalid_disp, ) delta_col = np.pad( delta_col, pad_width=((padding_top, padding_bottom), (padding_left, padding_right)), constant_values=self._invalid_disp, ) self._invalid_out_of_grid_disparities(cost_volumes.attrs["step"], delta_col, img_left["col_disparity"]) self._invalid_out_of_grid_disparities(cost_volumes.attrs["step"], delta_row, img_left["row_disparity"]) return delta_col, delta_row, disp_map["correlation_score"].data
[docs] def _invalid_out_of_grid_disparities( self, step: list, delta: NDArray[np.floating], disparity: xr.DataArray ) -> None: """ Replace delta values by invalid_disp value when it is outside the corresponding disparity range defined by the disparity grid. :param step: [row_step, col_step] :param delta: refined disparity map :param disparity: pixelic disparity grids with min and max `band_disp` coordinates. """ delta[delta <= disparity.sel(band_disp="min").data[:: step[0], :: step[1]]] = self._invalid_disp delta[delta >= disparity.sel(band_disp="max").data[:: step[0], :: step[1]]] = self._invalid_disp