Baseline MLP

A conventional fully connected neural network used as the baseline for meta-learning trajectory prediction.
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform (tells JAX to allocate GPU memory on demand rather than preallocating)
from jaxDiversity.mlp import MultiActMLP, deterministic_init, init_linear_weight

source

compute_loss

 compute_loss (model, x, y, afuncs)

Compute the L2 loss of the model on the given data.
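
For reference, the loss behaves like a mean squared error: in the test below the model is forced to output all zeros against all-ones targets, and the loss comes out to exactly 1. A minimal sketch of that reduction, assuming a mean over all output elements (the exact reduction inside compute_loss may differ):

import jax.numpy as jnp

# mean squared error over all output elements (assumed reduction)
preds = jnp.zeros((5, 2))    # what the zero-activation, bias-free test model produces
targets = jnp.ones((5, 2))
print(jnp.mean((preds - targets) ** 2))  # 1.0, matching test_eq(loss, 1) below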

# test compute_loss
import jax
import jax.numpy as jnp
from fastcore.test import test_eq

key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)
x = jnp.ones((5, 4))

# 4 inputs, 2 outputs, a single hidden layer of width 18, no bias terms
model = MultiActMLP(4, 2, [18], model_key, bias=False)
model = init_linear_weight(model, deterministic_init, init_key)
y = jnp.ones((5, 2))

# zero-valued activation functions force the model output to be all zeros
afuncs = [lambda x: 0, lambda x: 0]

loss, grad = compute_loss(model, x, y, afuncs)
# all-zero predictions against all-ones targets give an L2 loss of exactly 1
test_eq(loss, 1)
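
compute_loss returns the gradients alongside the loss value. A hedged sketch of how those gradients could drive a plain gradient-descent step, assuming grad is a pytree of per-parameter gradients in the Equinox style (the library's actual training loop may differ):

import jax
import equinox as eqx

learning_rate = 1e-2
# scale every gradient leaf, then apply the updates to the matching model parameters
updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grad)
model = eqx.apply_updates(model, updates)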