Implementation

There is a Python implementation of the algorithm in Blackjax here.

For example, this is the momentum update, calculated in an efficient way:

import jax.numpy as jnp

def update_momentum(d):

  def update(eps, u, g):
      # e is the unit vector along the negative energy gradient
      g_norm = jnp.sqrt(jnp.sum(jnp.square(g)))
      e = -g / g_norm
      ue = jnp.dot(u, e)
      delta = eps * g_norm / (d - 1)
      zeta = jnp.exp(-delta)
      # numerator of the exact update, rescaled by 2 * exp(-delta)
      uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u
      # delta_r = log(cosh(delta) + (e . u) * sinh(delta));
      # (d - 1) * delta_r is the kinetic energy change from this update
      delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2)
      # renormalize so the momentum stays on the unit sphere
      return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r

  return update

Recall from the tutorial that the momentum update should be:

\[ u \mapsto \frac{u + (\sinh{(\delta)}+ {e} \cdot u (\cosh (\delta) -1)) {e} }{\cosh{(\delta)} + {e} \cdot u \sinh{(\delta)}} \]

where \(\delta = \epsilon \vert \nabla E(x) \vert / (d-1)\) and \({e} = - \nabla E(x) / \vert \nabla E(x) \vert\).

The code above gives exactly this map: multiplying the numerator and denominator by \(2 e^{-\delta}\) and writing \(\zeta = e^{-\delta}\) turns the numerator into the expression for uu, and the final renormalization to unit length accounts for the denominator, so \(\cosh\) and \(\sinh\) never need to be evaluated directly.
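As a quick sanity check (not from the post), one can compare the implementation above against a direct, hypothetical transcription of the formula, naive_update below; the two should agree to numerical precision:

import jax
import jax.numpy as jnp

def naive_update(d, eps, u, g):
    # direct transcription of the momentum update formula
    g_norm = jnp.sqrt(jnp.sum(jnp.square(g)))
    e = -g / g_norm
    ue = jnp.dot(u, e)
    delta = eps * g_norm / (d - 1)
    numerator = u + (jnp.sinh(delta) + ue * (jnp.cosh(delta) - 1)) * e
    denominator = jnp.cosh(delta) + ue * jnp.sinh(delta)
    return numerator / denominator

d, eps = 5, 0.3
key_u, key_g = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (d,))
u = u / jnp.linalg.norm(u)            # unit-norm momentum
g = jax.random.normal(key_g, (d,))    # stand-in for the energy gradient

uu, _ = update_momentum(d)(eps, u, g)
print(jnp.allclose(uu, naive_update(d, eps, u, g), atol=1e-6))  # True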

The leapfrog integrator described in the supplementary material can be defined like so (with some extra values passed along, such as the log density at the new position and the change in kinetic energy):

def leapfrog(d, T, V):

  def step(x, u, g, eps, sigma):

    # V T V: half-step momentum update, full-step position update,
    # then another half-step momentum update
    uu, r1 = V(eps * 0.5, u, g * sigma)
    xx, l, gg = T(eps, x, uu * sigma)
    uu, r2 = V(eps * 0.5, uu, gg * sigma)

    # kinetic energy change, accumulated over the two momentum half-steps
    kinetic_change = (r1 + r2) * (d - 1)

    return xx, uu, l, gg, kinetic_change

  return step, 1  # number of gradient calls per step
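
The position update T is not shown here. A minimal sketch, assuming it simply moves the position along the (already sigma-scaled) momentum and that g denotes the gradient of the energy \(E(x) = -\log p(x)\) (the sign convention implied by the formula above), together with hypothetical wiring of the pieces defined above, could look like this:

import jax
import jax.numpy as jnp

def update_position(logdensity_fn):
    # one value_and_grad evaluation per step, matching the
    # "one gradient call per step" count returned by leapfrog
    logdensity_and_grad = jax.value_and_grad(logdensity_fn)

    def update(eps, x, u):
        xx = x + eps * u                # move along the momentum
        l, grad_l = logdensity_and_grad(xx)
        return xx, l, -grad_l           # g = grad E = -grad log p

    return update

# hypothetical wiring of the pieces defined above
d = 5
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # standard Gaussian
T = update_position(logdensity_fn)
V = update_momentum(d)
step, grads_per_step = leapfrog(d, T, V)

x = jnp.ones(d)
u = jnp.ones(d) / jnp.sqrt(d)           # unit-norm momentum
_, grad_l = jax.value_and_grad(logdensity_fn)(x)
xx, uu, l, gg, kinetic_change = step(x, u, -grad_l, eps=0.1, sigma=jnp.ones(d))

In this reading, sigma acts as a per-coordinate scale: with ones it is a no-op, and otherwise it rescales the gradient and momentum before each update.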