From 077c52afc00f66ad368e0f57efe4adfdb42b948e Mon Sep 17 00:00:00 2001
From: hongkai-dai
Date: Tue, 12 Apr 2022 12:23:25 -0700
Subject: [PATCH] Add barrier loss to total_loss()

---
 .../test/test_train_lyapunov_barrier.py      |  46 +-
 .../train_lyapunov_barrier.py                | 462 ++++++++++++------
 2 files changed, 358 insertions(+), 150 deletions(-)

diff --git a/neural_network_lyapunov/test/test_train_lyapunov_barrier.py b/neural_network_lyapunov/test/test_train_lyapunov_barrier.py
index 91496b7a..dce71ab9 100644
--- a/neural_network_lyapunov/test/test_train_lyapunov_barrier.py
+++ b/neural_network_lyapunov/test/test_train_lyapunov_barrier.py
@@ -232,13 +232,23 @@ def test_total_loss(self):
         positivity_state_samples = state_samples_all.clone()
         derivative_state_samples = state_samples_all.clone()
         derivative_state_samples_next = state_samples_next.clone()
+        safe_state_samples = torch.empty((0, dut.x_dim()), dtype=dut.dtype())
+        unsafe_state_samples = torch.empty((0, dut.x_dim()), dtype=dut.dtype())
+        barrier_derivative_state_samples = torch.empty((0, dut.x_dim()),
+                                                       dtype=dut.dtype())
+
         total_loss_return = dut.total_loss(
             positivity_state_samples, derivative_state_samples,
             state_samples_next, dut.lyapunov_positivity_sample_cost_weight,
             dut.lyapunov_derivative_sample_cost_weight,
             dut.lyapunov_positivity_mip_cost_weight,
             dut.lyapunov_derivative_mip_cost_weight,
-            dut.boundary_value_gap_mip_cost_weight)
+            dut.boundary_value_gap_mip_cost_weight, safe_state_samples,
+            unsafe_state_samples, barrier_derivative_state_samples,
+            dut.safe_sample_cost_weight, dut.unsafe_sample_cost_weight,
+            dut.barrier_derivative_sample_cost_weight,
+            dut.safe_mip_cost_weight, dut.unsafe_mip_cost_weight,
+            dut.barrier_derivative_mip_cost_weight)
 
         self.assertEqual(
             positivity_state_samples.shape[0] + 1,
@@ -520,7 +530,7 @@ def test_solve_barrier_derivative_mip(self):
                     adversarial, dut.barrier_x_star, dut.barrier_c,
                     dut.barrier_epsilon).detach().numpy())
 
-    def test_barrier_loss(self):
+    def test_barrier_loss1(self):
         dut = train_lyapunov_barrier.Trainer()
         dut.add_barrier(self.barrier_system,
                         x_star=(self.system.x_lo * 0.25 +
                                 self.system.x_up * 0.75),
                         c=0.1,
                         barrier_epsilon=0.3)
@@ -569,6 +579,38 @@ def test_barrier_loss(self):
         self.assertGreater(barrier_loss.derivative_state_samples.shape[0],
                            num_derivative_state_samples)
 
+    def test_barrier_loss2(self):
+        dut = train_lyapunov_barrier.Trainer()
+        dut.add_barrier(self.barrier_system,
+                        x_star=(self.system.x_lo * 0.25 +
+                                self.system.x_up * 0.75),
+                        c=0.1,
+                        barrier_epsilon=0.3)
+
+        safe_state_samples = torch.empty((0, self.barrier_system.system.x_dim),
+                                         dtype=self.dtype)
+        unsafe_state_samples = torch.empty(
+            (0, self.barrier_system.system.x_dim), dtype=self.dtype)
+        derivative_state_samples = torch.empty(
+            (0, self.barrier_system.system.x_dim), dtype=self.dtype)
+        safe_sample_cost_weight = 1.
+        unsafe_sample_cost_weight = 2.
+        derivative_sample_cost_weight = 3.
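+        # All three MIP cost weights are None below, so compute_barrier_loss
+        # should skip the MIP solves and return None for each MIP loss term,
+        # while the sample losses over the empty sample tensors are zero.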
+        safe_mip_cost_weight = None
+        unsafe_mip_cost_weight = None
+        derivative_mip_cost_weight = None
+        barrier_loss = dut.compute_barrier_loss(
+            safe_state_samples, unsafe_state_samples, derivative_state_samples,
+            safe_sample_cost_weight, unsafe_sample_cost_weight,
+            derivative_sample_cost_weight, safe_mip_cost_weight,
+            unsafe_mip_cost_weight, derivative_mip_cost_weight)
+        self.assertEqual(barrier_loss.safe_sample_loss.item(), 0)
+        self.assertEqual(barrier_loss.unsafe_sample_loss.item(), 0)
+        self.assertEqual(barrier_loss.derivative_sample_loss.item(), 0)
+        self.assertIsNone(barrier_loss.safe_mip_loss)
+        self.assertIsNone(barrier_loss.unsafe_mip_loss)
+        self.assertIsNone(barrier_loss.derivative_mip_loss)
+
 
 class TestTrainValueApproximator(unittest.TestCase):
     def setUp(self):
diff --git a/neural_network_lyapunov/train_lyapunov_barrier.py b/neural_network_lyapunov/train_lyapunov_barrier.py
index 788c9afd..ad4b7df1 100644
--- a/neural_network_lyapunov/train_lyapunov_barrier.py
+++ b/neural_network_lyapunov/train_lyapunov_barrier.py
@@ -83,6 +83,12 @@ def __init__(self):
         self.safe_regions = []
         self.barrier_value_mip_pool_solutions = 1
         self.barrier_derivative_mip_pool_solutions = 1
+        self.safe_sample_cost_weight = 0.
+        self.unsafe_sample_cost_weight = 0.
+        self.barrier_derivative_sample_cost_weight = 0.
+        self.safe_mip_cost_weight = None
+        self.unsafe_mip_cost_weight = None
+        self.barrier_derivative_mip_cost_weight = None
 
         # The learning rate of the optimizer
         self.learning_rate = 0.003
@@ -623,94 +629,14 @@ def __init__(self, loss: torch.Tensor, lyap_loss, barrier_loss):
             assert (isinstance(barrier_loss, Trainer.BarrierLoss))
             self.barrier_loss = barrier_loss
 
-    def compute_barrier_loss(self, safe_state_samples, unsafe_state_samples,
-                             derivative_state_samples, safe_sample_cost_weight,
-                             unsafe_sample_cost_weight,
-                             derivative_sample_cost_weight,
-                             safe_mip_cost_weight, unsafe_mip_cost_weight,
-                             derivative_mip_cost_weight) -> BarrierLoss:
-        barrier_loss = Trainer.BarrierLoss()
-
-        if safe_mip_cost_weight is not None:
-            safe_mip, barrier_loss.safe_mip_obj, safe_mip_adversarial = \
-                self.solve_barrier_value_mip(safe_flag=True)
-            if safe_mip_cost_weight != 0:
-                barrier_loss.safe_mip_loss = [
-                    safe_mip_cost_weight *
-                    mip.compute_objective_from_mip_data_and_solution(
-                        solution_number=0, penalty=1e-13) for mip in safe_mip
-                ]
-        else:
-            barrier_loss.safe_mip_loss = None
-            barrier_loss.safe_mip_obj = None
-            safe_mip_adversarial = None
-
-        if unsafe_mip_cost_weight is not None:
-            unsafe_mip, barrier_loss.unsafe_mip_obj, unsafe_mip_adversarial = \
-                self.solve_barrier_value_mip(safe_flag=False)
-            if unsafe_mip_cost_weight != 0:
-                barrier_loss.unsafe_mip_loss = [
-                    unsafe_mip_cost_weight *
-                    mip.compute_objective_from_mip_data_and_solution(
-                        solution_number=0, penalty=1e-13) for mip in unsafe_mip
-                ]
-        else:
-            barrier_loss.unsafe_mip_loss = None
-            barrier_loss.unsafe_mip_obj = None
-            unsafe_mip_adversarial = None
-
-        if derivative_mip_cost_weight is not None:
-            derivative_mip, barrier_loss.derivative_mip_obj, \
-                derivative_mip_adversarial = self.solve_barrier_derivative_mip(
-                )
-            if derivative_mip_cost_weight != 0:
-                barrier_loss.derivative_mip_loss = derivative_mip_cost_weight \
-                    * derivative_mip.\
-                    compute_objective_from_mip_data_and_solution(
-                        solution_number=0, penalty=1e-13)
-        else:
-            barrier_loss.derivative_mip_loss = None
-            barrier_loss.derivative_mip_obj = None
-            derivative_mip_adversarial = None
-
-        barrier_loss.safe_state_samples = safe_state_samples
-        barrier_loss.unsafe_state_samples = unsafe_state_samples
-        barrier_loss.derivative_state_samples = derivative_state_samples
-        if safe_mip_cost_weight is not None and \
-                safe_mip_adversarial is not None and \
-                len(safe_mip_adversarial) > 0:
-            barrier_loss.safe_state_samples = torch.cat(
-                (safe_state_samples, torch.cat(safe_mip_adversarial, dim=0)),
-                dim=0)
-        if unsafe_mip_cost_weight is not None and \
-                unsafe_mip_adversarial is not None and \
-                len(unsafe_mip_adversarial) > 0:
-            barrier_loss.unsafe_state_samples = torch.cat(
-                (unsafe_state_samples, torch.cat(unsafe_mip_adversarial,
-                                                 dim=0)),
-                dim=0)
-        if derivative_mip_cost_weight is not None and \
-                derivative_mip_adversarial is not None:
-            barrier_loss.derivative_state_samples = torch.cat(
-                (derivative_state_samples, derivative_mip_adversarial), dim=0)
-        barrier_loss.safe_sample_loss, barrier_loss.unsafe_sample_loss, \
-            barrier_loss.derivative_sample_loss = self.barrier_sample_loss(
-                barrier_loss.safe_state_samples[-self.max_sample_pool_size:],
-                barrier_loss.unsafe_state_samples[-self.max_sample_pool_size:],
-                barrier_loss.derivative_state_samples[
-                    -self.max_sample_pool_size:],
-                safe_sample_cost_weight, unsafe_sample_cost_weight,
-                derivative_sample_cost_weight)
-        return barrier_loss
-
-    def total_loss(self, positivity_state_samples,
-                   lyap_derivative_state_samples,
-                   lyap_derivative_state_samples_next,
-                   lyap_positivity_sample_cost_weight,
-                   lyap_derivative_sample_cost_weight,
-                   lyap_positivity_mip_cost_weight,
-                   lyap_derivative_mip_cost_weight,
-                   boundary_value_gap_mip_cost_weight) -> TotalLossReturn:
+    def compute_lyapunov_loss(self, positivity_state_samples,
+                              lyap_derivative_state_samples,
+                              lyap_derivative_state_samples_next,
+                              lyap_positivity_sample_cost_weight,
+                              lyap_derivative_sample_cost_weight,
+                              lyap_positivity_mip_cost_weight,
+                              lyap_derivative_mip_cost_weight,
+                              boundary_value_gap_mip_cost_weight) -> LyapLoss:
         """
         Compute the total loss as the summation of
         1. hinge(-V(xⁱ) + ε₂ |xⁱ - x*|₁) for sampled state xⁱ.
@@ -742,7 +668,8 @@ def total_loss(self, positivity_state_samples,
         """
+        if self.lyapunov_hybrid_system is None:
+            # Return before dereferencing the (absent) Lyapunov system.
+            return Trainer.LyapLoss()
         dtype = self.lyapunov_hybrid_system.system.dtype
         lyap_loss = Trainer.LyapLoss()
-        barrier_loss = Trainer.BarrierLoss()
         if lyap_positivity_mip_cost_weight is not None:
             lyap_positivity_mip, lyap_loss.positivity_mip_obj,\
                 positivity_mip_adversarial = self.solve_positivity_mip()
@@ -761,8 +688,6 @@ def total_loss(self, positivity_state_samples,
             lyap_derivative_mip_adversarial = None
             lyap_derivative_mip_adversarial_next = None
 
-        loss = torch.tensor(0., dtype=dtype)
-
         lyap_loss.positivity_mip_loss = torch.tensor(0., dtype=dtype)
         if lyap_positivity_mip_cost_weight != 0 and\
                 lyap_positivity_mip_cost_weight is not None:
@@ -835,11 +760,139 @@ def total_loss(self, positivity_state_samples,
                 lyap_derivative_state_samples_next_in_pool,
                 lyap_positivity_sample_cost_weight,
                 lyap_derivative_sample_cost_weight)
+        return lyap_loss
+
+    def compute_barrier_loss(self, safe_state_samples, unsafe_state_samples,
+                             derivative_state_samples, safe_sample_cost_weight,
+                             unsafe_sample_cost_weight,
+                             derivative_sample_cost_weight,
+                             safe_mip_cost_weight, unsafe_mip_cost_weight,
+                             derivative_mip_cost_weight) -> BarrierLoss:
+        barrier_loss = Trainer.BarrierLoss()
+        if self.barrier_system is None:
+            return barrier_loss
 
-        loss = lyap_loss.positivity_sample_loss + \
-            lyap_loss.derivative_sample_loss + \
-            lyap_loss.positivity_mip_loss + lyap_loss.derivative_mip_loss +\
-            lyap_loss.gap_mip_loss
+        if safe_mip_cost_weight is not None:
+            safe_mip, barrier_loss.safe_mip_obj, safe_mip_adversarial = \
+                self.solve_barrier_value_mip(safe_flag=True)
+            if safe_mip_cost_weight != 0:
+                barrier_loss.safe_mip_loss = [
+                    safe_mip_cost_weight *
+                    mip.compute_objective_from_mip_data_and_solution(
+                        solution_number=0, penalty=1e-13) for mip in safe_mip
+                ]
+        else:
+            barrier_loss.safe_mip_loss = None
+            barrier_loss.safe_mip_obj = None
+            safe_mip_adversarial = None
+
+        if unsafe_mip_cost_weight is not None:
+            unsafe_mip, barrier_loss.unsafe_mip_obj, unsafe_mip_adversarial = \
+                self.solve_barrier_value_mip(safe_flag=False)
+            if unsafe_mip_cost_weight != 0:
+                barrier_loss.unsafe_mip_loss = [
+                    unsafe_mip_cost_weight *
+                    mip.compute_objective_from_mip_data_and_solution(
+                        solution_number=0, penalty=1e-13) for mip in unsafe_mip
+                ]
+        else:
+            barrier_loss.unsafe_mip_loss = None
+            barrier_loss.unsafe_mip_obj = None
+            unsafe_mip_adversarial = None
+
+        if derivative_mip_cost_weight is not None:
+            derivative_mip, barrier_loss.derivative_mip_obj, \
+                derivative_mip_adversarial = self.solve_barrier_derivative_mip(
+                )
+            if derivative_mip_cost_weight != 0:
+                barrier_loss.derivative_mip_loss = derivative_mip_cost_weight \
+                    * derivative_mip.\
+                    compute_objective_from_mip_data_and_solution(
+                        solution_number=0, penalty=1e-13)
+        else:
+            barrier_loss.derivative_mip_loss = None
+            barrier_loss.derivative_mip_obj = None
+            derivative_mip_adversarial = None
+
+        barrier_loss.safe_state_samples = safe_state_samples
+        barrier_loss.unsafe_state_samples = unsafe_state_samples
+        barrier_loss.derivative_state_samples = derivative_state_samples
+        if safe_mip_cost_weight is not None and \
+                safe_mip_adversarial is not None and \
+                len(safe_mip_adversarial) > 0:
+            barrier_loss.safe_state_samples = torch.cat(
+                (safe_state_samples, torch.cat(safe_mip_adversarial, dim=0)),
+                dim=0)
+        if unsafe_mip_cost_weight is not None and \
+                unsafe_mip_adversarial is not None and \
+                len(unsafe_mip_adversarial) > 0:
+            barrier_loss.unsafe_state_samples = torch.cat(
+                (unsafe_state_samples, torch.cat(unsafe_mip_adversarial,
+                                                 dim=0)),
+                dim=0)
+        if derivative_mip_cost_weight is not None and \
+                derivative_mip_adversarial is not None:
+            barrier_loss.derivative_state_samples = torch.cat(
+                (derivative_state_samples, derivative_mip_adversarial), dim=0)
+        barrier_loss.safe_sample_loss, barrier_loss.unsafe_sample_loss, \
+            barrier_loss.derivative_sample_loss = self.barrier_sample_loss(
+                barrier_loss.safe_state_samples[-self.max_sample_pool_size:],
+                barrier_loss.unsafe_state_samples[-self.max_sample_pool_size:],
+                barrier_loss.derivative_state_samples[
+                    -self.max_sample_pool_size:],
+                safe_sample_cost_weight, unsafe_sample_cost_weight,
+                derivative_sample_cost_weight)
+        return barrier_loss
+
+    def total_loss(self, positivity_state_samples,
+                   lyap_derivative_state_samples,
+                   lyap_derivative_state_samples_next,
+                   lyap_positivity_sample_cost_weight,
+                   lyap_derivative_sample_cost_weight,
+                   lyap_positivity_mip_cost_weight,
+                   lyap_derivative_mip_cost_weight,
+                   boundary_value_gap_mip_cost_weight, safe_state_samples,
+                   unsafe_state_samples, barrier_derivative_state_samples,
+                   safe_sample_cost_weight, unsafe_sample_cost_weight,
+                   barrier_derivative_sample_cost_weight, safe_mip_cost_weight,
+                   unsafe_mip_cost_weight,
+                   barrier_derivative_mip_cost_weight) -> TotalLossReturn:
+        lyap_loss = self.compute_lyapunov_loss(
+            positivity_state_samples, lyap_derivative_state_samples,
+            lyap_derivative_state_samples_next,
+            lyap_positivity_sample_cost_weight,
+            lyap_derivative_sample_cost_weight,
+            lyap_positivity_mip_cost_weight, lyap_derivative_mip_cost_weight,
+            boundary_value_gap_mip_cost_weight)
+
+        barrier_loss = self.compute_barrier_loss(
+            safe_state_samples, unsafe_state_samples,
+            barrier_derivative_state_samples, safe_sample_cost_weight,
+            unsafe_sample_cost_weight, barrier_derivative_sample_cost_weight,
+            safe_mip_cost_weight, unsafe_mip_cost_weight,
+            barrier_derivative_mip_cost_weight)
+
+        def add_loss(total_loss, individual_loss):
+            # torch's in-place += mutates `total_loss`, so the caller's
+            # `loss` tensor accumulates every non-None term.
+            if individual_loss is not None:
+                total_loss += individual_loss
+
+        loss = torch.tensor(0., dtype=self.dtype())
+        add_loss(loss, lyap_loss.positivity_sample_loss)
+        add_loss(loss, lyap_loss.derivative_sample_loss)
+        add_loss(loss, lyap_loss.positivity_mip_loss)
+        add_loss(loss, lyap_loss.derivative_mip_loss)
+        add_loss(loss, lyap_loss.gap_mip_loss)
+
+        add_loss(loss, barrier_loss.safe_sample_loss)
+        add_loss(loss, barrier_loss.unsafe_sample_loss)
+        add_loss(loss, barrier_loss.derivative_sample_loss)
+        if barrier_loss.safe_mip_loss is not None and len(
+                barrier_loss.safe_mip_loss) > 0:
+            loss += torch.sum(torch.stack(barrier_loss.safe_mip_loss))
+        if barrier_loss.unsafe_mip_loss is not None and len(
+                barrier_loss.unsafe_mip_loss) > 0:
+            loss += torch.sum(torch.stack(barrier_loss.unsafe_mip_loss))
+        add_loss(loss, barrier_loss.derivative_mip_loss)
 
         return Trainer.TotalLossReturn(loss, lyap_loss, barrier_loss)
 
@@ -886,7 +939,10 @@ def _training_params(self):
             self.R_options.variables()
         return training_params
 
-    def train(self, state_samples_all):
+    def train(self,
+              state_samples_all,
+              safe_state_samples=None,
+              unsafe_state_samples=None):
         train_start_time = time.time()
         if self.output_flag:
             self.print()
@@ -894,16 +950,24 @@
         assert (state_samples_all.shape[1] ==
                 self.lyapunov_hybrid_system.system.x_dim)
         positivity_state_samples = state_samples_all.clone()
-        derivative_state_samples = state_samples_all.clone()
+        lyap_derivative_state_samples = state_samples_all.clone()
+        if safe_state_samples is None:
+            safe_state_samples = torch.empty((0, self.x_dim()),
+                                             dtype=self.dtype())
+        if unsafe_state_samples is None:
+            unsafe_state_samples = torch.empty((0, self.x_dim()),
+                                               dtype=self.dtype())
+        barrier_derivative_state_samples = state_samples_all.clone()
         if (state_samples_all.shape[0] > 0):
-            derivative_state_samples_next = torch.stack([
+            lyap_derivative_state_samples_next = torch.stack([
                 self.lyapunov_hybrid_system.system.step_forward(
-                    derivative_state_samples[i])
-                for i in range(derivative_state_samples.shape[0])
+                    lyap_derivative_state_samples[i])
+                for i in range(lyap_derivative_state_samples.shape[0])
             ],
-                                                        dim=0)
+                                                             dim=0)
        else:
-            derivative_state_samples_next = torch.empty_like(state_samples_all)
+            lyap_derivative_state_samples_next = torch.empty_like(
+                state_samples_all)
 
         iter_count = 0
         training_params = self._training_params()
@@ -927,26 +991,31 @@
             # If we train a feedback system, then we will modify the
             # controller in each iteration, hence the next sample state
             # changes in each iteration.
-            if (derivative_state_samples.shape[0] > 0):
-                derivative_state_samples_next =\
+            if (lyap_derivative_state_samples.shape[0] > 0):
+                lyap_derivative_state_samples_next =\
                     self.lyapunov_hybrid_system.system.step_forward(
-                        derivative_state_samples)
+                        lyap_derivative_state_samples)
             else:
-                derivative_state_samples_next = torch.empty_like(
-                    derivative_state_samples)
+                lyap_derivative_state_samples_next = torch.empty_like(
+                    lyap_derivative_state_samples)
             total_loss_return = self.total_loss(
-                positivity_state_samples, derivative_state_samples,
-                derivative_state_samples_next,
+                positivity_state_samples, lyap_derivative_state_samples,
+                lyap_derivative_state_samples_next,
                 self.lyapunov_positivity_sample_cost_weight,
                 self.lyapunov_derivative_sample_cost_weight,
                 self.lyapunov_positivity_mip_cost_weight,
                 self.lyapunov_derivative_mip_cost_weight,
-                self.boundary_value_gap_mip_cost_weight)
+                self.boundary_value_gap_mip_cost_weight, safe_state_samples,
+                unsafe_state_samples, barrier_derivative_state_samples,
+                self.safe_sample_cost_weight, self.unsafe_sample_cost_weight,
+                self.barrier_derivative_sample_cost_weight,
+                self.safe_mip_cost_weight, self.unsafe_mip_cost_weight,
+                self.barrier_derivative_mip_cost_weight)
             positivity_state_samples = \
                 total_loss_return.lyap_loss.positivity_state_samples
-            derivative_state_samples = \
+            lyap_derivative_state_samples = \
                 total_loss_return.lyap_loss.derivative_state_samples
-            derivative_state_samples_next = \
+            lyap_derivative_state_samples_next = \
                 total_loss_return.lyap_loss.derivative_state_samples_next
 
             if self.enable_wandb:
@@ -989,39 +1058,78 @@
                     total_loss_return.lyap_loss.positivity_mip_obj,
                     total_loss_return.lyap_loss.derivative_mip_obj)
 
-    def train_lyapunov_on_samples(self, state_samples_all, num_epochs,
-                                  batch_size):
-        """
-        Train a ReLU network on given state samples (not the adversarial states
-        found by MIP). The loss function is the weighted sum of the lyapunov
-        positivity condition violation and the lyapunov derivative condition
-        violation on these samples. We stop the training when either the
-        maximum iteration is reached, or when many consecutive iterations the
-        MIP costs keeps increasing (which means the network overfits to the
-        training data). Return the best network (the one with the minimal
-        MIP loss) found so far.
-        @param state_samples_all A torch tensor, state_samples_all[i] is the
-        i'th sample
-        """
+    def x_dim(self):
+        if self.lyapunov_hybrid_system is not None:
+            return self.lyapunov_hybrid_system.system.x_dim
+        elif self.barrier_system is not None:
+            return self.barrier_system.system.x_dim
+
+    def dtype(self):
+        if self.lyapunov_hybrid_system is not None:
+            return self.lyapunov_hybrid_system.system.dtype
+        elif self.barrier_system is not None:
+            return self.barrier_system.system.dtype
+
+    def train_on_samples(self, state_samples_all, safe_samples, unsafe_samples,
+                         num_epochs, batch_size):
         assert (isinstance(state_samples_all, torch.Tensor))
         assert (state_samples_all.shape[1] ==
                 self.lyapunov_hybrid_system.system.x_dim)
+        if safe_samples is None:
+            safe_samples = torch.empty((0, self.x_dim()), dtype=self.dtype())
+        if unsafe_samples is None:
+            unsafe_samples = torch.empty((0, self.x_dim()), dtype=self.dtype())
+
         best_loss = np.inf
         training_params = self._training_params()
         optimizer = torch.optim.Adam(training_params, lr=self.learning_rate)
-        dataset = torch.utils.data.TensorDataset(state_samples_all)
-        train_set_size = int(len(dataset) * 0.8)
-        test_set_size = len(dataset) - train_set_size
-        train_dataset, test_dataset = torch.utils.data.random_split(
-            dataset, [train_set_size, test_set_size])
-        data_loader = torch.utils.data.DataLoader(train_dataset,
-                                                  batch_size=batch_size,
-                                                  shuffle=True)
-        test_state_samples = test_dataset[:][0]
+
+        def split_data(states):
+            dataset = torch.utils.data.TensorDataset(states)
+            train_set_size = int(len(dataset) * 0.8)
+            test_set_size = len(dataset) - train_set_size
+            train_dataset, test_dataset = torch.utils.data.random_split(
+                dataset, [train_set_size, test_set_size])
+            data_loader = torch.utils.data.DataLoader(train_dataset,
+                                                      batch_size=batch_size,
+                                                      shuffle=True)
+            test_state_samples = test_dataset[:][0]
+            return data_loader, test_state_samples
+
+        data_loader, test_state_samples = split_data(state_samples_all)
+        if safe_samples.shape[0] > 0:
+            safe_data_loader, test_safe_state_samples = split_data(
+                safe_samples)
+        else:
+            safe_data_loader = None
+            test_safe_state_samples = torch.empty((0, self.x_dim()),
+                                                  dtype=self.dtype())
+        if unsafe_samples.shape[0] > 0:
+            unsafe_data_loader, test_unsafe_state_samples = split_data(
+                unsafe_samples)
+        else:
+            unsafe_data_loader = None
+            test_unsafe_state_samples = torch.empty((0, self.x_dim()),
+                                                    dtype=self.dtype())
 
         for epoch in range(num_epochs):
            running_loss = 0.
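+            # Combine the Lyapunov state batches with the barrier safe /
+            # unsafe batches. zip() pairs the three loaders batch-by-batch
+            # and stops at the shortest one.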
-            for _, batch_data in enumerate(data_loader):
+            if safe_data_loader is None and unsafe_data_loader is None:
+                data_loaders = data_loader
+            else:
+                # Safe and unsafe samples must be provided together;
+                # otherwise zip() below would iterate over a None loader.
+                assert (safe_data_loader is not None)
+                assert (unsafe_data_loader is not None)
+                data_loaders = zip(data_loader, safe_data_loader,
+                                   unsafe_data_loader)
+            for _, batch_data in enumerate(data_loaders):
                 state_samples_batch = batch_data[0]
+                if safe_data_loader is not None:
+                    safe_state_samples_batch = batch_data[1]
+                else:
+                    safe_state_samples_batch = torch.empty((0, self.x_dim()),
+                                                           dtype=self.dtype())
+                if unsafe_data_loader is not None:
+                    unsafe_state_samples_batch = batch_data[2]
+                else:
+                    unsafe_state_samples_batch = torch.empty(
+                        (0, self.x_dim()), dtype=self.dtype())
                 optimizer.zero_grad()
                 state_samples_next = torch.stack([
                     self.lyapunov_hybrid_system.system.step_forward(
@@ -1037,7 +1145,17 @@ def train_lyapunov_on_samples(self, state_samples_all, num_epochs,
                     self.lyapunov_derivative_sample_cost_weight,
                     lyap_positivity_mip_cost_weight=None,
                     lyap_derivative_mip_cost_weight=None,
-                    boundary_value_gap_mip_cost_weight=0)
+                    boundary_value_gap_mip_cost_weight=0,
+                    safe_state_samples=safe_state_samples_batch,
+                    unsafe_state_samples=unsafe_state_samples_batch,
+                    barrier_derivative_state_samples=state_samples_batch,
+                    safe_sample_cost_weight=self.safe_sample_cost_weight,
+                    unsafe_sample_cost_weight=self.unsafe_sample_cost_weight,
+                    barrier_derivative_sample_cost_weight=self.
+                    barrier_derivative_sample_cost_weight,
+                    safe_mip_cost_weight=None,
+                    unsafe_mip_cost_weight=None,
+                    barrier_derivative_mip_cost_weight=None)
                 total_loss_return.loss.backward()
                 optimizer.step()
                 running_loss += total_loss_return.loss.item()
@@ -1057,27 +1175,75 @@ def train_lyapunov_on_samples(self, state_samples_all, num_epochs,
                     self.lyapunov_derivative_sample_cost_weight,
                     lyap_positivity_mip_cost_weight=None,
                     lyap_derivative_mip_cost_weight=None,
-                    boundary_value_gap_mip_cost_weight=0)
+                    boundary_value_gap_mip_cost_weight=0,
+                    safe_state_samples=test_safe_state_samples,
+                    unsafe_state_samples=test_unsafe_state_samples,
+                    barrier_derivative_state_samples=test_state_samples,
+                    safe_sample_cost_weight=self.safe_sample_cost_weight,
+                    unsafe_sample_cost_weight=self.unsafe_sample_cost_weight,
+                    barrier_derivative_sample_cost_weight=self.
+                    barrier_derivative_sample_cost_weight,
+                    safe_mip_cost_weight=None,
+                    unsafe_mip_cost_weight=None,
+                    barrier_derivative_mip_cost_weight=None)
                 test_loss = test_loss_return.loss
                 print(f"epoch {epoch}, training loss " +
                       f"{running_loss / len(data_loader)}, test loss " +
                      f"{test_loss.item()}")
            if test_loss.item() < best_loss:
                best_loss = test_loss.item()
-                best_lyapunov_relu = copy.deepcopy(
-                    self.lyapunov_hybrid_system.lyapunov_relu)
-                if isinstance(self.lyapunov_hybrid_system.system,
-                              feedback_system.FeedbackSystem):
+                if self.lyapunov_hybrid_system is not None:
+                    best_lyapunov_relu = copy.deepcopy(
+                        self.lyapunov_hybrid_system.lyapunov_relu)
+                if self.barrier_system is not None:
+                    best_barrier_relu = copy.deepcopy(
+                        self.barrier_system.barrier_relu)
+                if self.lyapunov_hybrid_system is not None and isinstance(
+                        self.lyapunov_hybrid_system.system,
+                        feedback_system.FeedbackSystem):
                    best_controller_relu = copy.deepcopy(
                        self.lyapunov_hybrid_system.system.controller_network)
+                if self.barrier_system is not None and isinstance(
+                        self.barrier_system.system,
+                        feedback_system.FeedbackSystem):
+                    best_controller_relu = copy.deepcopy(
+                        self.barrier_system.system.controller_network)
 
        print(f"best loss {best_loss}")
-        self.lyapunov_hybrid_system.lyapunov_relu.load_state_dict(
-            best_lyapunov_relu.state_dict())
-        if isinstance(self.lyapunov_hybrid_system.system,
-                      feedback_system.FeedbackSystem):
-            self.lyapunov_hybrid_system.system.controller_network.\
-                load_state_dict(best_controller_relu.state_dict())
+        if self.lyapunov_hybrid_system is not None:
+            self.lyapunov_hybrid_system.lyapunov_relu.load_state_dict(
+                best_lyapunov_relu.state_dict())
+            if isinstance(self.lyapunov_hybrid_system.system,
+                          feedback_system.FeedbackSystem):
+                self.lyapunov_hybrid_system.system.controller_network.\
+                    load_state_dict(best_controller_relu.state_dict())
+        if self.barrier_system is not None:
+            self.barrier_system.barrier_relu.load_state_dict(
+                best_barrier_relu.state_dict())
+            if isinstance(self.barrier_system.system,
+                          feedback_system.FeedbackSystem):
+                self.barrier_system.system.controller_network.load_state_dict(
+                    best_controller_relu.state_dict())
+
+    def train_lyapunov_on_samples(self, state_samples_all, num_epochs,
+                                  batch_size):
+        """
+        Train a ReLU network on the given state samples (not on the
+        adversarial states found by MIP). The loss function is the weighted
+        sum of the Lyapunov positivity condition violation and the Lyapunov
+        derivative condition violation on these samples. We train for
+        num_epochs epochs, and return the best network (the one with the
+        minimal loss on the held-out test samples) found so far.
+        @param state_samples_all A torch tensor, state_samples_all[i] is the
+        i'th sample
+        """
+        return self.train_on_samples(state_samples_all,
+                                     safe_samples=None,
+                                     unsafe_samples=None,
+                                     num_epochs=num_epochs,
+                                     batch_size=batch_size)
 
 
 class AdversarialTrainingOptions:
     def __init__(self):
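---

Usage notes. A minimal sketch of driving the extended total_loss() by hand,
assuming a Trainer `dut` that already has a Lyapunov system and a barrier
added via add_barrier(), plus sample tensors positivity_state_samples /
derivative_state_samples / derivative_state_samples_next as in
test_total_loss above. The empty barrier sample tensors and None MIP weights
mirror the defaults set in Trainer.__init__, so all barrier MIP solves are
skipped and only the sample losses contribute.

    import torch

    # Empty barrier samples: the barrier sample losses evaluate to zero.
    safe_state_samples = torch.empty((0, dut.x_dim()), dtype=dut.dtype())
    unsafe_state_samples = torch.empty((0, dut.x_dim()), dtype=dut.dtype())
    barrier_derivative_state_samples = torch.empty((0, dut.x_dim()),
                                                   dtype=dut.dtype())

    total_loss_return = dut.total_loss(
        positivity_state_samples, derivative_state_samples,
        derivative_state_samples_next,
        dut.lyapunov_positivity_sample_cost_weight,
        dut.lyapunov_derivative_sample_cost_weight,
        dut.lyapunov_positivity_mip_cost_weight,
        dut.lyapunov_derivative_mip_cost_weight,
        dut.boundary_value_gap_mip_cost_weight, safe_state_samples,
        unsafe_state_samples, barrier_derivative_state_samples,
        dut.safe_sample_cost_weight, dut.unsafe_sample_cost_weight,
        dut.barrier_derivative_sample_cost_weight,
        dut.safe_mip_cost_weight, dut.unsafe_mip_cost_weight,
        dut.barrier_derivative_mip_cost_weight)
    # Backprop; assumes at least one weighted loss term carries gradients.
    total_loss_return.loss.backward()

One caveat on train_on_samples(): the zipped iteration truncates each epoch
to the shortest of the three data loaders, so if the safe/unsafe sample sets
are much smaller than the Lyapunov sample set, most Lyapunov batches are
never visited. The safe and unsafe samples must also be supplied together
(both loaders exist or neither), which the asserts above make explicit.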