Speeding up PyMC3 NUTS Sampler
I’m trying to use the NUTS sampler in PyMC3. However, on my model it was running at only 2 iterations per second, while the Metropolis-Hastings sampler ran 450x faster.
I showed my example to some of the PyMC3 developers on Twitter, and Thomas Wiecki suggested this trick:
@tdhopper @Springcoil You need pm.NUTS(scaling=np.power(model.dict_to_array(v_params.stds), 2), is_cov=True) (terrible syntax, I know).
— Thomas Wiecki (@twiecki) November 8, 2016
This gave a 25x speedup of the NUTS sampler. The code looks like this:
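import numpy as np
import pymc3 as pm

def sample_nuts(niter):  # illustrative wrapper; the function name is hypothetical
    with pm.Model() as model:
        # SETUP MODEL HERE
        # Run ADVI first: returns posterior means, standard deviations, and the ELBO trace
        mu, sds, elbo = pm.variational.advi(n=200000)
        # Squared ADVI stds form a diagonal covariance; is_cov=True is an argument to NUTS, not sample
        step = pm.NUTS(scaling=np.power(model.dict_to_array(sds), 2), is_cov=True)
        # Start the chain at the ADVI means
        return pm.sample(niter, step=step, start=mu)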
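Why does scaling by the ADVI variances help so much? Here is a toy sketch, not PyMC3 code, under a simplifying assumption: the posterior is Gaussian with independent coordinates whose standard deviations are known (standing in for ADVI's `sds`). NUTS moves by leapfrog integration, whose step size must be small enough for the tightest coordinate, so with a poorly matched scaling it then needs many steps to travel along the widest one. The hypothetical helper `leapfrog_steps_needed` estimates both quantities for a given diagonal inverse mass matrix, which is roughly the role NUTS's `scaling` plays when `is_cov=True`.

```python
import numpy as np

def leapfrog_steps_needed(stds, inv_mass_diag, safety=0.5):
    """Toy model (hypothetical helper): for a Gaussian target with independent
    coordinates of standard deviation `stds`, return the leapfrog step size and
    the number of steps needed to move the slowest coordinate through half an
    oscillation, given a diagonal inverse mass matrix."""
    # Hamiltonian dynamics give each coordinate a frequency sqrt(inv_mass) / std
    omega = np.sqrt(inv_mass_diag) / stds
    # Leapfrog is only stable when eps * omega_max < 2; use a fraction of that bound
    eps = safety * 2.0 / omega.max()
    # Half an oscillation of the slowest coordinate takes pi / omega_min time units
    n_steps = int(np.ceil(np.pi / omega.min() / eps))
    return float(eps), n_steps

stds = np.array([1.0, 100.0])  # one tight coordinate, one wide coordinate
print(leapfrog_steps_needed(stds, np.ones_like(stds)))  # identity scaling
print(leapfrog_steps_needed(stds, stds ** 2))           # variance scaling, as in the trick
```

With identity scaling the wide coordinate needs 315 leapfrog steps per half-oscillation; with the variances as scaling both coordinates move at the same rate and 4 steps suffice. Each leapfrog step costs a gradient evaluation, so on badly scaled posteriors a well-matched scaling can make NUTS much cheaper per sample. Real posteriors are not independent Gaussians, so the actual gain on a given model will differ.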