Speeding up PyMC3 NUTS Sampler

I’m trying to use the NUTS sampler in PyMC3.

However, it ran at only 2 iterations per second on my model, while the Metropolis-Hastings sampler ran 450x faster.

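I showed my example to some of the PyMC3 devs on Twitter, and Thomas Wiecki showed me this trick: fit the model with ADVI first, then use the ADVI results to scale NUTS and to pick its starting point.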

It resulted in a 25x speedup of the NUTS sampler. The code looks like this:

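import numpy as np
import pymc3 as pm

with pm.Model() as model:
    # SETUP MODEL HERE
    # Fit ADVI to get approximate posterior means (mu) and standard deviations (sds)
    mu, sds, elbo = pm.variational.advi(n=200000)
    # Squared standard deviations are variances; is_cov=True marks scaling as a covariance
    step = pm.NUTS(scaling=np.power(model.dict_to_array(sds), 2), is_cov=True)
    # niter is the number of samples to draw, defined elsewhere
    trace = pm.sample(niter,
                      step=step,
                      start=mu,
                      )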
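
To make the scaling argument more concrete, here is a rough NumPy-only sketch of what the `np.power(model.dict_to_array(sds), 2)` step works out to. The variable names and numbers are made up for illustration, and `flatten_to_array` is a hypothetical stand-in, not PyMC3's actual `dict_to_array` (which orders values according to the model's own variable ordering). The point is only that per-variable standard deviations get flattened into one vector and squared into variances, which, as far as I understand it, NUTS then treats as the diagonal of a covariance matrix.

```python
import numpy as np

# Hypothetical ADVI output: one standard deviation entry per model variable.
sds = {
    "intercept": np.array(0.5),
    "slopes": np.array([0.1, 2.0, 10.0]),
}

def flatten_to_array(values):
    """Concatenate every variable's values into one flat vector.

    Illustrative stand-in only; it uses dict insertion order.
    """
    return np.concatenate([np.ravel(v) for v in values.values()])

sd_vector = flatten_to_array(sds)

# Squaring standard deviations gives variances, one per flattened parameter.
variances = np.power(sd_vector, 2)
print(variances)
```

Parameters with very different posterior spreads (0.1 versus 10.0 here) end up with very different variances, and giving the sampler that information up front is presumably why it no longer has to crawl along with a single step size that suits none of them.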