env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
make_step(model, x, y, afuncs, optim, opt_state, compute_loss)

inner_opt(model, train_data, test_data, afuncs, opt, loss_fn, config, training=False, verbose=False)
    Inner optimization loop.
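For context, a minimal sketch of what a make_step-style update could look like, assuming the model is an Equinox module, optim is an optax optimizer, and compute_loss takes (model, x, y, afuncs); the actual implementation may differ.

import equinox as eqx

# Hedged sketch of a single jitted training step (assumed signatures).
@eqx.filter_jit
def make_step_sketch(model, x, y, afuncs, optim, opt_state, compute_loss):
    # Differentiate the loss w.r.t. the model's array-valued leaves only.
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y, afuncs)
    updates, opt_state = optim.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss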
# test inner_opt (baseline)
dev_inner_config = InnerConfig(
test_train_split=0.8,
input_dim=2,
output_dim=2,
hidden_layer_sizes=[18],
batch_size=64,
epochs=2,
lr=1e-3,
mu=0.9,
n_fns=2,
l2_reg=1e-1,
seed=42,
)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]
train_dataset = DummyDataset(
1000, dev_inner_config.input_dim, dev_inner_config.output_dim
)
test_dataset = DummyDataset(
1000, dev_inner_config.input_dim, dev_inner_config.output_dim
)
train_dataloader = NumpyLoader(
train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
opt = optax.rmsprop(
learning_rate=dev_inner_config.lr,
momentum=dev_inner_config.mu,
decay=dev_inner_config.l2_reg,
)
model = MultiActMLP(
dev_inner_config.input_dim,
dev_inner_config.output_dim,
dev_inner_config.hidden_layer_sizes,
model_key,
bias=False,
)
logging.info("Baseline NN inner loop test")
baselineNN, opt_state, inner_results = inner_opt(
model=model,
train_data=train_dataloader,
test_data=test_dataloader,
afuncs=afuncs,
opt=opt,
loss_fn=compute_loss_baseline,
config=dev_inner_config,
training=True,
verbose=True,
)
INFO:root:Baseline NN inner loop test
INFO:root:Epoch 000 | Train Loss: 1.6608e-01 | Test Loss: 1.7718e-01 | Grad Norm: 4.0744e-01
INFO:root:Epoch 001 | Train Loss: 1.4460e-01 | Test Loss: 1.2461e-01 | Grad Norm: 1.6079e-01
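For reference, compute_loss_baseline is presumably a plain regression loss on the network's 2-dimensional output; a hedged sketch (the real function may also fold in the l2_reg term, and the model call convention model(xi, afuncs) is an assumption):

import jax
import jax.numpy as jnp

# Hedged sketch of a baseline loss: mean squared error between the model's
# output and the targets.
def baseline_loss_sketch(model, x, y, afuncs):
    preds = jax.vmap(lambda xi: model(xi, afuncs))(x)
    return jnp.mean((preds - y) ** 2)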
# test inner_opt (HNN)
dev_inner_config = InnerConfig(
test_train_split=0.8,
input_dim=2,
output_dim=1,
hidden_layer_sizes=[18],
batch_size=64,
epochs=2,
lr=1e-3,
mu=0.9,
n_fns=2,
l2_reg=1e-1,
seed=42,
)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, 2)
train_dataloader = NumpyLoader(
train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
opt = optax.rmsprop(
learning_rate=dev_inner_config.lr,
momentum=dev_inner_config.mu,
decay=dev_inner_config.l2_reg,
)
model = MultiActMLP(
dev_inner_config.input_dim,
dev_inner_config.output_dim,
dev_inner_config.hidden_layer_sizes,
model_key,
bias=False,
)
logging.info("Hamiltonian NN inner loop test")
HNN, opt_state, inner_results = inner_opt(
model=model,
train_data=train_dataloader,
test_data=test_dataloader,
afuncs=afuncs,
opt=opt,
loss_fn=compute_loss_hnn,
config=dev_inner_config,
training=True,
verbose=True,
)
INFO:root:Hamiltonian NN inner loop test
INFO:root:Epoch 000 | Train Loss: 9.8625e-02 | Test Loss: 8.2503e-02 | Grad Norm: 2.8775e-01
INFO:root:Epoch 001 | Train Loss: 7.7401e-02 | Test Loss: 7.7187e-02 | Grad Norm: 1.0838e-01
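Note that the HNN config uses output_dim=1 while the targets stay 2-dimensional (DummyDataset(1000, input_dim, 2)): a Hamiltonian network outputs a scalar H(q, p) and is trained through its symplectic gradient, dq/dt = dH/dp and dp/dt = -dH/dq. A hedged sketch of what compute_loss_hnn could look like; names and the model call convention are illustrative:

import jax
import jax.numpy as jnp

# Hedged sketch of an HNN loss: the network outputs a scalar H(q, p), and the
# predicted dynamics come from Hamilton's equations.
def hnn_loss_sketch(model, x, y, afuncs):
    def hamiltonian(z):
        return model(z, afuncs).squeeze()        # scalar H(q, p)
    dH = jax.vmap(jax.grad(hamiltonian))(x)      # (batch, 2): [dH/dq, dH/dp]
    preds = jnp.stack([dH[:, 1], -dH[:, 0]], axis=-1)  # [dq/dt, dp/dt]
    return jnp.mean((preds - y) ** 2)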
outer_loss(outer_models, inner_model, x, y, loss_fn, base_act)

outer_step(outer_models, inner_model, x, y, meta_opt, meta_opt_state, loss_fn, base_act)

outer_opt(train_dataloader, test_dataloader, loss_fn, inner_config, outer_config, opt, meta_opt, save_path=None)
    Outer optimization loop.
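A hedged sketch of what outer_step could look like, assuming the learned activations (outer_models) are Equinox modules updated with the optax meta-optimizer through the outer_loss documented above; the real implementation may differ (e.g. in how the inner model is threaded through):

import equinox as eqx

# Hedged sketch of a meta-gradient step: differentiate outer_loss w.r.t. the
# learned activation models and apply the meta-optimizer update.
@eqx.filter_jit
def outer_step_sketch(outer_models, inner_model, x, y,
                      meta_opt, meta_opt_state, loss_fn, base_act):
    def meta_loss(outer_models):
        return outer_loss(outer_models, inner_model, x, y, loss_fn, base_act)
    loss, grads = eqx.filter_value_and_grad(meta_loss)(outer_models)
    updates, meta_opt_state = meta_opt.update(
        grads, meta_opt_state, eqx.filter(outer_models, eqx.is_array)
    )
    outer_models = eqx.apply_updates(outer_models, updates)
    return outer_models, meta_opt_state, loss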
# test outer_opt Baseline
inner_config = InnerConfig(
test_train_split=0.8,
input_dim=2,
output_dim=2,
hidden_layer_sizes=[32],
batch_size=64,
epochs=5,
lr=1e-3,
mu=0.9,
n_fns=2,
l2_reg=1e-1,
seed=42,
)
outer_config = OuterConfig(
input_dim=1,
output_dim=1,
hidden_layer_sizes=[18],
batch_size=1,
steps=2,
print_every=1,
lr=1e-3,
mu=0.9,
seed=24,
)
train_dataset = DummyDataset(1000, inner_config.input_dim, inner_config.output_dim)
test_dataset = DummyDataset(1000, inner_config.input_dim, inner_config.output_dim)
train_dataloader = NumpyLoader(
train_dataset, batch_size=inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
test_dataset, batch_size=inner_config.batch_size, shuffle=True
)
opt = optax.rmsprop(
learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg
)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)
logging.info("Baseline NN outer loop test")
baseline_acts, baseline_stats = outer_opt(
train_dataloader,
test_dataloader,
compute_loss_baseline,
inner_config,
outer_config,
opt,
meta_opt,
save_path=None,
)
INFO:root:Baseline NN outer loop test
INFO:root:Step 000 | Train Loss: 8.4353e-02 | Test Loss: 8.5589e-02 | Grad Norm: 6.3488e-01
INFO:root:Step 001 | Train Loss: 9.1323e-02 | Test Loss: 8.8853e-02 | Grad Norm: 2.6783e-01
# test outer_opt HNN
inner_config = InnerConfig(
test_train_split=0.8,
input_dim=2,
output_dim=1,
hidden_layer_sizes=[32],
batch_size=64,
epochs=5,
lr=1e-3,
mu=0.9,
n_fns=2,
l2_reg=1e-1,
seed=42,
)
outer_config = OuterConfig(
input_dim=1,
output_dim=1,
hidden_layer_sizes=[18],
batch_size=1,
steps=2,
print_every=1,
lr=1e-3,
mu=0.9,
seed=24,
)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(
train_dataset, batch_size=inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
test_dataset, batch_size=inner_config.batch_size, shuffle=True
)
opt = optax.rmsprop(
learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg
)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)
logging.info("Hamiltonian NN outer loop test")
HNN_acts, HNN_stats = outer_opt(
train_dataloader,
test_dataloader,
compute_loss_hnn,
inner_config,
outer_config,
opt,
meta_opt,
save_path=None,
)
INFO:root:Hamiltonian NN outer loop test
INFO:root:Step 000 | Train Loss: 1.1194e-01 | Test Loss: 9.4442e-02 | Grad Norm: 8.8605e-01
INFO:root:Step 001 | Train Loss: 8.8917e-02 | Test Loss: 9.3278e-02 | Grad Norm: 7.9053e-01
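As a quick sanity check, the learned activations can be probed on a 1-D grid (outer_config.input_dim == 1). A hedged sketch, assuming each element of baseline_acts / HNN_acts is a callable mapping a length-1 array to a length-1 array; the actual return type of outer_opt may differ:

import jax
import jax.numpy as jnp

# Hedged sketch: probe each learned activation on a grid and report its range.
xs = jnp.linspace(-2.0, 2.0, 200)[:, None]   # (200, 1) inputs
for name, acts in [("baseline", baseline_acts), ("HNN", HNN_acts)]:
    for i, act in enumerate(acts):
        ys = jax.vmap(act)(xs)
        logging.info(f"{name} act[{i}] range: [{ys.min():.3f}, {ys.max():.3f}]")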