env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
xavier_uniform_init (weight:jax.Array, key:jax.random.PRNGKey)
Xavier uniform initialization.
xavier_normal_init (weight:jax.Array, key:jax.random.PRNGKey)
Xavier normal initialization.
deterministic_init (weight:jax.Array, key:jax.random.PRNGKey)
Constant initialization; the parameters exist only for consistency with the other initializers.
trunc_init (weight:jax.Array, key:jax.random.PRNGKey)
Truncated normal initialization.
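All four initializers share the (weight, key) signature expected by init_linear_weight, so they are interchangeable. A minimal sketch of what one of them computes, assuming the textbook Xavier/Glorot uniform bound (the exact constants used by the module are not shown here):

import jax
import jax.numpy as jnp

def xavier_uniform_sketch(weight: jax.Array, key: jax.Array) -> jax.Array:
    # Xavier/Glorot uniform: sample U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
    fan_out, fan_in = weight.shape  # eqx.nn.Linear stores weight as (out, in)
    a = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, weight.shape, minval=-a, maxval=a)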
init_linear_weight (model, init_fn, key)
Initialize the linear weights of a model with a given init_fn.
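A minimal sketch of how init_linear_weight can be implemented with equinox, assuming the standard pattern of collecting every eqx.nn.Linear weight with jax.tree_util and swapping them via eqx.tree_at (the function name here is illustrative):

import equinox as eqx
import jax

def init_linear_weight_sketch(model, init_fn, key):
    # gather the weight of every Linear layer in the model pytree
    is_linear = lambda x: isinstance(x, eqx.nn.Linear)
    get_weights = lambda m: [
        l.weight
        for l in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
        if is_linear(l)
    ]
    old_weights = get_weights(model)
    # one fresh key per layer, then re-initialize each weight with init_fn
    keys = jax.random.split(key, len(old_weights))
    new_weights = [init_fn(w, k) for w, k in zip(old_weights, keys)]
    return eqx.tree_at(get_weights, model, new_weights)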
MultiActMLP (*args, **kwargs)
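No docstring is rendered for this class, but the tests below pin down its behaviour: the hidden units are split evenly across the supplied activation functions (18 units and 3 functions give 6 units each). A hedged sketch consistent with those tests, with the second return value assumed to be the final hidden activations:

import equinox as eqx
import jax

class MultiActMLPSketch(eqx.Module):
    layers: list

    def __init__(self, in_size, out_size, hidden_sizes, key, bias=True):
        sizes = [in_size] + list(hidden_sizes) + [out_size]
        keys = jax.random.split(key, len(sizes) - 1)
        self.layers = [
            eqx.nn.Linear(m, n, use_bias=bias, key=k)
            for m, n, k in zip(sizes[:-1], sizes[1:], keys)
        ]

    def __call__(self, x, afuncs):
        for layer in self.layers[:-1]:
            x = layer(x)
            # apply each activation to its own equal slice of the hidden units
            n = x.shape[-1] // len(afuncs)
            for i, f in enumerate(afuncs):
                x = x.at[i * n : (i + 1) * n].set(f(x[i * n : (i + 1) * n]))
        # output plus (assumed) hidden activations
        return self.layers[-1](x), x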
# test MultiActMLP
import jax
import jax.numpy as jnp
from fastcore.test import test_eq

key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)
x = jnp.ones((5, 4))
model = MultiActMLP(4, 2, [18], model_key, bias=False)
model = init_linear_weight(model, deterministic_init, init_key)

# forward pass with a single identity activation
afuncs = [lambda x: x]
y, _ = jax.vmap(model, in_axes=(0, None))(x, afuncs)

# test 0: the model initializes with the expected constant weights
test_eq(jnp.all(model.layers[0].weight == 1e-3), True)

# test 1: a zero activation zeroes the output (no bias)
afuncs = [lambda x: 0]
y, _ = jax.vmap(model, in_axes=(0, None))(x, afuncs)
test_eq(jnp.all(y == 0), True)

# test 2: mixing activations splits the hidden units evenly
afuncs = [lambda x: 1, lambda x: 2, lambda x: 3]
weights = model.layers[-1].weight.T
# expected hidden activations: 6 units each of 1, 2 and 3
dummy = jnp.ones((5, 18))
dummy = dummy.at[:, 6:].set(2)
dummy = dummy.at[:, 12:].set(3)
y, _ = jax.vmap(model, in_axes=(0, None))(x, afuncs)
test_eq(jnp.all(y == dummy @ weights), True)
save (filename, hyperparams, model)
Save model and hyperparameters to file.
make_mlp (config_dict)
Initialize an MLP using hyperparameters from config_dict.
load (filename, make=make_mlp)
Load model and hyperparameters from file.
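A hedged sketch of the serialization pattern these three helpers suggest, following the usual equinox recipe of a JSON line of hyperparameters followed by the serialized array leaves (the exact on-disk layout is an assumption):

import json
import equinox as eqx

def save_sketch(filename, hyperparams, model):
    with open(filename, "wb") as f:
        # human-readable hyperparameters first, raw arrays after
        f.write((json.dumps(hyperparams) + "\n").encode())
        eqx.tree_serialise_leaves(f, model)

def load_sketch(filename, make=make_mlp):
    with open(filename, "rb") as f:
        hyperparams = json.loads(f.readline().decode())
        # rebuild a model skeleton with the right shapes, then fill in its leaves
        skeleton = make(hyperparams)
        return eqx.tree_deserialise_leaves(f, skeleton)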
mlp_afunc (x, model, base_act)
An MLP that behaves like an activation function.
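Consistent with the test below, a minimal sketch of mlp_afunc treats the scalar-in/scalar-out MLP as an elementwise correction added on top of the base activation (this implementation is inferred from the test's expected value, not taken from the source):

import jax

def mlp_afunc_sketch(x, model, base_act):
    # map the scalar MLP over every element of x and add the base activation
    correction = jax.vmap(model)(x[..., None])[..., 0]
    return base_act(x) + correction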
# test mlp_afunc
import equinox as eqx

key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)
x = jnp.ones(6)
model = eqx.nn.MLP(
    in_size=1, out_size=1, width_size=18, depth=1, key=model_key, use_bias=False
)
model = init_linear_weight(model, deterministic_init, init_key)
act = mlp_afunc(x, model, jnp.sin)

# the output keeps the input shape and equals the base activation
# plus the constant-input MLP correction
test_eq(act.shape, x.shape)
test_eq(act, jnp.sin(x) + model(jnp.ones(1))[0])