Hamiltonian Neural Network

Hamiltonian neural network for meta-learning trajectory prediction.
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform

hamiltonian_factory

 hamiltonian_factory (model, afuncs)

Returns a function that computes the Hamiltonian of a given model.
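
A minimal sketch of the factory pattern described above, not the library's implementation. The names `make_hamiltonian` and `toy_model` are hypothetical, and the sketch omits `afuncs` because its role is not described here. It assumes only that the model maps a single state (q, p) to an output of shape (1,), which the wrapper squeezes to a scalar so that it can later be differentiated.

```python
import jax.numpy as jnp

def make_hamiltonian(model_fn):
    """Hypothetical factory: wrap a model so it maps one state (q, p) to a scalar."""
    def hamiltonian(state):
        # model_fn is assumed to return shape (1,); squeeze it to a scalar
        return jnp.squeeze(model_fn(state))
    return hamiltonian

# Toy stand-in for a trained network: harmonic oscillator energy H = (q^2 + p^2) / 2
def toy_model(state):
    return 0.5 * jnp.sum(state ** 2, keepdims=True)

H = make_hamiltonian(toy_model)
energy = H(jnp.array([1.0, 2.0]))  # scalar energy of the state q=1, p=2
```

Returning a closure keeps the scalar-valued function separate from the model itself, which is convenient when the result is passed to `jax.grad`.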


compute_loss

 compute_loss (model, x, y, afuncs)

Applies Hamilton’s equations to obtain dqdp, then computes the loss.
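
The idea can be sketched as follows, under stated assumptions. This is an illustration, not the code behind `compute_loss`: the fixed toy `hamiltonian`, the helper names `time_derivatives` and `mse_loss`, and the choice of mean squared error are all assumptions, and the sketch returns only the scalar loss rather than a tuple. Hamilton's equations give dq/dt = ∂H/∂p and dp/dt = −∂H/∂q, so the predicted derivatives come from the gradient of the scalar Hamiltonian.

```python
import jax
import jax.numpy as jnp

def hamiltonian(state):
    # Toy stand-in for the learned Hamiltonian: H = (q^2 + p^2) / 2
    q, p = state[0], state[1]
    return 0.5 * (q ** 2 + p ** 2)

def time_derivatives(state):
    # Hamilton's equations: dq/dt = dH/dp, dp/dt = -dH/dq
    grad_H = jax.grad(hamiltonian)(state)
    return jnp.stack([grad_H[1], -grad_H[0]])

def mse_loss(x, y):
    # Map over the batch of states, then take the mean squared error against y
    pred = jax.vmap(time_derivatives)(x)
    return jnp.mean((pred - y) ** 2)

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])   # batch of (q, p) states
y = jnp.array([[0.0, -1.0], [1.0, 0.0]])  # exact (dq/dt, dp/dt) for the toy H
loss = mse_loss(x, y)
```

Because the targets here are the exact derivatives of the toy Hamiltonian, the loss vanishes. With a learned model, the same quantity would be minimized with respect to the model parameters.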

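# Test compute_loss.
# MultiActMLP, init_linear_weight, and deterministic_init are assumed to be
# imported from this project's own modules.
import jax
import jax.numpy as jnp
from fastcore.test import test_eq

key = jax.random.PRNGKey(0)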
model_key, init_key = jax.random.split(key)
x = jnp.ones((5, 2))

model = MultiActMLP(2, 1, [18], model_key, bias=False)
model = init_linear_weight(model, deterministic_init, init_key)
y = jnp.ones((5, 2))

afuncs = [lambda x: 1, lambda x: 0]

loss, _ = compute_loss(model, x, y, afuncs)
test_eq(loss, 1.0)