Anyone keen to build this one?
It's supposed to perform pretty well: https://ieeexplore.ieee.org/document/9950236
I had a crack at it below, but I keep getting all-F (all-ones) hashes, so I think I'm way off...
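One thing I noticed while poking at it (just a guess, not verified against the paper): scipy's hadamard() uses the Sylvester construction, so row 0 of the matrix is all +1s, which makes class 0's hash center the all-ones vector. A model that collapses onto that center would print exactly the all-F hashes described above. Quick check:

import numpy as np
from scipy.linalg import hadamard

print(hadamard(32)[0])  # every entry is +1, so center 0 packs to 0xFFFFFFFF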
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.applications.resnet50 import preprocess_input, ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from scipy.linalg import hadamard
import random
import os
# Function to generate hash centers using a Hadamard matrix
def generate_hash_centers(hash_size):
assert (hash_size & (hash_size - 1) == 0) and hash_size != 0, "Hash size must be a power of 2"
H = hadamard(hash_size)
return np.where(H > 0, 1, 0)
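# Sketch of an option (not mandated by the paper as far as I can tell): stacking H and -H
# doubles the number of available centers when there are more classes than hash bits:
# H = hadamard(hash_size)
# return np.where(np.concatenate([H, -H], axis=0) > 0, 1, 0)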
# Custom loss function
def hamming_distance(tensor):
"""Compute pairwise Hamming distance for a batch of binary vectors."""
x = tf.cast(tensor, dtype=tf.int32)
    x_expand = tf.expand_dims(x, 1)  # shape (batch, 1, bits)
    x_t = tf.expand_dims(x, 0)  # shape (1, batch, bits); broadcasting compares every pair
    distances = tf.math.reduce_sum(tf.math.abs(x_expand - x_t), axis=-1)
return distances
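# Quick eager-mode sanity check for the helper above (it isn't called by the loss
# below, which inlines the same computation):
# codes = tf.constant([[0, 1, 1, 0], [1, 1, 0, 0]])
# hamming_distance(codes)  # -> [[0, 2], [2, 0]]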
def custom_loss(hash_centers, margin=0.5, lambda_dist=0.1):
"""
Custom loss function incorporating distinct quantization with central similarity.
- hash_centers: Predefined hash centers for each class
- margin: Minimum desired Hamming distance between different class hash outputs
- lambda_dist: Weighting factor for the distinct quantization component of the loss
"""
hash_centers_tensor = tf.constant(hash_centers, dtype=tf.float32)
    def loss(y_true, y_pred):
        # Flatten labels to shape (batch,) regardless of how Keras passes them
        labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # Central similarity loss: pull each output towards its class center
        centers = tf.gather(hash_centers_tensor, labels)
        central_similarity_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(centers, y_pred))
        # Pairwise L1 distances on the soft sigmoid outputs. tf.round has no usable
        # gradient, so distances on thresholded codes would never train; L1 distance
        # on [0, 1] outputs is a differentiable surrogate for Hamming distance.
        expanded_pred = tf.expand_dims(y_pred, 0)
        transposed_pred = tf.expand_dims(y_pred, 1)
        distances = tf.reduce_sum(tf.abs(expanded_pred - transposed_pred), axis=2)
        # Mask for distinct quantization: exclude self and same-class comparisons
        batch_size = tf.shape(y_pred)[0]
        mask_self = 1.0 - tf.eye(batch_size)
        labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
        mask_class = 1.0 - tf.cast(labels_equal, dtype=tf.float32)
        mask = mask_self * mask_class
        # Distinct quantization loss: penalise different-class pairs closer than the margin
        penalties = tf.maximum(0.0, margin - distances)
        distinct_loss = tf.reduce_sum(penalties * mask) / (tf.reduce_sum(mask) + 1e-8)
        # Combine losses
        return central_similarity_loss + lambda_dist * distinct_loss
return loss
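# Smoke test for the loss on a toy batch (shapes are my assumption; run in eager mode):
# centers = generate_hash_centers(8)
# loss_fn = custom_loss(centers)
# y_true = tf.constant([0.0, 1.0])
# y_pred = tf.random.uniform((2, 8))
# print(float(loss_fn(y_true, y_pred)))  # single scalar combining both terms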
# Model creation function
def create_model(hash_size=64):
base_model = ResNet50(include_top=False, input_shape=(224, 224, 3), pooling='avg')
hash_layer = Dense(hash_size, activation='sigmoid')
model = Sequential([base_model, hash_layer])
return model
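# Optional tweak (my assumption, not from the paper): freezing the ImageNet backbone
# so only the hash layer trains can stabilise the first epochs:
# base_model.trainable = False  # set inside create_model, before building Sequential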
def add_noise(img):
    VARIABILITY = 25
    deviation = VARIABILITY * random.random()
    noise = np.random.normal(0, deviation, img.shape)
    img = img + noise
    # np.clip returns a new array; the clipped result must be assigned back
    img = np.clip(img, 0., 255.)
    return img
def preprocess_image(img):
img = add_noise(img)
img = preprocess_input(img)
return img
# Function to preprocess images for training and generate augmented images
def preprocess_images(image_directory, batch_size=32):
datagen = ImageDataGenerator(
preprocessing_function=preprocess_image,
rotation_range=10,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
generator = datagen.flow_from_directory(
image_directory,
target_size=(224, 224),
batch_size=batch_size,
        class_mode='sparse'  # integer class labels, used to index the hash centers
)
if generator.samples == 0:
print("No images found in specified directory.")
else:
print(f"Found {generator.samples} images belonging to {generator.num_classes} classes.")
return generator
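# Caveat: generate_hash_centers(hash_size) yields only hash_size centers, so the
# tf.gather in the loss indexes out of range if generator.num_classes > hash_size.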
# Load or train model
def get_model(hash_size, hash_centers, dataset_path=None):
model_path = f'resnet50_hash_model_{hash_size}.keras'
if os.path.exists(model_path):
# Load model without specifying custom loss
model = tf.keras.models.load_model(model_path, compile=False)
# After loading, recompile the model with the custom loss
model.compile(optimizer='adam', loss=custom_loss(hash_centers))
else:
model = create_model(hash_size)
model.compile(optimizer='adam', loss=custom_loss(hash_centers))
if dataset_path:
train_generator = preprocess_images(dataset_path)
            model.fit(train_generator, epochs=2)
        # The .keras extension already selects the native Keras format; passing
        # save_format='tf' alongside it errors out on recent TF releases
        model.save(model_path)
return model
# Function to generate hash for a single image
def generate_hash(model, preprocessed_img):
predictions = model.predict(preprocessed_img)
binary_hash = np.where(predictions > 0.5, 1, 0)
return binary_hash
# Function to convert binary hash to hexadecimal
def binary_to_hex(binary_hash):
return ''.join(format(x, '02x') for x in np.packbits(binary_hash[0]))
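# Worked example: a collapsed model that outputs all ones produces exactly the
# all-F string described at the top of this issue:
# binary_to_hex(np.ones((1, 32), dtype=np.uint8))  # -> 'ffffffff'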
# Main execution setup
if __name__ == "__main__":
    hash_size = 32  # 32-bit hash; must be a power of 2 for the Hadamard construction
hash_centers = generate_hash_centers(hash_size)
print("Type of hash_centers:", type(hash_centers))
print("Shape of hash_centers:", hash_centers.shape)
# Assuming an image path for demonstration; replace with your actual image path
img_path = 'image.jpg'
img = load_img(img_path, target_size=(224, 224))
    # Preprocess the image for the model (the batch dimension is added below)
    preprocessed_img = preprocess_input(img_to_array(img))
print("Type of preprocessed_img:", type(preprocessed_img))
print("Shape of preprocessed_img:", preprocessed_img.shape)
# Add batch dimension
preprocessed_img = np.expand_dims(preprocessed_img, axis=0)
# Get or train the model
model = get_model(hash_size, hash_centers, dataset_path='./test/')
# predict
predictions = model.predict(preprocessed_img)
print("Raw predictions:", predictions)
# Generate hash for the provided image
hash_code = generate_hash(model, preprocessed_img)
hex_hash = binary_to_hex(hash_code)
print("Generated Hash for the Image (Hex):", hex_hash)