env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
compute_loss (model, x, y, afuncs)
Compute the L2 loss of the model on the given data, returning a `(loss, grad)` pair.
# test compute_loss
# Every activation below returns 0 and the model has no bias, so the
# prediction is presumably all zeros. Against an all-ones target, the
# L2 loss is then expected to be exactly 1 (asserted at the end).
key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)

# Inputs and targets: 5 rows of ones. Presumably these are 5 samples with
# 4 input features and 2 outputs, matching MultiActMLP(4, 2, ...)
# (TODO confirm).
x = jnp.ones((5, 4))
y = jnp.ones((5, 2))

# Build the bias-free MLP (one hidden layer of width 18) and initialise its
# linear weights deterministically.
model = init_linear_weight(
    MultiActMLP(4, 2, [18], model_key, bias=False),
    deterministic_init,
    init_key,
)

# Two constant-zero activation functions.
# NOTE(review): presumably one per layer (hidden + output); confirm against
# MultiActMLP.
afuncs = [lambda _: 0, lambda _: 0]

loss, grad = compute_loss(model, x, y, afuncs)
test_eq(loss, 1)