Calculate member-based ensemble loss without running the forward pass twice. #26318
Unanswered · HarveySouth asked this question in Q&A · Replies: 1 comment
-
I've moved my simple FFNN from Linen to NNX and come up with:

```python
import concurrent.futures
import threading

import jax
import jax.numpy as jnp
from flax import nnx

# n_ensemble_members and resolved_lambda are defined at module level.


@nnx.jit
def jit_loss_calculation(member_prediction, training_labels, non_current_member_outputs):
    # Accuracy term: squared error of this member's prediction against the labels.
    member_error = jnp.square(member_prediction - training_labels)
    # This member's share of the ensemble centroid.
    member_contribution_to_ensemble = jnp.divide(member_prediction, n_ensemble_members)
    ensemble_centroid = member_contribution_to_ensemble + non_current_member_outputs
    # Diversity term: squared distance between the ensemble centroid and this member.
    member_diversity = jnp.square(ensemble_centroid - member_prediction)
    full_loss = member_error - (resolved_lambda * member_diversity)
    return full_loss.mean()


def run_ensemble_member_loss_and_grad_in_parallel(
    training_input, training_labels, shared_data, lock_memory, condition, member_index, model
):
    def all_predictions_set():
        return all(lock_memory)

    def ncl_member_loss(model):
        member_prediction = model(training_input).squeeze()
        with condition:
            # jnp arrays are immutable (.at[].set() returns a new array), so the
            # prediction is written into a shared Python list instead.
            shared_data[member_index] = member_prediction
            lock_memory[member_index] = True
            condition.notify_all()
        with condition:
            condition.wait_for(all_predictions_set)
        jax.block_until_ready(shared_data)
        non_current_member_outputs = jnp.sum(
            jnp.stack(shared_data[:member_index] + shared_data[member_index + 1:]), axis=0
        ) / n_ensemble_members
        return jit_loss_calculation(member_prediction, training_labels, non_current_member_outputs)

    return nnx.value_and_grad(ncl_member_loss)(model)


def setup_parallel_execution(training_input, training_labels, models):
    shared_data = [None] * n_ensemble_members  # one prediction slot per member
    lock_memory = [False] * n_ensemble_members
    condition = threading.Condition()  # use threading.Condition to try and avoid deadlock with jax
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_ensemble_members) as executor:
        futures = [
            executor.submit(
                run_ensemble_member_loss_and_grad_in_parallel,
                training_input, training_labels, shared_data, lock_memory, condition, member_index, model,
            )
            for member_index, model in enumerate(models)
        ]
        losses_and_grads = [future.result() for future in futures]
    losses, grads = zip(*losses_and_grads)
    return losses, grads
```

with the training loop:

```python
for epoch in range(epoch_num):
    ...
    for step, (batch_x, batch_y) in enumerate(training_set):
        ...
        losses, grads = setup_parallel_execution(batch_x, batch_y, ensemble_models)
        for i in range(len(member_optimizers)):
            member_optimizers[i].update(grads[i])
```

Validity TBD, and definitely not the best solution, but it seems to work and doesn't require running the ensemble more than necessary.
-
I want to implement negative correlation learning (NCL) in JAX. NCL is a regression-ensemble training algorithm that updates each member of the ensemble with its own loss function: the squared error between the member's prediction and the target value, combined with the squared error between the ensemble output and the member's prediction.
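In symbols (my notation, not from the original post): writing $f_i$ for member $i$ of an ensemble of $M$ members, $\bar{f}$ for the ensemble mean prediction, and $\lambda$ for the diversity weight (the `resolved_lambda` in the reply above), the per-member loss being described is roughly

$$
\mathcal{L}_i(x, y) \;=\; \bigl(f_i(x) - y\bigr)^2 \;-\; \lambda\,\bigl(\bar{f}(x) - f_i(x)\bigr)^2,
\qquad
\bar{f}(x) \;=\; \frac{1}{M}\sum_{j=1}^{M} f_j(x),
$$

averaged over the batch, which is what `jit_loss_calculation` above computes for a single member.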
Ideally I can:
I'm having two difficulties:
I solved this inefficiently in PyTorch by ignoring the second difficulty, and just running the ensemble twice:
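The original PyTorch snippet isn't included here, but the two-pass idea it describes looks roughly like the following JAX/NNX sketch (illustrative only; `two_pass_step`, `ncl_lambda`, and treating the centroid as a constant are assumptions of this sketch, not the author's code):

```python
import jax.numpy as jnp
from flax import nnx

def two_pass_step(models, x, y, ncl_lambda):
    # Pass 1: every member's prediction, kept as plain arrays (no gradients).
    preds = jnp.stack([m(x).squeeze() for m in models])  # (n_members, batch)
    centroid = preds.mean(axis=0)  # treated as a constant in the losses below

    results = []
    for model in models:
        def member_loss(model):
            # Pass 2: the same forward pass runs again, now under value_and_grad.
            p = model(x).squeeze()
            return jnp.mean((p - y) ** 2 - ncl_lambda * (centroid - p) ** 2)
        results.append(nnx.value_and_grad(member_loss)(model))
    return results  # list of (loss, grads), one per member
```

The inefficiency is visible directly: each member's forward pass runs once in pass 1 and again inside `value_and_grad`.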
I've been able to run the ensemble members in parallel with vmap, as I did in Python, but I haven't been able to come up with an alternative approach that runs the training step efficiently in JAX, and I'm looking for help.
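For reference, one way the vmapped forward pass mentioned above could look (a sketch only: it assumes every member shares the same architecture, and the helpers `stack_ensemble` / `ensemble_forward` are hypothetical names, not from the post):

```python
import jax
import jax.numpy as jnp
from flax import nnx

def stack_ensemble(models):
    # All members share one architecture, so they share a single graphdef.
    graphdef, _ = nnx.split(models[0])
    states = [nnx.split(m)[1] for m in models]
    # Stack every parameter along a new leading "member" axis.
    stacked_state = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *states)
    return graphdef, stacked_state

def ensemble_forward(graphdef, stacked_state, x):
    def member_forward(state):
        model = nnx.merge(graphdef, state)
        return model(x).squeeze()
    # vmap over the leading member axis of the parameters; x is broadcast.
    return jax.vmap(member_forward)(stacked_state)  # shape (n_members, batch)
```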