#@title Imports & Utils
# !pip install -q git+https://github.com/abhijeetgangan/jax-md.git
import jax.numpy as jnp
import numpy as onp
from jax import debug
from jax import jit
from jax import grad
from jax import random
from jax import lax
from jax.config import config
from sampling.correlation_length import ess_corr
config.update('jax_enable_x64', True)
from jax_md import simulate
from jax_md import space
from jax_md import energy
from jax_md import elasticity
from jax_md import quantity
from jax_md import dataclasses
from jax_md.util import f64
# Other libraries
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as nnp
import seaborn as sns
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy.linalg as npl
# import seaborn as sns
import matplotlib.pyplot as plt
from plotting import rama_plot
from jax_md.simulate import Sampler
# from sampling.sampler import Sampler
from jax_md import space, quantity
from jax_md import simulate, energy
import math
import jax
import scipy
import jax.numpy as nnp
# LAMMPS thermo output: whitespace-separated columns Time, T, P, V, E, H.
# `sep=r"\s+"` is the documented replacement for the deprecated
# `delim_whitespace=True` (the old call passed the string " " to that boolean flag).
data_lammps = pd.read_csv("lammps_nvt.dat", sep=r"\s+", header=None)
# Drop any column containing NaNs before naming the six expected columns.
data_lammps = data_lammps.dropna(axis=1)
data_lammps.columns = ['Time', 'T', 'P', 'V', 'E', 'H']
t_l, T, P, V, E, H = (
    data_lammps['Time'],
    data_lammps['T'],
    data_lammps['P'],
    data_lammps['V'],
    data_lammps['E'],
    data_lammps['H'],
)
def run(step_size, temp, chain_length):
print(len(E))
# sns.lineplot(x=nnp.arange(0,chain_length,100),y=data_lammps['E'][:100])
# plt.savefig("si_plot")
# raise Exception
lammps_step_0 = onp.loadtxt('step_1.traj', dtype=f64)
# Load positions from lammps
positions = jnp.array(lammps_step_0[:,2:5], dtype=f64)
# Load velocities from lammps
velocity = jnp.array(lammps_step_0[:,5:8], dtype=f64)
latvec = jnp.array([[21.724, 0.000000, 0.000000], [0.00000, 21.724, 0.00000],[0.00000, 0.0000, 21.724]])
# Import unit system
from jax_md import units
# Metal units
unit = units.metal_unit_system()
# Simulation parameters
timestep = 1e-3
fs = timestep * unit['time']
ps = unit['time']
dt = fs
write_every = 1
box = latvec
T_init = temp * unit['temperature']
Mass = 28.0855 * unit['mass']
key = random.PRNGKey(121)
steps = chain_length
# Logger to save data
log = {
'E': jnp.zeros((steps // write_every,)),
'P': jnp.zeros((steps // write_every,)),
'T': jnp.zeros((steps // write_every,)),
'kT': jnp.zeros((steps // write_every,)),
'state' : jnp.zeros((steps // write_every, 512, 3)),
}
# Setup the periodic boundary conditions.
displacement, shift = space.periodic_general(latvec)
dist_fun = space.metric(displacement)
neighbor_fn, energy_fn = energy.stillinger_weber_neighbor_list(displacement, latvec, disable_cell_list=True)
energy_fn = jit(energy_fn)
# Extra capacity to prevent overflow
nbrs = neighbor_fn.allocate(positions, box=box, extra_capacity=2)
# NVT simulation
init_fn, apply_fn = simulate.nvt_nose_hoover(energy_fn, shift, dt=dt, kT=T_init, tau=100 * dt, chain_length=3, chain_steps=1, sy_steps=1)
apply_fn = jit(apply_fn)
state = init_fn(key, positions, box=box, neighbor=nbrs, kT=T_init, mass=Mass)
# Restart from LAMMPS velocities
state = dataclasses.replace(state, momentum = Mass * velocity * unit['velocity'])
# lammps_boltzmann = 0.001987191
to_unit_cell = lambda x : x # jax.numpy.mod(x, 21.724)
nlogp = lambda nbrs: lambda x : energy_fn(nnp.reshape(to_unit_cell(x), (state.position.shape)), neighbor=nbrs, box=box) / T_init
value_grad = lambda nbrs : jax.value_and_grad(nlogp(nbrs))
class MD():
def __init__(self, d, nbrs):
self.d = d
self.nbrs = nbrs
def grad_nlogp(self, x):
return value_grad(self.nbrs)(x)
def transform(self, x):
# return x
out1 = nlogp(self.nbrs)(x) * T_init
out2 = x
# jax.debug.print(out.item())
return out1, out2
def prior_draw(self, key):
return to_unit_cell(nnp.reshape(state.position, math.prod(state.position.shape)))
# return nnp.reshape(state_r.position, math.prod(state_r.position.shape))
eps_in_si = step_size*scipy.constants.femto * nnp.sqrt(3 * 512 * scipy.constants.k * temp)
si_to_gmol = nnp.sqrt(1000*scipy.constants.Avogadro)/scipy.constants.angstrom
eps_val = eps_in_si * si_to_gmol
print("eps val", eps_val)
target = MD(d = math.prod(state.position.shape), nbrs=nbrs)
# nnp.zeros(512*3)+Mass,
# eps=eps_val, L=temp*eps_val,
sampler = Sampler(target, shift_fn = shift, eps=eps_val, L=temp*eps_val,
masses = nnp.zeros(512*3)+Mass,
frac_tune1=0.0, frac_tune2=0.0, frac_tune3=0.0)
# masses = jnp.ones(512*3))
num_chains = 1
output, T, V, L, eps = sampler.sample(chain_length, num_chains, output= 'detailed',
)
ener, samples = output
# raise Exception
print("Results\n\n\n")
print(f"max dE {nnp.max(T+V)}")
# print(energy_fn(state_r.position, neighbor=nbrs_r, box=box))
# print(target.grad_nlogp(nnp.reshape(state_r.position, math.prod(state_r.position.shape)))[1]*(temp * lammps_boltzmann))
# print(f"Nose-Hoover mean of Potential is {log_r['E'].mean()} and variance is {log_r['E'].var()}")
print(f"MCLMC mean of Potential is {ener.mean()} and variance is {ener.var()}")
# print(f"Nose-Hoover ESS is {ess_corr(log_r['state'])}")
print(f"MCLMC ESS is {ess_corr(samples)}")
print(f"MCLMC energy error is {(nnp.square(T+V)/(512*3)).mean()}")
# print(f"Nose-Hoover energy error is {(nnp.square(log_r['E'][1:]-log_r['E'][:-1])/(512*3)).mean()}")
# samples = nnp.clip(samples, a_max=-1000.0)
print(samples.shape, "shape")
data = pd.DataFrame(data = ener, columns=['E'])
# sns.lineplot(x=nnp.arange(0,chain_length),y=log_r['E'][:chain_length])
sns.lineplot(data=data,y='E', x=nnp.arange(chain_length))
# # sns.histplot(data=e_data,x='E', ax=axs[3][1])
# plt.savefig("situning_plot")
print(f"unit temp {unit['temperature']}")
return samples, ener, T,V