From 6d1507bbf82915b961d668fc0438546452fdd58c Mon Sep 17 00:00:00 2001 From: Rob Cornish Date: Mon, 1 Apr 2019 14:27:32 +0100 Subject: [PATCH 1/3] Fixed bug in log_likelihood calculation --- pixconcnn/models/gated_pixelcnn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixconcnn/models/gated_pixelcnn.py b/pixconcnn/models/gated_pixelcnn.py index 30df71e..9849338 100644 --- a/pixconcnn/models/gated_pixelcnn.py +++ b/pixconcnn/models/gated_pixelcnn.py @@ -94,7 +94,7 @@ def log_likelihood(self, device, samples): logits = self.forward(norm_samples) # Note that probs has shape # (batch, num_colors, channels, height, width) - probs = F.softmax(logits, dim=1) + probs = F.log_softmax(logits, dim=1) # Calculate probability of each pixel for i in range(height): @@ -103,9 +103,9 @@ def log_likelihood(self, device, samples): # Get the batch of true values at pixel (k, i, j) true_vals = samples[:, k, i, j] # Get probability assigned by model to true pixel - probs_pixel = probs[:, true_vals, k, i, j][:, 0] - # Add log probs (1e-9 to avoid log(0)) - log_probs += torch.log(probs_pixel + 1e-9) + all_probs_pixel = probs[:, :, k, i, j] + probs_pixel = all_probs_pixel.gather(dim=1, index=true_vals.unsqueeze(1).expand_as(all_probs_pixel))[:, 0] + log_probs += probs_pixel # Reset model to train mode self.train() From f2f94c3c843a91104217f14bf6801cb7e87c4f66 Mon Sep 17 00:00:00 2001 From: Rob Cornish Date: Mon, 1 Apr 2019 14:37:28 +0100 Subject: [PATCH 2/3] Vectorized log_likelihood() --- pixconcnn/models/gated_pixelcnn.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/pixconcnn/models/gated_pixelcnn.py b/pixconcnn/models/gated_pixelcnn.py index 9849338..5b8bf03 100644 --- a/pixconcnn/models/gated_pixelcnn.py +++ b/pixconcnn/models/gated_pixelcnn.py @@ -92,20 +92,9 @@ def log_likelihood(self, device, samples): norm_samples = samples.float() / (self.num_colors - 1) # Calculate pixel probs according to the model logits = self.forward(norm_samples) - # Note that probs has shape - # (batch, num_colors, channels, height, width) - probs = F.log_softmax(logits, dim=1) - - # Calculate probability of each pixel - for i in range(height): - for j in range(width): - for k in range(num_channels): - # Get the batch of true values at pixel (k, i, j) - true_vals = samples[:, k, i, j] - # Get probability assigned by model to true pixel - all_probs_pixel = probs[:, :, k, i, j] - probs_pixel = all_probs_pixel.gather(dim=1, index=true_vals.unsqueeze(1).expand_as(all_probs_pixel))[:, 0] - log_probs += probs_pixel + + all_log_probs = -F.cross_entropy(logits, samples, reduction="none") + log_probs = all_log_probs.sum((1, 2, 3)) # Reset model to train mode self.train() From d068c595d2de3c46b9a09a9ec2136cf31085676b Mon Sep 17 00:00:00 2001 From: Rob Cornish Date: Mon, 1 Apr 2019 16:44:40 +0100 Subject: [PATCH 3/3] Removed now redundant code --- pixconcnn/models/gated_pixelcnn.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pixconcnn/models/gated_pixelcnn.py b/pixconcnn/models/gated_pixelcnn.py index 5b8bf03..bbbcce7 100644 --- a/pixconcnn/models/gated_pixelcnn.py +++ b/pixconcnn/models/gated_pixelcnn.py @@ -71,7 +71,7 @@ def sample(self, device, num_samples=16, temp=1., return_likelihood=False): return samples.cpu() def log_likelihood(self, device, samples): - """Calculates log likelihood of samples under model. + """Calculates log likelihood (in nats) of samples under model. Parameters ---------- @@ -84,10 +84,6 @@ def log_likelihood(self, device, samples): # Set model to evaluation mode self.eval() - num_samples, num_channels, height, width = samples.size() - log_probs = torch.zeros(num_samples) - log_probs = log_probs.to(device) - # Normalize samples before passing through model norm_samples = samples.float() / (self.num_colors - 1) # Calculate pixel probs according to the model