env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
hamiltonian_factory (model, afuncs)
Returns a function that computes the Hamiltonian of a given model.
compute_loss (model, x, y, afuncs)
Computes Hamilton’s equations to obtain `dqdp`, then computes the loss.
# test compute_loss
# Inputs and targets are both all-ones arrays of shape (5, 2);
# `y` mirrors `x` exactly in shape and dtype.
x = jnp.ones(shape=(5, 2))
y = jnp.ones_like(x)

# Derive separate PRNG keys from a single fixed seed: one for building the
# model, one for initializing its linear weights.
key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)
model = init_linear_weight(
    MultiActMLP(2, 1, [18], model_key, bias=False),
    deterministic_init,
    init_key,
)

# Constant activation functions: the first always returns 1, the second 0.
afuncs = [lambda _: 1, lambda _: 0]

# compute_loss returns a pair; only the loss is checked here.
loss, _ = compute_loss(model, x, y, afuncs)
test_eq(loss, 1.0)