MLP

module for an MLP that takes in multiple activation functions
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform (allocate GPU memory on demand instead of letting XLA preallocate)

source

xavier_uniform_init

 xavier_uniform_init (weight:jax.Array, key:jax.random.PRNGKey)

xavier uniform initialization
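
A minimal usage sketch. Only the (weight, key) call signature comes from the entry above; the bound in the comment assumes the standard Glorot/Xavier uniform scheme, U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
w = jnp.zeros((18, 4))  # e.g. an (out_features, in_features) linear weight
w_new = xavier_uniform_init(w, key)  # same shape, values redrawn, presumably within ±sqrt(6 / (4 + 18))
assert w_new.shape == w.shape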


source

xavier_normal_init

 xavier_normal_init (weight:jax.Array, key:jax.random.PRNGKey)

xavier normal initialization
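
The normal variant has the same interface; under the standard Glorot normal scheme (an assumption here) the values are drawn from a zero-mean normal with std sqrt(2 / (fan_in + fan_out)).

w_new = xavier_normal_init(jnp.zeros((18, 4)), jax.random.PRNGKey(1))
print(float(w_new.std()))  # roughly sqrt(2 / 22) ≈ 0.30 if the standard scheme is used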


source

deterministic_init

 deterministic_init (weight:jax.Array, key:jax.random.PRNGKey)

constant initialization (the arguments exist only for consistency with the other initializations)
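
A quick sketch: as test 0 of the MultiActMLP tests below checks, the result is filled with a single constant (1e-3 there), and the key has no effect on the output.

w = deterministic_init(jnp.zeros((18, 4)), jax.random.PRNGKey(0))
assert jnp.all(w == w.ravel()[0])  # every entry is the same constant; the key is unused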


source

trunc_init

 trunc_init (weight:jax.Array, key:jax.random.PRNGKey)

truncated normal initialization


source

init_linear_weight

 init_linear_weight (model, init_fn, key)

initialize linear weights of a model with a given init_fn
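
A usage sketch mirroring the tests below: take any Equinox model and re-initialize all of its linear weights with one of the init functions above.

import equinox as eqx

model_key, init_key = jax.random.split(jax.random.PRNGKey(0))
mlp = eqx.nn.MLP(in_size=4, out_size=2, width_size=18, depth=1, key=model_key)
mlp = init_linear_weight(mlp, trunc_init, init_key)  # any of the initializers above can be passed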


source

MultiActMLP

 MultiActMLP (*args, **kwargs)
# test MultiActMLP
key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)
x = jnp.ones((5, 4))
model = MultiActMLP(4, 2, [18], model_key, bias=False)
model = init_linear_weight(model, deterministic_init, init_key)
afuncs = [lambda x: x]
y, _ = jax.vmap(model, in_axes=(0, None))(x, afuncs)
# test 0 : see if model initializes correctly
test_eq(jnp.all(model.layers[0].weight == 1e-3), True)

# test 1 : see if activations work at all
afuncs = [lambda x: 0]
y, _ = jax.vmap(model, in_axes=(0, None))(x, afuncs)
test_eq(jnp.all(y == 0), True)

# test 2 : see if mixing activations works
afuncs = [lambda x: 1, lambda x: 2, lambda x: 3]
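# with 3 activations over 18 hidden units, each activation covers a contiguous block of 6 units,
# so every sample's hidden output should be [1]*6 + [2]*6 + [3]*6, exactly the `dummy` built below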
weights = model.layers[-1].weight.T
dummy = jnp.ones((5, 18))
dummy = dummy.at[:, 6:].set(2)
dummy = dummy.at[:, 12:].set(3)
y, _ = jax.vmap(model, in_axes=(0, None))(x, afuncs)
test_eq(jnp.all(y == dummy @ weights), True)

source

save

 save (filename, hyperparams, model)

save model and hyperparameters to file
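
A hedged sketch of a save call. The hyperparameter key names below are assumptions, chosen to match the MultiActMLP(4, 2, [18], key, bias=False) constructor call used in the tests; only the (filename, hyperparams, model) argument order comes from the signature above.

hyperparams = {"in_size": 4, "out_size": 2, "hidden_sizes": [18], "bias": False}  # assumed key names
save("multiact_mlp.eqx", hyperparams, model)  # model from the MultiActMLP tests above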


source

make_mlp

 make_mlp (config_dict)

initialize MLP using hyperparameters from config_dict
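
A sketch of a config_dict. The exact keys expected by make_mlp are not documented here, so the ones below are hypothetical, again mirroring the MultiActMLP constructor arguments.

config_dict = {
    "in_size": 4,         # hypothetical key names
    "out_size": 2,
    "hidden_sizes": [18],
    "seed": 0,            # a PRNG seed for the model key
    "bias": False,
}
model = make_mlp(config_dict)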


source

load

 load (filename, make=make_mlp)

load model and hyperparameters from file
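
Loading reverses the save above: presumably the stored hyperparameters are passed to make (default make_mlp) to rebuild the architecture before the weights are restored. Whether the hyperparameters are also returned is not shown, so treat the single return value below as an assumption.

loaded_model = load("multiact_mlp.eqx")  # rebuilds the model with make_mlp and restores its weights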


source

mlp_afunc

 mlp_afunc (x, model, base_act)

MLP that behaves like an activation function

# test mlp_afunc
key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)
x = jnp.ones((6))
model = eqx.nn.MLP(
    in_size=1, out_size=1, width_size=18, depth=1, key=model_key, use_bias=False
)
model = init_linear_weight(model, deterministic_init, init_key)
act = mlp_afunc(x, model, jnp.sin)
test_eq(act.shape, x.shape)
test_eq(act, jnp.sin(x) + model(jnp.ones(1))[0])