Interactive Sampling for High Dimensional Markov Chains

In this blog post I outline some methods I have used for MCMC sampling in problems with many random variables. These are the kinds of problems where memory requirements (and even disk space requirements) may be too large to sample long enough to attain convergence. Suppose your model has lots of parameters, say 350,000, and the majority are nuisance parameters that are being used to integrate out some error process. Suppose you'd like to take 100,000 samples. If your storage format is 64-bit floats, saving the chain will require about 280 gigabytes of disk or memory (\(350{,}000 \times 100{,}000 \times 8\) bytes). Strategies for reducing the storage requirements:

  1. Thinning. Here you only keep every \(k\)th sample. If \(k\) is 10, we would keep only 10,000 of the 100,000 samples. Even if you toss 90% of your samples, you will still need about 28 GB of storage space.
  2. Selective storage. Here you store only the main model parameters and keep just the current values of the nuisance parameters for drawing the next sample. Supposing that there are 250 main model parameters, this would require 0.2 gigabytes. Much more manageable. The problem with this approach is that if you find that you need to keep sampling, you are unable to do so because you don't have the full chain state for the 100,000th sample. So to increase your sample size, you would need to restart sampling and hope you sample long enough.
  3. Selective thinning. Here you partition your parameter space into main model parameters \(\beta\) and nuisance parameters \(\gamma\) as before. The idea is to keep every sample only for the main model parameters and to store the full chain state (both the main model parameters and the nuisance parameters) only when sampling is finished. This has a much smaller memory footprint, like (2), but allows the chain to be restarted as if we were storing the full chain state at every sample.
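
The storage cost of each strategy can be checked with a few lines of arithmetic. This is only a sketch: it assumes 8 bytes per 64-bit float and decimal gigabytes (\(10^9\) bytes), and the function name `chain_storage_gb` is made up for illustration.

```python
BYTES_PER_FLOAT64 = 8


def chain_storage_gb(n_params, n_samples, thin=1):
    """Approximate storage (decimal GB) for a chain that keeps every `thin`-th sample."""
    kept = n_samples // thin
    return n_params * kept * BYTES_PER_FLOAT64 / 1e9


full = chain_storage_gb(350_000, 100_000)               # store everything
thinned = chain_storage_gb(350_000, 100_000, thin=10)   # strategy 1: thinning
selective = chain_storage_gb(250, 100_000)              # strategies 2 and 3: main parameters only
final_state = chain_storage_gb(350_000, 1)              # strategy 3: one extra full chain state

print(full, thinned, selective, final_state)  # 280.0 28.0 0.2 0.0028
```

The last number is the point of selective thinning: keeping one full chain state costs a few megabytes, which is negligible next to the trace of the main parameters.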

PyMC3 allows for (1) and (2) above but, as far as I know, not (3). This post illustrates selective thinning for a toy problem and uses MariaDB for backend storage.

import pandas as pd
import sqlalchemy
from orgtools import table
import pymc3 as pm3
import numpy as np
import hashlib
import matplotlib.pyplot as plt

# some data
x = np.random.randn(1000)

Here is the basic model we will use to illustrate these concepts:

# initialize model
model = pm3.Model()

with model:
    mu = pm3.Flat('mu')
    sigma = pm3.Uniform('sigma', lower=.1, upper=5)
    gamma = pm3.Normal('gamma', mu=0, sd=1, shape=10)
    like = pm3.Normal('Likelihood', mu=mu, sd=sigma, observed=x)

Note that the random variables gamma aren't connected to the likelihood, so their samples should simply be draws from a standard normal. They are included here only for demonstration purposes, as stand-ins for nuisance parameters. For our purposes, I consider the nuisance parameters to be \(\gamma\) and the main model parameters to be \(\mu\) and \(\sigma\). Let's take some samples:

chain_length = 10000

with model:
    step = pm3.Metropolis()
    trace = pm3.sample(chain_length, step=step)

Let's examine the trace and notice it contains our main model parameters (labeled as mu and sigma) and the nuisance parameters (labeled as gamma):

pm3.traceplot(trace);
plt.savefig('/tmp/trace.png')
'/tmp/trace.png'

Examining the trace, you can see we have both main and nuisance parameters for the last 5 samples:

table(pm3.trace_to_dataframe(trace).tail(5),5)

Note that we don't care about the variables \(\gamma_1\) through \(\gamma_{10}\), yet they are stored for every sample. Consequently, an analysis requiring very long traces and/or traces having many nuisance variables can tax computer memory and/or disk space capacities, as I argued above. But to be able to continue sampling, we need to have the current chain state for both the main and nuisance parameters. So in the code below we proceed as follows:

  1. Define \(S\) as the number of posterior samples between interactive breaks in sampling.
  2. Define \(B\) as the number of interactive breaks in sampling.
  3. For break \(b\), set the start values to chain_state if \(b>1\), or to some arbitrary starting value if \(b=1\).
  4. Begin sampling at the start values and continue until \(S\) samples have been taken.
  5. Define the trace \(T_b\) as the full trace and \(T_{mb}\) as the trace of the main model parameters. Both of these traces will have \(S\) samples.
  6. For each break \(b \in \{1, \ldots, B\}\), store \(T_{mb}\) on disk (or in a database), store the last sample from \(T_b\), and define it as chain_state.
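
The steps above can be sketched without PyMC3 or a database. The toy below is an illustration under stated assumptions, not the implementation used in this post: the target is a set of independent standard normals, the sampler is a hand-written random-walk Metropolis step, and the names `metropolis_step` and `sample_in_batches` are made up.

```python
import math
import random


def metropolis_step(state, rng, scale=0.5):
    """One random-walk Metropolis sweep; each coordinate targets N(0, 1) (toy stand-in model)."""
    new_state = list(state)
    for j, x in enumerate(state):
        proposal = x + rng.gauss(0.0, scale)
        # log acceptance ratio for a standard normal target
        log_ratio = 0.5 * (x * x - proposal * proposal)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            new_state[j] = proposal
    return new_state


def sample_in_batches(S, B, n_main=2, n_nuisance=10, seed=0):
    rng = random.Random(seed)
    chain_state = [0.0] * (n_main + n_nuisance)  # arbitrary start for b = 1
    main_trace = []    # main parameters only: B * S rows in total
    chain_states = []  # full state, stored once per break: B rows
    for b in range(B):
        for _ in range(S):
            chain_state = metropolis_step(chain_state, rng)
            main_trace.append(chain_state[:n_main])  # nuisance values are discarded
        chain_states.append(list(chain_state))       # full state kept only at the break
    return main_trace, chain_states


main_trace, chain_states = sample_in_batches(S=200, B=5)
print(len(main_trace), len(chain_states))  # 1000 5
```

Because the full state is saved at each break, sampling can resume from `chain_states[-1]` at any time, while the stored trace only ever holds the main parameters.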

For our example above, we would store these values for the current chain state:

chain_state = pm3.trace_to_dataframe(trace).tail(1)
table(chain_state,1)

And the full 10000 samples for our (in this case) two main model parameters:

trace_mb = pm3.trace_to_dataframe(trace)[['mu','sigma']]
trace_mb.shape

Further, note that the last sample in trace_mb coincides with chain_state:

table(trace_mb.tail(1),5)

Note that this method will record the trace history of the main model parameters for \(B \times S\) samples but will only record the full chain state \(B\) times. In effect this does 2 things:

  1. Selectively thins the model parameters we don't care about.
  2. Allows sampling to continue, since the full chain state is always known for the last sample taken.

These features allow us to restart sampling and cut down on storage requirements for parameters we don't care about. For implementing this, we will store all trace information in a MariaDB database.

Setup Connection to Database

(This section is a PGP-encrypted block in the original post, so the connection details are not shown. The code below assumes it defines engine, a SQLAlchemy engine connected to the MariaDB database.)
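
The sampling loop relies on an `engine` object for its database connections. As a minimal sketch, assuming SQLAlchemy is used as in the imports above, an engine is created from a connection URL. The MariaDB URL in the comment is hypothetical: the user, password, host, database name, and the `pymysql` driver are all placeholders rather than the settings used in this post. The runnable line uses an in-memory SQLite database as a stand-in so the example works without a server.

```python
from sqlalchemy import create_engine, text

# Hypothetical MariaDB URL; every component is a placeholder:
# engine = create_engine("mysql+pymysql://user:password@localhost/mcmc")

# Stand-in so the sketch runs without a database server.
engine = create_engine("sqlite:///:memory:")

# Quick check that the engine can open a connection and run a query.
with engine.connect() as connection:
    answer = connection.execute(text("select 1")).scalar()

print(answer)
```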

Loop for interactive sampling

Suppose we wish to take 100,000 samples with 10 breaks (\(B=10\)). So, we will be taking \(S=10{,}000\) samples per break. First, we will run the sampler for only 100 samples and store the final value of the chain as the starting value for our large-scale sampler. This isn't strictly necessary.

S = 10000
B = 10

# get decent starting values (and define chain_state)
# !!! Or load starting values from the last row of chain_state in the database !!!
with model:
    step = pm3.Metropolis()
    start_trace = pm3.sample(100, step=step)
    chain_state = start_trace[-1].copy()

with model:
    for i in range(B):
        startvals = chain_state
        step = pm3.Metropolis()
        trace_m = pm3.sample(S, step=step, start=startvals)
        #
        # for this b, record trace_mb and the ending chain_state
        #
        trace_mb = pm3.trace_to_dataframe(trace_m)[['mu','sigma']].copy()
        # ending chain state: a dict for restarting the sampler, a one-row dataframe for storage
        chain_state = trace_m[-1].copy()
        chain_state_df = pm3.trace_to_dataframe(trace_m).tail(1).copy()
        # note that hash_id allows you to match chain_state to the batch of samples it came from
        now = pd.to_datetime('now')
        hash_id = hashlib.md5(str(now).encode('utf-8')).hexdigest()
        chain_state_df['timestamp'] = now
        chain_state_df['id'] = hash_id
        trace_mb['id'] = hash_id
        # connect to database and store tables
        connection = engine.connect()
        chain_state_df.to_sql("chain_state", connection, if_exists='append', index=False)
        trace_mb.to_sql("trace", connection, if_exists='append', index=False)
        connection.close()

# use read_sql to retrieve the stored tables
query1 = """select * from chain_state;"""
query2 = """select * from trace;"""

connection = engine.connect()
chain_state_all = pd.read_sql(query1, connection)
trace = pd.read_sql(query2, connection)
connection.close()

# check ids: each stored chain_state should equal the last trace row of its batch
last_trace = trace.groupby('id')[['mu','sigma']].last().reset_index()
# merge last of trace and chain_state_all and look at main model parameters
# (mu_x, sigma_x come from chain_state; mu_y, sigma_y come from the trace).
table(pd.merge(chain_state_all,last_trace,on='id',how='left').head(10)[['id','timestamp','mu_x','mu_y','sigma_x','sigma_y']],10)

We can of course use the trace for inference:

# set burnin
burnin = 5000
plt.figure()
trace.iloc[burnin:]['mu'].hist(bins=70)
plt.savefig('/tmp/hist.png')
'/tmp/hist.png'

Continue Sampling

Suppose we discover that 100,000 samples isn't enough, and we need a longer chain for convergence reasons. Or suppose we want to interactively investigate the chain as we sample in small chunks. Rather than restart sampling, we can start a new chain where the old one left off and add it to our database. Grab the last stored chain state (and toss the id and timestamp variables):

# drop returns a new dataframe, which avoids modifying a slice of chain_state_all in place
chain_state_df = chain_state_all.tail(1).drop(['id','timestamp'],axis=1)
table(chain_state_df,1)

PyMC3 needs a dict object for start values and it needs to be of a particular form:

  1. Define \(\gamma = \begin{bmatrix} \gamma_0 & \ldots &\gamma_5 & \ldots & \gamma_9 \end{bmatrix}\)
  2. Convert everything to a dict

# construct starting values from df
gamma_labels = ['gamma__' + str(i) for i in range(10)]
gamma_vals = chain_state_df[gamma_labels].values.reshape(10)
# .item() replaces np.asscalar, which has been removed from recent NumPy releases
chain_state = { 'gamma' : gamma_vals, 'mu' : chain_state_df.mu.values.item(),
                'sigma' : chain_state_df.sigma.values.item()}
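
The same conversion can be written once for any one-dimensional vector variable. The sketch below assumes the flattened column names follow the `name__index` pattern seen above. The helper name `row_to_start` is hypothetical, and it works on a plain dict (for example one built with `chain_state_df.iloc[0].to_dict()`) rather than on PyMC3 objects.

```python
import re


def row_to_start(row):
    """Group flattened 'name__i' entries into lists; leave scalar entries untouched."""
    scalars, vectors = {}, {}
    for key, value in row.items():
        match = re.fullmatch(r"(.+)__(\d+)", key)
        if match:
            name, index = match.group(1), int(match.group(2))
            vectors.setdefault(name, {})[index] = value
        else:
            scalars[key] = value
    start = dict(scalars)
    for name, parts in vectors.items():
        # order the components by their index, not by column order
        start[name] = [parts[i] for i in sorted(parts)]
    return start


row = {"mu": 0.03, "sigma": 0.98, "gamma__0": -1.2, "gamma__1": 0.4, "gamma__2": 0.7}
start = row_to_start(row)
print(start)  # {'mu': 0.03, 'sigma': 0.98, 'gamma': [-1.2, 0.4, 0.7]}
```

The vector entries come back as plain lists; wrapping them with `np.asarray` before passing them as start values keeps them in the same form as `gamma_vals` above.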

With that we should be good to go. We can re-run the loop above and add an additional 100k samples to our chain.

with model:
    for i in range(B):
        startvals = chain_state
        step = pm3.Metropolis()
        trace_m = pm3.sample(S, step=step, start=startvals, progressbar=False)
        #
        # for this b, record trace_mb and the ending chain_state
        #
        trace_mb = pm3.trace_to_dataframe(trace_m)[['mu','sigma']].copy()
        # ending chain state: a dict for restarting the sampler, a one-row dataframe for storage
        chain_state = trace_m[-1].copy()
        chain_state_df = pm3.trace_to_dataframe(trace_m).tail(1).copy()
        # note that hash_id allows you to match chain_state to the batch of samples it came from
        now = pd.to_datetime('now')
        hash_id = hashlib.md5(str(now).encode('utf-8')).hexdigest()
        chain_state_df['timestamp'] = now
        chain_state_df['id'] = hash_id
        trace_mb['id'] = hash_id
        # connect to database and store tables
        connection = engine.connect()
        chain_state_df.to_sql("chain_state", connection, if_exists='append', index=False)
        trace_mb.to_sql("trace", connection, if_exists='append', index=False)
        connection.close()