diff --git a/chambers/callbacks.py b/chambers/callbacks.py index e401da2..2325dc2 100644 --- a/chambers/callbacks.py +++ b/chambers/callbacks.py @@ -1,7 +1,15 @@ +import json +import os +import datetime + import faiss -import numpy as np import tensorflow as tf +from chambers.models.base import PredictReturnYModel, set_predict_return_y +from chambers.models.bloodhound import batch_predict_pairs +from chambers.utils.ranking import rank_labels + + def extract_features(dataset, model): features = [] labels = [] @@ -14,8 +22,16 @@ def extract_features(dataset, model): labels = tf.concat(labels, axis=0) return features.numpy(), labels.numpy() + class GlobalRankingMetricCallback(tf.keras.callbacks.Callback): - def __init__(self, dataset: tf.data.Dataset, metric_funcs, feature_dim=None, name="ranking_metrics", use_gpu=False): + def __init__( + self, + dataset: tf.data.Dataset, + metric_funcs, + feature_dim=None, + name="ranking_metrics", + use_gpu=False, + ): super().__init__() self.dataset = dataset self.metric_funcs = metric_funcs @@ -33,7 +49,9 @@ def on_epoch_end(self, epoch, logs=None): labels = labels.astype(int) self.index.add_with_ids(features, labels) - binary_ranking = self._compute_binary_ranking(features, labels, k=1001, remove_top1=True) + binary_ranking = self._compute_binary_ranking( + features, labels, k=1001, remove_top1=True + ) for i, metric_fn in enumerate(self.metric_funcs): metric_name = "{}".format(metric_fn.__name__) @@ -54,7 +72,8 @@ def _build_index(self): model_output_dim = self.model.output_shape[-1] if model_output_dim is None: raise ValueError( - "Can not determine feature dimension from model output shape. Provide the 'feature_dim' argument.") + "Can not determine feature dimension from model output shape. Provide the 'feature_dim' argument." 
+ ) self.feature_dim = model_output_dim INDEX_KEY = "IDMap,Flat" @@ -64,4 +83,175 @@ def _build_index(self): gpu_resources = faiss.StandardGpuResources() index = faiss.index_cpu_to_gpu(gpu_resources, 0, index) - return index \ No newline at end of file + return index + + +class PairedRankingMetricCallback(tf.keras.callbacks.Callback): + def __init__( + self, + model, + dataset: tf.data.Dataset, + metric_funcs, + encoder=None, + batch_size=10, + dataset_len=None, + remove_top1=False, + verbose=True, + name="ranking_metrics", + ): + super().__init__() + self.model = model + self.encoder = encoder + self.dataset = dataset + self.metric_funcs = metric_funcs + self.batch_size = batch_size + self.dataset_len = dataset_len + self.remove_top1 = remove_top1 + self.verbose = verbose + self.name = name + self._supports_tf_logs = True + + self.model = set_predict_return_y(model) + if self.encoder is not None: + self.encoder = set_predict_return_y(encoder) + + def set_model(self, model): + if self.model is None: + self.model = set_predict_return_y(model) + + def on_epoch_end(self, epoch, logs=None): + if self.encoder is not None: + qz, yq = self.encoder.predict(self.dataset) + nq = len(qz) + else: + qz = self.dataset + yq = None + nq = self.dataset_len + + z, y = batch_predict_pairs( + model=self.model, + q=qz, + bq=self.batch_size, + yq=yq, + nq=nq, + verbose=self.verbose, + ) + yqz, ycz = y + y = tf.cast(tf.equal(yqz, tf.transpose(ycz)), tf.int32) + + binary_ranking, index_ranking = rank_labels(y, z, remove_top1=self.remove_top1) + + for i, metric_fn in enumerate(self.metric_funcs): + metric_name = "{}".format(metric_fn.__name__) + metric_score = metric_fn(binary_ranking) + logs[metric_name] = metric_score + + +class ExperimentCallback(tf.keras.callbacks.Callback): + def __init__( + self, + experiments_dir, + checkpoint_monitor, + checkpoint_mode="max", + tensorboard_update_freq="epoch", + tensorboard_write_graph=True, + config_dump=None, + ): + now_timestamp = 
datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + self.experiment_dir = os.path.join(experiments_dir, now_timestamp) + self.log_dir = os.path.join(self.experiment_dir, "logs") + self.model_dir = os.path.join(self.experiment_dir, "model") + self.checkpoint_dir = os.path.join(self.model_dir, "checkpoints") + self.export_dir = os.path.join(self.model_dir, "export") + + self.config_dump = config_dump + + csv_logger = tf.keras.callbacks.CSVLogger( + filename=os.path.join(self.log_dir, "epoch_results.txt") + ) + checkpointer = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join( + self.checkpoint_dir, "{epoch:02d}-{" + checkpoint_monitor + ":.5f}.h5" + ), + monitor=checkpoint_monitor, + mode=checkpoint_mode, + save_weights_only=True, + ) + tensorboard = tf.keras.callbacks.TensorBoard( + log_dir=self.log_dir, + update_freq=tensorboard_update_freq, + profile_batch=0, + write_graph=tensorboard_write_graph, + ) + + callbacks = [csv_logger, checkpointer, tensorboard] + self._callback_list = tf.keras.callbacks.CallbackList( + callbacks=callbacks, add_history=False, add_progbar=False + ) + + def set_params(self, params): + self.params = params + self._callback_list.set_params(params) + + def set_model(self, model): + self.model = model + self._callback_list.set_model(model) + + def on_batch_begin(self, batch, logs=None): + self._callback_list.on_batch_begin(batch, logs) + + def on_batch_end(self, batch, logs=None): + self._callback_list.on_batch_end(batch, logs) + + def on_epoch_begin(self, epoch, logs=None): + self._callback_list.on_epoch_begin(epoch, logs) + + def on_epoch_end(self, epoch, logs=None): + self._callback_list.on_epoch_end(epoch, logs) + + def on_train_batch_begin(self, batch, logs=None): + self._callback_list.on_train_batch_begin(batch, logs) + + def on_train_batch_end(self, batch, logs=None): + self._callback_list.on_train_batch_end(batch, logs) + + def on_test_batch_begin(self, batch, logs=None): + self._callback_list.on_test_batch_begin(batch, 
logs) + + def on_test_batch_end(self, batch, logs=None): + self._callback_list.on_test_batch_end(batch, logs) + + def on_predict_batch_begin(self, batch, logs=None): + self._callback_list.on_predict_batch_begin(batch, logs) + + def on_predict_batch_end(self, batch, logs=None): + self._callback_list.on_predict_batch_end(batch, logs) + + def on_train_begin(self, logs=None): + os.makedirs(self.experiment_dir, exist_ok=True) + os.makedirs(self.log_dir, exist_ok=True) + os.makedirs(self.checkpoint_dir, exist_ok=True) + os.makedirs(self.export_dir, exist_ok=True) + + if self.config_dump is not None: + with open(os.path.join(self.experiment_dir, "config_dump.json"), "w") as f: + json.dump(self.config_dump, f) + + self.model.save_weights(os.path.join(self.checkpoint_dir, "init.h5")) + self._callback_list.on_train_begin(logs) + + def on_train_end(self, logs=None): + self.model.save(os.path.join(self.export_dir), include_optimizer=True) + self._callback_list.on_train_end(logs) + + def on_test_begin(self, logs=None): + self._callback_list.on_test_begin(logs) + + def on_test_end(self, logs=None): + self._callback_list.on_test_end(logs) + + def on_predict_begin(self, logs=None): + self._callback_list.on_predict_begin(logs) + + def on_predict_end(self, logs=None): + self._callback_list.on_predict_end(logs) diff --git a/chambers/layers/attention.py b/chambers/layers/attention.py index a366b32..f2f3082 100644 --- a/chambers/layers/attention.py +++ b/chambers/layers/attention.py @@ -90,8 +90,9 @@ def __init__( dense_kernel_initializer="glorot_uniform", dropout_rate=0.1, causal=False, + **kwargs ): - super(MultiHeadAttention, self).__init__() + super(MultiHeadAttention, self).__init__(**kwargs) self.head_dim = head_dim self.num_heads = num_heads self.dense_kernel_initializer = dense_kernel_initializer @@ -106,7 +107,8 @@ def __init__( self.permute_mask = tf.keras.layers.Permute((2, 1)) def build(self, input_shape): - (b, _, d) = input_shape[0] + d = input_shape[0][-1] + self.w_query 
= self.add_weight( name="w_query", shape=(d, self.num_heads, self.head_dim), @@ -209,20 +211,29 @@ def compute_mask(self, inputs, mask=None): return None def get_config(self): + if isinstance(self.dense_kernel_initializer, tf.keras.initializers.Initializer): + dense_kernel_initializer = tf.keras.initializers.serialize( + self.dense_kernel_initializer + ) + else: + dense_kernel_initializer = self.dense_kernel_initializer + config = { "head_dim": self.head_dim, "num_heads": self.num_heads, - "dense_kernel_initializer": tf.keras.initializers.serialize( - self.dense_kernel_initializer - ), + "dense_kernel_initializer": dense_kernel_initializer, "dropout_rate": self.dropout_rate, "causal": self.causal, } base_config = super(MultiHeadAttention, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, config): - config["dense_kernel_initializer"] = tf.keras.initializers.deserialize( - config["dense_kernel_initializer"] - ) + if isinstance( + config["dense_kernel_initializer"], tf.keras.initializers.Initializer + ): + config["dense_kernel_initializer"] = tf.keras.initializers.deserialize( + config["dense_kernel_initializer"] + ) return cls(**config) diff --git a/chambers/layers/distance.py b/chambers/layers/distance.py index 61d9dea..48dd080 100644 --- a/chambers/layers/distance.py +++ b/chambers/layers/distance.py @@ -1,8 +1,9 @@ +import math import tensorflow as tf class Distance(tf.keras.layers.Layer): - def __init__(self, axis=-1, keepdims=True, **kwargs): + def __init__(self, axis=-1, keepdims=False, **kwargs): super(Distance, self).__init__(**kwargs) self.axis = axis self.keepdims = keepdims @@ -18,61 +19,82 @@ class L1Distance(Distance): """ L1 distance or "Manhattan-distance" layer - This layer takes as input a list of two vectors [v1, v2] and computes - the L1 distance between v1 and v2 according to the following equation: + This layer takes as input a list of two vectors [a, b] and computes + the L1 
distance between a and b according to the following equation: - l1 = |v1 - v2| + l1 = |a - b| """ def call(self, inputs, **kwargs): - v1, v2 = inputs - x = v1 - v2 + a, b = inputs + x = a - b x = tf.abs(x) x = tf.reduce_sum(x, axis=self.axis, keepdims=self.keepdims) return x @tf.keras.utils.register_keras_serializable(package="Chambers") -class CosineDistance(Distance): +class L2Distance(Distance): """ - Cosine distance layer - - This layer takes as input a list of two vectors [v1, v2] and computes - the Cosine distance between v1 and v2 according to the following equation: + L2 distance layer. Also known as Euclidean distance. - cosine similarity = (v1 . v2) / (||v1|| * ||v2||) + This layer takes as input a list of two vectors [a, b] and computes + the Euclidean distance between a and b according to the following equation: - cosine distance = 1 - cosine similarity + euclidean distance = sqrt((a - b) . (a - b)) """ def call(self, inputs, **kwargs): - v1, v2 = inputs - v1 = tf.nn.l2_normalize(v1, axis=self.axis) - v2 = tf.nn.l2_normalize(v2, axis=self.axis) - x = v1 * v2 + a, b = inputs + x = a - b + x = tf.square(x) x = tf.reduce_sum(x, axis=self.axis, keepdims=self.keepdims) - x = 1 - x + x = tf.sqrt(x) return x @tf.keras.utils.register_keras_serializable(package="Chambers") -class L2Distance(Distance): +class CosineSimilarity(Distance): """ - L2 distance layer. Also knows as Euclidean distance. + Cosine similarity layer + + This layer takes as input a list of two vectors [a, b] and computes + the cosine similarity between a and b according to the following equation: - This layer takes as input a list of two vectors [v1, v2] and computes - the Euclidean distance between v1 and v2 according to the following equation: + cosine similarity = (a . b) / (||a|| * ||b||) - euclidean distance = sqrt((v1 - v2) . 
(v1 - v2) + scaled cosine similarity = (cosine similarity + 1) / 2 """ def call(self, inputs, **kwargs): - v1, v2 = inputs - x = v1 - v2 - x = tf.square(x) + a, b = inputs + x = self._cosine_similarity(a, b) + return self._scale(x) + + def _cosine_similarity(self, a, b): + a = tf.nn.l2_normalize(a, axis=self.axis) + b = tf.nn.l2_normalize(b, axis=self.axis) + x = a * b x = tf.reduce_sum(x, axis=self.axis, keepdims=self.keepdims) - x = tf.sqrt(x) return x + + def _scale(self, cos_sim): + return (cos_sim + 1) / 2 + + +class AngularCosineSimilarity(CosineSimilarity): + def _scale(self, cos_sim): + return 1 - tf.math.acos(cos_sim) / math.pi + + +class CubicCosineSimilarity(CosineSimilarity): + def _scale(self, cos_sim): + return 0.5 + 0.25 * cos_sim + 0.25 * tf.pow(cos_sim, 3) + + +class SqrtCosineSimilarity(CosineSimilarity): + def _scale(self, cos_sim): + return 1 - tf.sqrt((1 - cos_sim) / 2) diff --git a/chambers/layers/embedding.py b/chambers/layers/embedding.py index 01bf9aa..4651886 100644 --- a/chambers/layers/embedding.py +++ b/chambers/layers/embedding.py @@ -31,7 +31,7 @@ def call(self, inputs, mask=None, **kwargs): x = self.positional_encoding(sequence_len, self.embedding_dim) if self.add_to_input: - x = inputs + x + x = inputs + tf.cast(x, inputs.dtype) return x @@ -89,7 +89,7 @@ def __init__( scale=None, eps=1e-6, add_to_input=True, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.embedding_dim = embedding_dim @@ -122,7 +122,7 @@ def call(self, inputs, mask=None, **kwargs): x = self.compute_positional_mask(ones) if self.add_to_input: - x = inputs + x + x = inputs + tf.cast(x, inputs.dtype) return x @@ -176,11 +176,12 @@ def __init__( dtype=None, add_to_input=True, name="learned_embedding", + **kwargs, ): self.initializer = initializer self.add_to_input = add_to_input self.supports_masking = True - super(LearnedEmbedding1D, self).__init__(dtype=dtype, name=name) + super(LearnedEmbedding1D, self).__init__(dtype=dtype, name=name, **kwargs) def 
build(self, input_shape): self.embedding = self.add_weight( @@ -209,9 +210,12 @@ def get_config(self): base_config = super(LearnedEmbedding1D, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, config): if isinstance(config["initializer"], tf.keras.initializers.Initializer): - config["initializer"] = tf.keras.initializers.deserialize(config["initializer"]) + config["initializer"] = tf.keras.initializers.deserialize( + config["initializer"] + ) return cls(**config) @@ -237,6 +241,7 @@ def __init__( initializer=None, dtype=None, name="concat_embedding", + **kwargs, ): assert ( side == "left" or side == "right" @@ -248,7 +253,7 @@ def __init__( self.side = side self.initializer = initializer self.concat = tf.keras.layers.Concatenate(axis=axis) - super(ConcatEmbedding, self).__init__(dtype=dtype, name=name) + super(ConcatEmbedding, self).__init__(dtype=dtype, name=name, **kwargs) def build(self, input_shape): self.embedding = self.add_weight( @@ -287,8 +292,11 @@ def get_config(self): base_config = super(ConcatEmbedding, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, config): if isinstance(config["initializer"], tf.keras.initializers.Initializer): - config["initializer"] = tf.keras.initializers.deserialize(config["initializer"]) + config["initializer"] = tf.keras.initializers.deserialize( + config["initializer"] + ) return cls(**config) diff --git a/chambers/layers/transformer.py b/chambers/layers/transformer.py index ba95c5e..280d052 100644 --- a/chambers/layers/transformer.py +++ b/chambers/layers/transformer.py @@ -16,8 +16,9 @@ def __init__( dense_dropout_rate=0.1, norm_epsilon=1e-6, pre_norm=False, + **kwargs, ): - super(EncoderLayer, self).__init__() + super(EncoderLayer, self).__init__(**kwargs) self.embed_dim = embed_dim self.num_heads = num_heads self.ff_dim = ff_dim @@ -96,6 +97,7 @@ def get_config(self): base_config = 
super(EncoderLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, config): if isinstance( config["dense_kernel_initializer"], tf.keras.initializers.Initializer @@ -120,8 +122,9 @@ def __init__( norm_epsilon=1e-6, pre_norm=False, causal=True, + **kwargs, ): - super(DecoderLayer, self).__init__() + super(DecoderLayer, self).__init__(**kwargs) self.embed_dim = embed_dim # TODO: get embed_dim from inputs_shape passed to build and remove this argument. self.num_heads = num_heads self.ff_dim = ff_dim @@ -237,6 +240,7 @@ def get_config(self): base_config = super(DecoderLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, config): if isinstance( config["dense_kernel_initializer"], tf.keras.initializers.Initializer @@ -262,7 +266,7 @@ def __init__( norm_epsilon=1e-6, pre_norm=False, norm_output=False, - **kwargs + **kwargs, ): self.embed_dim = embed_dim self.num_heads = num_heads @@ -332,6 +336,7 @@ def get_config(self): base_config = super(Encoder, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, config): if isinstance( config["dense_kernel_initializer"], tf.keras.initializers.Initializer @@ -359,7 +364,7 @@ def __init__( norm_output=False, causal=True, return_sequence=False, - **kwargs + **kwargs, ): self.embed_dim = embed_dim self.num_heads = num_heads @@ -397,6 +402,7 @@ def build(self, input_shape): ) for i in range(self.num_layers) ] + super(Decoder, self).build(input_shape) def call(self, inputs, mask=None, training=None, **kwargs): x, x_encoder = inputs @@ -449,9 +455,10 @@ def get_config(self): "causal": self.causal, "return_sequence": self.return_sequence, } - base_config = super(EncoderLayer, self).get_config() + base_config = super(Decoder, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod def from_config(cls, 
config): if isinstance( config["dense_kernel_initializer"], tf.keras.initializers.Initializer diff --git a/chambers/losses/metric_learning.py b/chambers/losses/metric_learning.py index d478c4c..72ea209 100644 --- a/chambers/losses/metric_learning.py +++ b/chambers/losses/metric_learning.py @@ -6,13 +6,13 @@ from chambers.miners import MultiSimilarityMiner as _MSMiner -class PairBasedLoss(tf.keras.losses.Loss, abc.ABC): +class PairLoss(tf.keras.losses.Loss, abc.ABC): def __init__( self, ignore_diag=True, ignore_negative_labels=True, miner=None, - name="pair_based_loss", + name=None, **kwargs, ): """ @@ -59,6 +59,12 @@ def compute_similarity_matrix(self, y_pred: tf.Tensor) -> tf.Tensor: """ return tf.matmul(y_pred, y_pred, transpose_b=True) + def compute_signed_masks(self, y_true): + y_true = tf.reshape(y_true, [-1, 1]) + pos_mask = tf.equal(y_true, tf.transpose(y_true)) + neg_mask = tf.logical_not(pos_mask) + return pos_mask, neg_mask + def get_signed_pairs( self, similarity_matrix: tf.Tensor, y_true: tf.Tensor ) -> Tuple[tf.RaggedTensor, tf.RaggedTensor]: @@ -69,34 +75,26 @@ def get_signed_pairs( :param y_true: The class labels for the embeddings as a Tensor with shape [n]. :return: Positive pairs and negative pairs as a tuple of 2D RaggedTensors each with shape [n, 0... n]. 
""" - y_true = tf.reshape(y_true, [-1, 1]) - pos_pair_mask = tf.cast(tf.equal(y_true, tf.transpose(y_true)), tf.uint8) - neg_pair_mask = 1 - pos_pair_mask + pos_mask, neg_pair_mask = self.compute_signed_masks(y_true) if self.ignore_negative_labels: - not_triplet_neg = tf.cast(tf.greater_equal(y_true, 0), tf.uint8) - pos_pair_mask = pos_pair_mask * not_triplet_neg - neg_pair_mask = neg_pair_mask * not_triplet_neg + not_triplet_neg = tf.greater_equal(y_true, 0) + pos_mask = pos_mask & not_triplet_neg + neg_pair_mask = neg_pair_mask & not_triplet_neg if self.ignore_diag: # ignore mirror pairs nrows = tf.shape(similarity_matrix)[0] ncols = tf.shape(similarity_matrix)[1] - inverse_eye = 1 - tf.eye(nrows, ncols, dtype=tf.uint8) - pos_pair_mask = pos_pair_mask * inverse_eye - neg_pair_mask = neg_pair_mask * inverse_eye + inverse_eye = tf.logical_not(tf.eye(nrows, ncols, dtype=tf.bool)) + pos_mask = pos_mask & inverse_eye + neg_pair_mask = neg_pair_mask & inverse_eye # get similarities of positive pairs - pos_mat = tf.RaggedTensor.from_row_lengths( - values=similarity_matrix[tf.cast(pos_pair_mask, tf.bool)], - row_lengths=tf.cast(tf.reduce_sum(pos_pair_mask, axis=1), tf.int32), - ) + pos_mat = tf.ragged.boolean_mask(similarity_matrix, pos_mask) # get similarities of negative pairs - neg_mat = tf.RaggedTensor.from_row_lengths( - values=similarity_matrix[tf.cast(neg_pair_mask, tf.bool)], - row_lengths=tf.cast(tf.reduce_sum(neg_pair_mask, axis=1), tf.int32), - ) + neg_mat = tf.ragged.boolean_mask(similarity_matrix, neg_pair_mask) return pos_mat, neg_mat @@ -113,8 +111,75 @@ def compute_loss( pass +class PairMatrixLoss(PairLoss): + def compute_similarity_matrix(self, y_pred: tf.Tensor) -> tf.Tensor: + return y_pred + + def compute_signed_masks(self, y_true): + pos_mask = tf.cast(y_true, tf.bool) + neg_mask = tf.logical_not(pos_mask) + return pos_mask, neg_mask + + +@tf.keras.utils.register_keras_serializable(package="Chambers") +class MultiSimilarityLoss(PairLoss): + """ + 
Multi-similarity loss + + References: + + [1] Wang, Xun et al. “Multi-Similarity Loss With General Pair Weighting for Deep Metric Learning.” + 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2019): 5017-5025. + https://arxiv.org/abs/1904.06627 + """ + + def __init__( + self, + pos_scale=2.0, + neg_scale=40.0, + threshold=0.5, + ignore_diag=True, + ignore_negative_labels=True, + miner=_MSMiner(margin=0.1), + name="multi_similarity_loss", + **kwargs, + ): + super().__init__( + ignore_diag=ignore_diag, + ignore_negative_labels=ignore_negative_labels, + miner=miner, + name=name, + **kwargs, + ) + self.pos_scale = pos_scale # alpha + self.neg_scale = neg_scale # beta + self.threshold = threshold # lambda + + def compute_loss(self, positive_pairs, negative_pairs): + pos_loss = ( + tf.math.log( + 1 + + tf.reduce_sum( + tf.exp(-self.pos_scale * (positive_pairs - self.threshold)), axis=1 + ) + ) + / self.pos_scale + ) + neg_loss = ( + tf.math.log( + 1 + + tf.reduce_sum( + tf.exp(self.neg_scale * (negative_pairs - self.threshold)), axis=1 + ) + ) + / self.neg_scale + ) + + return pos_loss + neg_loss + + @tf.keras.utils.register_keras_serializable(package="Chambers") -class MultiSimilarityLoss(PairBasedLoss): +class MultiSimilarityLossMatrix(PairMatrixLoss): """ Multi-similarity loss @@ -171,7 +236,7 @@ def compute_loss(self, positive_pairs, negative_pairs): @tf.keras.utils.register_keras_serializable(package="Chambers") -class ContrastiveLoss(PairBasedLoss): +class ContrastiveLoss(PairLoss): def __init__( self, positive_margin=1.0, diff --git a/chambers/metrics_ranking.py b/chambers/metrics_ranking.py index e4962b0..db36e71 100644 --- a/chambers/metrics_ranking.py +++ b/chambers/metrics_ranking.py @@ -82,6 +82,8 @@ def precision_at_k(binary_ranking, k: int): def mean_average_precision(binary_ranking, k: int = None): + binary_ranking = tf.cast(binary_ranking, tf.float32) + if k is None or k > tf.shape(binary_ranking)[1]: k = 
tf.shape(binary_ranking)[1] diff --git a/chambers/models/backbones/vision_transformer.py b/chambers/models/backbones/vision_transformer.py index 2de8321..848b95b 100644 --- a/chambers/models/backbones/vision_transformer.py +++ b/chambers/models/backbones/vision_transformer.py @@ -5,6 +5,7 @@ from chambers.augmentations import ImageNetNormalization from chambers.layers.embedding import ConcatEmbedding, LearnedEmbedding1D +from chambers.layers.reduce import Sum from chambers.layers.transformer import Encoder from chambers.utils.layer_utils import inputs_to_input_layer @@ -113,7 +114,9 @@ def _get_model_info(weights, model_name): return default_size, has_feature -def _obtain_inputs(input_tensor, input_shape, default_size, min_size, weights, model_name, name=None): +def _obtain_inputs( + input_tensor, input_shape, default_size, min_size, weights, model_name, name=None +): if input_shape is not None and _are_weights_pretrained(weights, model_name): default_shape = (default_size, default_size, input_shape[-1]) if tuple(input_shape) != default_shape: @@ -123,6 +126,7 @@ def _obtain_inputs(input_tensor, input_shape, default_size, min_size, weights, m ) ) + # TODO: write own `obtain_input_shape` function input_shape = imagenet_utils.obtain_input_shape( input_shape=input_shape, default_size=default_size, @@ -165,6 +169,28 @@ def _load_weights(model, weights, include_top): model.load_weights(weights) +def _pool(x, method=None, prefix=""): + if method == "avg": + x = tf.keras.layers.Cropping1D((1, 0), name=prefix + "sequence_embeddings")(x) + x = tf.keras.layers.GlobalAveragePooling1D(name=prefix + "avg_pool")(x) + elif method == "max": + x = tf.keras.layers.Cropping1D((1, 0), name=prefix + "sequence_embeddings")(x) + x = tf.keras.layers.GlobalMaxPooling1D(name=prefix + "max_pool")(x) + elif method == "sum": + x = tf.keras.layers.Cropping1D((1, 0), name=prefix + "sequence_embeddings")(x) + x = Sum(axis=1, name=prefix + "sum_pool")(x) + elif method == "cls": + x = 
tf.keras.Sequential( + [ + tf.keras.layers.Cropping1D((0, x.shape[1] - 1)), + tf.keras.layers.Reshape([-1]), + ], + name=prefix + "cls_embedding", + )(x) + + return x + + def VisionTransformer( patch_size, patch_dim, @@ -176,7 +202,7 @@ def VisionTransformer( input_shape=None, include_top=True, weights="imagenet21k+_224", - pooling=None, + pooling="cls", feature_dim=None, classes=1000, classifier_activation=None, @@ -244,20 +270,7 @@ def VisionTransformer( norm_output=True, )(x) - if pooling == "avg": - x = tf.keras.layers.Cropping1D((1, 0), name="sequence_embeddings")(x) - x = tf.keras.layers.GlobalAveragePooling1D(name="avg_pool")(x) - elif pooling == "max": - x = tf.keras.layers.Cropping1D((1, 0), name="sequence_embeddings")(x) - x = tf.keras.layers.GlobalMaxPooling1D(name="max_pool")(x) - else: - x = tf.keras.Sequential( - [ - tf.keras.layers.Cropping1D((0, x.shape[1] - 1)), - tf.keras.layers.Reshape([-1]), - ], - name="cls_embedding", - )(x) + x = _pool(x, method=pooling) if feature_dim is not None: x = tf.keras.layers.Dense(units=feature_dim, activation="tanh", name="feature")( @@ -353,20 +366,7 @@ def DistilledVisionTransformer( norm_output=True, )(x) - if pooling == "avg": - x_cls = tf.keras.layers.Cropping1D((2, 0), name="sequence_embeddings")(x) - x_cls = tf.keras.layers.GlobalAveragePooling1D(name="avg_pool")(x_cls) - elif pooling == "max": - x_cls = tf.keras.layers.Cropping1D((2, 0), name="sequence_embeddings")(x) - x_cls = tf.keras.layers.GlobalMaxPooling1D(name="max_pool")(x_cls) - else: - x_cls = tf.keras.Sequential( - [ - tf.keras.layers.Cropping1D((0, x.shape[1] - 1)), - tf.keras.layers.Reshape([-1]), - ], - name="cls_embedding", - )(x) + x_cls = _pool(x, method=pooling) x_dist = tf.keras.Sequential( [ @@ -392,6 +392,7 @@ def DistilledVisionTransformer( else: x = tf.keras.layers.Average()([x_cls, x_dist]) + x = tf.keras.layers.Activation("linear", dtype=tf.float32, name="cast_float32")(x) model = tf.keras.models.Model(inputs=inputs, outputs=x, 
name=model_name) _load_weights(model, weights, include_top) @@ -404,7 +405,7 @@ def ViTS16( input_shape=None, include_top=True, weights="imagenet_224_deit", - pooling=None, + pooling="cls", feature_dim=None, classes=1000, classifier_activation=None, @@ -440,7 +441,7 @@ def ViTB16( input_shape=None, include_top=True, weights="imagenet21k+_224", - pooling=None, + pooling="cls", feature_dim=None, classes=1000, classifier_activation=None, @@ -476,7 +477,7 @@ def ViTB32( input_shape=None, include_top=True, weights="imagenet21k+_384", - pooling=None, + pooling="cls", feature_dim=None, classes=1000, classifier_activation=None, @@ -512,7 +513,7 @@ def ViTL16( input_shape=None, include_top=True, weights="imagenet21k+_224", - pooling=None, + pooling="cls", feature_dim=None, classes=1000, classifier_activation=None, @@ -548,7 +549,7 @@ def ViTL32( input_shape=None, include_top=True, weights="imagenet21k+_384", - pooling=None, + pooling="cls", feature_dim=None, classes=1000, classifier_activation=None, @@ -585,7 +586,7 @@ def DeiTS16( input_shape=None, include_top=True, weights="imagenet_224", - pooling=None, + pooling="cls", classes=1000, classifier_activation=None, ): @@ -621,7 +622,7 @@ def DeiTB16( input_shape=None, include_top=True, weights="imagenet_224", - pooling=None, + pooling="cls", classes=1000, classifier_activation=None, ): diff --git a/chambers/models/base.py b/chambers/models/base.py new file mode 100644 index 0000000..ba00b88 --- /dev/null +++ b/chambers/models/base.py @@ -0,0 +1,80 @@ +import tensorflow as tf +from tensorflow.python.keras.engine import data_adapter +import types + + +# def MakeClassFromInstance(instance): +# from copy import deepcopy +# copy = deepcopy(instance.__dict__) +# InstanceFactory = type('InstanceFactory', (instance.__class__,), {}) +# InstanceFactory.__init__ = lambda self, *args, **kwargs: self.__dict__.update(copy) +# return InstanceFactory + + +class BaseModel(tf.keras.Model): + @classmethod + def from_model(cls, model): + return 
cls(inputs=model.inputs, outputs=model.outputs, name=model.name) + + +class PredictReturnYModel(BaseModel): + def predict_step(self, data): + """The logic for one inference step. + + This method can be overridden to support custom inference logic. + This method is called by `Model.make_predict_function`. + + This method should contain the mathematical logic for one step of inference. + This typically includes the forward pass. + + Configuration details for *how* this logic is run (e.g. `tf.function` and + `tf.distribute.Strategy` settings), should be left to + `Model.make_predict_function`, which can also be overridden. + + Arguments: + data: A nested structure of `Tensor`s. + + Returns: + The result of one inference step, typically the output of calling the + `Model` on data. + """ + data = data_adapter.expand_1d(data) + x, y, _ = data_adapter.unpack_x_y_sample_weight(data) + + if y is None: + return self(x, training=False) + + return self(x, training=False), y + + +def set_predict_return_y(model): + def predict_step(self, data): + """The logic for one inference step. + + This method can be overridden to support custom inference logic. + This method is called by `Model.make_predict_function`. + + This method should contain the mathematical logic for one step of inference. + This typically includes the forward pass. + + Configuration details for *how* this logic is run (e.g. `tf.function` and + `tf.distribute.Strategy` settings), should be left to + `Model.make_predict_function`, which can also be overridden. + + Arguments: + data: A nested structure of `Tensor`s. + + Returns: + The result of one inference step, typically the output of calling the + `Model` on data. 
+ """ + data = data_adapter.expand_1d(data) + x, y, _ = data_adapter.unpack_x_y_sample_weight(data) + + if y is None: + return self(x, training=False) + + return self(x, training=False), y + + model.predict_step = types.MethodType(predict_step, model) + return model diff --git a/chambers/models/bloodhound.py b/chambers/models/bloodhound.py new file mode 100644 index 0000000..610f4df --- /dev/null +++ b/chambers/models/bloodhound.py @@ -0,0 +1,492 @@ +import tensorflow as tf +from tensorflow.python.keras.utils import layer_utils + +from chambers.layers.attention import MultiHeadAttention +from chambers.layers.distance import CosineSimilarity +from chambers.layers.embedding import PositionalEmbedding2D +from chambers.layers.transformer import Decoder, DecoderLayer +from chambers.models import backbones +from chambers.models.backbones.vision_transformer import _obtain_inputs +from chambers.utils.generic import ProgressBar + + +@tf.keras.utils.register_keras_serializable(package="Chambers") +class _Pool3DAxis1(tf.keras.layers.Layer): + def __init__(self, method="avg", keepdims=False, name=None, **kwargs): + super(_Pool3DAxis1, self).__init__(name=name, **kwargs) + if method not in {"avg", "max", "sum", "cls"}: + raise ValueError("`method` must be either 'avg', 'max', 'sum' or 'cls'.") + + self.method = method + self.keepdims = keepdims + self.axis = 1 + + def call(self, inputs, **kwargs): + x = self._slice(inputs) + + if self.method == "avg": + x = tf.reduce_mean(x, axis=self.axis, keepdims=self.keepdims) + elif self.method == "max": + x = tf.reduce_max(x, axis=self.axis, keepdims=self.keepdims) + elif self.method == "sum": + x = tf.reduce_sum(x, axis=self.axis, keepdims=self.keepdims) + + return x + + def _slice(self, x): + if self.method == "cls": + x = x[:, 0, :] + else: + x = x[:, 1:, :] + return x + + def get_config(self): + config = { + "method": self.method, + "keepdims": self.keepdims, + } + base_config = super(_Pool3DAxis1, self).get_config() + return 
dict(list(base_config.items()) + list(config.items())) + + +@tf.keras.utils.register_keras_serializable(package="Chambers") +class _Pool4DAxis2(_Pool3DAxis1): + def __init__(self, *args, **kwargs): + super(_Pool4DAxis2, self).__init__(*args, **kwargs) + self.axis = 2 + + def _slice(self, x): + if self.method == "cls": + x = x[:, :, 0, :] + else: + x = x[:, :, 1:, :] + return x + + +@tf.keras.utils.register_keras_serializable(package="Chambers") +class ExpandDims(tf.keras.layers.Layer): + def __init__(self, axis=0, **kwargs): + super(ExpandDims, self).__init__(**kwargs) + self.axis = axis + + def call(self, inputs, **kwargs): + return tf.expand_dims(inputs, self.axis) + + def get_config(self): + config = { + "axis": self.axis, + } + base_config = super(ExpandDims, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@tf.keras.utils.register_keras_serializable(package="Chambers") +class MultiHeadAttention4D(MultiHeadAttention): + def call(self, inputs, mask=None, training=None): + """ + Einsum notation: + b = batch_size + e = expanded dimension + t = sequence length + d = embedding dimension + n = num heads + h = head dimension + """ + q = inputs[0] # [b, e, tq, d] + v = inputs[1] # [e, b, tv, d] + k = inputs[2] if len(inputs) > 2 else v # [e, b, tv, d] + + # linear projections + head split + query = tf.einsum("ebtd,dnh->ebnth", q, self.w_query) + self.b_query + value = tf.einsum("betd,dnh->benth", v, self.w_value) + self.b_value + key = tf.einsum("betd,dnh->benth", k, self.w_key) + self.b_key + + # TODO: Mask + + # TODO: Make attention attend across query batch (b) dimension for each candidate + attention = self.attention([query, value, key], mask=mask, training=training) + + # linear projection + head merge + x = ( + tf.einsum("benth,ndh->betd", attention, self.w_projection) + + self.b_projection + ) + + return x + + +@tf.keras.utils.register_keras_serializable(package="Chambers") +class DecoderLayer4D(DecoderLayer): + def 
def BloodhoundFunctional(
    query_shape,
    candidates_shape,
    n_layers,
    n_heads,
    ff_dim,
    dropout_rate=0.1,
    include_top=True,
    pooling=None,
    name=None,
):
    """Builds the query/candidate cross-attention decoder model.

    q -> encoder -> -----> zq
                    |
                    v
    c -> encoder -> decoder -> zc

    Arguments:
        query_shape: Shape of one encoded query (without the batch dim).
        candidates_shape: Shape of one encoded candidate (without the batch dim).
        n_layers: Number of decoder layers.
        n_heads: Number of attention heads per layer.
        ff_dim: Hidden width of each feed-forward sub-layer.
        dropout_rate: Rate used for both attention and dense dropout.
        include_top: If True, output the cosine similarity between pooled
            query and candidate features.
        pooling: One of 'avg', 'max', 'sum', 'cls' or None. Required when
            ``include_top=True``.
        name: Optional model name.

    Returns:
        A `tf.keras.Model` mapping ``[queries, candidates]`` to either a
        similarity output (``include_top=True``) or the ``[q, c]`` feature pair.

    Raises:
        ValueError: If ``include_top=True`` and ``pooling`` is None.
    """
    # Fail fast: validate the argument combination before building any
    # (potentially expensive) decoder layers.
    if include_top and pooling is None:
        raise ValueError(
            "`include_top=True` requires `pooling` to be either 'avg', 'max', 'sum', or 'cls'."
        )

    inputs_q = tf.keras.layers.Input(shape=query_shape)
    inputs_c = tf.keras.layers.Input(shape=candidates_shape)

    # NOTE: positional embedding here if encoder is not a transformer
    q = ExpandDims(axis=1, name="q_expand")(inputs_q)
    c = ExpandDims(axis=0, name="c_expand")(inputs_c)
    c = Decoder4D(
        embed_dim=c.shape[-1],
        num_heads=n_heads,
        ff_dim=ff_dim,
        num_layers=n_layers,
        attention_dropout_rate=dropout_rate,
        dense_dropout_rate=dropout_rate,
        norm_output=True,
        pre_norm=False,
        causal=False,
        name="decoder",
    )([c, q])

    if pooling is not None:
        q = _Pool4DAxis2(method=pooling, name="pool_q")(q)
        c = _Pool4DAxis2(method=pooling, name="pool_c")(c)

    if include_top:
        x = CosineSimilarity(axis=-1)([q, c])
        # x = tf.keras.layers.Dense(1, activation="sigmoid", dtype=tf.float32)(c)
        # x = tf.keras.layers.Lambda(lambda x: x[:, :, 0])(x)
    else:
        x = [q, c]

    return tf.keras.Model(inputs=[inputs_q, inputs_c], outputs=x, name=name)
def BloodhoundRes(
    embed_dim,
    n_layers,
    n_heads,
    ff_dim,
    dropout_rate=0.1,
    query_tensor=None,
    query_shape=None,
    candidates_tensor=None,
    candidates_shape=None,
    include_top=True,
    weights="imagenet21k+_224",
    pooling=None,
    model_name=None,
):
    """Bloodhound with a truncated ResNet50 backbone as the shared encoder.

    q -> encoder -> -----> zq
                    |
                    v
    c -> encoder -> decoder -> zc

    Queries and candidates pass through the same ResNet50 (cut at
    ``conv4_block6_out``), are projected to ``embed_dim`` channels, given a
    2-D positional embedding, flattened into token sequences and fed to the
    cross-attention decoder.
    """
    # Shared CNN encoder, truncated at the conv4 stage.
    resnet = tf.keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_shape=(224, 224, 3)
    )
    enc = tf.keras.Model(
        inputs=resnet.inputs,
        outputs=resnet.get_layer("conv4_block6_out").output,
        name="encoder",
    )

    inputs_q = _obtain_inputs(
        query_tensor,
        query_shape,
        default_size=224,
        min_size=32,
        weights=weights,
        model_name=model_name,
        name="query",
    )
    inputs_c = _obtain_inputs(
        candidates_tensor,
        candidates_shape,
        default_size=224,
        min_size=32,
        weights=weights,
        model_name=model_name,
        name="candidates",
    )

    feat_q = enc(inputs_q)
    feat_c = enc(inputs_c)

    # The 1x1 projection is shared between both streams; the positional
    # embeddings are NOT shared (one layer instance per stream).
    proj = tf.keras.layers.Conv2D(filters=embed_dim, kernel_size=1)
    feat_q = proj(feat_q)
    feat_c = proj(feat_c)
    feat_q = PositionalEmbedding2D(embedding_dim=embed_dim, add_to_input=True)(feat_q)
    feat_c = PositionalEmbedding2D(embedding_dim=embed_dim, add_to_input=True)(feat_c)

    # Flatten the spatial grid into a token sequence: [H, W, D] -> [H*W, D].
    feat_q = tf.keras.layers.Reshape(
        [feat_q.shape[1] * feat_q.shape[2], feat_q.shape[3]]
    )(feat_q)
    feat_c = tf.keras.layers.Reshape(
        [feat_c.shape[1] * feat_c.shape[2], feat_c.shape[3]]
    )(feat_c)

    # NOTE: positional embedding here if encoder is not a transformer
    dec = BloodhoundFunctional(
        query_shape=feat_q.shape[1:],
        candidates_shape=feat_c.shape[1:],
        n_layers=n_layers,
        n_heads=n_heads,
        ff_dim=ff_dim,
        dropout_rate=dropout_rate,
        include_top=include_top,
        pooling=pooling,
        name="decoder",
    )
    x = dec([feat_q, feat_c])

    x = tf.keras.layers.Activation("linear", dtype=tf.float32, name="cast_float32")(x)

    if query_tensor is not None:
        inputs_q = layer_utils.get_source_inputs(query_tensor)
    if candidates_tensor is not None:
        inputs_c = layer_utils.get_source_inputs(candidates_tensor)

    return tf.keras.Model(inputs=[inputs_q, inputs_c], outputs=x, name=model_name)
def batch_predict_pairs(
    model, q, bq, c=None, bc=None, yq=None, yc=None, nq=None, nc=None, verbose=True
):
    """Predict a score for every (query, candidate) pair, batching both sides.

    Arguments:
        model: A model whose ``predict`` accepts the ``((x_q, x_c), ...)``
            batches produced by `pair_iteration_dataset`.
        q: Queries -- a tensor or a `tf.data.Dataset`.
        bq: Query batch size.
        c: Candidates -- a tensor or a `tf.data.Dataset`. Defaults to `q`
            (all-pairs self comparison).
        bc: Candidate batch size. Defaults to `bq`.
        yq: Optional query labels (ignored when `q` is already a dataset).
        yc: Optional candidate labels (ignored when `c` is already a dataset).
        nq: Number of queries; required when `q` has unknown cardinality.
        nc: Number of candidates; required when `c` has unknown cardinality.
        verbose: If True, show a progress bar over the pair batches.

    Returns:
        The ``[nq, nc]`` prediction matrix, or ``(matrix, (yq, yc))`` when the
        model also returns labels (see `PredictReturnYModel`).
    """
    if c is None:
        # No candidates given: compare the queries against themselves.
        c = q
        bc = bq
        yc = yq
        nc = nq
    elif bc is None:
        bc = bq

    q, nq = _to_dataset(q, yq, nq)
    c, nc = _to_dataset(c, yc, nc)

    bq = tf.cast(bq, tf.int64)
    bc = tf.cast(bc, tf.int64)
    nq = tf.cast(nq, tf.int64) if nq is not None else nq
    nc = tf.cast(nc, tf.int64) if nc is not None else nc

    # Clamp batch sizes to the dataset sizes so a single short batch cannot
    # exceed the data.
    # NOTE(review): the reshape in `reshape_pair_predictions` appears to
    # assume `nq`/`nc` divide evenly by `bq`/`bc` -- confirm for ragged
    # final batches.
    bq = tf.minimum(bq, nq)
    bc = tf.minimum(bc, nc)

    td = pair_iteration_dataset(q, c, bq, bc, yq, yc, nq, nc)

    if verbose:
        nqb = tf.cast(tf.math.ceil(nq / bq), tf.int32)
        ncb = tf.cast(tf.math.ceil(nc / bc), tf.int32)

        # Progress is counted per (query batch, candidate batch) pair.
        prog = ProgressBar(total=nqb * ncb)
        td = td.apply(prog.dataset_apply_fn)

    z = model.predict(td)

    if isinstance(z, tuple):
        # The model returned (predictions, labels) -- unpack and reshape both.
        z, y = z
        z, y = reshape_pair_predictions(z, bq, bc, nq, nc, y=y)
        return z, y

    z = reshape_pair_predictions(z, bq, bc, nq, nc)
    return z
+ embed_dim=512, + num_heads=8, + dense_kernel_initializer="glorot_uniform", + attention_dropout_rate=0.1, + dense_dropout_rate=0.1, + norm_epsilon=1e-6, + **kwargs, + ): + super(ECA, self).__init__(**kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dense_kernel_initializer = dense_kernel_initializer + self.attention_dropout_rate = attention_dropout_rate + self.dense_dropout_rate = dense_dropout_rate + self.norm_epsilon = norm_epsilon + + self.multi_head_attention = MultiHeadAttention4D( + head_dim=self.embed_dim // self.num_heads, + num_heads=self.num_heads, + dense_kernel_initializer=self.dense_kernel_initializer, + dropout_rate=self.attention_dropout_rate, + causal=False, + ) + self.dropout = tf.keras.layers.Dropout(dense_dropout_rate) + self.norm = tf.keras.layers.LayerNormalization(epsilon=norm_epsilon) + + def call(self, inputs, training=None, **kwargs): + x = self.multi_head_attention([inputs, inputs, inputs], training=training) + x = self.dropout(x, training=training) + x = self.norm(inputs + x) + return x + + def get_config(self): + if isinstance(self.dense_kernel_initializer, tf.keras.initializers.Initializer): + dense_kernel_initializer = tf.keras.initializers.serialize( + self.dense_kernel_initializer + ) + else: + dense_kernel_initializer = self.dense_kernel_initializer + + config = { + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "dense_kernel_initializer": dense_kernel_initializer, + "attention_dropout_rate": self.attention_dropout_rate, + "dense_dropout_rate": self.dense_dropout_rate, + "norm_epsilon": self.norm_epsilon, + } + base_config = super(ECA, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + if isinstance( + config["dense_kernel_initializer"], tf.keras.initializers.Initializer + ): + config["dense_kernel_initializer"] = tf.keras.initializers.deserialize( + config["dense_kernel_initializer"] + ) + + return 
@tf.keras.utils.register_keras_serializable(package="Chambers")
class CFA(tf.keras.layers.Layer):
    """Cross-Feature Attention block.

    Cross-attends the first input onto the second (key/value) input, applies a
    residual + layer norm, then a two-layer MLP with another residual + norm.
    Registered for Keras serialization, consistent with `ECA` and
    `FeatureFusionLayer`.
    """

    def __init__(
        self,
        embed_dim=512,
        num_heads=8,
        ff_dim=2048,
        dense_kernel_initializer="glorot_uniform",
        attention_dropout_rate=0.1,
        dense_dropout_rate=0.1,
        norm_epsilon=1e-6,
        **kwargs,
    ):
        super(CFA, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dense_kernel_initializer = dense_kernel_initializer
        self.attention_dropout_rate = attention_dropout_rate
        self.dense_dropout_rate = dense_dropout_rate
        self.norm_epsilon = norm_epsilon

        self.multi_head_attention = MultiHeadAttention4D(
            head_dim=self.embed_dim // self.num_heads,
            num_heads=self.num_heads,
            dense_kernel_initializer=self.dense_kernel_initializer,
            dropout_rate=self.attention_dropout_rate,
            causal=False,
        )
        self.dropout = tf.keras.layers.Dropout(dense_dropout_rate)
        self.norm = tf.keras.layers.LayerNormalization(epsilon=norm_epsilon)

        # mlp
        self.dense1 = tf.keras.layers.Dense(
            ff_dim, activation=gelu, kernel_initializer=dense_kernel_initializer
        )
        self.dense2 = tf.keras.layers.Dense(
            embed_dim, kernel_initializer=dense_kernel_initializer
        )
        self.dropout2 = tf.keras.layers.Dropout(dense_dropout_rate)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=norm_epsilon)

    def call(self, inputs, training=None, **kwargs):
        """Inputs: ``(x, x_kv)`` -- query stream and key/value stream."""
        x, x_kv = inputs
        attention = self.multi_head_attention([x, x_kv, x_kv], training=training)
        attention = self.dropout(attention, training=training)
        x = self.norm(x + attention)

        xd = self.dense1(x)
        xd = self.dense2(xd)
        # Pass `training` explicitly (consistent with `self.dropout` above) so
        # the MLP dropout is deterministic at inference regardless of the
        # surrounding call context.
        xd = self.dropout2(xd, training=training)
        x = self.norm2(xd + x)
        return x

    def get_config(self):
        if isinstance(self.dense_kernel_initializer, tf.keras.initializers.Initializer):
            dense_kernel_initializer = tf.keras.initializers.serialize(
                self.dense_kernel_initializer
            )
        else:
            dense_kernel_initializer = self.dense_kernel_initializer

        config = {
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dense_kernel_initializer": dense_kernel_initializer,
            "attention_dropout_rate": self.attention_dropout_rate,
            "dense_dropout_rate": self.dense_dropout_rate,
            "norm_epsilon": self.norm_epsilon,
        }
        base_config = super(CFA, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        if isinstance(
            config["dense_kernel_initializer"], tf.keras.initializers.Initializer
        ):
            config["dense_kernel_initializer"] = tf.keras.initializers.deserialize(
                config["dense_kernel_initializer"]
            )

        return cls(**config)
class FeatureFusionDecoder(tf.keras.layers.Layer):
    """Stack of `FeatureFusionLayer`s that fuses two feature streams.

    Each layer self-attends within and cross-attends between the two streams;
    optionally layer-normalizes both outputs.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        ff_dim,
        num_layers,
        dense_kernel_initializer="glorot_uniform",
        attention_dropout_rate=0.1,
        dense_dropout_rate=0.1,
        norm_epsilon=1e-6,
        norm_output=False,
        **kwargs,
    ):
        super(FeatureFusionDecoder, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        self.dense_kernel_initializer = dense_kernel_initializer
        self.attention_dropout_rate = attention_dropout_rate
        self.dense_dropout_rate = dense_dropout_rate
        self.norm_epsilon = norm_epsilon
        self.norm_output = norm_output

        if norm_output:
            # NOTE(review): a single LayerNormalization instance is applied to
            # both streams (shared weights) -- presumably intentional; confirm.
            self.norm_layer = tf.keras.layers.LayerNormalization(epsilon=norm_epsilon)
        else:
            self.norm_layer = None
        self.supports_masking = True

    def build(self, input_shape):
        # FIX: forward `dense_kernel_initializer` and `norm_epsilon` to the
        # fusion layers. Previously they were stored on the decoder but never
        # passed through, so configured values were silently ignored.
        self.layers = [
            FeatureFusionLayer(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                ff_dim=self.ff_dim,
                dense_kernel_initializer=self.dense_kernel_initializer,
                attention_dropout_rate=self.attention_dropout_rate,
                dense_dropout_rate=self.dense_dropout_rate,
                norm_epsilon=self.norm_epsilon,
            )
            for i in range(self.num_layers)
        ]
        super(FeatureFusionDecoder, self).build(input_shape)

    def call(self, inputs, mask=None, training=None, **kwargs):
        """Inputs: ``(x, x_kv)``; returns the fused pair in the same order."""
        x, x_kv = inputs

        for layer in self.layers:
            x, x_kv = layer([x, x_kv], mask=mask, training=training)

        if self.norm_output:
            x = self.norm_layer(x)
            x_kv = self.norm_layer(x_kv)

        return x, x_kv

    def get_config(self):
        if isinstance(self.dense_kernel_initializer, tf.keras.initializers.Initializer):
            dense_kernel_initializer = tf.keras.initializers.serialize(
                self.dense_kernel_initializer
            )
        else:
            dense_kernel_initializer = self.dense_kernel_initializer

        config = {
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "num_layers": self.num_layers,
            "dense_kernel_initializer": dense_kernel_initializer,
            "attention_dropout_rate": self.attention_dropout_rate,
            "dense_dropout_rate": self.dense_dropout_rate,
            "norm_epsilon": self.norm_epsilon,
            "norm_output": self.norm_output,
        }
        base_config = super(FeatureFusionDecoder, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        if isinstance(
            config["dense_kernel_initializer"], tf.keras.initializers.Initializer
        ):
            config["dense_kernel_initializer"] = tf.keras.initializers.deserialize(
                config["dense_kernel_initializer"]
            )

        return cls(**config)
def Bloodhound4D(
    n_layers,
    n_heads,
    ff_dim,
    dropout_rate=0.1,
    query_tensor=None,
    query_shape=None,
    candidates_tensor=None,
    candidates_shape=None,
    include_top=True,
    weights="imagenet21k+_224",
    pooling=None,
    model_name=None,
):
    """Feature-fusion Bloodhound with a shared ViT-B/16 encoder.

    q -> encoder -> -----> zq
                    |
                    v
    c -> encoder -> decoder -> zc
    """
    # NOTE(review): the encoder weights are hard-coded to "imagenet21k+_224";
    # the `weights` argument only influences `_obtain_inputs` -- confirm.
    vit = backbones.ViTB16(
        weights="imagenet21k+_224",
        pooling=None,
        include_top=False,
    )
    # NOTE(review): `kernel_size` is typically a tuple; `_obtain_inputs`
    # receives it as `min_size` -- verify it accepts non-int values.
    patch_size = vit.get_layer("patch_embeddings").get_layer("embedding").kernel_size
    enc = tf.keras.Model(vit.inputs, vit.outputs, name="encoder")

    inputs_q = _obtain_inputs(
        query_tensor,
        query_shape,
        default_size=224,
        min_size=patch_size,
        weights=weights,
        model_name=model_name,
        name="query",
    )
    inputs_c = _obtain_inputs(
        candidates_tensor,
        candidates_shape,
        default_size=224,
        min_size=patch_size,
        weights=weights,
        model_name=model_name,
        name="candidates",
    )

    encoded_q = enc(inputs_q)
    encoded_c = enc(inputs_c)

    # NOTE: positional embedding here if encoder is not a transformer
    dec = BloodhoundFunctional(
        query_shape=encoded_q.shape[1:],
        candidates_shape=encoded_c.shape[1:],
        n_layers=n_layers,
        n_heads=n_heads,
        ff_dim=ff_dim,
        dropout_rate=dropout_rate,
        include_top=include_top,
        pooling=pooling,
        name="decoder",
    )
    x = dec([encoded_q, encoded_c])
    x = tf.keras.layers.Activation("linear", dtype=tf.float32, name="cast_float32")(x)

    if query_tensor is not None:
        inputs_q = layer_utils.get_source_inputs(query_tensor)
    if candidates_tensor is not None:
        inputs_c = layer_utils.get_source_inputs(candidates_tensor)

    return tf.keras.Model(inputs=[inputs_q, inputs_c], outputs=x, name=model_name)
def BloodhoundRes(
    embed_dim,
    n_layers,
    n_heads,
    ff_dim,
    dropout_rate=0.1,
    query_tensor=None,
    query_shape=None,
    candidates_tensor=None,
    candidates_shape=None,
    include_top=True,
    weights="imagenet21k+_224",
    pooling=None,
    model_name=None,
):
    """Feature-fusion Bloodhound with a truncated ResNet50 encoder.

    q -> encoder -> -----> zq
                    |
                    v
    c -> encoder -> decoder -> zc

    Unlike the non-fusion variant, the projection, positional embedding and
    flattening are folded INTO the shared encoder model, so both streams share
    all of those weights.
    """
    inputs_q = _obtain_inputs(
        query_tensor,
        query_shape,
        default_size=224,
        min_size=32,
        weights=weights,
        model_name=model_name,
        name="query",
    )
    inputs_c = _obtain_inputs(
        candidates_tensor,
        candidates_shape,
        default_size=224,
        min_size=32,
        weights=weights,
        model_name=model_name,
        name="candidates",
    )

    resnet = tf.keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_shape=(224, 224, 3)
    )
    trunk = tf.keras.Model(
        inputs=resnet.inputs, outputs=resnet.get_layer("conv4_block6_out").output
    )

    # Project to embed_dim, add 2-D positional information and flatten the
    # spatial grid into a token sequence -- all inside the shared encoder.
    feat = trunk.output
    feat = tf.keras.layers.Conv2D(filters=embed_dim, kernel_size=1)(feat)
    feat = PositionalEmbedding2D(embedding_dim=embed_dim, add_to_input=True)(feat)
    feat = tf.keras.layers.Reshape([feat.shape[1] * feat.shape[2], feat.shape[3]])(feat)
    enc = tf.keras.Model(inputs=trunk.inputs, outputs=feat, name="encoder")

    tokens_q = enc(inputs_q)
    tokens_c = enc(inputs_c)

    # NOTE: positional embedding here if encoder is not a transformer
    dec = BloodhoundFunctional(
        query_shape=tokens_q.shape[1:],
        candidates_shape=tokens_c.shape[1:],
        n_layers=n_layers,
        n_heads=n_heads,
        ff_dim=ff_dim,
        dropout_rate=dropout_rate,
        include_top=include_top,
        pooling=pooling,
        name="decoder",
    )
    x = dec([tokens_q, tokens_c])
    x = tf.keras.layers.Activation("linear", dtype=tf.float32, name="cast_float32")(x)

    if query_tensor is not None:
        inputs_q = layer_utils.get_source_inputs(query_tensor)
    if candidates_tensor is not None:
        inputs_c = layer_utils.get_source_inputs(candidates_tensor)

    return tf.keras.Model(inputs=[inputs_q, inputs_c], outputs=x, name=model_name)
def rank_labels(y, scores, remove_top1=False):
    """Sort the labels in ``y`` row-wise by descending ``scores``.

    Arguments:
        y: Label tensor aligned element-wise with ``scores``.
        scores: Ranking scores, one row per query.
        remove_top1: If True, drop each row's highest-scoring entry
            (e.g. to exclude a query's self-match).

    Returns:
        A ``(ranking, index_ranking)`` tuple: the labels reordered by score,
        and the column indices that produced that order.
    """
    order = tf.argsort(scores, axis=1, direction="DESCENDING")
    if remove_top1:
        order = order[:, 1:]

    gathered = tf.gather_nd(y, arg_to_gather_nd(order))
    ranking = tf.reshape(gathered, order.shape)
    return ranking, order