Making Tensorflow Faster

Speeding up Tensorflow

In writing my previous posts, I stumbled upon a few things that made code run significantly faster in tensorflow. This post summarizes my notes on these techniques. To a tensorflow aficionado they are probably well known, but as one of the uninitiated, I was surprised at how much difference a little extra code could make. Here we examine these techniques using the example from a previous post. The techniques are:

  • Wrapping a call to tensorflow_probability.mcmc.sample_chain in a tensorflow function (tf.function)
  • Compiling your tensorflow function with XLA

To start, we generate toy data for a simple ordinary least squares regression problem (for more details, see my previous tensorflow post).

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings("ignore", category=Warning) 
# set the numpy seed so the simulated data are reproducible
np.random.seed(1234)

Here we generate data for \(N=500\) and \(K=2\):

# set tensorflow data type
dtype = tf.float32

##
## simple OLS Data Generation Process
##
# True beta
b = np.array([10, -1])
N = 500
# True error std deviation
sigma_e = 1

x = np.c_[np.ones(N), np.random.randn(N)]
y = x.dot(b) + sigma_e * np.random.randn(N)
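As a quick sanity check on the simulated data, closed-form OLS via numpy's least-squares solver should recover coefficients close to the true b = [10, -1] and a residual standard deviation near sigma_e = 1. This check is an addition for illustration and not part of the original workflow. It regenerates the data with the same seed and recipe as above.

```python
import numpy as np

# Same data-generating process and seed as above
np.random.seed(1234)
b = np.array([10, -1])
N = 500
sigma_e = 1
x = np.c_[np.ones(N), np.random.randn(N)]
y = x.dot(b) + sigma_e * np.random.randn(N)

# Closed-form OLS estimate; lstsq solves min ||x @ beta - y||^2
beta_hat, *_ = np.linalg.lstsq(x, y, rcond=None)

# Residual standard deviation with a degrees-of-freedom correction (N - K)
resid = y - x.dot(beta_hat)
sigma_hat = np.sqrt(resid.dot(resid) / (N - x.shape[1]))
print(beta_hat, sigma_hat)
```

The MCMC posterior means for beta and sigma should land near these point estimates, which makes them a handy reference when checking the sampler output.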

Next, we convert the data to tensors and set up the log-likelihood for this problem:

X = tf.constant(x, dtype=dtype)
Y = tf.constant(y, dtype=dtype)
pi = tf.constant(np.pi, dtype=dtype)

def ols_loglike(beta, sigma):
    # xb (mu_i for each observation)
    mu = tf.linalg.matvec(X, beta)
    # this is normal pdf logged and summed over all observations
    ll = - (X.shape[0]/2.)*tf.math.log(2.*pi*sigma**2) -\
        (1./(2.*sigma**2.))*tf.math.reduce_sum((Y-mu)**2., axis=-1)
    return ll
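Written out, ols_loglike evaluates the Gaussian log-likelihood of the linear model. Here \(x_i\) is the \(i\)-th row of \(X\) and \(N\) is the number of observations (X.shape[0] in the code):

```latex
\ell(\beta, \sigma) = -\frac{N}{2}\log\left(2\pi\sigma^{2}\right) - \frac{1}{2\sigma^{2}}\sum_{i=1}^{N}\left(y_i - x_i^{\top}\beta\right)^{2}
```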

Tensorflow Function Wrapper

As I demonstrated in the earlier post, it is straightforward to set up an MCMC sampler with tensorflow_probability. Wrapping the sampler in a tensorflow function leads to BIG speed increases, because the traced graph avoids the per-operation Python overhead of eager execution. Below we provide a quick speed comparison (notebook CPU with 8 cores).

First, let's run the model without the wrapping technique. Setting up the kernels:

# a naive initial value for the chain (for beta and sigma):
init = [tf.constant([0., 0.], dtype=dtype), tf.constant(1., dtype=dtype)]
samples = 2000
burnin = 500
init_step_size = .3

nuts_kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=ols_loglike,
    step_size=init_step_size,
    )
adapt_nuts_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts_kernel,
    num_adaptation_steps=burnin,
    step_size_getter_fn=lambda pkr: pkr.step_size,
    log_accept_prob_getter_fn=lambda pkr: pkr.log_accept_ratio,
    step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(step_size=new_step_size)
    )
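The three lambdas passed to DualAveragingStepSizeAdaptation tell the adaptation where to read the step size and the acceptance ratio in the inner kernel's results, and how to write an updated step size back. This works because TFP kernel results are namedtuples, so `_replace` returns a copy with one field changed.

The sketch below shows the idea with a stand-in. `FakeNutsResults` is a hypothetical namedtuple, not the real TFP results type. It carries only the two fields the lambdas touch.

```python
from collections import namedtuple

# Hypothetical stand-in for the NUTS kernel results; the real TFP type
# has many more fields, but these lambdas only read/write these two.
FakeNutsResults = namedtuple("FakeNutsResults", ["step_size", "log_accept_ratio"])

# Same getter/setter lambdas as in the kernel setup above
step_size_getter_fn = lambda pkr: pkr.step_size
log_accept_prob_getter_fn = lambda pkr: pkr.log_accept_ratio
step_size_setter_fn = lambda pkr, new_step_size: pkr._replace(step_size=new_step_size)

results = FakeNutsResults(step_size=0.3, log_accept_ratio=-0.2)
updated = step_size_setter_fn(results, 0.15)  # returns a modified copy

print(step_size_getter_fn(results))        # original is unchanged: 0.3
print(step_size_getter_fn(updated))        # new step size: 0.15
print(log_accept_prob_getter_fn(updated))  # other fields carried over: -0.2
```

Because the setter returns a new tuple instead of mutating the old one, it fits TensorFlow's functional style, where state is passed from one step to the next.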

Next, we time the expensive part, running the chain:

%%timeit -n1  -r1
tfp.mcmc.sample_chain(
    num_results=samples,
    current_state=init,
    kernel=adapt_nuts_kernel,
    num_burnin_steps=100,
    parallel_iterations=5)

If we wrap the sampler in a tensorflow function, we get dramatic speedups:

@tf.function
def sampler(init_vals):

    # the outer @tf.function already traces everything called here into
    # its graph, so decorating this inner function is optional
    @tf.function
    def ols_loglike(beta, sigma):
        # xb (mu_i for each observation)
        mu = tf.linalg.matvec(X, beta)
        # this is normal pdf logged and summed over all observations
        ll = - (X.shape[0]/2.)*tf.math.log(2.*pi*sigma**2) -\
            (1./(2.*sigma**2.))*tf.math.reduce_sum((Y-mu)**2., axis=-1)
        return ll

    nuts_kernel = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=ols_loglike,
        step_size=init_step_size,
        )
    adapt_nuts_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=nuts_kernel,
        num_adaptation_steps=burnin,
        step_size_getter_fn=lambda pkr: pkr.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.log_accept_ratio,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(step_size=new_step_size)
        )
    sample_vals, stats = tfp.mcmc.sample_chain(num_results=samples,
                                               current_state=init_vals,
                                               kernel=adapt_nuts_kernel,
                                               num_burnin_steps=100,
                                               parallel_iterations=5)
    return sample_vals, stats

Check out the type of the sampler:

type(sampler)
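Once decorated, sampler is no longer a plain Python function but a tf.function object. It traces the Python body into a graph on the first call. Later calls reuse that trace as long as the inputs have the same structure, dtypes, and shapes.

The sketch below shows when this retracing happens, using a hypothetical function `scaled_sum`. The Python list `trace_log` only grows while tracing, because plain Python side effects do not run when the cached graph executes.

```python
import tensorflow as tf

trace_log = []  # Python-side record of each trace

@tf.function
def scaled_sum(v):
    trace_log.append(v.shape)  # Python side effect: runs only while tracing
    return 2.0 * tf.reduce_sum(v)

scaled_sum(tf.constant([1.0, 2.0]))
after_first = len(trace_log)       # first call traces the function

scaled_sum(tf.constant([3.0, 4.0]))
after_same_shape = len(trace_log)  # same dtype/shape: cached graph reused

scaled_sum(tf.constant([1.0, 2.0, 3.0]))
after_new_shape = len(trace_log)   # new shape: function is traced again

print(after_first, after_same_shape, after_new_shape)
```

This is also why passing init as tensors of fixed shape and dtype is convenient: repeated calls to sampler(init) can reuse the same graph.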

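Now let's time sampling from the wrapped function (this single timed run also includes the one-time cost of tracing the function into a graph):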

%%timeit -n1 -r1
sampler(init)

That is a BIG speedup (it is ~14x faster) just by wrapping your code in a tensorflow function.
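The effect is easy to reproduce outside MCMC. The sketch below uses a hypothetical toy workload, `many_small_ops`, which is illustrative and not part of the sampler. The first call to the wrapped version traces it and is kept out of the timing. Actual numbers depend on your machine.

```python
import timeit
import tensorflow as tf

# Hypothetical toy workload: a chain of many tiny elementwise ops.
def many_small_ops(x):
    for _ in range(200):  # Python loop: unrolled into the graph when traced
        x = x * 0.99 + 0.01
    return x

# Same Python function, traced into a graph by tf.function
graph_ops = tf.function(many_small_ops)

x0 = tf.constant([1.0, 2.0])
graph_ops(x0)  # first call traces the function; exclude it from the timing

eager_s = timeit.timeit(lambda: many_small_ops(x0), number=20)
graph_s = timeit.timeit(lambda: graph_ops(x0), number=20)
print(f"eager: {eager_s:.4f}s  tf.function: {graph_s:.4f}s")
```

Both versions compute the same result. The eager version dispatches each of the 400 small operations from Python on every call, while the traced version executes them as one graph.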

XLA Mode

Perhaps we can do even better using the new XLA compiler for our tensorflow function. This is experimental, but let's try it.

@tf.function(experimental_compile=True)
def sampler(init_vals):

    # decorating the inner function is optional: the outer function
    # already traces everything it calls into its own graph
    @tf.function(experimental_compile=True)
    def ols_loglike(beta, sigma):
        # xb (mu_i for each observation)
        mu = tf.linalg.matvec(X, beta)
        # this is normal pdf logged and summed over all observations
        ll = - (X.shape[0]/2.)*tf.math.log(2.*pi*sigma**2) -\
            (1./(2.*sigma**2.))*tf.math.reduce_sum((Y-mu)**2., axis=-1)
        return ll

    nuts_kernel = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=ols_loglike,
        step_size=init_step_size,
        )
    adapt_nuts_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=nuts_kernel,
        num_adaptation_steps=burnin,
        step_size_getter_fn=lambda pkr: pkr.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.log_accept_ratio,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(step_size=new_step_size)
        )
    sample_vals, stats = tfp.mcmc.sample_chain(num_results=samples,
                                               current_state=init_vals,
                                               kernel=adapt_nuts_kernel,
                                               num_burnin_steps=100,
                                               parallel_iterations=5)
    return sample_vals, stats

Timing the XLA-compiled sampler:

%%timeit -n1 -r1
sampler(init)
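Because %%timeit -n1 -r1 times a single call, the XLA timing above also includes tracing and XLA compilation. The sketch below separates that one-off cost from the steady-state run time. It uses a hypothetical toy objective, `xla_sse`, rather than the sampler. It also passes jit_compile=True, which newer TensorFlow releases accept in place of experimental_compile=True.

```python
import time
import numpy as np
import tensorflow as tf

# Hypothetical toy objective standing in for the sampler;
# jit_compile=True asks XLA to compile the traced graph.
@tf.function(jit_compile=True)
def xla_sse(x, y, beta):
    resid = y - tf.linalg.matvec(x, beta)
    return tf.reduce_sum(resid ** 2)

rng = np.random.default_rng(0)
x_np = rng.standard_normal((500, 2)).astype(np.float32)
y_np = rng.standard_normal(500).astype(np.float32)
beta_np = np.array([0.5, -0.5], dtype=np.float32)
x, y, beta = tf.constant(x_np), tf.constant(y_np), tf.constant(beta_np)

start = time.perf_counter()
first = xla_sse(x, y, beta).numpy()   # pays for tracing + XLA compilation
first_s = time.perf_counter() - start

start = time.perf_counter()
second = xla_sse(x, y, beta).numpy()  # reuses the compiled program
second_s = time.perf_counter() - start

print(f"first call: {first_s:.4f}s  second call: {second_s:.4f}s")
```

Calling .numpy() forces each result to be materialized, so the timer covers the full computation. For a fair comparison between the plain and XLA versions of sampler, the same warm-up-then-time approach can be applied to both.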