Type inference is a major feature of several programming languages, most notably languages from the ML family like Haskell. In this post I want to provide a brief overview of type inference, along with a simple Python implementation for a toy ML-like language.

## Uni-directional type inference

While static typing is very useful, one of its potential downsides is verbosity. The programmer has to annotate values with types throughout the code, which results in more effort and clutter. What’s really annoying, though, is that in many cases these annotations feel superfluous. Consider this classical C++ example from pre-C++11 times:

```
std::vector<Blob*> blobs;
std::vector<Blob*>::iterator iter = blobs.begin();
```

Clearly when the compiler sees `blobs.begin()`, it knows the type of `blobs`, so it also knows the type returned by the `begin()` method invoked on it, because it is familiar with the declaration of `begin`. Why should the programmer be burdened with spelling out the type of the iterator? Indeed, one of the most welcome changes in C++11 was lifting this burden by repurposing `auto` for basic type inference:

```
std::vector<Blob*> blobs;
auto iter = blobs.begin();
```

Go has a similar capability with the `:=` syntax. Given some function:

```
func parseThing(...) (Node, error) {
}
```

We can simply write:

```
node, err := parseThing(...)
```

Without having to explicitly declare that `node` has type `Node` and `err` has type `error`.

These features are certainly useful, and they involve some degree of type inference from the compiler. Some functional programming proponents say this is not *real* type inference, but I think the difference is just a matter of degree. There’s certainly *some* inference going on here, with the compiler calculating and assigning the right types for expressions without the programmer’s help. Since this calculation flows in one direction (from the declaration of the `vector::begin` method to the `auto` assignment), I’ll call it *uni-directional* type inference [1].
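To make the one-way flow concrete, here is a minimal Python sketch of uni-directional inference over a made-up expression language. The tuple-based AST, the function `infer_forward` and the declared function `parse_thing` (of type `String -> Node`, a simplified stand-in for the Go example) are all hypothetical; the point is only that every type is read off an existing declaration and pushed toward its use.

```python
# Hypothetical mini-language (not microml): ("int", n), ("var", name) or
# ("call", fn_name, arg_expr). A function type is ("->", param_type, ret_type).


def infer_forward(expr, env):
    """Compute expr's type bottom-up, using only types that are already declared."""
    kind = expr[0]
    if kind == "int":
        return "Int"
    if kind == "var":
        return env[expr[1]]                      # KeyError if the name has no declared type
    if kind == "call":
        _, param_type, ret_type = env[expr[1]]   # read the callee's declaration
        arg_type = infer_forward(expr[2], env)
        if arg_type != param_type:
            raise TypeError(f"expected {param_type}, got {arg_type}")
        return ret_type                          # flows from the declaration to the use site
    raise ValueError(f"unknown expression kind: {kind}")


env = {"parse_thing": ("->", "String", "Node"), "src": "String"}
# Like `node := parse_thing(src)`: node's type comes straight from parse_thing's declaration.
env["node"] = infer_forward(("call", "parse_thing", ("var", "src")), env)
print(env["node"])  # Node
```

Note what this sketch cannot do: an undeclared name, such as a function parameter with no annotation, simply raises `KeyError`, because no information ever flows *into* it from the way it is used.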

## Bi-directional type inference (Hindley-Milner)

If we define a new `map` function in Haskell to map a function over a list, we can do it as follows:

```
mymap f [] = []
mymap f (first:rest) = f first : mymap f rest
```

Note that we did not specify the types of either the arguments of `mymap` or its return value. The Haskell compiler can infer them on its own, using the definition provided:

```
> :t Main.mymap
Main.mymap :: (t1 -> t) -> [t1] -> [t]
```

The compiler has determined that the first argument of `mymap` is a generic function, assigning its argument the type `t1` and its return value the type `t`. The second argument of `mymap` has the type `[t1]`, which means “list of `t1`”; then the return value of `mymap` has the type “list of `t`”.

How was this accomplished?

Let’s start with the second argument. From the `[] = []` variant, and also from the `(first:rest)` deconstruction, the compiler infers it has a list type. But there’s nothing else in the code constraining the element type, so the compiler chooses a generic type specifier: `t1`. `f first` applies `f` to an element of this list, so `f` has to take `t1`; nothing constrains its return value type, so it gets the generic `t`. The result is that `f` has type `(t1 -> t)`, which in Haskell parlance means “a function from `t1` to `t`”. Finally, the result of `mymap` is built by prepending `f first` (of type `t`) to the recursive call with `:`, so it has type `[t]`.

Here is another example, written in a toy language I put together for the sake of this post. The language is called **microml**, and its implementation is described at the end of the post:

```
foo f g x = if f(x == 1) then g(x) else 20
```

Here `foo` is declared as a function with three arguments. What is its type? Let’s try to run type inference manually. First, note that the body of the function consists of an `if` expression. As is common in programming languages, `if` has strict typing rules in microml: the condition must be boolean (`Bool`), and the types of the `then` and `else` clauses must match.

So we know that `f(x == 1)` has to return a `Bool`. Moreover, since `x` is compared to an integer, `x` is an `Int`, and since `x == 1` is a `Bool`, `f` takes a `Bool` argument. What is the type of `g`? Well, it has an `Int` argument, and its return value must match the type of the `else` clause, which is an `Int` as well.

To summarize:

- The type of `x` is `Int`
- The type of `f` is `Bool -> Bool`
- The type of `g` is `Int -> Int`

So the overall type of `foo` is:

```
((Bool -> Bool), (Int -> Int), Int) -> Int
```

It takes three arguments, the types of which we have determined, and returns an `Int`.

Note how this type inference process is not just going in one direction, but seems to be “jumping around” the body of the function figuring out known types due to typing rules. This is why I call it bi-directional type inference, but it’s much better known as Hindley-Milner type inference, since it was independently discovered by Roger Hindley in 1969 and Robin Milner in 1978.

## How Hindley-Milner type inference works

We’ve seen a couple of examples of manually running type inference on some code above. Now let’s see how to translate it to an implementable algorithm. I’m going to present the process in several separate stages, for simplicity. Some other presentations of the algorithm combine several of these stages, but seeing them separately is more educational, IMHO.

The stages are:

- Assign symbolic type names (like `t1`, `t2`, …) to all subexpressions.
- Using the language’s typing rules, write a list of *type equations* (or *constraints*) in terms of these type names.
- Solve the list of type equations using unification.

Let’s use this example again:

```
foo f g x = if f(x == 1) then g(x) else 20
```

Starting with **stage 1**, we’ll list all subexpressions in this declaration (starting with the declaration itself) and assign unique type names to them:

```
foo                             t0
f                               t1
g                               t2
x                               t3
if f(x == 1) then g(x) else 20  t4
f(x == 1)                       t5
x == 1                          t6
x                               t3
1                               Int
g(x)                            t7
20                              Int
```

Note that every subexpression gets a type, and we de-duplicate them (e.g. `x` is encountered twice and gets the same type name assigned). Constant nodes get known types.
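Stage 1 can be sketched as a pre-order walk that hands out fresh names and reuses a variable’s name whenever it reappears. This is a hypothetical sketch, not microml’s `typing.assign_typenames`: it assumes a tuple-based AST and a single flat symbol table, so it ignores scoping and shadowing.

```python
from itertools import count

# Hypothetical tuple AST (not microml's real AST classes):
#   ("int", n)  ("var", name)  ("eq", lhs, rhs)
#   ("app", fn, [args])  ("if", cond, then, else)  ("lambda", [params], body)


def assign_typenames(expr, symtab, fresh, out):
    """Pre-order walk: record (label, typename) for expr in `out`, then recurse."""
    kind = expr[0]
    if kind == "int":
        out.append((str(expr[1]), "Int"))      # constants get known types
        return
    if kind == "var":
        name = expr[1]
        if name not in symtab:                 # first time we see this name
            symtab[name] = f"t{next(fresh)}"
        out.append((name, symtab[name]))       # same variable, same typename
        return
    out.append((kind, f"t{next(fresh)}"))      # any other node gets a fresh name
    if kind == "lambda":
        for param in expr[1]:                  # parameters are named up front
            symtab[param] = f"t{next(fresh)}"
            out.append((param, symtab[param]))
        assign_typenames(expr[2], symtab, fresh, out)
    elif kind == "app":
        assign_typenames(expr[1], symtab, fresh, out)
        for arg in expr[2]:
            assign_typenames(arg, symtab, fresh, out)
    else:                                      # "eq" and "if": every child is an expression
        for child in expr[1:]:
            assign_typenames(child, symtab, fresh, out)


# foo f g x = if f(x == 1) then g(x) else 20
foo = ("lambda", ["f", "g", "x"],
       ("if", ("app", ("var", "f"), [("eq", ("var", "x"), ("int", 1))]),
              ("app", ("var", "g"), [("var", "x")]),
              ("int", 20)))

out = []
assign_typenames(foo, {}, count(), out)
for label, typename in out:
    print(f"{label:8} {typename}")
```

The walk gives the lambda `t0`, its parameters `f`, `g` and `x` the names `t1` through `t3`, and then visits the body in the same order as microml’s own typename assignment shown later in the post: every later use of `f`, `g` or `x` reuses its name, and both literals get `Int`.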

In **stage 2**, we’ll use the language’s typing rules to write down equations involving these type names. Usually books and papers use slightly scary formal notation for typing rules; for example, for `if`:

\[\frac{\Gamma \vdash e_0 : Bool, \quad \Gamma \vdash e_1 : T, \quad \Gamma \vdash e_2 : T}{\Gamma \vdash \mathit{if}\: e_0 \:\mathit{then}\: e_1 \:\mathit{else}\: e_2 : T}\]

All this means is the intuitive typing of `if` we’ve described above: the condition is expected to be boolean, the types of the `then` and `else` clauses are expected to match, and their type becomes the type of the whole expression.

To unravel the notation, prepend “given that” to the expression above the line and “we can derive” to the expression below the line. \(\Gamma \vdash e_0 : Bool\) means that \(e_0\) is typed to `Bool` in the set of typing assumptions called \(\Gamma\).

Similarly, a typing rule for single-argument function application would be:

\[\frac{\Gamma \vdash e_0 : T, \quad \Gamma \vdash f : T \rightarrow U}{\Gamma \vdash f(e_0) : U}\]

The real trick of type inference is running these typing rules *in reverse*. The rule tells us how to assign types to the whole expression given its constituent types, but we can also use it as an equation that works both ways and lets us infer constituent types from the whole expression’s type.

Let’s see what equations we can come up with, looking at the code:

From `f(x == 1)` we infer `t1 = (t6 -> t5)`, because `t1` is the type of `f`, `t6` is the type of `x == 1`, and `t5` is the type of `f(x == 1)`. Note that we’re using the typing rules for function application here.

Moreover, we can infer that `t3` is `Int` and `t6` is `Bool` because of the typing rule of the `==` operator.

Similarly, from `g(x)` we infer `t2 = (t3 -> t7)`.

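From the `if` expression, we infer that `t5` is `Bool` (since it is the type of the condition, `f(x == 1)`), and that `t4 = t7` and `t4 = Int`: the `then` and `else` clauses must match, their common type becomes the type of the whole `if`, and the `else` clause `20` is an `Int`.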
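Generating these equations is mechanical: one case per kind of node, each writing down its typing rule in terms of the stage 1 names. The following is a hypothetical sketch, not microml’s `generate_equations`. It assumes a tuple AST in which every node already carries its stage 1 typename as its second element, assumes (as in this example) that `==` compares integers, writes type constants as capitalized strings, and writes function types as `("->", (arg_types...), result_type)`.

```python
# Hypothetical annotated AST: (kind, typename, ...). Leaves: ("var", t, name), ("int", "Int", n).


def generate_equations(node, eqs):
    """Append (left, right) type equations for `node` and its subexpressions to `eqs`."""
    kind, t = node[0], node[1]
    if kind == "eq":                             # lhs == rhs
        lhs, rhs = node[2], node[3]
        generate_equations(lhs, eqs)
        generate_equations(rhs, eqs)
        eqs.append((lhs[1], "Int"))              # == compares integers here...
        eqs.append((rhs[1], "Int"))
        eqs.append((t, "Bool"))                  # ...and produces a Bool
    elif kind == "app":                          # fn(args...)
        fn, args = node[2], node[3]
        generate_equations(fn, eqs)
        for arg in args:
            generate_equations(arg, eqs)
        eqs.append((fn[1], ("->", tuple(a[1] for a in args), t)))
    elif kind == "if":
        cond, then, els = node[2], node[3], node[4]
        for child in (cond, then, els):
            generate_equations(child, eqs)
        eqs.append((cond[1], "Bool"))            # the condition must be Bool
        eqs.append((t, then[1]))                 # both branches have the if's type
        eqs.append((t, els[1]))
    elif kind == "lambda":
        params, body = node[2], node[3]
        generate_equations(body, eqs)
        eqs.append((t, ("->", tuple(p[1] for p in params), body[1])))
    # "var" and "int" leaves add no equations of their own in this sketch.


def show(t):
    """Render a type in the style of microml's output, e.g. (t6 -> t5)."""
    if not isinstance(t, tuple):
        return t
    _, args, ret = t
    params = show(args[0]) if len(args) == 1 else "(" + ", ".join(map(show, args)) + ")"
    return f"({params} -> {show(ret)})"


# foo f g x = if f(x == 1) then g(x) else 20, annotated with its stage 1 typenames.
foo = ("lambda", "t0", [("var", "t1", "f"), ("var", "t2", "g"), ("var", "t3", "x")],
       ("if", "t4",
        ("app", "t5", ("var", "t1", "f"), [("eq", "t6", ("var", "t3", "x"), ("int", "Int", 1))]),
        ("app", "t7", ("var", "t2", "g"), [("var", "t3", "x")]),
        ("int", "Int", 20)))

eqs = []
generate_equations(foo, eqs)
for left, right in eqs:
    print(f"{show(left):4} {show(right)}")
```

The nine equations it prints match, in order, microml’s equation list shown later in the post, minus the two `Int Int` equations that microml emits for the literals `1` and `20`.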

Now we have a list of equations, and our task is to find the most general solution, treating the equations as constraints. This is done by using the unification algorithm, which I described in detail in the previous post. The solution we’re seeking here is precisely the *most general unifier*.
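Here is a compact sketch of unification that solves the equations for `foo` (the same ones microml prints later in this post, minus the two for the literals `1` and `20`). It is a hypothetical standalone version, not the code from microml or from the previous post. It assumes type variables are lowercase strings (`t0`, `t1` and so on), type constants are capitalized (`Int`, `Bool`), and function types are `("->", (args...), result)`; the unifier is a plain dict from variables to types, with an occurs check to reject infinite types.

```python
def is_var(t):
    """Type variables are lowercase strings like 't3'; constants are 'Int', 'Bool'."""
    return isinstance(t, str) and t[0].islower()


def resolve(t, subst):
    """Apply the substitution to t recursively, following chains of bindings."""
    if is_var(t):
        return resolve(subst[t], subst) if t in subst else t
    if isinstance(t, tuple):
        _, args, ret = t
        return ("->", tuple(resolve(a, subst) for a in args), resolve(ret, subst))
    return t


def occurs(var, t, subst):
    """True if type variable `var` appears inside t (after applying subst)."""
    t = resolve(t, subst)
    if t == var:
        return True
    if isinstance(t, tuple):
        return any(occurs(var, a, subst) for a in t[1]) or occurs(var, t[2], subst)
    return False


def unify(a, b, subst):
    """Extend `subst` so that a and b become equal, or raise TypeError."""
    a, b = resolve(a, subst), resolve(b, subst)
    if a == b:
        return
    if is_var(a):
        if occurs(a, b, subst):
            raise TypeError(f"infinite type: {a} = {b}")
        subst[a] = b
    elif is_var(b):
        unify(b, a, subst)
    elif isinstance(a, tuple) and isinstance(b, tuple) and len(a[1]) == len(b[1]):
        for x, y in zip(a[1], b[1]):
            unify(x, y, subst)                   # argument types must match pairwise
        unify(a[2], b[2], subst)                 # and so must the result types
    else:
        raise TypeError(f"cannot unify {a} with {b}")


def show(t):
    """Render a type in the style of microml's output."""
    if not isinstance(t, tuple):
        return t
    _, args, ret = t
    params = show(args[0]) if len(args) == 1 else "(" + ", ".join(map(show, args)) + ")"
    return f"({params} -> {show(ret)})"


equations = [
    ("t3", "Int"), ("Int", "Int"), ("t6", "Bool"),
    ("t1", ("->", ("t6",), "t5")), ("t2", ("->", ("t3",), "t7")),
    ("t5", "Bool"), ("t4", "t7"), ("t4", "Int"),
    ("t0", ("->", ("t1", "t2", "t3"), "t4")),
]
subst = {}
for left, right in equations:
    unify(left, right, subst)
print(show(resolve("t0", subst)))  # (((Bool -> Bool), (Int -> Int), Int) -> Int)
```

Binding `t4` to `t7` and then unifying `t4` with `Int` shows why `resolve` follows chains: the second equation ends up binding `t7`, and `t4` picks up `Int` through it.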

For our expression, the algorithm will find the type of `foo` to be:

```
((Bool -> Bool), (Int -> Int), Int) -> Int
```

As expected.

If we make a slight modification to the expression to remove the comparison of `x` with 1:

```
foo f g x = if f(x) then g(x) else 20
```

Then we can no longer constrain the type of `x`, since all we know about it is that it’s passed into functions `f` and `g`, and nothing else constrains the arguments of these functions. The type inference process will thus calculate this type for `foo`:

```
((a -> Bool), (a -> Int), a) -> Int
```

It assigns `x` the generic type name `a`, and uses it for the arguments of `f` and `g` as well.
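The generic `a` comes from a final cosmetic step: type variables left unbound after unification are renamed to short letters. Below is a minimal sketch, assuming type variables are lowercase strings such as `t3`, constants are capitalized, and function types are `("->", (args...), result)`; it renames in order of first appearance, and whether microml’s `rename_types=True` uses exactly this scheme is an assumption. The input is a hypothetical solved type for the modified `foo` (the name `t3` is made up).

```python
from string import ascii_lowercase


def rename_type_vars(t, names=None):
    """Rename type variables to a, b, c, ... in order of first appearance.

    Assumes at most 26 distinct variables, which is enough for a sketch.
    """
    if names is None:
        names = {}
    if isinstance(t, tuple):                     # ("->", (args...), result)
        _, args, ret = t
        new_args = tuple(rename_type_vars(a, names) for a in args)
        return ("->", new_args, rename_type_vars(ret, names))
    if t[0].islower():                           # a leftover type variable like t3
        if t not in names:
            names[t] = ascii_lowercase[len(names)]
        return names[t]
    return t                                     # constants such as Int and Bool


def show(t):
    """Render a type in the style of microml's output."""
    if not isinstance(t, tuple):
        return t
    _, args, ret = t
    params = show(args[0]) if len(args) == 1 else "(" + ", ".join(map(show, args)) + ")"
    return f"({params} -> {show(ret)})"


# Hypothetical solved type for: foo f g x = if f(x) then g(x) else 20
solved = ("->", (("->", ("t3",), "Bool"), ("->", ("t3",), "Int"), "t3"), "Int")
print(show(rename_type_vars(solved)))  # (((a -> Bool), (a -> Int), a) -> Int)
```

Because one `names` dict is threaded through the whole walk, every occurrence of `t3` maps to the same letter, which is exactly why `f`, `g` and `x` all end up sharing `a`.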

## The implementation

An implementation of microml is available as a self-contained Python program that parses a microml declaration and infers its type. The best starting point is `main.py`, which spells out the stages of type inference:

```
import parser  # microml's own parser module (local to the project)
import typing  # microml's own typing.py (local module, not the stdlib typing)

code = 'foo f g x = if f(x == 1) then g(x) else 20'
print('Code', '----', code, '', sep='\n')

# Parse the microml code snippet into an AST.
p = parser.Parser()
e = p.parse_decl(code)
print('Parsed AST', '----', e, '', sep='\n')

# Stage 1: Assign symbolic typenames
typing.assign_typenames(e.expr)
print('Typename assignment', '----',
      typing.show_type_assignment(e.expr), '', sep='\n')

# Stage 2: Generate a list of type equations
equations = []
typing.generate_equations(e.expr, equations)
print('Equations', '----', sep='\n')
for eq in equations:
    print('{:15} {:20} | {}'.format(str(eq.left), str(eq.right), eq.orig_node))

# Stage 3: Solve equations using unification
unifier = typing.unify_all_equations(equations)
print('', 'Inferred type', '----',
      typing.get_expression_type(e.expr, unifier, rename_types=True),
      sep='\n')
```

This will print out:

```
Code
----
foo f g x = if f(x == 1) then g(x) else 20

Parsed AST
----
Decl(foo, Lambda([f, g, x], If(App(f, [(x == 1)]), App(g, [x]), 20)))

Typename assignment
----
Lambda([f, g, x], If(App(f, [(x == 1)]), App(g, [x]), 20)) t0
If(App(f, [(x == 1)]), App(g, [x]), 20) t4
App(f, [(x == 1)]) t5
f t1
(x == 1) t6
x t3
1 Int
App(g, [x]) t7
g t2
x t3
20 Int

Equations
----
Int             Int                  | 1
t3              Int                  | (x == 1)
Int             Int                  | (x == 1)
t6              Bool                 | (x == 1)
t1              (t6 -> t5)           | App(f, [(x == 1)])
t2              (t3 -> t7)           | App(g, [x])
Int             Int                  | 20
t5              Bool                 | If(App(f, [(x == 1)]), App(g, [x]), 20)
t4              t7                   | If(App(f, [(x == 1)]), App(g, [x]), 20)
t4              Int                  | If(App(f, [(x == 1)]), App(g, [x]), 20)
t0              ((t1, t2, t3) -> t4) | Lambda([f, g, x], If(App(f, [(x == 1)]), App(g, [x]), 20))

Inferred type
----
(((Bool -> Bool), (Int -> Int), Int) -> Int)
```

There are many more examples of type-inferred microml code snippets in the test file `test_typing.py`. Here’s another example which is interesting:

```
> foo f x = if x then lambda t -> f(t) else lambda j -> f(x)
((Bool -> a), Bool) -> (Bool -> a)
```

The actual inference is implemented in `typing.py`, which is fairly well commented and should be easy to understand after reading this post. The trickiest part is probably the unification algorithm, but that one is just a slight adaptation of the algorithm presented in the previous post.

[1] After this post was published, it was pointed out that another type … I’ll emphasize that my only use of the term “bi-directional” is to …