from moleculekit.molecule import Molecule
import jax_md
import os
from parameters import Parameters, set_box, set_positions
import jax.numpy as nnp
from torchmd.forcefields.forcefield import ForceField
import numpy.linalg as npl
import seaborn as sns
import matplotlib.pyplot as plt
from plotting import rama_plot
# from jax_md.simulate import Sampler
from sampling.sampler import Sampler
from jax_md import space, quantity
from jax_md import simulate, energy
from jax_md.simulate import ess_corr
import math
import jax
import scipy
import mdtraj as md
import pandas as pd
# Load alanine dipeptide: topology first, then coordinates and box.
testdir = "data/prod_alanine_dipeptide_amber/"
mol = Molecule(os.path.join(testdir, "structure.prmtop"))  # system topology
# "input.coor" holds the initial simulation coordinates,
# "input.xsc" holds the box dimensions.
for _input_name in ("input.coor", "input.xsc"):
    mol.read(os.path.join(testdir, _input_name))
from jax.lib import xla_bridge
# Report the platform of the XLA backend that JAX is running on.
print(xla_bridge.get_backend().platform)
# Output: cpu
# Build the force-field parameters from the same AMBER topology file.
_prmtop_path = os.path.join(testdir, "structure.prmtop")
ff = ForceField.create(mol, _prmtop_path)
parameters = Parameters(ff, mol, device='cpu', precision=float)

# Positions and box are laid out per replica; the rest of this file assumes
# a single replica.
nreplicas = 1  # don't change
pos = set_positions(nreplicas, mol.coords)
_replica_box = set_box(nreplicas, mol.box)
box = nnp.array(_replica_box, dtype='float32')
from forces import Forces
# Energy terms included in the potential.
_force_terms = [
    "bonds",
    "angles",
    "dihedrals",
    "impropers",
    "1-4",
    "electrostatics",
    "lj",
]
forces = Forces(
    parameters,
    terms=_force_terms,
    cutoff=9,
    switch_dist=7.5,
    rfa=True,
)
# Evaluate the potential once at the initial configuration (in a notebook
# this displays the value).
forces.compute(pos, box)
# Output: Array(-2190.62312837, dtype=float64)
# Calculating the potential should give the same value, -2190.6, as the
# original TorchMD tutorial:
# https://github.com/torchmd/torchmd/blob/master/examples/tutorial.ipynb
# Atom indices of the four atoms defining each dihedral; both lists are
# passed to mdtraj's compute_dihedrals further down.
phi_indices = [4, 6, 8, 14]
psi_indices = [6, 8, 14, 16]
from sampling.dynamics import hamiltonian_dynamics
from sampling.old_sampler import Sampler as OldSampler
from jax_md.simulate import Sampler as OriginalSampler
def run_mclmc(T, dt, L_factor, chain_length, gamma, masses=None, new=True):
    """Sample the loaded system with MCLMC and plot diagnostics.

    Relies on the module-level ``forces``, ``pos``, ``box``, ``mol``,
    ``phi_indices`` and ``psi_indices``.

    Args:
        T: Temperature. It is combined with ``scipy.constants.k`` for the
            step size and with ``BOLTZMAN`` (0.001987191) to scale the
            energy, i.e. it is treated as Kelvin.
        dt: Time-step scale; multiplied by ``scipy.constants.femto`` when
            the step size ``eps`` is derived.
        L_factor: The sampler's ``L`` is set to ``L_factor * eps``.
        chain_length: Number of sampling steps.
        gamma: Unused while the Langevin baseline below is commented out;
            kept so existing callers keep working.
        masses: Per-coordinate masses, one entry per flattened coordinate.
            Defaults to unit masses of length ``math.prod(pos.shape)``.
        new: If True use ``sampling.sampler.Sampler``; otherwise use the
            ``jax_md.simulate.Sampler`` with a periodic ``shift_fn``.

    Returns:
        Tuple ``(samples, K, V, L, eps)`` where the first four entries come
        from ``sampler.sample(..., output='detailed')`` and ``eps`` is the
        step size computed here.
    """
    print(f'Comparison of MCLMC and Langevin, at T={T}, dt={dt}, and with {chain_length} steps ')

    BOLTZMAN = 0.001987191
    kT = BOLTZMAN * T
    # Number of flattened coordinates (3 per atom for the single replica).
    n_dof = math.prod(pos.shape)
    if masses is None:
        masses = nnp.ones(n_dof)

    # ---- figure layout: row 0 dihedral scatter, rows 1-3 energy traces ----
    fig, axs = plt.subplots(4, 2, figsize=(8, 16))
    for top_ax in axs[0]:
        top_ax.set_xlim(-math.pi, math.pi)
        top_ax.set_ylim(-math.pi, math.pi)
    axs[0, 0].title.set_text('MCLMC')
    axs[0, 1].title.set_text('Langevin')
    fig.tight_layout(pad=3.0)

    # ---- pieces shared by both sampler variants ----
    def nlogp(x):
        # Potential of the flat coordinate vector x.
        return forces.compute(nnp.reshape(x, pos.shape), box)

    def to_unit_cell(x):
        # Wrap coordinates using the first box entry only.
        # NOTE(review): this assumes a cubic box -- confirm.
        return jax.numpy.mod(x, box[0][0][0])

    def make_target(energy_fn, transform_fn):
        """Build the target object the samplers expect."""
        value_grad = jax.value_and_grad(energy_fn)

        class MD():
            def __init__(self, d):
                self.d = d
                self.nbrs = None

            def grad_nlogp(self, x):
                # (value, gradient) of the scaled energy.
                return value_grad(x)

            def transform(self, x):
                return transform_fn(x)

            def prior_draw(self, key):
                # `key` is ignored: every chain starts from the loaded
                # coordinates, flattened and wrapped into the unit cell.
                return to_unit_cell(nnp.array(nnp.reshape(pos, n_dof), dtype='float64'))

        return MD(d=n_dof)

    # Step size: computed in SI units, then rescaled by sqrt(1000 * N_A) / angstrom.
    eps_in_si = dt * scipy.constants.femto * nnp.sqrt(n_dof * scipy.constants.k * T)
    si_to_gmol = nnp.sqrt(1000 * scipy.constants.Avogadro) / scipy.constants.angstrom
    eps = eps_in_si * si_to_gmol

    if new:
        # The energy wraps positions into the unit cell itself.
        target = make_target(
            energy_fn=lambda x: nlogp(to_unit_cell(x)) / kT,
            transform_fn=to_unit_cell,
        )
        sampler = Sampler(target, frac_tune1=0.0, frac_tune2=0.0, frac_tune3=0.0,
                          L=L_factor * eps, eps=eps)
        # NOTE: sigma does not arise from any tuning here: it is a fixed parameter.
        sampler.sigma = 1 / nnp.sqrt(masses)
        sampler.hamiltonian_dynamics = hamiltonian_dynamics(
            integrator=sampler.integrator,
            sigma=sampler.sigma,
            grad_nlogp=sampler.Target.grad_nlogp,
            d=sampler.Target.d,
        )
    else:
        # The original sampler receives a periodic shift function instead
        # of wrapping positions inside the energy.
        target = make_target(
            energy_fn=lambda x: nlogp(x) / kT,
            transform_fn=lambda x: x,
        )
        _, shift_fn = space.periodic(box[0][0][0].item())
        # NOTE: this variant always runs with unit masses, independent of `masses`.
        sampler = OriginalSampler(target, shift_fn=shift_fn, masses=nnp.ones(n_dof),
                                  frac_tune1=0.0, frac_tune2=0.0, frac_tune3=0.0,
                                  L=L_factor * eps, eps=eps)

    num_chains = 1
    samples, K, V, L, _ = sampler.sample(chain_length, num_chains, output='detailed')
    de = K + V

    # ---- dihedral scatter from every 10th sample ----
    subsampled = samples[::10, :]
    n_frames = subsampled.shape[0]
    trajectory = md.load('./data/prod_alanine_dipeptide_amber/structure.pdb')
    trajectory.xyz = nnp.array(nnp.reshape(subsampled, (n_frames, -1, 3)))
    # One diagonal box matrix per frame; tile keeps the frame axis even
    # when there is only a single frame.
    trajectory.unitcell_vectors = nnp.tile(nnp.diag(mol.box[:, 0]), (n_frames, 1, 1))
    angles = md.compute_dihedrals(trajectory, [phi_indices, psi_indices])
    sns.scatterplot(x=angles[:, 0], y=angles[:, 1], ax=axs[0, 0], s=1)

    # ---- summary statistics ----
    print("MCLMC\n\n")
    print("L: ", L)
    print("eps: ", eps)
    sigma = 1 / nnp.sqrt(masses)
    rmses = nnp.sqrt(nnp.mean((samples / sigma) ** 2, axis=0))
    print("Mean RMS: ", rmses.mean())
    print("Max RMS: ", nnp.max(rmses))
    print("Min RMS: ", nnp.min(rmses))
    print("Condition number: ", nnp.max(rmses) / nnp.min(rmses))
    print("Energy error: ", (nnp.square(de) / n_dof).mean())
    print("ESS (via ess_corr): ", ess_corr(samples))

    # ---- running sums of K, V and K + V: trace (left) and histogram (right) ----
    num_bins = 100
    steps = nnp.arange(chain_length)
    running = {'K': nnp.cumsum(K, axis=0), 'V': nnp.cumsum(V, axis=0)}
    running['E'] = running['K'] + running['V']
    for row, (label, values) in enumerate(running.items(), start=1):
        frame = pd.DataFrame(data=values, columns=[label])
        sns.lineplot(data=frame, y=label, x=steps, ax=axs[row][0])
        sns.histplot(data=frame, x=label, ax=axs[row][1], bins=num_bins)

    # name = 'mclmc' + str(eps) + str(L) + str(num_chains)
    # trajectory.save_pdb('./data/prod_alanine_dipeptide_amber/traj'+name+'.pdb')

    # ##### LANGEVIN (disabled baseline, kept for reference)
    # key = jax.random.PRNGKey(0)
    # dt = dt * 1e-3
    # init, update = simulate.nvt_langevin(nlogp, shift_fn, dt, kT=BOLTZMAN*T, gamma=gamma)
    # state = init(key, pos)
    # samples_langevin = []
    # for i in range(chain_length):
    #     if i%10==0:
    #         samples_langevin.append(state.position)
    #     state = update(state)
    # trajectory2 = md.load('./data/prod_alanine_dipeptide_amber/structure.pdb')
    # trajectory2.xyz=nnp.array(nnp.array(samples_langevin).squeeze())[::1]
    # unitC = nnp.array([(len(samples_langevin))*[nnp.diag(mol.box[:,0])]]).squeeze()
    # trajectory2.unitcell_vectors = unitC
    # angles_langevin = md.compute_dihedrals(trajectory2, [phi_indices, psi_indices])
    # sns.scatterplot(x = angles_langevin[:, 0], y = angles_langevin[:, 1], ax=axs[1], s=5)
    # # name2 = 'mclmc' + str(dt) + str(gamma) + str(num_chains)
    # # trajectory2.save_pdb('./data/prod_alanine_dipeptide_amber/traj'+name2+'.pdb')
    # print("Langevin\n\n")
    # print("Gamma: ", gamma)

    return samples, K, V, L, eps
    # , samples_langevin, gamma
# Run the new MCLMC sampler at T=300 for 10000 steps.
# NOTE: run_mclmc returns (samples, K, V, L, eps), so T_new holds K here.
x_initial_new,T_new, V_new,L,eps = run_mclmc(T=300, dt=1.0, L_factor=300, chain_length=10000, gamma=0.1, new = True)
# Output:
# Comparison of MCLMC and Langevin, at T=300, dt=1.0, and with 10000 steps
# MCLMC
# L: 215.25528397334264
# eps: 0.7175176132444755
# Mean RMS: 10.392512210737479
# Max RMS: 18.991122677905587
# Min RMS: 0.9761274676448282
# Condition number: 19.45557655879392
# Energy error: 0.1328809186856266
# ESS (via ess_corr): 0.00048483604479146534
# from jax_md.old_annealing_sampler import Sampler as AnnealingSampler
from sampling.annealing import Sampler as AnnealingSampler
from sampling.dynamics import resample_particles, systematic_resampling
def run_annealing(chain_length, num_chains):
    """Run the annealing sampler over a fixed temperature schedule and plot.

    Relies on the module-level ``forces``, ``pos``, ``box``, ``mol``,
    ``phi_indices`` and ``psi_indices``.

    Args:
        chain_length: Steps taken at each temperature of the schedule.
        num_chains: Number of chains run in parallel.

    Returns:
        Tuple ``(samples, energy)`` as returned by ``AnnealingSampler.sample``.
        In the recorded run their shapes were (3, 1000, 30, 2064) and
        (3, 1000, 30).
    """
    BOLTZMAN = 0.001987191
    # Number of flattened coordinates (3 per atom for the single replica).
    n_dof = math.prod(pos.shape)

    def nlogp(x):
        # Potential of the flat coordinate vector x.
        return forces.compute(nnp.reshape(x, pos.shape), box)

    def to_unit_cell(x):
        # Wrap coordinates using the first box entry only.
        # NOTE(review): this assumes a cubic box -- confirm.
        return jax.numpy.mod(x, box[0][0][0])

    def energy_fn(x):
        # Unscaled here: the temperature comes from the schedule passed to
        # sampler.sample below.
        return nlogp(to_unit_cell(x))

    value_grad = jax.value_and_grad(energy_fn)

    class MD():
        def __init__(self, d):
            self.d = d
            self.nbrs = None

        def grad_nlogp(self, x):
            # (value, gradient) of the energy.
            return value_grad(x)

        def transform(self, x):
            return to_unit_cell(x)

        def prior_draw(self, key):
            # `key` is ignored: every chain starts from the loaded
            # coordinates, flattened and wrapped into the unit cell.
            return to_unit_cell(nnp.array(nnp.reshape(pos, n_dof), dtype='float64'))

    target = MD(d=n_dof)
    sampler = AnnealingSampler(target)

    def temp_func(T, Tprev, L, eps):
        # Recompute the step size for temperature T; the incoming L and eps
        # are ignored. T / BOLTZMAN undoes the BOLTZMAN factor applied to
        # the schedule entries below.
        dt = 2
        eps_in_si = dt * scipy.constants.femto * nnp.sqrt(n_dof * scipy.constants.k * (T / BOLTZMAN))
        si_to_gmol = nnp.sqrt(1000 * scipy.constants.Avogadro) / scipy.constants.angstrom
        eps = eps_in_si * si_to_gmol
        return eps * 30, eps

    sampler.temp_func = temp_func

    def resample_systematically(logw, x, u, l, g, key, L, eps, T):
        # Optional resampling step; NOT installed by default. Assign it to
        # sampler.resample_particles to enable it.
        indices, key = systematic_resampling(logw * 0.01, key)
        jax.debug.print("indices {}", indices)
        jax.debug.print("logw {}", logw)
        x_resampled = nnp.take(x, indices, axis=0)
        u_resampled = nnp.take(u, indices, axis=0)
        l_resampled = nnp.take(l, indices)
        g_resampled = nnp.take(g, indices, axis=0)
        return (x_resampled, u_resampled, l_resampled, g_resampled, key, L, eps, T)

    def no_resampling(logw, x, u, l, g, key, L, eps, T):
        # Pass the particles through unchanged.
        return (x, u, l, g, key, L, eps, T)

    sampler.resample_particles = no_resampling

    temp_schedule = [4000.0 * BOLTZMAN, 370 * BOLTZMAN, 300 * BOLTZMAN]
    # Local name `energies` avoids shadowing the imported jax_md.energy.
    samples, energies = sampler.sample(
        steps_at_each_temp=chain_length,
        temp_schedule=temp_schedule,
        num_chains=num_chains,
        tune_steps=0,
        random_key=jax.random.PRNGKey(42),
    )
    print(samples.shape, "shape\n\n\n")
    print(energies.shape)

    def pl(i):
        # Scatter the dihedral angles of all samples drawn at schedule
        # entry i onto the current matplotlib axes.
        subsampled = nnp.reshape(samples[i], (chain_length * num_chains, n_dof))
        print(subsampled.shape, "shape")
        n_frames = subsampled.shape[0]
        trajectory = md.load('./data/prod_alanine_dipeptide_amber/structure.pdb')
        trajectory.xyz = nnp.array(nnp.reshape(subsampled, (n_frames, -1, 3)))
        # One diagonal box matrix per frame; tile keeps the frame axis even
        # when there is only a single frame.
        trajectory.unitcell_vectors = nnp.tile(nnp.diag(mol.box[:, 0]), (n_frames, 1, 1))
        angles = md.compute_dihedrals(trajectory, [phi_indices, psi_indices])
        sns.scatterplot(x=angles[:, 0], y=angles[:, 1], s=1)
        plt.xlim(-math.pi, math.pi)
        plt.ylim(-math.pi, math.pi)

    # One scatter per temperature in the schedule.
    for stage in range(len(temp_schedule)):
        pl(stage)

    return samples, energies
# Anneal 30 chains with 1000 steps at each temperature of the schedule.
# NOTE: this rebinds the module-level name `energy`, which was imported
# from jax_md at the top of the file.
samples, energy = run_annealing(chain_length=1000, num_chains=30)
# Output:
# eps: 5.24000776241689, L: 157.20023287250672, T: 7.948764000000001
# eps: 1.5936861437850638, L: 47.810584313551914, T: 7.948764000000001
# eps: 1.435035226488951, L: 43.05105679466853, T: 0.7352606700000001
# (3, 1000, 30, 2064) shape
# (3, 1000, 30)
# (30000, 2064) shape
# (30000, 2064) shape
# (30000, 2064) shape