Efficient type validation for Python functions

Written by Matt Sosna on April 18, 2021

When it comes to writing complex pipelines running in production, it’s critical to have a clear understanding of what each function does, and how its outputs affect downstream functions. But despite our best efforts to write modular, well-tested functions, bugs love hiding in the handoffs between functions, and they can be hard to catch even with end-to-end tests.

This post will cover a Python decorator for input validation, which we can use to “lock” the inputs to our functions and immediately notice when there’s an unexpected mismatch.

Example pipeline

Consider a simple pipeline where we query an API, clean the data we get back, then save a CSV. We have one main function a user can trigger, run_pipeline, that takes a date range, runs query_api and process_dict, then saves a CSV.

Python
import datetime as dt

def run_pipeline(start_date: dt.datetime,
                 end_date: dt.datetime) -> None:
    """
    Query API for data between dates, then process output.
    Note: end_date exclusive.
    """
    response = query_api(start_date, end_date)
    df = process_dict(response)
    df.to_csv('cleaned_data.csv')

Simple enough, right? Let’s now look at query_api and process_dict. While they also look straightforward, there are actually two points of failure, one easier to spot than the other. Can you see them?

Python
import logging
import requests
import pandas as pd
import datetime as dt

DATA_URL = "http://ourcompany.com/data/start_date={}&end_date={}"


def query_api(start_date: dt.datetime,
              end_date: dt.datetime) -> dict:
    """
    Query data API for values between selected dates.
    Note: end_date exclusive.
    """

    start_date = dt.datetime.strftime(start_date, '%Y-%m-%d')
    end_date = dt.datetime.strftime(end_date, '%Y-%m-%d')

    try:
        data = requests.get(DATA_URL.format(start_date, end_date))
        return data.json()
    except (requests.RequestException, ValueError):  # request or JSON decoding failed
        logging.error("Could not retrieve data")


def process_dict(data: dict):
    """
    Convert dict to df and replace values in 'sales' column.
    """
    df = pd.DataFrame({'sales': data['sales']}, index=data['date'])
    df_filt = df.copy()  # work on a copy so the original frame is untouched
    df_filt['sales'] = df_filt['sales'].replace({'nan': pd.NA, '': pd.NA})

    return df_filt

Failure 1: Non-datetime inputs to query_api

You might have caught the first point of failure, in query_api: what if the inputs aren’t datetime objects?

If start_date and end_date aren’t datetime values, we’ll fail immediately on lines 16-17 of the code above (the two strftime calls), where we try to convert those objects to strings. Nothing stops this from happening, as run_pipeline (the function actually exposed to the user) passes its inputs straight into query_api without a second thought.
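To see the failure concretely, here’s a minimal sketch that calls strftime the same way query_api does, once with a datetime and once with a string. The helper name format_date is hypothetical and only used for illustration.

```python
import datetime as dt


def format_date(date) -> str:
    # Same call pattern as in query_api (format_date is a made-up name)
    return dt.datetime.strftime(date, '%Y-%m-%d')


# A real datetime formats cleanly
formatted = format_date(dt.datetime(2020, 1, 1))

# A string raises TypeError, since strftime expects a date/datetime object
try:
    format_date('2020-01-01')
    error_name = None
except TypeError as err:
    error_name = type(err).__name__
```

The error surfaces inside query_api rather than at the pipeline’s entry point, which is why checking inputs in run_pipeline gives a clearer signal.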

We could fix this first issue by checking the inputs to run_pipeline and making sure they’re datetime objects, immediately alerting the user of the issue before continuing the pipeline.

But what about our second issue? Have you found it yet?

Failure 2: Query failure in query_api

Requesting data from an API doesn’t always work. The server could be offline, the URL could be invalid, or our request could cause an internal error on the server side. To address this variability and avoid crashing our pipeline, we put our API request inside a try-except block.

Great… except for what happens if the except block gets triggered.

If something goes wrong with our API request, we log an error message… but don’t explicitly return anything. By default, Python functions return None. If our code takes the except branch, query_api returns None, which is then passed into process_dict. process_dict then crashes when we try to get the sales and date fields of what’s supposed to be a dictionary.
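Here’s a small sketch of that chain of events without any network access. The function flaky_query is a hypothetical stand-in for query_api whose request always fails, so the except branch always runs.

```python
import logging


def flaky_query() -> dict:
    """Stand-in for query_api whose request always fails (hypothetical)."""
    try:
        raise ConnectionError("server offline")
    except ConnectionError:
        logging.error("Could not retrieve data")
        # No return statement here, so Python implicitly returns None


response = flaky_query()

# Mimic what process_dict does first: index into the "dictionary"
try:
    response['sales']
    failure = None
except TypeError as err:
    # None does not support indexing, so this branch runs
    failure = err
```

Note that the return type hint says dict, yet nothing complains when None comes back; the crash only appears one function later.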

These issues will inevitably pop up as we start running our pipeline in earnest, but ideally we catch them early in development.

Solution

We can catch bugs faster if we keep a close eye on what values are going into our functions. While type hints indicate what data type we expect, Python won’t stop wrong inputs from being passed in. We therefore need to write the logic ourselves to validate that the inputs are correct.
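As a quick illustration that hints are not enforced at runtime, consider this toy function (double is a made-up example, not part of the pipeline):

```python
def double(x: int) -> int:
    # The hint says int, but nothing checks it when the function is called
    return x * 2


# A str slips through, and `*` quietly repeats it instead of doubling a number
result = double('ab')

# The hints are stored as metadata on the function, nothing more
hints = double.__annotations__
```

The call succeeds and returns a string, which is exactly the kind of silent mismatch that can travel downstream before anything breaks.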

There are a few ways to validate inputs. A simple one is a helper function called at the top of the function, like _validate_query_api_inputs below:

Python
def query_api(start_date: dt.datetime,
              end_date: dt.datetime) -> dict:
    """
    Query data API for values between selected dates.
    Note: end_date exclusive.
    """
    _validate_query_api_inputs(start_date, end_date)
    ...

def _validate_query_api_inputs(start_date, end_date):
    assert isinstance(start_date, dt.datetime)
    assert isinstance(end_date, dt.datetime)

But if all we’re doing is checking whether the inputs match their type hints, we probably want a more generic function; otherwise, we’ll need to write _validate functions for run_pipeline, query_api, and process_dict that all effectively do the same thing.

To address these needs, I wrote a decorator called enforce_type_hints. This decorator iterates through the inputs to a function and confirms they match the provided type hints. If there’s a mismatch, enforce_type_hints raises an AssertionError. Here’s a simple example:

Python
@enforce_type_hints
def query_api(start_date: dt.datetime,
              end_date: dt.datetime):
    ...

# Example 1: correct inputs
query_api(dt.datetime(2020, 1, 1), dt.datetime(2020, 1, 2))
    # {...} returned

# Example 2: incorrect inputs
query_api('2020-01-01', '2020-01-02')
    # AssertionError: start_date is type <class 'str'>

Let’s now check out how enforce_type_hints works. Want to skip the tutorial and see the raw code? Check out the jewelry GitHub repo, and submit a PR if you see a way to make it better!

ArgChecker

Because there are multiple components to running our type checks, I’ve put enforce_type_hints inside a Python class called ArgChecker. ArgChecker looks something like this:

Python
from typing import Any, Tuple


class ArgChecker:
    """
    Methods for enforcing type hints on decorated functions
    """

    def enforce_type_hints(self, func):
        """
        Ensure arguments passed into a function match the provided
        type hints
        """
        # 1. Get arg names and type hints from function
        # 2. Call _make_typings_comparable
        # 3. Call _check_positional_args
        # 4. Call _check_keyword_args
        # 5. Return func if no errors

    def _make_typings_comparable(self,
                                 types_dict: dict) -> dict:
        """
        Convert types from the typing module (e.g. Dict, Tuple) to
        Python natives
        """
        # 1. Convert basic Typing types to Python natives
            # List -> list, Dict -> dict, Tuple -> tuple
        # 2. Extract info from Union & Optional types
            # e.g. Optional[int] -> (int, type(None))
        # 3. Add np.int64 if type is int
            # e.g. (int, str) -> (int, np.int64, str)
        # 4. Return dict of arg name and accepted base Python types

    def _check_positional_args(self,
                               args: Tuple[Any, ...],
                               arg_names: list,
                               type_hints: dict) -> None:
        """
        Check whether all argument types in args tuple match argument
        types in type_hints dict
        """
        # 1. Iterate through positional args
        # 2. Skip if arg name is 'self'
        # 3. Skip if type hint is typing.Any
        # 4. Raise AssertionError if mismatch

    def _check_keyword_args(self,
                            kwargs: dict,
                            type_hints: dict) -> None:
        """
        Check whether all argument types in kwargs dict match argument
        types in type_hints dict
        """
        # 1. Iterate through keyword args
        # 2. Skip if type hint is typing.Any
        # 3. Raise AssertionError if mismatch

In short, enforce_type_hints gets the arguments and type hints from the function, converts the type hints to native Python types if necessary, and then iteratively checks that the positional and keyword arguments match their type hints. This check occurs as an assert isinstance(arg, accepted_types) call.

Fixing our pipeline

By decorating our functions with enforce_type_hints, we’re guaranteed that every argument reaching the code inside the function matches the data type in its type hint.

Yes, raising an AssertionError grinds our pipeline to a halt. But it immediately points us to where the error occurs, before it can propagate downstream. And removing any room for interpretation in our function inputs means we’re forced to think hard about what, exactly, each function is supposed to do (as well as the functions before and after it).

But back to our original pipeline. Here’s how I’d change our pipeline to make it more robust.

Python
import datetime as dt

from jewelry import ArgChecker   # <------ new
ac = ArgChecker()                # <------ new

@ac.enforce_type_hints           # <------ new
def run_pipeline(start_date: dt.datetime,
                 end_date: dt.datetime):
    """
    Query API for data between dates, then process output.
    Note: end_date exclusive.
    """
    response = query_api(start_date, end_date)
    df = process_dict(response)
    df.to_csv('cleaned_data.csv')

It’s critical to get the inputs to our pipeline correct, as it saves a lot of error handling downstream. By locking the inputs with enforce_type_hints, we immediately stop the pipeline from executing if the inputs aren’t valid.

Now let’s change query_api and process_dict.

Python
import logging
import requests
import pandas as pd
import datetime as dt

DATA_URL = "http://ourcompany.com/data/start_date={}&end_date={}"


@ac.enforce_type_hints    # <-------- new
def query_api(start_date: dt.datetime,
              end_date: dt.datetime) -> dict:
    """
    Query data API for values between selected dates.
    Note: end_date exclusive.
    """

    start_date = dt.datetime.strftime(start_date, '%Y-%m-%d')
    end_date = dt.datetime.strftime(end_date, '%Y-%m-%d')

    try:
        data = requests.get(DATA_URL.format(start_date, end_date))
        return data.json()
    except (requests.RequestException, ValueError):  # request or JSON decoding failed
        logging.error("Could not retrieve data")
        return {}      # <-------- new


@ac.enforce_type_hints   # <-------- new
def process_dict(data: dict):
    """
    Convert dict to df and replace values in 'sales' column.
    """
    # New section
    missing_fields = {'sales', 'date'}.difference(data.keys())
    if missing_fields:
        logging.error(f"Input missing these fields: {missing_fields}")
        return pd.DataFrame()

    df = pd.DataFrame({'sales': data['sales']}, index=data['date'])
    df_filt = df.copy()  # work on a copy so the original frame is untouched
    df_filt['sales'] = df_filt['sales'].replace({'nan': pd.NA, '': pd.NA})

    return df_filt

We start by enforcing type hints on query_api and process_dict. In the except block in query_api, we now return an empty dictionary rather than None, to maintain a consistent output. Finally, in process_dict we check whether the input dictionary contains the necessary fields; if not, we log the missing fields and return an empty dataframe.
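The missing-fields check is a set difference between the required keys and the keys we actually received. Here’s a small sketch of how it behaves for an empty, partial, and complete response; find_missing_fields is a hypothetical helper that isolates that one step.

```python
def find_missing_fields(data: dict,
                        required=frozenset({'sales', 'date'})) -> set:
    """Return the required keys absent from data (hypothetical helper)."""
    return set(required).difference(data.keys())


# The empty dict returned by the except branch of query_api
empty_response = find_missing_fields({})

# A response that only contains one of the two fields
partial_response = find_missing_fields({'sales': [10, 20]})

# A well-formed response: nothing is missing, so the set is empty
complete_response = find_missing_fields(
    {'sales': [10, 20], 'date': ['2020-01-01', '2020-01-02']}
)
```

An empty set is falsy in Python, so process_dict only logs an error and returns early when at least one field is missing.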

There are plenty of further rabbit holes we can pursue, such as whether run_pipeline should write an empty CSV or send some sort of alert that the job failed. But I’ll leave those up to you to decide!

Conclusions

This post introduced enforce_type_hints, a decorator for ensuring inputs to a function match their type hints. Raising an error the moment a function receives an input it wasn’t expecting might seem a little extreme, but it forces us to think critically about what exactly our functions are doing, and how upstream outputs get ingested downstream.

Interested in contributing? Fork the GitHub repo and submit a PR. There are plenty of other areas to explore. I’m personally interested in a decorator for validating whether dataframe or dict inputs have the necessary fields (like the missing-fields check on lines 34-37 of the query_api and process_dict code above).
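As a starting point for that idea, here’s one rough sketch of what a field-validating decorator for dict inputs could look like. Everything in it is hypothetical: the name require_fields, its signature, and the example function count_rows are assumptions for illustration and do not exist in the repo.

```python
import functools
import inspect


def require_fields(arg_name: str, fields):
    """Hypothetical decorator: assert a dict argument has the given keys."""
    def decorator(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Resolve the argument by name, whether passed positionally,
            # by keyword, or left at its default
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            missing = set(fields).difference(bound.arguments[arg_name].keys())
            assert not missing, (
                f"{arg_name} missing fields: {sorted(missing)}"
            )
            return func(*args, **kwargs)

        return wrapper
    return decorator


@require_fields('data', {'sales', 'date'})
def count_rows(data: dict) -> int:
    return len(data['sales'])
```

With this in place, count_rows({'sales': [1, 2], 'date': ['a', 'b']}) runs normally, while a dict lacking 'date' fails fast with an AssertionError naming the missing field. A dataframe version could compare against column names instead of keys.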

Thanks for reading, and may your pipelines always do exactly what you intend.

Best,
Matt
