Simplifying Assumptions

Less painful Recurrent Batch Normalization in Theano

by Tim Cooijmans on 2017-06-15

My debut paper, Recurrent Batch Normalization, works great but is impossible to implement without breaking absolutely all of your abstractions. Personally, I’ve switched to Layer Normalization because it’s just as good and drastically simpler, but thanks to fellow MILA student Olivier Mastropietro I was reminded of a script I wrote back when I was working on recurrent batch normalization.

In particular, I was trying to shoehorn batch norm into Caglar Gulcehre’s Attentive Reader code. As always when working with Other People’s Code(tm), I tried to avoid venturing into its bowels, so I wrote code that takes an arbitrary BN-enabled training graph and hacks it up to get the corresponding inference graph.

Here’s how it works. As you construct your training graph, you label every tensor x that represents a batch statistic by setting x.tag.bn_statistic = True. Then you train your model. Once the model is trained, you call get_inference_graph, giving it your model’s symbolic inputs, its symbolic outputs, and an iterable over batches of training data (each batch a dict mapping input names to values, since batches are passed to the compiled function as keyword arguments). get_inference_graph then traverses your graph to find all batch statistics, builds a new graph whose outputs are those batch statistics, and runs it on the provided training data, averaging the batch statistics over the batches to estimate population statistics. Finally, it creates the inference graph, which is like the original graph except that batch statistics are replaced by population statistics.
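To make the estimate-then-replace idea concrete without any Theano machinery, here is a minimal NumPy sketch. It assumes a plain feed-forward batch-norm layer, and the names batch_norm and estimate_population_stats are hypothetical, not part of the script below. Training mode normalizes with the statistics of the current batch; the population statistics are simply those batch statistics averaged over a set of estimation batches; inference mode normalizes with the population statistics instead.

```python
import numpy as np

def batch_norm(x, mean=None, var=None, eps=1e-5):
    """Normalize x of shape (batch, features).

    If mean/var are omitted, use the batch statistics (training mode);
    otherwise use the given population statistics (inference mode).
    Returns the normalized output together with the statistics used.
    """
    if mean is None:
        mean = x.mean(axis=0)
        var = x.var(axis=0)
    return (x - mean) / np.sqrt(var + eps), mean, var

def estimate_population_stats(estimation_batches):
    """Average the batch statistics over an iterable of training batches."""
    means, variances = [], []
    for x in estimation_batches:
        _, mean, var = batch_norm(x)  # training mode, keep the statistics
        means.append(mean)
        variances.append(var)
    return np.mean(means, axis=0), np.mean(variances, axis=0)

rng = np.random.RandomState(0)
batches = [rng.normal(loc=3.0, scale=2.0, size=(64, 5)) for _ in range(50)]
pop_mean, pop_var = estimate_population_stats(batches)

# at inference time a single example is normalized with population statistics
example = rng.normal(loc=3.0, scale=2.0, size=(1, 5))
y, _, _ = batch_norm(example, pop_mean, pop_var)
```

A batch of one has zero variance, which is exactly why inference needs population statistics rather than batch statistics.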

This would be trivial if it weren’t for Theano’s scan feature, because everything would be trivial if it weren’t for Theano’s scan feature. Scan encodes loops in the graph by representing the loop body as a completely separate “inner” graph, whose “inner” inputs correspond to the “outer” inputs of the Scan node and whose “inner” outputs correspond to its “outer” outputs. They correspond, but they are not connected, so you can’t just take batch statistics from the “inner” graph and connect them to stuff in the “outer” graph.
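A rough pure-Python analogy may help. The sketch below uses a hypothetical toy_scan, not Theano’s actual implementation. The step function stands in for the “inner” graph: it receives one time slice of the sequence plus the carried state, and the loop stacks whatever state it returns into the “outer” output. A value computed inside the step, like the batch mean below, is local to it and simply does not exist at the outer level.

```python
import numpy as np

def toy_scan(step, sequence, initial_state):
    """Loop `step` over the first axis of `sequence`, carrying a state.

    `step(x_t, h_prev)` only sees "inner" values: one time slice and the
    previous state. The loop stacks the returned states along a new
    leading time axis, which plays the role of the "outer" output.
    """
    states = []
    h = initial_state
    for x_t in sequence:
        h = step(x_t, h)
        states.append(h)
    return np.stack(states)

def step(x_t, h_prev):
    pre = x_t + h_prev                  # (batch, features)
    batch_mean = pre.mean(axis=0)       # batch statistic, local to this step
    return np.tanh(pre - batch_mean)    # only the new state leaves the step

x = np.random.RandomState(0).normal(size=(7, 4, 3))  # (time, batch, features)
hs = toy_scan(step, x, np.zeros((4, 3)))
```

Getting batch_mean out of the loop requires changing what the step returns; there is no way to reach into the step from outside.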

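In order to export the batch statistics to the “outer” graph so I can estimate population statistics, my code has to add them as extra Scan outputs (nitsots, i.e. outputs that are not fed back into the next iteration). Once I have the population statistics, I need to import them back into the “inner” graph by adding appropriate Scan sequences and non-sequences and whatnots: a sequence of time indices, plus the population statistics as non-sequences from which each step picks out its own row.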
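Both directions can be sketched in a toy NumPy loop. This is a hypothetical analogy under simplifying assumptions (mean-only normalization, one normalized pre-activation per step, equal-length estimation batches), not the Theano code below. For export, the training loop returns each step’s batch mean as an extra per-step output that is stacked over time but never fed back. For import, the inference loop receives the population statistics whole, like a non-sequence, uses the time index t to pick out row t, and falls back on the last row when the sequence is longer than anything seen during estimation, much as the script does with T.arange, ifelse and inner_popstat[-1].

```python
import numpy as np

def run_rnn(xs, popstats=None):
    """Toy recurrent net with mean-only normalization at each step.

    xs: (time, batch, features). In training mode (popstats is None) each
    step centers with its own batch mean and also returns that mean as an
    extra per-step output. In inference mode, popstats of shape
    (estimated_time, features) is passed in whole and indexed by the time
    step, clamped to the last available row.
    """
    h = np.zeros(xs.shape[1:])
    states, batch_means = [], []
    for t, x_t in enumerate(xs):          # t acts as the extra index sequence
        pre = x_t + 0.5 * h
        if popstats is None:
            mean = pre.mean(axis=0)       # batch statistic ("inner" value)
            batch_means.append(mean)      # exported as an extra output
        else:
            mean = popstats[min(t, len(popstats) - 1)]  # imported popstat
        h = np.tanh(pre - mean)
        states.append(h)
    extra = np.stack(batch_means) if popstats is None else None
    return np.stack(states), extra

rng = np.random.RandomState(0)
# estimate population statistics from several equal-length training batches
estimates = [run_rnn(rng.normal(size=(5, 16, 3)))[1] for _ in range(20)]
popstats = np.mean(estimates, axis=0)     # (time=5, features=3)

# inference on a single, longer sequence reuses the last step's statistics
hs, _ = run_rnn(rng.normal(size=(8, 1, 3)), popstats)
```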

In the end, the code works for any level of nesting of Scan loops (including when you have no Scan at all). I suppose it would have been strategic to publicize this code back then, but perhaps it can still serve as a historical testament to the wicked rituals we had to go through back in the day of symbolic computation graphs. You know, before we all switched to PyTorch.

The code is here and a usage example is here. I reproduce my baby in all its glory below:

import sys, cPickle as pkl
import theano, itertools, pprint, copy, numpy as np, theano.tensor as T
import theano.ifelse  # used explicitly in get_population_outputs
from collections import OrderedDict
from theano.gof.op import ops_with_inner_function
from theano.scan_module.scan_op import Scan
from theano.scan_module import scan_utils

def equizip(*sequences):
    sequences = list(map(list, sequences))
    assert all(len(sequence) == len(sequences[0]) for sequence in sequences[1:])
    return zip(*sequences)

# get outer versions of the given inner variables of a scan node
def export(node, extra_inner_outputs):
    assert isinstance(node.op, Scan)

    # this is ugly but we can't use scan_utils.scan_args because that
    # clones the inner graph and then extra_inner_outputs aren't in
    # there anymore
    old_inner_inputs = node.op.inputs
    old_inner_outputs = node.op.outputs
    old_outer_inputs = node.inputs

    new_inner_inputs = list(old_inner_inputs)
    new_inner_outputs = list(old_inner_outputs)
    new_outer_inputs = list(old_outer_inputs)
    new_info = copy.deepcopy(node.op.info)

    # put the new inner outputs in the right place in the output list and
    # update info
    new_info["n_nit_sot"] += len(extra_inner_outputs)
    yuck = len(old_inner_outputs) - new_info["n_shared_outs"]
    new_inner_outputs[yuck:yuck] = extra_inner_outputs

    # in step 8, theano.scan() adds an outer input (being the actual
    # number of steps) for each nitsot. we need to do the same thing.
    # note these don't come with corresponding inner inputs.
    offset = (1 + node.op.n_seqs + node.op.n_mit_mot + node.op.n_mit_sot +
              node.op.n_sit_sot + node.op.n_shared_outs)
    # the outer input is just the actual number of steps, which is
    # always available as the first outer input.
    new_outer_inputs[offset:offset] = [new_outer_inputs[0]] * len(extra_inner_outputs)

    new_op = Scan(new_inner_inputs, new_inner_outputs, new_info)
    outer_outputs = new_op(*new_outer_inputs)

    # grab the outputs we actually care about
    extra_outer_outputs = outer_outputs[yuck:yuck + len(extra_inner_outputs)]
    return extra_outer_outputs

def gather_symbatchstats_and_estimators(outputs):
    symbatchstats = []
    estimators = []
    visited_scan_ops = set()

    for var in theano.gof.graph.ancestors(outputs):
        if hasattr(var.tag, "bn_statistic"):
            var.tag.original_id = id(var)
            symbatchstats.append(var)
            estimators.append(var)

        # descend into Scan (graph inputs have no owner)
        if var.owner is None:
            continue
        op = var.owner.op
        if isinstance(op, Scan) and op not in visited_scan_ops:
            visited_scan_ops.add(op)
            print "descending into", var

            # the recursion returns (symbatchstats, estimators), in that order
            inner_symbatchstats, inner_estimators = gather_symbatchstats_and_estimators(op.outputs)
            outer_estimators = export(var.owner, inner_estimators)

            symbatchstats.extend(inner_symbatchstats)
            estimators.extend(outer_estimators)

    return symbatchstats, estimators

def get_population_outputs(batch_outputs, popstats):
    replacements = []
    visited_scan_ops = set()

    for var in theano.gof.graph.ancestors(batch_outputs):
        if hasattr(var.tag, "bn_statistic"):
            # can't rely on object identity because scan_args clones; use original_id
            popstat = next(popstat for batchstat, popstat in popstats.items() if batchstat.tag.original_id == var.tag.original_id)
            replacements.append((var, T.patternbroadcast(popstat, var.broadcastable)))

        # descend into Scan (graph inputs have no owner)
        if var.owner is None:
            continue
        op = var.owner.op
        if isinstance(op, Scan):
            # this would cause multiple replacements for this variable
            assert not hasattr(var.tag, "bn_statistic")

            if op in visited_scan_ops:
                continue
            visited_scan_ops.add(op)
            print "descending into", var

            node = var.owner
            sa = scan_utils.scan_args(outer_inputs=node.inputs, outer_outputs=node.outputs,
                                      _inner_inputs=node.op.inputs, _inner_outputs=node.op.outputs,
                                      info=node.op.info)

            # add subscript as sequence
            # TODO check if this integer input drops the scan to cpu, if so use float and cast back in subtensor expression
            indices = T.arange(sa.n_steps)
            index = scan_utils.safe_new(indices[0])
            sa.outer_in_seqs.append(indices)
            sa.inner_in_seqs.append(index)

            # add popstats as nonsequences (because they may be shorter than len(indices))
            inner_popstats = {}
            for batchstat, outer_popstat in popstats.items():
                # this can't be subscripted hence won't appear in the inner graph
                if outer_popstat.ndim == 0:
                    continue

                inner_popstat = scan_utils.safe_new(outer_popstat)
                sa.outer_in_non_seqs.append(outer_popstat)
                sa.inner_in_non_seqs.append(inner_popstat)

                inner_popstats[batchstat] = theano.ifelse.ifelse(index < inner_popstat.shape[0],
                                                                 inner_popstat[index],
                                                                 inner_popstat[-1])

            # recurse on inner graph
            new_inner_outputs = sa.inner_outputs
            new_inner_outputs = get_population_outputs(new_inner_outputs, inner_popstats)

            # construct new scan node
            new_op = Scan(sa.inner_inputs, new_inner_outputs, sa.info)
            new_outer_outputs = new_op(*sa.outer_inputs)

            # the new scan only gained extra inputs, so its outer outputs
            # correspond one-to-one with the old outer outputs; replace one-to-one
            replacements.extend(equizip(node.outputs, new_outer_outputs))

    print "replacements", replacements
    population_outputs = scan_utils.clone(batch_outputs, replace=replacements)
    return population_outputs

def get_inference_graph(inputs, batch_outputs, estimation_batches):
    symbatchstats, estimators = gather_symbatchstats_and_estimators(batch_outputs)
    print "symbatchstats x estimators", equizip(symbatchstats, estimators)

    if not symbatchstats:
        print "NO BATCH STATISTICS FOUND IN GRAPH"
    #assert symbatchstats

    def aggregate_varlen(aggregate, sample):
        # grow to accommodate shape
        aggregate = np.pad(aggregate,
                           [(0, max(0, sample.shape[j] - aggregate.shape[j]))
                            for j in range(aggregate.ndim)],
                           mode="constant")
        aggregate[tuple(map(slice, sample.shape))] += sample
        return aggregate

    # take average of batch statistics over estimation_batches
    estimator_fn = theano.function(inputs, estimators, on_unused_input="warn")
    batchstats = {}
    for i, batch in enumerate(estimation_batches):
        estimates = estimator_fn(**batch)
        for symbatchstat, estimator, estimate in equizip(symbatchstats, estimators, estimates):
            batchstats.setdefault(symbatchstat, []).append(estimate)

    popstats = {}
    coverages = {}
    for symbatchstat in symbatchstats:
        if batchstats[symbatchstat][0].ndim > 1:
            # assume first axis is time
            maxlen = max(map(len, batchstats[symbatchstat]))
            # pad all batch stats to maxlen by repeating last time step
            padded_batchstats = [
                np.pad(batchstat,
                       [(0, maxlen - len(batchstat))] + [(0, 0) for _ in range(1, batchstat.ndim)],
                       mode="edge")
                for batchstat in batchstats[symbatchstat]]
            popstat = sum(bs / len(padded_batchstats) for bs in padded_batchstats)

            coverages[symbatchstat] = (
                np.arange(maxlen)[None, :] <
                np.asarray(list(map(len, batchstats[symbatchstat])))[:, None]
            ).sum(axis=0)
        else:
            # not time-separated, just average as is
            popstat = sum(bs / len(batchstats[symbatchstat]) for bs in batchstats[symbatchstat])
        popstats[symbatchstat] = popstat

    if True:
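        # dump and plot all statistics for inspection; the plots expect each
        # batch statistic to also carry a human-readable x.tag.bn_label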
        import matplotlib.pyplot as plt
        for symbatchstat, popstat in popstats.items():
            if popstat.ndim == 1:
                plt.figure()
                plt.hist(popstat)
                plt.title(symbatchstat.tag.bn_label)
            elif False:
                plt.matshow(popstat, cmap="bone")
                plt.colorbar()
            elif popstat.ndim == 2:
                choice = np.random.choice(popstat.shape[1], size=(min(20, popstat.shape[1]),), replace=False)
                fig, axes = plt.subplots(2, sharex=True)

                axes[0].plot(popstat[:, choice])
                axes[0].set_title("values")

                axes[1].plot(coverages[symbatchstat])
                axes[1].set_title("support")
                axes[1].set_ylabel("batches")
                axes[1].set_xlabel("time steps")

                fig.suptitle(symbatchstat.tag.bn_label)
            else:
                # scalar and higher-rank statistics have no plot here
                continue

            plt.savefig("%s.pdf" % symbatchstat.tag.bn_label, bbox_inches="tight")
        #plt.show()
        #import pdb; pdb.set_trace()
        with open("allstats.pkl", "wb") as f:
            pkl.dump(dict(batchstats=batchstats, coverages=coverages, popstats=popstats), f)

    sympopstats = {}
    for symbatchstat, popstat in popstats.items():
        # need as_tensor_variable to make sure it's not a CudaNdarray
        # because then the replacement will fail as symbatchstat may not
        # have been moved to the gpu yet.
        sympopstat = T.as_tensor_variable(theano.shared(popstat)).copy(name="popstat_%s" % symbatchstat.name)
        sympopstats[symbatchstat] = sympopstat

    population_outputs = get_population_outputs(batch_outputs, sympopstats)

    return population_outputs
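The least obvious numerical step in get_inference_graph is averaging time-indexed statistics from batches of different lengths. The standalone NumPy sketch below (a hypothetical average_time_indexed, no Theano required) mirrors that logic: each batch’s statistics are padded to the longest length by repeating their last time step (np.pad with mode="edge"), the padded arrays are averaged, and a coverage count records how many batches actually reached each time step.

```python
import numpy as np

def average_time_indexed(batchstats):
    """Average per-batch statistics whose first axis is time.

    batchstats: list of arrays of shape (T_i, ...) with varying T_i.
    Returns (popstat, coverage): popstat has shape (max T_i, ...) and
    coverage[t] counts the batches that have more than t time steps.
    """
    maxlen = max(len(bs) for bs in batchstats)
    # pad every batch to maxlen by repeating its last time step
    padded = [np.pad(bs, [(0, maxlen - len(bs))] + [(0, 0)] * (bs.ndim - 1), mode="edge")
              for bs in batchstats]
    popstat = np.mean(padded, axis=0)
    lengths = np.array([len(bs) for bs in batchstats])
    coverage = (np.arange(maxlen)[None, :] < lengths[:, None]).sum(axis=0)
    return popstat, coverage

# two batches: one three steps long, one a single step long (1 feature each)
popstat, coverage = average_time_indexed([np.array([[1.0], [2.0], [3.0]]),
                                          np.array([[5.0]])])
# popstat -> [[3.0], [3.5], [4.0]], coverage -> [2, 1, 1]
```

Edge padding is a design choice: late time steps that only long batches reach are estimated from fewer real samples, and the coverage counts make that visible, which is what the script’s “support” plot shows.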