Loops

Inner and outer loops for metalearning
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
import logging

# Configure the logger
logging.basicConfig(level=logging.INFO)

source

make_step

 make_step (model, x, y, afuncs, optim, opt_state, compute_loss)
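A minimal sketch of how a step like this is typically implemented with Equinox and Optax; the `compute_loss(model, x, y, afuncs)` signature and the `filter_*` pattern are assumptions, not the actual source:

import equinox as eqx

@eqx.filter_jit
def make_step_sketch(model, x, y, afuncs, optim, opt_state, compute_loss):
    # Loss and gradients w.r.t. the model's array leaves (assumed signature).
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y, afuncs)
    # One optimizer update, e.g. from the RMSProp instances built below.
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss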

source

inner_opt

 inner_opt (model, train_data, test_data, afuncs, opt, loss_fn, config,
            training=False, verbose=False)

Inner optimization loop. Returns the trained model, the final optimizer state, and per-epoch results.
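Building on the sketch above, a hedged skeleton of how an inner loop like this is usually structured; the per-epoch evaluation and logging format are inferred from the output below, and the exact role of the training flag is an assumption:

import logging
import equinox as eqx

def inner_opt_sketch(model, train_data, test_data, afuncs, opt, loss_fn,
                     config, training=False, verbose=False):
    # Optimizer state over the model's array leaves.
    opt_state = opt.init(eqx.filter(model, eqx.is_array))
    results = []
    for epoch in range(config.epochs):
        for x, y in train_data:
            model, opt_state, train_loss = make_step_sketch(
                model, x, y, afuncs, opt, opt_state, loss_fn
            )
        # Evaluate once per epoch on the held-out loader (assumed behavior).
        test_losses = [loss_fn(model, x, y, afuncs) for x, y in test_data]
        test_loss = sum(test_losses) / len(test_losses)
        results.append((train_loss, test_loss))
        if verbose:
            logging.info(
                f"Epoch {epoch:03d} | Train Loss: {float(train_loss):.4e}"
                f" | Test Loss: {float(test_loss):.4e}"
            )
    return model, opt_state, results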

# test inner_opt: baseline MLP
dev_inner_config = InnerConfig(
    test_train_split=0.8,
    input_dim=2,
    output_dim=2,
    hidden_layer_sizes=[18],
    batch_size=64,
    epochs=2,
    lr=1e-3,
    mu=0.9,
    n_fns=2,
    l2_reg=1e-1,
    seed=42,
)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]  # candidate activations: quadratic and identity
train_dataset = DummyDataset(
    1000, dev_inner_config.input_dim, dev_inner_config.output_dim
)
test_dataset = DummyDataset(
    1000, dev_inner_config.input_dim, dev_inner_config.output_dim
)
train_dataloader = NumpyLoader(
    train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
    test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)

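# Note: optax.rmsprop's `decay` is the discount factor for the running
# average of squared gradients, not an L2 penalty, even though it is
# driven by the config's l2_reg field here.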
opt = optax.rmsprop(
    learning_rate=dev_inner_config.lr,
    momentum=dev_inner_config.mu,
    decay=dev_inner_config.l2_reg,
)
model = MultiActMLP(
    dev_inner_config.input_dim,
    dev_inner_config.output_dim,
    dev_inner_config.hidden_layer_sizes,
    model_key,
    bias=False,
)
logging.info("Baseline NN inner loop test")
baselineNN, opt_state, inner_results = inner_opt(
    model=model,
    train_data=train_dataloader,
    test_data=test_dataloader,
    afuncs=afuncs,
    opt=opt,
    loss_fn=compute_loss_baseline,
    config=dev_inner_config,
    training=True,
    verbose=True,
)
INFO:root:Baseline NN inner loop test
INFO:root:Epoch 000 | Train Loss: 1.6608e-01 | Test Loss: 1.7718e-01 | Grad Norm: 4.0744e-01
INFO:root:Epoch 001 | Train Loss: 1.4460e-01 | Test Loss: 1.2461e-01 | Grad Norm: 1.6079e-01
# test inner_opt: Hamiltonian NN
dev_inner_config = InnerConfig(
    test_train_split=0.8,
    input_dim=2,
    output_dim=1,
    hidden_layer_sizes=[18],
    batch_size=64,
    epochs=2,
    lr=1e-3,
    mu=0.9,
    n_fns=2,
    l2_reg=1e-1,
    seed=42,
)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]
# output_dim is 1 here (the HNN predicts a scalar), but the targets are
# still 2-D, so the datasets keep two output dimensions.
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, 2)
train_dataloader = NumpyLoader(
    train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
    test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
opt = optax.rmsprop(
    learning_rate=dev_inner_config.lr,
    momentum=dev_inner_config.mu,
    decay=dev_inner_config.l2_reg,
)
model = MultiActMLP(
    dev_inner_config.input_dim,
    dev_inner_config.output_dim,
    dev_inner_config.hidden_layer_sizes,
    model_key,
    bias=False,
)

logging.info("Hamiltonian NN inner loop test")
HNN, opt_state, inner_results = inner_opt(
    model=model,
    train_data=train_dataloader,
    test_data=test_dataloader,
    afuncs=afuncs,
    opt=opt,
    loss_fn=compute_loss_hnn,
    config=dev_inner_config,
    training=True,
    verbose=True,
)
INFO:root:Hamiltonian NN inner loop test
INFO:root:Epoch 000 | Train Loss: 9.8625e-02 | Test Loss: 8.2503e-02 | Grad Norm: 2.8775e-01
INFO:root:Epoch 001 | Train Loss: 7.7401e-02 | Test Loss: 7.7187e-02 | Grad Norm: 1.0838e-01

source

outer_loss

 outer_loss (outer_models, inner_model, x, y, loss_fn, base_act)
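One plausible reading, sketched under loud assumptions: each learned activation network is treated as a residual around the fixed base activation, and the task loss is evaluated with those activations plugged into the trained inner model. How base_act actually enters is a guess:

def outer_loss_sketch(outer_models, inner_model, x, y, loss_fn, base_act):
    # Assumption: act(v) = base_act(v) + net(v) for each learned net.
    afuncs = [lambda v, m=m: base_act(v) + m(v) for m in outer_models]
    return loss_fn(inner_model, x, y, afuncs)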

source

outer_step

 outer_step (outer_models, inner_model, x, y, meta_opt, meta_opt_state,
             loss_fn, base_act)
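A hedged sketch of the meta-update, mirroring the make_step pattern: differentiate the outer loss with respect to the learned activation networks while holding the trained inner model fixed (a first-order scheme; the real implementation may differ):

import equinox as eqx

@eqx.filter_jit
def outer_step_sketch(outer_models, inner_model, x, y, meta_opt,
                      meta_opt_state, loss_fn, base_act):
    loss, grads = eqx.filter_value_and_grad(outer_loss_sketch)(
        outer_models, inner_model, x, y, loss_fn, base_act
    )
    updates, meta_opt_state = meta_opt.update(grads, meta_opt_state)
    outer_models = eqx.apply_updates(outer_models, updates)
    return outer_models, meta_opt_state, loss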

source

outer_opt

 outer_opt (train_dataloader, test_dataloader, loss_fn, inner_config,
            outer_config, opt, meta_opt, save_path=None)

Outer optimization loop. Returns the learned activation functions and training statistics.
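Putting the pieces together, a hedged skeleton of the meta-training loop; make_act_nets and make_inner_model are hypothetical stand-ins for however the real code builds the learned activation networks and the fresh inner MLP, and the choice of base_act is likewise an assumption:

import logging
import jax
import equinox as eqx

def outer_opt_sketch(train_dataloader, test_dataloader, loss_fn,
                     inner_config, outer_config, opt, meta_opt,
                     save_path=None):
    # save_path handling omitted in this sketch.
    key = jax.random.PRNGKey(outer_config.seed)
    act_key, model_key = jax.random.split(key)
    outer_models = make_act_nets(outer_config, act_key)          # hypothetical
    meta_opt_state = meta_opt.init(eqx.filter(outer_models, eqx.is_array))
    stats = []
    for step in range(outer_config.steps):
        # Train a fresh inner model with the current learned activations.
        inner_model = make_inner_model(inner_config, model_key)  # hypothetical
        afuncs = [lambda v, m=m: m(v) for m in outer_models]
        inner_model, _, _ = inner_opt_sketch(
            inner_model, train_dataloader, test_dataloader, afuncs,
            opt, loss_fn, inner_config, training=True
        )
        # Meta-update the activation networks on a held-out batch.
        x, y = next(iter(test_dataloader))
        outer_models, meta_opt_state, loss = outer_step_sketch(
            outer_models, inner_model, x, y, meta_opt, meta_opt_state,
            loss_fn, base_act=jax.nn.relu  # assumed base activation
        )
        stats.append(float(loss))
        if step % outer_config.print_every == 0:
            logging.info(f"Step {step:03d} | Outer Loss: {float(loss):.4e}")
    return outer_models, stats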

# test outer_opt: baseline MLP
inner_config = InnerConfig(
    test_train_split=0.8,
    input_dim=2,
    output_dim=2,
    hidden_layer_sizes=[32],
    batch_size=64,
    epochs=5,
    lr=1e-3,
    mu=0.9,
    n_fns=2,
    l2_reg=1e-1,
    seed=42,
)
outer_config = OuterConfig(
    input_dim=1,
    output_dim=1,
    hidden_layer_sizes=[18],
    batch_size=1,
    steps=2,
    print_every=1,
    lr=1e-3,
    mu=0.9,
    seed=24,
)
train_dataset = DummyDataset(1000, inner_config.input_dim, inner_config.output_dim)
test_dataset = DummyDataset(1000, inner_config.input_dim, inner_config.output_dim)
train_dataloader = NumpyLoader(
    train_dataset, batch_size=inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
    test_dataset, batch_size=inner_config.batch_size, shuffle=True
)

opt = optax.rmsprop(
    learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg
)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

logging.info("Baseline NN outer loop test")
baseline_acts, baseline_stats = outer_opt(
    train_dataloader,
    test_dataloader,
    compute_loss_baseline,
    inner_config,
    outer_config,
    opt,
    meta_opt,
    save_path=None,
)
INFO:root:Baseline NN outer loop test
INFO:root:Step 000 | Train Loss: 8.4353e-02 | Test Loss: 8.5589e-02 | Grad Norm: 6.3488e-01
INFO:root:Step 001 | Train Loss: 9.1323e-02 | Test Loss: 8.8853e-02 | Grad Norm: 2.6783e-01
# test outer_opt: Hamiltonian NN
inner_config = InnerConfig(
    test_train_split=0.8,
    input_dim=2,
    output_dim=1,
    hidden_layer_sizes=[32],
    batch_size=64,
    epochs=5,
    lr=1e-3,
    mu=0.9,
    n_fns=2,
    l2_reg=1e-1,
    seed=42,
)
outer_config = OuterConfig(
    input_dim=1,
    output_dim=1,
    hidden_layer_sizes=[18],
    batch_size=1,
    steps=2,
    print_every=1,
    lr=1e-3,
    mu=0.9,
    seed=24,
)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(
    train_dataset, batch_size=inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
    test_dataset, batch_size=inner_config.batch_size, shuffle=True
)

opt = optax.rmsprop(
    learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg
)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

logging.info("Hamiltonian NN outer loop test")
HNN_acts, HNN_stats = outer_opt(
    train_dataloader,
    test_dataloader,
    compute_loss_hnn,
    inner_config,
    outer_config,
    opt,
    meta_opt,
    save_path=None,
)
INFO:root:Hamiltonian NN outer loop test
INFO:root:Step 000 | Train Loss: 1.1194e-01 | Test Loss: 9.4442e-02 | Grad Norm: 8.8605e-01
INFO:root:Step 001 | Train Loss: 8.8917e-02 | Test Loss: 9.3278e-02 | Grad Norm: 7.9053e-01