In [ ]:
#@title Imports & Utils
# !pip install -q git+https://github.com/abhijeetgangan/jax-md.git

import jax.numpy as jnp
import numpy as onp
from jax import debug
from jax import jit
from jax import grad
from jax import random
from jax import lax
from jax.config import config
from sampling.correlation_length import ess_corr


config.update('jax_enable_x64', True)
from jax_md import simulate
from jax_md import space
from jax_md import energy
from jax_md import elasticity
from jax_md import quantity
from jax_md import dataclasses
from jax_md.util import f64

# Other libraries
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd 
import jax.numpy as nnp
import seaborn as sns
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

import numpy.linalg as npl
# import seaborn as sns
import matplotlib.pyplot as plt 
from plotting import rama_plot
from jax_md.simulate import Sampler
# from sampling.sampler import Sampler
from jax_md import space, quantity
from jax_md import simulate, energy
import math
import jax
import scipy
import jax.numpy as nnp

# LAMMPS NVT thermo output (lammps_nvt.dat), whitespace-separated, no header row.
# `sep=r"\s+"` is pandas' documented equivalent of `delim_whitespace=True`
# (that flag is a boolean and is deprecated since pandas 2.2).
data_lammps = pd.read_csv("lammps_nvt.dat", sep=r"\s+", header=None)
# Drop any column that contains NaNs before assigning names.
data_lammps = data_lammps.dropna(axis=1)
# NOTE(review): column meanings are inferred from these names (time, temperature,
# pressure, volume, energy, enthalpy) — confirm against the LAMMPS thermo_style.
data_lammps.columns = ['Time', 'T', 'P', 'V', 'E', 'H']
t_l, T, P, V, E, H = (data_lammps[col] for col in data_lammps.columns)

def _mclmc_step_size(step_size, temp, n_dof):
  """Convert a step size in femtoseconds into the sampler's ``eps``.

  Computes ``step_size * 1 fs * sqrt(n_dof * k_B * temp)`` in SI units
  (kg^(1/2) * m). It then rescales by ``sqrt(1000 * N_A) / 1 Angstrom`` to
  (g/mol)^(1/2) * Angstrom.

  Args:
    step_size: step size in femtoseconds.
    temp: temperature in kelvin.
    n_dof: number of degrees of freedom (flattened configuration size).
  """
  eps_in_si = step_size * scipy.constants.femto * nnp.sqrt(n_dof * scipy.constants.k * temp)
  si_to_gmol = nnp.sqrt(1000 * scipy.constants.Avogadro) / scipy.constants.angstrom
  return eps_in_si * si_to_gmol


def run(step_size, temp, chain_length, traj_path='step_1.traj'):
  """Sample a Stillinger-Weber system with MCLMC, starting from a LAMMPS snapshot.

  Loads positions/velocities from ``traj_path`` and builds a Nose-Hoover NVT
  state. Only that state's positions are used afterwards. It wraps U(x)/kT in
  an ``MD`` target and runs ``Sampler`` for ``chain_length`` steps with all
  tuning fractions set to 0. Finally it prints diagnostics and plots the
  energy trace.

  Args:
    step_size: sampler step size in femtoseconds (converted to ``eps``).
    temp: temperature in kelvin.
    chain_length: number of sampler steps.
    traj_path: whitespace-delimited snapshot; columns 2:5 are read as
      positions, columns 5:8 as velocities.

  Returns:
    ``(samples, ener, T, V)``:
      - ``ener`` and ``samples`` unpack the ``output`` of
        ``Sampler.sample(..., output='detailed')``. These are presumably the
        two values from ``MD.transform``; confirm against ``Sampler``.
      - ``T`` and ``V`` are the sampler's other two leading returns.
  """
  from jax_md import units

  # --- Initial configuration from LAMMPS -------------------------------------
  lammps_snapshot = onp.loadtxt(traj_path, dtype=f64)
  positions = jnp.array(lammps_snapshot[:, 2:5], dtype=f64)
  velocity = jnp.array(lammps_snapshot[:, 5:8], dtype=f64)
  # Cubic cell with edge length 21.724.
  latvec = jnp.eye(3) * 21.724
  box = latvec

  # --- Metal units and physical parameters -----------------------------------
  unit = units.metal_unit_system()
  dt = 1e-3 * unit['time']
  T_init = temp * unit['temperature']
  Mass = 28.0855 * unit['mass']  # silicon atomic mass
  key = random.PRNGKey(121)

  # --- Periodic Stillinger-Weber potential with a neighbor list --------------
  # NOTE(review): upstream jax_md's periodic_general defaults to
  # fractional_coordinates=True; confirm the snapshot coordinates follow the
  # convention this fork expects.
  displacement, shift = space.periodic_general(latvec)
  neighbor_fn, energy_fn = energy.stillinger_weber_neighbor_list(
      displacement, latvec, disable_cell_list=True)
  energy_fn = jit(energy_fn)
  # Extra capacity to prevent overflow.
  nbrs = neighbor_fn.allocate(positions, box=box, extra_capacity=2)

  # Nose-Hoover NVT state. Only `state.position` is read below.
  # `chain_length=3` here is the thermostat chain length. It is unrelated to
  # this function's `chain_length` (the number of sampler steps).
  init_fn, _ = simulate.nvt_nose_hoover(
      energy_fn, shift, dt=dt, kT=T_init, tau=100 * dt,
      chain_length=3, chain_steps=1, sy_steps=1)
  state = init_fn(key, positions, box=box, neighbor=nbrs, kT=T_init, mass=Mass)
  # Restart from LAMMPS velocities (the sampler below does not read momentum).
  state = dataclasses.replace(state, momentum=Mass * velocity * unit['velocity'])
  # Flattened configuration size, derived from the snapshot rather than
  # hard-coding 512 atoms * 3.
  dim = math.prod(state.position.shape)

  # Identity map; a wrapping alternative would be jax.numpy.mod(x, 21.724).
  to_unit_cell = lambda x: x

  def nlogp(nbrs):
    """Return x -> U(x) / kT for a flattened configuration x."""
    def _nlogp(x):
      coords = nnp.reshape(to_unit_cell(x), state.position.shape)
      return energy_fn(coords, neighbor=nbrs, box=box) / T_init
    return _nlogp

  def value_grad(nbrs):
    """Return a function computing (U(x)/kT, d/dx U(x)/kT)."""
    return jax.value_and_grad(nlogp(nbrs))

  class MD():
    """Sampler target: U(x)/kT on a flattened configuration of size ``d``."""

    def __init__(self, d, nbrs):
      self.d = d
      self.nbrs = nbrs

    def grad_nlogp(self, x):
      """Return ``(U(x)/kT, gradient)`` for flattened ``x``."""
      return value_grad(self.nbrs)(x)

    def transform(self, x):
      """Return ``(U(x), x)``: the potential energy and the configuration itself."""
      out1 = nlogp(self.nbrs)(x) * T_init
      out2 = x
      return out1, out2

    def prior_draw(self, key):
      """Return the flattened initial configuration; ``key`` is ignored."""
      return to_unit_cell(nnp.reshape(state.position, math.prod(state.position.shape)))

  eps_val = _mclmc_step_size(step_size, temp, dim)
  print("eps val", eps_val)

  target = MD(d=dim, nbrs=nbrs)
  # NOTE(review): L is scaled by the temperature in kelvin (L = temp * eps);
  # confirm this is intended.
  sampler = Sampler(target, shift_fn=shift, eps=eps_val, L=temp * eps_val,
                    masses=nnp.zeros(dim) + Mass,
                    frac_tune1=0.0, frac_tune2=0.0, frac_tune3=0.0)

  num_chains = 1
  output, T, V, _, _ = sampler.sample(chain_length, num_chains, output='detailed')
  ener, samples = output

  # --- Diagnostics -------------------------------------------------------------
  print("Results\n\n\n")
  # NOTE(review): this is the signed maximum; use nnp.max(nnp.abs(T + V)) if
  # the worst-case magnitude is what should be reported.
  print(f"max dE {nnp.max(T+V)}")
  print(f"MCLMC mean of Potential is {ener.mean()} and variance is {ener.var()}")
  print(f"MCLMC ESS is {ess_corr(samples)}")
  print(f"MCLMC energy error is {(nnp.square(T+V)/dim).mean()}")
  print(samples.shape, "shape")

  # Energy trace; the x-axis follows the actual number of recorded samples.
  data = pd.DataFrame(data=ener, columns=['E'])
  ax = sns.lineplot(data=data, y='E', x=nnp.arange(len(data)))
  ax.set(xlabel='MCLMC step', ylabel='Potential energy E')

  print(f"unit temp {unit['temperature']}")

  return samples, ener, T, V
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [ ]:
# MCLMC run: 2 fs step size, 300 K, 10 000 sampler steps.
# NOTE(review): unpacking into T and V rebinds the module-level LAMMPS `T`/`V`
# columns loaded in the data cell.
samples, ener, T, V = run(
    step_size=2,
    temp=300,
    chain_length=10000,
)
501
/global/homes/r/reubenh/.conda/envs/py38/lib/python3.8/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
eps val 1.2379506249354235
/global/homes/r/reubenh/.conda/envs/py38/lib/python3.8/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
Results



max dE 0.012758478636392567
MCLMC mean of Potential is -2200.3206001787125 and variance is 0.9918181413749773
MCLMC ESS is 0.0007601470059473381
MCLMC energy error is 2.0043163946486407e-09
(10000, 1536) shape
unit temp 8.617330337217213e-05
In [ ]:
# Convergence of the running mean of the MCLMC potential energy towards a
# reference value, over the first 999 samples.
# NOTE(review): the provenance of this reference mean is not recorded — TODO document.
m = -2200.420922954092
# O(n) cumulative sum instead of recomputing each prefix mean (999 device round-trips).
ener_prefix = onp.asarray(ener)[:999]
prefix_len = onp.arange(1, len(ener_prefix) + 1)
running_mean = onp.cumsum(ener_prefix) / prefix_len
ax = sns.lineplot(x=prefix_len, y=onp.abs(running_mean - m))
ax.set(xlabel='Number of samples', ylabel='|running mean of E - reference|');
Out[ ]:
<Axes: >