import os
import matplotlib.pyplot as pyplot
import numpy as np
from openmm import unit
from openmmtools.multistate import (MultiStateReporter, MultiStateSampler,
ReplicaExchangeAnalyzer,
ReplicaExchangeSampler)
import physical_validation as pv
from physical_validation.data.ensemble_data import EnsembleData
from physical_validation.data.observable_data import ObservableData
from physical_validation.data.simulation_data import SimulationData
from physical_validation.data.unit_data import UnitData
from physical_validation.util.error import InputError
# Molar gas constant R (= N_A * k_B), used here as a per-mole "Boltzmann constant" so that
# kB * T is a molar energy; callers convert it to kJ/mol (or kJ/mol/K) where it is used.
kB = unit.MOLAR_GAS_CONSTANT_R
def physical_validation_ensemble(
    output_data="output.nc", output_directory="ouput", plotfile='ensemble_check',
    pairs='single', ref_state_index=0, frame_start=0, frame_stride=1, frame_end=-1,
    bootstrap_error=False, bootstrap_repetitions=200, bootstrap_seed=None, data_is_uncorrelated=False,
):
    """
    Run the physical_validation ensemble check on one or more pairs of thermodynamic states
    from a replica exchange simulation.

    Potential energies are recovered from the reduced potentials evaluated in state 0, which
    presumably assumes the states differ only in temperature (temperature replica exchange).

    :param output_data: Path to the NetCDF-formatted file containing replica exchange simulation data
    :type output_data: str
    :param output_directory: Not used by this function; kept for backward compatibility
    :type output_directory: str
    :param plotfile: Filename for outputting the ensemble check plot. For pairs='adjacent' or 'all',
        '_{i}_{j}' is appended for each checked state pair (i, j)
    :type plotfile: str
    :param pairs: Run the ensemble validation on all replica pair combinations ('all'), adjacent pairs
        ('adjacent'), or a single pair with optimal spacing ('single'). Case-insensitive; unrecognized
        values print a message and fall back to 'single'
    :type pairs: str
    :param ref_state_index: Index in the temperature list used as one of the states when pairs='single'.
        The other state is the one whose temperature offset best matches 2*kB*T_ref**2/std(E_ref);
        the reference state itself is never selected. Ignored for 'adjacent' and 'all'
    :type ref_state_index: int
    :param frame_start: Set the starting frame for the analysis (default=0)
    :type frame_start: int
    :param frame_stride: Set the sample spacing for decorrelated frames (default=1)
    :type frame_stride: int
    :param frame_end: Set the (exclusive) end frame for the analysis. Values <= 0 (default=-1) mean
        "through the last frame"
    :type frame_end: int
    :param bootstrap_error: Flag indicating if standard error of energies should be computed via bootstrapping (default=False)
    :type bootstrap_error: bool
    :param bootstrap_repetitions: Number of bootstrap repetitions (default=200)
    :type bootstrap_repetitions: int
    :param bootstrap_seed: Set a bootstrapping seed for reproducible results. If None, a random seed is used. (default=None)
    :type bootstrap_seed: int
    :param data_is_uncorrelated: Flag indicating if the data after applying frame_start, frame_stride, and frame_end is uncorrelated.
        Applying the appropriate frame slicing greatly speeds up the ensemble check. (default=False)
    :type data_is_uncorrelated: bool

    :returns: quantiles - dict mapping 'state{i}_state{j}' to the first entry of the list returned by
        physical_validation.ensemble.check for that pair. Pairs for which the check raised InputError
        are reported with a printed message and omitted.
    :rtype: dict
    """
    pairs_option = pairs.lower()
    if pairs_option not in ('single', 'adjacent', 'all'):
        print(f"Error: Pair option '{pairs}' not recognized, using default option 'single'")
        pairs_option = 'single'

    # Slice applied along the frame axis; a non-positive frame_end means "through the last frame".
    frame_slice = slice(frame_start, frame_end if frame_end > 0 else None, frame_stride)
    temperature_list, state_energies = _read_state_energies(output_data, frame_slice)

    T_unit = temperature_list[0].unit
    T_array = np.array([temp.value_in_unit(T_unit) for temp in temperature_list])
    n_states = len(temperature_list)

    # Each entry is (state1_index, state2_index, plot filename).
    if pairs_option == 'single':
        partner_index = _optimal_partner_index(
            state_energies, temperature_list, T_array, T_unit, ref_state_index
        )
        if partner_index is None:
            print(f"No state other than reference state {ref_state_index} is available "
                  f"for the ensemble check. Skipping...")
            pair_specs = []
        else:
            pair_specs = [(ref_state_index, partner_index, plotfile)]
    elif pairs_option == 'adjacent':
        pair_specs = [(i, i + 1, f"{plotfile}_{i}_{i + 1}") for i in range(n_states - 1)]
    else:  # 'all'
        pair_specs = [
            (i, j, f"{plotfile}_{i}_{j}")
            for i in range(n_states)
            for j in range(i + 1, n_states)
        ]

    # The energies are handed to set_simulation_data as an openmm Quantity in kJ/mol,
    # as in the original implementation.
    state_energies_quantity = state_energies * unit.kilojoule_per_mole

    quantiles = {}
    for state1_index, state2_index, filename in pair_specs:
        sim_data1, sim_data2 = set_simulation_data(
            state_energies_quantity,
            T_array,
            state1_index,
            state2_index
        )
        try:
            quantiles_ij = pv.ensemble.check(
                sim_data1,
                sim_data2,
                total_energy=False,
                filename=filename,
                bootstrap_error=bootstrap_error,
                bootstrap_repetitions=bootstrap_repetitions,
                bootstrap_seed=bootstrap_seed,
                data_is_uncorrelated=data_is_uncorrelated,
            )
        except InputError:
            print(f"Insufficient overlap between trajectories for states {state1_index} and {state2_index}. Skipping...")
            continue
        quantiles[f"state{state1_index}_state{state2_index}"] = quantiles_ij[0]
    return quantiles


def _read_state_energies(output_data, frame_slice):
    """
    Read state temperatures and per-state potential energies from a replica exchange file.

    :param output_data: Path to the MultiStateReporter NetCDF file
    :param frame_slice: slice applied to the frame (iteration) axis
    :returns: (temperature_list, state_energies) where state_energies[s, t] is the potential energy
        in kJ/mol (plain floats) of the configuration occupying state s at selected frame t
    """
    reporter = MultiStateReporter(output_data, open_mode="r")
    try:
        analyzer = ReplicaExchangeAnalyzer(reporter)
        states = reporter.read_thermodynamic_states()[0]
        temperature_list = [state.temperature for state in states]
        replica_energies, _, _, replica_state_indices = analyzer.read_energies()
    finally:
        # Close the data file even if one of the reads fails.
        reporter.close()

    # replica_energies is indexed [replica, state, frame] and holds reduced potentials.
    # Only state 0 is needed: multiplying by kT_0 converts it back to a potential energy.
    kT_0 = (kB * temperature_list[0]).value_in_unit(unit.kilojoule_per_mole)
    replica_potentials = replica_energies[:, 0, frame_slice] * kT_0
    replica_state_indices = np.asarray(replica_state_indices[:, frame_slice], dtype=int)

    # replica_state_indices[r, t] is the state held by replica r at frame t; scatter each
    # replica's energy into the row of the state it occupied (one permutation per frame).
    n_frames = replica_potentials.shape[1]
    state_energies = np.zeros((len(temperature_list), n_frames))
    state_energies[replica_state_indices, np.arange(n_frames)] = replica_potentials
    return temperature_list, state_energies


def _optimal_partner_index(state_energies, temperature_list, T_array, T_unit, ref_state_index):
    """
    Choose the state whose temperature offset from the reference best matches 2*kB*T_ref**2/std(E_ref).

    :param state_energies: per-state potential energies in kJ/mol (plain floats), indexed [state, frame]
    :returns: index of the chosen state (never the reference itself), or None if there is no other state
    """
    n_states = len(temperature_list)
    # Normalizes negative indices and raises IndexError for out-of-range ones.
    ref_position = range(n_states)[ref_state_index]
    candidates = [k for k in range(n_states) if k != ref_position]
    if not candidates:
        return None

    T_ref = temperature_list[ref_state_index]
    std_ref = np.std(state_energies[ref_state_index]) * unit.kilojoule_per_mole
    # Target temperature spacing for the ensemble check, in T_unit.
    deltaT = (2 * kB * T_ref**2 / std_ref).value_in_unit(T_unit)
    T_diff = np.abs(T_ref.value_in_unit(T_unit) - T_array)
    mismatch = np.abs(deltaT - T_diff)
    # min() keeps the first index on ties, like np.argmin.
    return min(candidates, key=lambda k: mismatch[k])
def set_simulation_data(
    state_energies, T_array, state1_index, state2_index
):
    """
    Create and set physical_validation SimulationData objects for a pair of specified states.

    :param state_energies: Per-state potential energies in kJ/mol, indexed [state, frame]
        (presumably a numpy array or openmm Quantity in kJ/mol -- confirm with caller)
    :param T_array: Temperature of each state; assumed to be in kelvin, since the UnitData
        below declares 'K' and a kb in kJ/mol/K
    :param state1_index: Index of the first state
    :type state1_index: int
    :param state2_index: Index of the second state
    :type state2_index: int
    :returns: (sim_data1, sim_data2) - NVT SimulationData objects for state1_index and
        state2_index, sharing a single UnitData object
    :rtype: tuple
    """
    # All conversion factors are 1: energies are supplied in kJ/mol and temperatures in K,
    # so kb only needs to be expressed in kJ/mol/K.
    default_UnitData = UnitData(
        kb=kB.value_in_unit(unit.kilojoule_per_mole/unit.kelvin),
        energy_conversion=1,
        length_conversion=1,
        volume_conversion=1,
        temperature_conversion=1,
        pressure_conversion=1,
        time_conversion=1,
        energy_str='kJ/mol',
        length_str='nm',
        volume_str='nm^3',
        temperature_str='K',
        pressure_str='bar',
        time_str='ps'
    )
    sim_data1 = _nvt_simulation_data(
        state_energies[state1_index, :], T_array[state1_index], default_UnitData
    )
    sim_data2 = _nvt_simulation_data(
        state_energies[state2_index, :], T_array[state2_index], default_UnitData
    )
    return sim_data1, sim_data2


def _nvt_simulation_data(energies, temperature, units):
    """Build an NVT SimulationData with `energies` as the potential energy observable at `temperature`."""
    sim_data = SimulationData()
    sim_data.observables = ObservableData(
        potential_energy=energies,
    )
    sim_data.ensemble = EnsembleData(
        ensemble='NVT',
        energy=energies,
        temperature=temperature
    )
    sim_data.units = units
    return sim_data