Factory Method

python
factory method
Author

Neil Shephard

Published

March 26, 2023

The Factory Method is a simple and widely used design pattern.

The idea is that you wish to carry out a task, but there are different ways of achieving it. For example, you may wish to open files, but there are different file types that may need opening. Or you may wish to apply some sort of filter to image data, but there are different filters to choose from. The user should have a single command to call, while in the background the work that is done is chosen based on the supplied data/parameters. This makes it possible to extend the supported functions without changing the user interface, and so the design pattern helps fulfil the S and O of the SOLID design principles, that is Single-responsibility and Open-closed.
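
To make the file-opening example concrete, here is a minimal sketch of the same idea using only the standard library. The names load(), _get_reader(), _read_json() and _read_csv() are hypothetical and chosen purely for illustration; they are not part of the worked example that follows.

```python
"""Minimal Factory Method sketch: one entry point, several file readers."""
import csv
import json
from io import StringIO
from typing import Any, Callable


def _read_json(text: str) -> Any:
    """Parse JSON text into Python objects."""
    return json.loads(text)


def _read_csv(text: str) -> list[dict[str, str]]:
    """Parse CSV text into a list of rows keyed by the header line."""
    return list(csv.DictReader(StringIO(text)))


def _get_reader(file_type: str) -> Callable[[str], Any]:
    """Creator: choose the reader that matches the requested file type."""
    if file_type == "json":
        return _read_json
    if file_type == "csv":
        return _read_csv
    raise ValueError(f"Unsupported file type: {file_type}")


def load(text: str, file_type: str) -> Any:
    """Single entry point that users call, whatever the file type."""
    reader = _get_reader(file_type)
    return reader(text)


# The caller only ever uses load(); the reader is chosen behind the scenes.
json_result = load('{"a": 1}', "json")
csv_result = load("x,y\n1,2\n", "csv")
```

Supporting a further file type in this sketch would mean adding one reader and one branch in _get_reader(), while load() stays unchanged.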

Worked Example

As a worked example we will use the different threshold methods available from the image processing package scikit-image. There are a number of these available, and typically they take a NumPy array which represents an image and derive a threshold from it in the given manner.

Define Interface

This is the function that users will call. It takes two arguments: data, a np.ndarray (NumPy array) representing the image for which a threshold for filtering is to be derived, and method, the threshold method to use. Any additional keyword arguments are passed on to the chosen threshold function.

"""Factory Method Example"""
from typing import Callable
import numpy as np
from skimage.filters import (
    threshold_mean,
    threshold_otsu,
    threshold_triangle,
    threshold_yen,
)


def threshold(data: np.ndarray, method: str, **kwargs: dict) -> float:
    """Derive threshold for filtering of the given array using the specified method.

    Parameters
    ----------
    data: np.ndarray
        NumPy array representing the image for which a threshold is derived.
    method: str
        Threshold method from Scikit-image to apply ("mean", "otsu", "triangle" or "yen").
    **kwargs
        Additional keyword arguments passed to the chosen threshold function.

    Returns
    -------
    float
        Threshold derived by the chosen method.
    """
    filter_threshold = _get_threshold(method)
    return filter_threshold(data, **kwargs)

Define Private Function

Next we define a private function, _get_threshold(), which determines, based on the method argument, which threshold function to return.

def _get_threshold(method: str) -> Callable:
    """Creator component which determines which threshold method to use.

    Parameters
    ----------
    method: str
        Threshold method from Scikit-image to apply.

    Returns
    -------
    Callable
        Function that derives a threshold using the requested method.

    Raises
    ------
    ValueError
        If the method is not supported.
    """
    if method == "mean":
        return _mean
    if method == "otsu":
        return _otsu
    if method == "triangle":
        return _triangle
    if method == "yen":
        return _yen
    raise ValueError(method)
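
The chain of if statements in _get_threshold() could also be written as a dictionary lookup. The sketch below is one possible variation, not the code from the associated repository. To keep it runnable without scikit-image it uses two hypothetical stand-in functions, _mean() and _median(), built on the standard library statistics module, and it operates on plain lists rather than arrays.

```python
"""Dictionary-based creator, sketched with standard-library stand-ins."""
from statistics import fmean, median
from typing import Callable, Sequence


def _mean(data: Sequence[float]) -> float:
    """Stand-in threshold: the arithmetic mean of the values."""
    return fmean(data)


def _median(data: Sequence[float]) -> float:
    """Stand-in threshold: the median of the values."""
    return median(data)


# Hypothetical registry mapping method names to threshold functions.
_THRESHOLDS: dict[str, Callable[[Sequence[float]], float]] = {
    "mean": _mean,
    "median": _median,
}


def _get_threshold(method: str) -> Callable[[Sequence[float]], float]:
    """Creator: look up the threshold function for the given method name."""
    try:
        return _THRESHOLDS[method]
    except KeyError:
        options = ", ".join(sorted(_THRESHOLDS))
        raise ValueError(
            f"Unknown method '{method}', expected one of: {options}"
        ) from None


def threshold(data: Sequence[float], method: str) -> float:
    """Single entry point, mirroring the interface used in this post."""
    return _get_threshold(method)(data)


values = [0.1, 0.2, 0.9]
mean_threshold = threshold(values, "mean")
median_threshold = threshold(values, "median")
```

With this layout, supporting another method means adding one function and one dictionary entry, and the error message lists the valid options automatically.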

Define filters

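We now define each of the functions which the creator can return. Each is a thin wrapper around the corresponding scikit-image threshold function.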

def _mean(data: np.ndarray, **kwargs: dict) -> float:
    """Threshold based on mean method."""
    return threshold_mean(data, **kwargs)


def _otsu(data: np.ndarray, **kwargs: dict) -> float:
    """Threshold based on Otsu's method."""
    return threshold_otsu(data, **kwargs)


def _triangle(data: np.ndarray, **kwargs: dict) -> float:
    """Threshold based on triangle method."""
    return threshold_triangle(data, **kwargs)


def _yen(data: np.ndarray, **kwargs: dict) -> float:
    """Threshold based on Yen's method."""
    return threshold_yen(data, **kwargs)

Usage

Now that the interface has been defined, users only need to import the single function threshold().

import numpy as np

from python_design_patterns import threshold

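# np.random is a module, so create a seeded generator and draw a 100 x 100 array from it
rng = np.random.default_rng(seed=52449807)
data = rng.random((100, 100))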

filter_threshold = threshold(data, method="otsu")

Tests

Being diligent programmers, we also provide a short test suite.

"""Tests for Factory Method."""
import numpy as np
import pytest

from python_design_patterns.factory_method import threshold


# Fixed seed so the expected thresholds are reproducible
rng = np.random.default_rng(seed=501472)
RANDOM_ARRAY = rng.random((10, 10))


@pytest.mark.parametrize(
    "data,method,expected_threshold",
    [
        (RANDOM_ARRAY, "mean", 0.45497009177756903),
        (RANDOM_ARRAY, "otsu", 0.4355642854559527),
        (RANDOM_ARRAY, "triangle", 0.38976173291463245),
        (RANDOM_ARRAY, "yen", 0.3706773360224157),
    ],
)
def test_threshold(data: np.ndarray, method: str, expected_threshold: float) -> None:
    """Check each method returns the expected threshold for a fixed random array."""
    # pytest.approx avoids brittle exact comparison of floating point values
    assert threshold(data, method) == pytest.approx(expected_threshold)

Extending

Suppose we decide to extend the threshold() function to support Li’s method (skimage.filters.threshold_li()). Doing so is now really simple. We don’t need to change the threshold() function itself (other than to list li as an option in the docstring). Instead we add an if method == "li": clause to _get_threshold() that returns a new function _li(), which in turn returns the result of skimage.filters.threshold_li(). NB don’t forget to import threshold_li too.

from skimage.filters import (
    ...
    threshold_li
)

def _get_threshold(method: str) -> Callable:
    """Creator component which determines which threshold method to use.

    Parameters
    ----------
    method: str
        Threshold method from Scikit-image to apply.

    Returns
    -------
    Callable
        Function that derives a threshold using the requested method.

    Raises
    ------
    ValueError
        If the method is not supported.
    """
    if method == "mean":
        return _mean
    if method == "otsu":
        return _otsu
    if method == "triangle":
        return _triangle
    if method == "yen":
        return _yen
    if method == "li":    # New method
        return _li
    raise ValueError(method)


def _li(data: np.ndarray, **kwargs: dict) -> float:
    """Threshold based on Li's method."""
    return threshold_li(data, **kwargs)

And we add a parameter to the tests.

Conclusion

The Factory Method is a simple design pattern that is very flexible and easy to understand. For an example please refer to the associated repository python-design-patterns where you will find the above code.