Skip to content

optimise

Parameter optimization routines for page dewarping.

This subpackage provides:

  • A function (draw_correspondences) to visualize matched points (dstpoints/projpts).
  • A function (optimise_params) that refines the page-dewarping parameter vector by minimizing an objective function.

The implementation dispatches to either JAX (if available) or SciPy based on the configured optimization method and available backends.

_base

Shared utilities for optimization backends.

draw_correspondences

draw_correspondences(img: ndarray, dstpoints: ndarray, projpts: ndarray) -> np.ndarray

Draw matching points (projected vs. desired) on a copy of the image.

Source code in src/page_dewarp/optimise/_base.py
def draw_correspondences(
    img: np.ndarray,
    dstpoints: np.ndarray,
    projpts: np.ndarray,
) -> np.ndarray:
    """Return a copy of ``img`` annotated with projected and target points.

    Both point sets are first converted with ``norm2pix`` using ``img.shape``.
    Each projected point is drawn as a filled circle in colour ``(255, 0, 0)``
    and each target point in ``(0, 0, 255)``; every projected/target pair is
    then joined by a white line. All drawing is done on a copy of ``img``.

    Args:
        img: Image to annotate.
        dstpoints: Desired (target) points.
        projpts: Projected points, paired positionally with ``dstpoints``.

    Returns:
        The annotated copy of ``img``.

    """
    canvas = img.copy()
    shape = img.shape
    target_px = norm2pix(shape, dstpoints, True)
    projected_px = norm2pix(shape, projpts, True)

    # Projected markers are drawn first so target markers land on top of them.
    markers = (
        (projected_px, (255, 0, 0)),
        (target_px, (0, 0, 255)),
    )
    for point_set, colour in markers:
        for pt in point_set:
            circle(canvas, fltp(pt), 3, colour, -1, LINE_AA)

    # Connect each projected point to the target it is supposed to match.
    for proj_pt, target_pt in zip(projected_px, target_px):
        line(canvas, fltp(proj_pt), fltp(target_pt), (255, 255, 255), 1, LINE_AA)

    return canvas

make_objective

make_objective(dstpoints: ndarray, keypoint_index: ndarray, shear_cost: float, rvec_slice: slice) -> Callable[[np.ndarray], float]

Create an objective function for scipy.optimize.minimize.

Parameters:

Name Type Description Default
dstpoints ndarray

Target points to match.

required
keypoint_index ndarray

Index array for keypoint extraction.

required
shear_cost float

Cost coefficient for shear penalty.

required
rvec_slice slice

Slice object for extracting rotation vector from params.

required

Returns:

Type Description
Callable[[ndarray], float]

Objective function suitable for scipy.optimize.minimize.

Source code in src/page_dewarp/optimise/_base.py
def make_objective(
    dstpoints: np.ndarray,
    keypoint_index: np.ndarray,
    shear_cost: float,
    rvec_slice: slice,
) -> Callable[[np.ndarray], float]:
    """Build the scalar cost function handed to ``scipy.optimize.minimize``.

    The returned callable projects the keypoints for a candidate parameter
    vector and returns the sum of squared differences to ``dstpoints``. When
    ``shear_cost`` is positive, ``shear_cost`` times the square of the first
    element of ``pvec[rvec_slice]`` is added to that sum.

    Args:
        dstpoints: Target points to match.
        keypoint_index: Index array for keypoint extraction.
        shear_cost: Cost coefficient for shear penalty.
        rvec_slice: Slice object for extracting rotation vector from params.

    Returns:
        Objective function suitable for scipy.optimize.minimize.

    """
    # Imported at call time rather than at module level.
    from ..keypoints import project_keypoints

    def objective(pvec: np.ndarray) -> float:
        residual = dstpoints - project_keypoints(pvec, keypoint_index)
        sq_error = np.sum(residual ** 2)
        if shear_cost > 0.0:
            # Penalise only the first component of the rotation vector.
            return sq_error + shear_cost * pvec[rvec_slice][0] ** 2
        return sq_error

    return objective

_jax

JAX-based optimization backend for page dewarping.

This module provides accelerated optimization using JAX's automatic differentiation for computing gradients, enabling efficient L-BFGS-B optimization.

optimise_params_jax

optimise_params_jax(name: str, small: ndarray, dstpoints: ndarray, span_counts: list[int], params: ndarray, debug_lvl: int) -> np.ndarray

Refine the parameter vector using JAX-accelerated L-BFGS-B optimization.

Parameters:

Name Type Description Default
name str

Image name for debug output.

required
small ndarray

Downscaled image for visualization.

required
dstpoints ndarray

Target points to match.

required
span_counts list[int]

Number of keypoints per span.

required
params ndarray

Initial parameter vector.

required
debug_lvl int

Debug verbosity level.

required

Returns:

Type Description
ndarray

Optimized parameter vector.

Source code in src/page_dewarp/optimise/_jax.py
def optimise_params_jax(
    name: str,
    small: np.ndarray,
    dstpoints: np.ndarray,
    span_counts: list[int],
    params: np.ndarray,
    debug_lvl: int,
) -> np.ndarray:
    """Run the JAX-backed L-BFGS-B refinement of the parameter vector.

    Prints the initial objective, the device in use, the elapsed time with the
    number of evaluations, and the final objective. When ``debug_lvl`` is at
    least 1 it also shows the keypoint correspondences before optimisation
    and delegates to ``_print_diagnostics`` afterwards.

    Args:
        name: Image name for debug output.
        small: Downscaled image for visualization.
        dstpoints: Target points to match.
        span_counts: Number of keypoints per span.
        params: Initial parameter vector.
        debug_lvl: Debug verbosity level.

    Returns:
        Optimized parameter vector.

    """
    show_debug = debug_lvl >= 1
    device = get_device(cfg.DEVICE)
    keypoint_index = make_keypoint_index(span_counts)

    # The shared (non-JAX) objective is only used to report the starting cost.
    initial_cost_fn = make_objective(
        dstpoints,
        keypoint_index,
        cfg.SHEAR_COST,
        slice(*cfg.RVEC_IDX),
    )
    print("  initial objective is", initial_cost_fn(params))

    if show_debug:
        before = draw_correspondences(
            small,
            dstpoints,
            project_keypoints(params, keypoint_index),
        )
        debug_show(name, 4, "keypoints before", before)

    print(f"  optimizing {len(params)} parameters on {device.device_kind.upper()}...")

    started_at = dt.now()
    # Every warning raised while the optimiser runs is suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with jax.default_device(device):
            result = _run_jax_lbfgsb(dstpoints, keypoint_index, params)
    elapsed = (dt.now() - started_at).total_seconds()

    print(
        f"  optimization (L-BFGS-B + JAX autodiff) took {elapsed:.2f}s, {result.nfev} evals",
    )
    print(f"  final objective is {result.fun:.6f}")

    refined = result.x
    if show_debug:
        _print_diagnostics(refined, keypoint_index, small, dstpoints, name)
    return refined

_scipy

SciPy-based optimization backend for page dewarping.

optimise_params_scipy

optimise_params_scipy(name: str, small: ndarray, dstpoints: ndarray, span_counts: list[int], params: ndarray, debug_lvl: int, method: str) -> np.ndarray

Refine the parameter vector using SciPy optimization.

Parameters:

Name Type Description Default
name str

Image name for debug output.

required
small ndarray

Downscaled image for visualization.

required
dstpoints ndarray

Target points to match.

required
span_counts list[int]

Number of keypoints per span.

required
params ndarray

Initial parameter vector.

required
debug_lvl int

Debug verbosity level.

required
method str

Optimization method (e.g., 'Powell', 'L-BFGS-B').

required

Returns:

Type Description
ndarray

Optimized parameter vector.

Source code in src/page_dewarp/optimise/_scipy.py
def optimise_params_scipy(
    name: str,
    small: np.ndarray,
    dstpoints: np.ndarray,
    span_counts: list[int],
    params: np.ndarray,
    debug_lvl: int,
    method: str,
) -> np.ndarray:
    """Run SciPy's ``minimize`` to refine the parameter vector.

    Prints the initial objective, the elapsed time with the number of
    evaluations, and the final objective. When ``method`` is ``"L-BFGS-B"`` a
    hint about the optional JAX extra is written to stderr. When
    ``debug_lvl`` is at least 1 it also shows the keypoint correspondences
    before optimisation and delegates to ``_print_diagnostics`` afterwards.

    Args:
        name: Image name for debug output.
        small: Downscaled image for visualization.
        dstpoints: Target points to match.
        span_counts: Number of keypoints per span.
        params: Initial parameter vector.
        debug_lvl: Debug verbosity level.
        method: Optimization method (e.g., 'Powell', 'L-BFGS-B').

    Returns:
        Optimized parameter vector.

    """
    show_debug = debug_lvl >= 1
    keypoint_index = make_keypoint_index(span_counts)
    cost_fn = make_objective(
        dstpoints,
        keypoint_index,
        cfg.SHEAR_COST,
        slice(*cfg.RVEC_IDX),
    )
    print("  initial objective is", cost_fn(params))

    if show_debug:
        before = draw_correspondences(
            small,
            dstpoints,
            project_keypoints(params, keypoint_index),
        )
        debug_show(name, 4, "keypoints before", before)

    print(f"  optimizing {len(params)} parameters...")

    if method == "L-BFGS-B":
        # Point users at the JAX extra, on stderr so stdout stays clean.
        print(
            "OPT_METHOD='L-BFGS-B' can use JAX for fast autodiff. "
            "Install with: pip install page-dewarp[jax]",
            file=sys.stderr,
        )

    solver_options = {"maxiter": cfg.OPT_MAX_ITER}
    started_at = dt.now()
    result = minimize(cost_fn, params, method=method, options=solver_options)
    elapsed = (dt.now() - started_at).total_seconds()

    print(f"  optimization ({method}) took {elapsed:.2f}s, {result.nfev} evals")
    print(f"  final objective is {result.fun:.6f}")

    refined = result.x
    if show_debug:
        _print_diagnostics(refined, keypoint_index, small, dstpoints, name)
    return refined