
# Training BNNs with HMC

Sep 5, 2019 · 3 min read · Python, Machine Learning


This post is a guide on how to use the Hamiltonian Monte Carlo (HMC) transition kernel provided by TensorFlow Probability to train Bayesian neural networks (BNNs) by sampling from their posterior distribution. If you haven’t heard of HMC before, check out this short introduction. Or, if you just want the 30-second elevator pitch: HMC is a Markov chain Monte Carlo (MCMC) algorithm that avoids the random-walk behavior of simpler MCMC methods such as random-walk Metropolis-Hastings or Gibbs sampling. It does so by using Hamilton’s equations of classical mechanics to take a series of gradient-informed steps through an auxiliary phase space, which can then be projected back down onto the probability density you care about. This massively speeds up mixing (i.e. it generates a Markov chain of less correlated samples) and lets the chain converge even on high-dimensional target distributions. Here’s the upfront disclaimer, though: that may not extend to the many tens of thousands to several million parameters of modern neural networks, which result in a posterior distribution of equal dimensionality.
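To make the elevator pitch a bit more concrete, here is a minimal sketch of a single-chain HMC sampler for a one-dimensional target, written in plain Python rather than TFP. The function names, the unit mass, and the default step size and trajectory length are choices made for this illustration only; TFP’s kernels are considerably more elaborate.

```python
import math
import random


def leapfrog(x, p, grad_potential, step_size, num_steps):
    """Integrate Hamilton's equations with the leapfrog scheme."""
    p -= 0.5 * step_size * grad_potential(x)  # initial half step in momentum
    for _ in range(num_steps - 1):
        x += step_size * p  # full step in position
        p -= step_size * grad_potential(x)  # full step in momentum
    x += step_size * p  # last full step in position
    p -= 0.5 * step_size * grad_potential(x)  # final half step in momentum
    return x, p


def hmc_chain(potential, grad_potential, x0, num_samples,
              step_size=0.2, num_steps=10, seed=0):
    """Generate a chain targeting exp(-potential(x)), assuming unit mass."""
    rng = random.Random(seed)
    x, chain = x0, []
    for _ in range(num_samples):
        p = rng.gauss(0.0, 1.0)  # fresh momentum for every trajectory
        x_new, p_new = leapfrog(x, p, grad_potential, step_size, num_steps)
        # Metropolis correction for the integrator's energy error
        h_old = potential(x) + 0.5 * p**2
        h_new = potential(x_new) + 0.5 * p_new**2
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            x = x_new
        chain.append(x)
    return chain


# Standard normal target: V(x) = x^2 / 2, so dV/dx = x.
chain = hmc_chain(lambda x: 0.5 * x**2, lambda x: x, x0=3.0, num_samples=2000)
mean = sum(chain) / len(chain)
var = sum((x - mean) ** 2 for x in chain) / len(chain)
print(f"mean={mean:.2f}, variance={var:.2f}")  # should be close to 0 and 1
```

Each transition draws a new momentum, follows the dynamics for a fixed number of leapfrog steps, and accepts or rejects the end point based on the change in total energy. The gradient only enters through `grad_potential`, which is where the "gradient-informed" part comes from.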

The code in this guide was written in TensorFlow (TF) v2.1 and TensorFlow Probability (TFP) v0.9.

## Simple Example

To first get a feel for what HMC is doing, let’s start with a very simple example and visualize it. Let’s define a bimodal distribution $\pi(\vec x)$ consisting of two Gaussians that are almost completely disjoint,

$\pi(\vec x) = \frac{1}{2} \Ncal(\vec\mu_1, \mat\Sigma) + \frac{1}{2} \Ncal(\vec\mu_2, \mat\Sigma)$

where

$\vec\mu_1 = (2,-2)^T, \quad \vec\mu_2 = (-2,2)^T, \quad \mat\Sigma = I_2 = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}.$

These sorts of distributions typically pose difficulties to MCMC algorithms since exploring them fully requires entering and passing through a region of low probability density, which by nature of the algorithm is a rare event.

Let’s see how many samples HMC requires to achieve mixing, i.e. to jump from whichever mode the chain starts in to the other. Once both modes are covered with a roughly equal number of samples, we can use the generated Markov chain $\Ccal$ to compute accurate expectation values with respect to $\pi$. For example, say we have an observable $f(\vec x)$ and are interested in the value it’s likely to take given the probabilities for different $\vec x$ under $\pi(\vec x)$,

$\expec_\pi[f] = \int_{\Rbb^2} \pi(\vec x) \, f(\vec x) \, \dif\vec x.$

If the Markov chain has converged, we can get an accurate estimate for this expectation value by averaging over the samples in the chain,

$\expec_\pi[f] \approx \hat f = \frac{1}{|\Ccal|} \sum_{\vec x \in \Ccal} f(\vec x).$
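As a quick sanity check of the averaging formula, here is a small sketch in plain Python. It assumes we can draw independent samples from the equal-weight mixture directly (which is easy for this toy distribution and stands in for a well-mixed chain), and it uses the observable $f(\vec x) = x_1 x_2$, chosen purely for illustration. Its exact expectation is $-4$, since within each mode the two coordinates are independent with means $\pm 2$ and $\mp 2$.

```python
import random

rng = random.Random(42)
MODES = [(2.0, -2.0), (-2.0, 2.0)]


def sample_bimodal():
    """Draw one point from the equal-weight mixture of two unit Gaussians."""
    mu = rng.choice(MODES)
    return (rng.gauss(mu[0], 1.0), rng.gauss(mu[1], 1.0))


def mc_expectation(f, samples):
    """Estimate E[f] by averaging f over the samples."""
    return sum(f(x) for x in samples) / len(samples)


samples = [sample_bimodal() for _ in range(20_000)]
f_hat = mc_expectation(lambda x: x[0] * x[1], samples)
print(f"estimate of E[x1 * x2]: {f_hat:.2f}")  # exact value is -4
```

With a Markov chain in place of `samples`, the same average applies, but the estimate is only trustworthy once both modes have been visited in roughly equal proportion.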

And here’s the code to generate $\Ccal$ using TFP:

simple-distros.py
import numpy as np
import plotly.graph_objects as go
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

normals_2d = [
    tfd.MultivariateNormalDiag([2.0, -2.0], [1.0, 1.0]),
    tfd.MultivariateNormalDiag([-2.0, 2.0], [1.0, 1.0]),
]
# equal mixture weights for both modes
bimodal_gauss = tfd.Mixture(tfd.Categorical(probs=[0.5, 0.5]), normals_2d)

@tf.function
def sample_chain(*args, **kwargs):
    """Since this is the bulk of the computation, using @tf.function
    here to compile a static graph for tfp.mcmc.sample_chain
    significantly improves performance.
    """
    return tfp.mcmc.sample_chain(*args, **kwargs)

step_size = 1e-3
kernel = tfp.mcmc.NoUTurnSampler(bimodal_gauss.log_prob, step_size=step_size)
# NOTE: num_adaptation_steps is required by the API, but its original value
# did not survive in this listing; 100 is an assumed placeholder.
adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel,
    num_adaptation_steps=100,
    step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
        step_size=new_step_size
    ),
    step_size_getter_fn=lambda pkr: pkr.step_size,
    log_accept_prob_getter_fn=lambda pkr: pkr.log_accept_ratio,
)

chain, trace, final_kernel_results = sample_chain(
    kernel=adaptive_kernel,
    num_results=150,
    current_state=tf.constant([0.1, 0.0]),
    return_final_kernel_results=True,
)

x = np.arange(-6, 6, 0.3)
# cast the grid to float32 to match the dtype of the distribution
domain = np.mgrid[-6:6:0.3, -6:6:0.3].reshape(2, -1).T.astype("float32")
chain = chain.numpy()

fig = go.Figure(
    data=[
        go.Surface(x=x, y=x, z=bimodal_gauss.prob(domain).numpy().reshape(len(x), -1)),
        go.Scatter3d(x=chain[:, 0], y=chain[:, 1], z=bimodal_gauss.prob(chain).numpy()),
    ]
)
fig.update_layout(
    height=700, title=go.layout.Title(text=f"step size: {step_size}", xref="paper", x=0)
)

Note that the awkward lambda functions passed to DualAveragingStepSizeAdaptation are only necessary due to temporary inconsistencies in TFP’s mcmc module and should become superfluous once that receives a good refactor as promised here. If you’re reading this post several months down the line, then probably all you’ll need to write is

adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=num_adaptation_steps
)

The resulting plot you get should look like this.

Bimodal distribution decorated by 100 HMC samples which managed to jump the gap between modes.

During the Hamiltonian evolution that’s simulated to obtain each new sample in the chain $\Ccal$, the potential energy $V$ corresponds to the negative log of the probability density shown above, $V(\vec x) = -\log \pi(\vec x)$. The two peaks thus correspond to two wells in the potential landscape that’s the arena of the Hamiltonian simulation. The very first Hamiltonian simulation sets off from the initial state passed to the sampler. It then takes a few iterations before the current head of the chain “falls” into one of the wells. Once there, Hamiltonian dynamics will likely stay there for a while and gather samples from that region of increased probability/decreased potential energy.

Transitioning between the two modes is only possible when starting out a Hamiltonian evolution with a particularly large initial momentum. The momenta are randomly sampled at the start of each new evolution, $\vec p \sim \Ncal(0, \mat\Sigma)$ (independent of the current position). Since the Gaussian has a higher density for small values of $|\vec p|$, low-energy trajectories are by far the most common. Only if we happen to sample a momentum that’s sufficiently large (and that happens to point in the right direction) can we overcome the potential barrier between the two modes.
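To put a number on "particularly large", here is a small back-of-the-envelope sketch in plain Python. It assumes equal mixture weights, an identity mass matrix (so $\vec p \sim \Ncal(0, I_2)$) and a trajectory that starts exactly at one of the modes. Under those assumptions the barrier is the difference in $V$ between the origin, which is the saddle point by symmetry, and a mode.

```python
import math

MU_1, MU_2 = (2.0, -2.0), (-2.0, 2.0)


def unit_gauss_pdf(x, mu):
    """Density of a 2d Gaussian with identity covariance."""
    sq_dist = (x[0] - mu[0]) ** 2 + (x[1] - mu[1]) ** 2
    return math.exp(-0.5 * sq_dist) / (2 * math.pi)


def potential(x):
    """V(x) = -log pi(x) for the equal-weight mixture."""
    return -math.log(0.5 * unit_gauss_pdf(x, MU_1) + 0.5 * unit_gauss_pdf(x, MU_2))


# Energy needed to get from a mode up to the saddle point at the origin.
barrier = potential((0.0, 0.0)) - potential(MU_1)
# For p ~ N(0, I_2), |p|^2 is chi-squared with 2 degrees of freedom, so
# P(|p|^2 / 2 > barrier) = exp(-barrier).
prob_enough_energy = math.exp(-barrier)
print(f"barrier height: {barrier:.2f}")  # 3.31
print(f"P(kinetic energy > barrier): {prob_enough_energy:.3f}")  # 0.037
```

So fewer than 4% of the sampled momenta carry enough kinetic energy to cross at all, and that is only a necessary condition: the momentum also has to point roughly toward the other mode.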

## HMC Routines

Now let’s look at actually applying HMC to the parameters of a neural network, on which we place Gaussian priors. First, let’s define a few functions to handle the grunt work in this exercise.

src/bnn/hmc.py
import functools as ft

import tensorflow as tf
import tensorflow_probability as tfp

import src.bnn.functions as bnn_fn

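def pre_train_nn(X_train, y_train, nodes_per_layer, epochs=100):
    """Pre-train NN to get good starting point for HMC.

    Args:
        X_train (Tensor or np.array): training samples
        y_train (Tensor or np.array): training labels
        nodes_per_layer (list): the number of nodes in each dense layer

    Returns:
        list: tensors specifying the weights of the trained network
        model: Keras Sequential model
    """
    # unpack instead of pop() so the caller's list is not mutated
    *hidden_layers, last_layer = nodes_per_layer
    model = tf.keras.Sequential()
    # NOTE: the layer and compile calls were lost from this listing; the ReLU
    # activations, Adam optimizer and MSE loss below are assumptions.
    for units in hidden_layers:
        model.add(tf.keras.layers.Dense(units, activation="relu"))
    model.add(tf.keras.layers.Dense(last_layer))
    model.compile(optimizer="adam", loss="mse")

    model.fit(X_train, y_train, epochs=epochs, verbose=0)
    return [tf.convert_to_tensor(w) for w in model.get_weights()], model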

def trace_fn(current_state, kernel_results, summary_freq=10, callbacks=()):
    """Can be passed to the HMC kernel to obtain a trace of intermediate
    kernel results and histograms of the network parameters in Tensorboard.
    """
    step = kernel_results.step
    with tf.summary.record_if(tf.equal(step % summary_freq, 0)):
        # current_state alternates weights and biases: [w1, b1, w2, b2, ...]
        for idx, tensor in enumerate(current_state):
            layer = idx // 2 + 1
            kind = "weights" if idx % 2 == 0 else "biases"
            tf.summary.histogram(f"{kind}_{layer}", tensor, step=step)
    return kernel_results, [cb(*current_state) for cb in callbacks]

# @tf.function
def sample_chain(*args, **kwargs):
    """Compile static graph for tfp.mcmc.sample_chain.
    Since this is the bulk of the computation, using @tf.function here
    significantly improves performance (empirically about 5x).
    """
    return tfp.mcmc.sample_chain(*args, **kwargs)

def run_hmc(
    target_log_prob_fn,
    step_size=0.01,
    num_leapfrog_steps=10,
    num_burnin_steps=1000,
    num_results=1000,
    current_state=None,
    resume=None,
    log_dir="logs/hmc/",
    sampler="nuts",
    **kwargs,
):
    """Creates an adaptive HMC kernel and generates a Markov chain of length num_results
    by performing gradient-informed Hamiltonian Monte Carlo transitions. Either the new
    or current position in parameter space is appended to the chain after each transition
    depending on a Metropolis accept/reject.

    Args:
        target_log_prob_fn {callable}: Determines the stationary distribution
            the Markov chain should converge to.

    Returns:
        burnin(s), chain(s), trace, final_kernel_result: Discarded samples generated during warm-up,
        the Markov chain(s), the trace generated by trace_fn and the kernel results of the last step
        (in case the computation needs to be resumed).
    """
    err = "Either current_state or resume is required when calling run_hmc"
    assert current_state is not None or resume is not None, err

    summary_writer = tf.summary.create_file_writer(log_dir)

    # NOTE: the opening lines of the step size adaptation wrappers were lost
    # from this listing. The wrapper used in the "hmc" branch and the number
    # of adaptation steps (80% of burn-in) are assumptions.
    num_adaptation_steps = int(num_burnin_steps * 0.8)
    if sampler == "nuts":
        kernel = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size)
        adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            kernel,
            num_adaptation_steps=num_adaptation_steps,
            step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
                step_size=new_step_size
            ),
            step_size_getter_fn=lambda pkr: pkr.step_size,
            log_accept_prob_getter_fn=lambda pkr: pkr.log_accept_ratio,
        )
    elif sampler == "hmc":
        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
        )
        adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            kernel, num_adaptation_steps=num_adaptation_steps
        )
    else:
        raise ValueError(f"unrecognized sampler: {sampler}")

    if resume:
        prev_chain, prev_trace, prev_kernel_results = resume
        # prev_chain is a list of per-parameter tensors, so the number of
        # previous steps is the length of any one of them
        step = len(prev_chain[0])
        current_state = tf.nest.map_structure(lambda chain: chain[-1], prev_chain)
    else:
        prev_kernel_results = None  # lets sample_chain bootstrap fresh results
        step = 0

    tf.summary.trace_on(graph=True, profiler=False)

    chain, trace, final_kernel_results = sample_chain(
        kernel=adaptive_kernel,
        current_state=current_state,
        previous_kernel_results=prev_kernel_results,
        num_results=num_results + num_burnin_steps,
        trace_fn=ft.partial(trace_fn, summary_freq=20),
        return_final_kernel_results=True,
        **kwargs,
    )

    with summary_writer.as_default():
        tf.summary.trace_export(name="hmc_trace", step=step)
    summary_writer.close()

    if resume:
        chain = nest_concat(prev_chain, chain)
        trace = nest_concat(prev_trace, trace)
    burnin, samples = zip(*[(t[:-num_results], t[-num_results:]) for t in chain])
    return burnin, samples, trace, final_kernel_results

def predict_from_chain(chain, X_test, uncertainty="aleatoric+epistemic"):
    """Takes a Markov chain of NN configurations and does the actual
    prediction on a test set X_test including aleatoric and optionally
    epistemic uncertainty estimation.
    """
    err = f"unrecognized uncertainty type: {uncertainty}"
    assert uncertainty in ["aleatoric", "aleatoric+epistemic"], err

    if uncertainty == "aleatoric":
        # collapse the chain to a single network with posterior mean parameters
        post_params = [tf.reduce_mean(t, axis=0) for t in chain]
        post_weights, post_biases = post_params[::2], post_params[1::2]
        post_model = bnn_fn.build_network(post_weights, post_biases)
        y_pred, y_var = post_model(X_test, training=False)

        return y_pred.numpy(), y_var.numpy()

    if uncertainty == "aleatoric+epistemic":
        # chain is a list of per-parameter tensors whose first axis indexes
        # the samples, so iterate over len(chain[0]) rather than len(chain)
        restructured_chain = [
            [tensor[i] for tensor in chain] for i in range(len(chain[0]))
        ]

        def predict_from_sample(sample):
            post_weights, post_biases = sample[::2], sample[1::2]
            post_model = bnn_fn.build_network(post_weights, post_biases)
            y_pred, y_var = post_model(X_test, training=False)
            return y_pred, y_var

        y_pred_mc_samples, y_var_mc_samples = zip(
            *[predict_from_sample(sample) for sample in restructured_chain]
        )
        # epistemic: variance of the predicted means across samples
        y_pred, y_var_epist = tf.nn.moments(
            tf.convert_to_tensor(y_pred_mc_samples), axes=0
        )
        # aleatoric: mean of the predicted variances across samples
        y_var_aleat = tf.reduce_mean(tf.convert_to_tensor(y_var_mc_samples), axis=0)
        y_var_tot = y_var_epist + y_var_aleat
        return y_pred.numpy(), y_var_tot.numpy()

def hmc_predict(
    weight_prior, bias_prior, init_state, X_train, y_train, X_test, y_test=None, **kwds
):
    """Top-level function that ties together run_hmc and predict_from_chain by accepting
    a train and test set plus parameter priors to construct the BNN's log probability
    function given the training data X_train, y_train.
    """
    default = dict(
    )
    kwds = {**default, **kwds}
    bnn_log_prob_fn = bnn_fn.target_log_prob_fn_factory(
        weight_prior, bias_prior, X_train, y_train
    )
    burnin, samples, trace, final_kernel_results = run_hmc(
        bnn_log_prob_fn, current_state=init_state, **kwds
    )

    y_pred, y_var = predict_from_chain(samples, X_test)
    return y_pred, y_var, final_kernel_results

def nest_concat(*args, axis=0):
    """Utility function for concatenating a new Markov chain or trace with
    older ones when resuming a previous calculation.
    """
    return tf.nest.map_structure(lambda *parts: tf.concat(parts, axis=axis), *args)

def ess(chains, **kwargs):
    """Measure effective sample size of Markov chain(s).

    Arguments:
        chains {Tensor or list of Tensors}: If list, first
            dimension should index identically distributed states.
    """
    return tfp.mcmc.effective_sample_size(chains, **kwargs)

def r_hat(tensors):
    """See https://tensorflow.org/probability/api_docs/python/tfp/mcmc/potential_scale_reduction.
    """
    return [tfp.mcmc.potential_scale_reduction(t) for t in tensors]

def get_num_chains(target_log_prob_fn, current_state):
    """Check how many chains your kernel thinks it's dealing
    with. Can help with debugging.
    """
    return tf.size(target_log_prob_fn(*current_state)).numpy()
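The way predict_from_chain combines the two kinds of uncertainty is just the law of total variance: the spread of the predicted means across posterior samples (epistemic) plus the average of the variances each sampled network predicts itself (aleatoric). Here is a minimal sketch of that decomposition for a single test point in plain Python; the function name and the toy numbers are made up for illustration.

```python
def total_predictive_variance(means, variances):
    """Combine per-sample predictions for a single test point.

    means[i] and variances[i] are the predictive mean and variance of the
    network defined by the i-th sample in the chain.
    """
    n = len(means)
    y_pred = sum(means) / n
    # epistemic: spread of the predicted means across posterior samples
    # (population variance, i.e. dividing by n as tf.nn.moments does)
    var_epistemic = sum((m - y_pred) ** 2 for m in means) / n
    # aleatoric: average noise level predicted by the networks themselves
    var_aleatoric = sum(variances) / n
    return y_pred, var_epistemic, var_aleatoric


y_pred, var_epist, var_aleat = total_predictive_variance(
    means=[1.0, 2.0, 3.0], variances=[0.5, 0.4, 0.6]
)
print(f"prediction: {y_pred}, total variance: {var_epist + var_aleat:.3f}")
```

If all sampled networks agree on the mean, the epistemic term vanishes and only the predicted noise remains, which is what the "aleatoric" branch approximates by averaging the parameters first.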

At the very top, we import a module that contains functions pertaining to Bayesian neural networks more generally, not just when training with HMC. Here’s what that looks like.

## BNN Routines

src/bnn/functions.py
import functools as ft

import tensorflow as tf
import tensorflow_probability as tfp

def dense(inputs, weights, biases, activation):
    return activation(tf.matmul(inputs, weights) + biases)

def build_network(weights_list, biases_list, activation=tf.nn.relu):
    def model(samples, training=True):
        net = samples
        for (weights, biases) in zip(weights_list[:-1], biases_list[:-1]):
            net = dense(net, weights, biases, activation)
        # final linear layer
        net = tf.matmul(net, weights_list[-1]) + biases_list[-1]
        # the model's predictive mean and log variance (each of size samples.shape[0])
        y_pred, y_log_var = tf.unstack(net, axis=1)
        if training:
            return tfp.distributions.Normal(
                loc=y_pred, scale=tf.sqrt(tf.exp(y_log_var))
            )
        else:
            return y_pred, tf.exp(y_log_var)

    return model

def bnn_log_prob_fn(X, y, weights, biases, get_mean=False):
    """Compute log likelihood of predicted labels y given the
    features X, weights W and biases b.

    Args:
        X (np.array): 2d feature values.
        y (np.array): 1d labels (ground truth).
        weights (list): 2d arrays of weights for each layer.
        biases (list): 1d arrays of biases for each layer.
        get_mean (bool, optional): Whether to return the mean log
            probability over all labels for diagnostics, e.g. to
            compare train and test set performance. Defaults to False.

    Returns:
        tf.tensor: Sum or mean of log probabilities of all labels.
    """
    network = build_network(weights, biases)
    labels_dist = network(X)
    if get_mean:
        return tf.reduce_mean(labels_dist.log_prob(y))
    return tf.reduce_sum(labels_dist.log_prob(y))

def prior_log_prob_fn(weight_prior, bias_prior, weights, biases):
    log_prob = sum([tf.reduce_sum(weight_prior.log_prob(w)) for w in weights])
    log_prob += sum([tf.reduce_sum(bias_prior.log_prob(b)) for b in biases])
    return log_prob

def target_log_prob_fn_factory(weight_prior, bias_prior, X_train, y_train):
    def target_log_prob_fn(*args):
        # args alternates weights and biases: [w1, b1, w2, b2, ...]
        weights, biases = args[::2], args[1::2]
        log_prob = prior_log_prob_fn(weight_prior, bias_prior, weights, biases)
        log_prob += bnn_log_prob_fn(X_train, y_train, weights, biases)
        return log_prob

    return target_log_prob_fn

def tracer_factory(X, y):
    return lambda *args: ft.partial(bnn_log_prob_fn, X, y, get_mean=True)(
        args[::2], args[1::2]
    ).numpy()

def get_random_initial_state(weight_prior, bias_prior, nodes_per_layer, overdisp=1.0):
    """Generate random initial configuration for weights and biases of a fully-connected NN
    sampled according to the specified prior distributions. This configuration can serve
    as a starting point for instance to generate a Markov chain of network configurations
    via Hamiltonian Monte Carlo which are distributed according to the posterior after having
    observed some data.
    """
    init_state = []
    for idx in range(len(nodes_per_layer) - 1):
        weights_shape = (nodes_per_layer[idx], nodes_per_layer[idx + 1])
        biases_shape = nodes_per_layer[idx + 1]
        # use overdispersion > 1 for better R-hat statistics
        weights = weight_prior.sample(tf.squeeze(weights_shape)) * overdisp
        biases = bias_prior.sample(tf.squeeze(biases_shape)) * overdisp
        init_state.extend((weights, biases))
    return init_state
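Since build_network returns a Normal distribution whose mean and log variance both come out of the network, the likelihood term in bnn_log_prob_fn is a heteroscedastic Gaussian log-likelihood. The following sketch writes that quantity out by hand in plain Python for a few scalar predictions. The function name and the numbers are illustrative only, and it is meant to mirror what `Normal(loc, scale).log_prob(y)` computes, not to replace it.

```python
import math


def gaussian_log_likelihood(y_true, y_pred, y_log_var):
    """Sum of log N(y | y_pred, exp(y_log_var)) over all data points."""
    total = 0.0
    for y, mu, log_var in zip(y_true, y_pred, y_log_var):
        total += -0.5 * (
            math.log(2 * math.pi) + log_var + (y - mu) ** 2 / math.exp(log_var)
        )
    return total


# A prediction that is off by 2 is penalized less if the network also
# predicts a larger variance for that point ...
confident = gaussian_log_likelihood([0.0], [2.0], [0.0])
uncertain = gaussian_log_likelihood([0.0], [2.0], [math.log(4.0)])
# ... but for an accurate prediction, claiming a large variance costs
# log-likelihood due to the log_var term.
accurate_confident = gaussian_log_likelihood([0.0], [0.0], [0.0])
accurate_uncertain = gaussian_log_likelihood([0.0], [0.0], [math.log(4.0)])
print(confident, uncertain, accurate_confident, accurate_uncertain)
```

This trade-off between the residual term and the log-variance term is what lets the network learn an input-dependent noise level instead of a single global one.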

## Boston Housing

Finally, let’s look at a concrete example by training a network with HMC to predict Boston housing prices and compare its performance with a maximum a posteriori (MAP) network (i.e. just maximum likelihood with a bit of regularization due to the priors). (The following is a VS Code-style Jupyter notebook; the # %% lines are cell delimiters.)

src/notebooks/hmc/boston.py
# %%
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

import src.bnn.hmc as hmcnn
import src.bnn.map as mapnn

# %%
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.boston_housing.load_data()
X_train, y_train, X_test, y_test = [
    arr.astype("float32") for arr in [X_train, y_train, X_test, y_test]
]

# %%
weight_prior = tfp.distributions.Normal(0, 0.2)
bias_prior = tfp.distributions.Normal(0, 0.2)
map_y_pred, map_y_var, map_log_probs, map_final_state = mapnn.map_predict(
    weight_prior, bias_prior, X_train, y_train, X_test, y_test
)

# %%
map_mae = (np.abs(map_y_pred - y_test)).mean()

# %%
hmc_y_pred, hmc_y_var, _ = hmcnn.hmc_predict(
    weight_prior, bias_prior, map_final_state, X_train, y_train, X_test, y_test
)

# %%
hmc_mae = (np.abs(hmc_y_pred - y_test)).mean()

The MAP module imported above is as follows.

src/bnn/map.py
import numpy as np
import tensorflow as tf

import src.bnn.functions as bnn_fn

def get_map_trace(
    target_log_prob_fn, state, num_iters=1000, save_every=10, callbacks=()
):
    state_vars = [tf.Variable(s) for s in state]
    # NOTE: the optimizer definition was lost from this listing; Adam with
    # its default learning rate is an assumption.
    optimizer = tf.optimizers.Adam()

    def map_loss():
        # negative log posterior (up to an additive constant)
        return -target_log_prob_fn(*state_vars)

    @tf.function
    def minimize():
        optimizer.minimize(map_loss, state_vars)

    state_trace, cb_trace = [[] for _ in state], [[] for _ in callbacks]
    for i in range(num_iters):
        if i % save_every == 0:
            for trace, var in zip(state_trace, state_vars):
                trace.append(var.numpy())
            for trace, cb in zip(cb_trace, callbacks):
                trace.append(cb(*state_vars))
        minimize()
    return state_trace, cb_trace

def get_best_map_state(map_trace, map_log_probs):
    # map_log_probs contains the log probability trace
    # for both the train [0] and the test set [1].
    test_set_max_log_prob_idx = np.argmax(map_log_probs[1])
    # We return the MAP NN configuration that achieved the
    # highest likelihood on the test set.
    return [tf.constant(tr[test_set_max_log_prob_idx]) for tr in map_trace]

def get_nodes_per_layer(n_features, net_taper=(1, 0.5, 0.2, 0.1)):
    nodes_per_layer = [int(n_features * x) for x in net_taper]
    # Ensure the last layer has two nodes so that output can be split into
    # predictive mean and learned loss attenuation (see eq. (7) of
    # https://arxiv.org/abs/1703.04977) which the network learns individually.
    nodes_per_layer.append(2)
    return nodes_per_layer

def map_predict(weight_prior, bias_prior, X_train, y_train, X_test, y_test):
    """Generate maximum a posteriori neural network predictions.

    Args:
        weight_prior (tfp.distribution): Prior probability for the weights
        bias_prior (tfp.distribution): Prior probability for the biases
        [X/y_train/test] (np.arrays): Train and test sets
    """

    log_prob_tracers = (
        bnn_fn.tracer_factory(X_train, y_train),
        bnn_fn.tracer_factory(X_test, y_test),
    )

    _, n_features = X_train.shape  # number of samples times number of features
    random_initial_state = bnn_fn.get_random_initial_state(
        weight_prior, bias_prior, get_nodes_per_layer(n_features)
    )

    trace, log_probs = get_map_trace(
        bnn_fn.target_log_prob_fn_factory(weight_prior, bias_prior, X_train, y_train),
        random_initial_state,
        num_iters=3000,
        callbacks=log_prob_tracers,
    )
    # Can be used as initial configuration for other methods such as Hamiltonian Monte Carlo.
    best_params = get_best_map_state(trace, log_probs)

    weights, biases = best_params[::2], best_params[1::2]
    model = bnn_fn.build_network(weights, biases)
    y_pred, y_var = model(X_test, training=False)
    return y_pred.numpy(), y_var.numpy(), log_probs, best_params
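To see what "maximum likelihood with a bit of regularization due to priors" means in the simplest possible setting, here is a toy sketch in plain Python. It assumes a single scalar parameter with a Normal(0, 0.2) prior and unit-variance Gaussian noise, runs gradient descent on the negative log posterior (the same objective get_map_trace minimizes, just without TensorFlow), and compares the result with the closed-form posterior mode. The data and hyperparameters are made up for illustration.

```python
def map_estimate(ys, prior_std=0.2, noise_std=1.0, lr=0.01, num_iters=2000):
    """Gradient descent on the negative log posterior of a scalar mean theta."""
    theta = 0.0
    for _ in range(num_iters):
        # gradient of the negative log prior: theta / prior_std^2
        grad = theta / prior_std**2
        # gradient of the negative log likelihood: sum(theta - y) / noise_std^2
        grad += sum(theta - y for y in ys) / noise_std**2
        theta -= lr * grad
    return theta


ys = [1.2, 0.8, 1.1, 0.9]
theta_map = map_estimate(ys)
theta_mle = sum(ys) / len(ys)  # maximum likelihood ignores the prior
# closed-form posterior mode of this conjugate Gaussian model
theta_exact = sum(ys) / (len(ys) + 1.0 / 0.2**2)
print(f"MLE: {theta_mle:.3f}, MAP: {theta_map:.3f}, exact: {theta_exact:.3f}")
```

With only four data points, the narrow prior pulls the MAP estimate strongly toward zero compared to the maximum likelihood estimate. The same shrinkage acts on every weight and bias of the network above.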

## Things to look forward to

TFP is still in the process of adding more functionality to its mcmc module. For instance, both the NoUTurnSampler and the DualAveragingStepSizeAdaptation appeared only shortly before this post. They’re both great improvements over the regular HamiltonianMonteCarlo kernel and SimpleStepSizeAdaptation, respectively. The latter should allow the kernel to converge to optimal step sizes faster, while the former should enable it to scale to even larger problems in terms of state space dimensionality. I, for one, am excited to see where this is headed.