About this Abstract |
Meeting |
Materials Science & Technology 2020 |
Symposium |
Ceramics and Glasses Simulations and Machine Learning |
Presentation Title |
JAX, M.D.: End-to-End Differentiable, Hardware Accelerated, Molecular Dynamics in Pure Python |
Author(s) |
Samuel Schoenholz, Ekin Dogus Cubuk |
On-Site Speaker (Planned) |
Samuel Schoenholz |
Abstract Scope |
Molecular dynamics (MD) software is used across a vast range of subjects, from physics and materials science to biochemistry and drug discovery. Most MD software relies heavily on handwritten derivatives and on code reuse across C++, FORTRAN, and CUDA. In this work we bring the substantial software advances that have taken place in machine learning to MD with JAX, M.D. (JAX MD). JAX MD is an end-to-end differentiable MD package written entirely in Python that can be just-in-time compiled to CPU, GPU, or TPU. JAX MD allows researchers to iterate extremely quickly and to incorporate machine learning models into their workflows easily. In addition to making workloads easier, JAX MD allows researchers to take derivatives through entire simulations in order to design physical systems with desirable properties. We discuss the architecture of JAX MD through several vignettes with an eye towards glass physics. Code available at www.github.com/google/jax-md. |
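To make the idea of taking derivatives through a whole simulation concrete, here is a minimal sketch written against core JAX only (`jax.grad`, `jax.jit`, `jax.lax.scan`). It does not use or reproduce the JAX MD API. The soft-sphere energy, the gradient-descent relaxation, the three-particle configuration, and the names `pair_energy` and `relaxed_energy` are all assumptions chosen for illustration. Forces come from automatic differentiation of the energy rather than from a handwritten derivative, and the relaxed energy is then differentiated with respect to the particle diameter `sigma`.

```python
import jax
import jax.numpy as jnp


def pair_energy(positions, sigma):
    """Soft-sphere energy (illustrative): sum over pairs of 0.5 * (1 - r/sigma)^2 for r < sigma."""
    n = positions.shape[0]
    i, j = jnp.triu_indices(n, k=1)  # each unordered pair once
    dr = positions[i] - positions[j]
    # Small epsilon keeps the square root differentiable if two particles coincide.
    r = jnp.sqrt(jnp.sum(dr ** 2, axis=-1) + 1e-12)
    overlap = jnp.maximum(1.0 - r / sigma, 0.0)
    return 0.5 * jnp.sum(overlap ** 2)


def relaxed_energy(sigma, positions, steps=100, dt=0.1):
    """Relax the system by gradient descent, then return the final energy."""

    def step(R, _):
        # Force obtained by automatic differentiation of the energy.
        force = -jax.grad(pair_energy)(R, sigma)
        return R + dt * force, None

    R_final, _ = jax.lax.scan(step, positions, None, length=steps)
    return pair_energy(R_final, sigma)


# Hypothetical configuration: three particles in 2D, two overlapping pairs.
positions = jnp.array([[0.0, 0.0], [0.8, 0.0], [0.0, 0.8]])

# Derivative of the relaxed energy with respect to the particle diameter,
# computed by differentiating through all relaxation steps and JIT-compiled.
d_energy_d_sigma = jax.jit(jax.grad(relaxed_energy))(1.0, positions)
```

Using `jax.lax.scan` for the time loop keeps the whole trajectory inside one traced computation, so the same function can be JIT-compiled and differentiated in reverse mode. A gradient such as `d_energy_d_sigma` could in principle feed an optimizer that tunes `sigma` toward a target property, which is the kind of design loop the abstract describes.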