
Commit 43e5dcf

Merge pull request #40 from wellcometrust/feature/nsorros/attention
Add attention
2 parents 1b570b8 + fc4a13f commit 43e5dcf

File tree

3 files changed: +81 -2 lines changed


wellcomeml/ml/attention.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import tensorflow as tf

class SelfAttention(tf.keras.layers.Layer):
    """https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf"""
    def __init__(self, attention_dim=20):
        super(SelfAttention, self).__init__()
        self.attention_dim = attention_dim

    def build(self, input_shape):
        self.WQ = self.add_weight(shape=(input_shape[-1], self.attention_dim), trainable=True, initializer='uniform')
        self.WK = self.add_weight(shape=(input_shape[-1], self.attention_dim), trainable=True, initializer='uniform')
        self.WV = self.add_weight(shape=(input_shape[-1], input_shape[-1]), trainable=True, initializer='uniform')

    def call(self, X):
        """
        In: (batch_size, sequence_length, embedding_dimension)
        Out: (batch_size, sequence_length, embedding_dimension)
        """
        Q = tf.matmul(X, self.WQ)
        K = tf.matmul(X, self.WK)
        V = tf.matmul(X, self.WV)

        attention_scores = tf.nn.softmax(tf.matmul(Q, tf.transpose(K, perm=[0,2,1])))
        return tf.matmul(attention_scores, V)

class FeedForwardAttention(tf.keras.layers.Layer):
    """https://colinraffel.com/publications/iclr2016feed.pdf"""
    def __init__(self):
        super(FeedForwardAttention, self).__init__()

    def build(self, input_shape):
        self.W = self.add_weight(shape=(input_shape[-1],1), trainable=True, initializer='uniform')

    def call(self, X):
        """
        In: (batch_size, sequence_length, embedding_dimension)
        Out: (batch_size, embedding_dimension)
        """
        e = tf.math.tanh(tf.matmul(X, self.W))
        attention_scores = tf.nn.softmax(e)
        return tf.matmul(tf.transpose(X, perm=[0,2,1]), attention_scores)

class HierarchicalAttention(tf.keras.layers.Layer):
    """https://www.aclweb.org/anthology/N16-1174/"""
    def __init__(self):
        super(HierarchicalAttention, self).__init__()

    def build(self, input_shape):
        self.attention_matrix = self.add_weight(shape=(input_shape[-1], input_shape[-2]), trainable=True, initializer='uniform')

    def call(self, X):
        """
        In: (batch_size, sequence_length, embedding_dimension)
        Out: (batch_size, sequence_length, embedding_dimension)
        """
        attention_scores = tf.nn.softmax(tf.math.tanh(tf.matmul(X, self.attention_matrix)))
        return tf.matmul(attention_scores, X)
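
As a quick shape sanity check, here is a minimal sketch (not part of this commit) that runs the new layers on a random batch; the batch size, sequence length, and embedding dimension below are arbitrary:

import tensorflow as tf

from wellcomeml.ml.attention import SelfAttention, HierarchicalAttention

# Arbitrary example dimensions: 4 documents, 30 tokens each, 64-dim embeddings
X = tf.random.uniform((4, 30, 64))

print(SelfAttention(attention_dim=20)(X).shape)  # (4, 30, 64): sequence re-weighted by softmax(QK^T) scores
print(HierarchicalAttention()(X).shape)          # (4, 30, 64): same shape as the input, so it can be added residually

Note that HierarchicalAttention builds a weight of shape (embedding_dimension, sequence_length), so once built it expects fixed-length inputs.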

wellcomeml/ml/bilstm.py

Lines changed: 12 additions & 1 deletion
@@ -3,17 +3,20 @@
 from sklearn.metrics import f1_score
 import tensorflow as tf
 
+from wellcomeml.ml.attention import HierarchicalAttention
 from wellcomeml.ml.keras_utils import Metrics
 
 class BiLSTMClassifier(BaseEstimator, ClassifierMixin):
     def __init__(self, learning_rate=0.01, batch_size=32, nb_epochs=5,
-                 dropout=0.1, nb_layers=2, multilabel=False):
+                 dropout=0.1, nb_layers=2, multilabel=False,
+                 attention=False):
         self.learning_rate = learning_rate
         self.batch_size = batch_size
         self.nb_epochs = nb_epochs
         self.dropout = dropout
         self.nb_layers = nb_layers
         self.multilabel = multilabel
+        self.attention = attention
 
     def fit(self, X, Y, embedding_matrix=None, *_):
         sequence_length = X.shape[1]
@@ -29,6 +32,12 @@ def residual_bilstm(x1, l2):
             x2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(int(x1.shape[-1]/2), return_sequences=True, kernel_regularizer=l2))(x1)
             return tf.keras.layers.add([x1, x2])
 
+        def residual_attention(x1):
+            x2 = HierarchicalAttention()(x1)
+            x2 = tf.keras.layers.Dropout(self.dropout)(x2)
+            x2 = tf.keras.layers.LayerNormalization()(x2)
+            return tf.keras.layers.add([x1, x2])
+
         l2 = tf.keras.regularizers.l2(1e-6)
         embeddings_initializer = tf.keras.initializers.Constant(embedding_matrix) if embedding_matrix else 'uniform'
         inp = tf.keras.layers.Input(shape=(sequence_length,))
@@ -40,6 +49,8 @@ def residual_bilstm(x1, l2):
         )(inp)
         for _ in range(self.nb_layers):
             x = residual_bilstm(x, l2)
+            if self.attention:
+                x = residual_attention(x)
         x = tf.keras.layers.GlobalMaxPooling1D()(x)
         x = tf.keras.layers.Dense(20, kernel_regularizer=l2)(x)
         out = tf.keras.layers.Dense(nb_outputs, activation=output_activation, kernel_regularizer=l2)(x)

wellcomeml/ml/cnn.py

Lines changed: 12 additions & 1 deletion
@@ -15,12 +15,14 @@
 from sklearn.metrics import f1_score, precision_score, recall_score
 import tensorflow as tf
 
+from wellcomeml.ml.attention import HierarchicalAttention
 from wellcomeml.ml.keras_utils import Metrics
 
 class CNNClassifier(BaseEstimator, ClassifierMixin):
     def __init__(self, context_window = 3, learning_rate=0.001,
                  batch_size=32, nb_epochs=5, dropout=0.2,
-                 nb_layers=4, hidden_size=100, multilabel=False):
+                 nb_layers=4, hidden_size=100, multilabel=False,
+                 attention=False):
         self.context_window = context_window
         self.learning_rate = learning_rate
         self.batch_size = batch_size
@@ -29,6 +31,7 @@ def __init__(self, context_window = 3, learning_rate=0.001,
         self.nb_layers = nb_layers
         self.hidden_size = hidden_size # note that on current implementation CNN use same hidden size as embedding so if embedding matrix is passed, this is not used. in the future we can decouple
         self.multilabel = multilabel
+        self.attention = attention
 
     def fit(self, X, Y, embedding_matrix=None):
         sequence_length = X.shape[1]
@@ -49,6 +52,12 @@ def residual_conv_block(x1):
             x2 = tf.keras.layers.LayerNormalization()(x2)
             return tf.keras.layers.add([x1, x2])
 
+        def residual_attention(x1):
+            x2 = HierarchicalAttention()(x1)
+            x2 = tf.keras.layers.Dropout(self.dropout)(x2)
+            x2 = tf.keras.layers.LayerNormalization()(x2)
+            return tf.keras.layers.add([x1, x2])
+
         embeddings_initializer = tf.keras.initializers.Constant(embedding_matrix) if embedding_matrix else 'uniform'
         inp = tf.keras.layers.Input(shape=(sequence_length,))
         x = tf.keras.layers.Embedding(
@@ -59,6 +68,8 @@ def residual_conv_block(x1):
         x = tf.keras.layers.LayerNormalization()(x)
         for i in range(self.nb_layers):
             x = residual_conv_block(x)
+            if self.attention:
+                x = residual_attention(x)
         x = tf.keras.layers.GlobalMaxPooling1D()(x)
         x = tf.keras.layers.Dense(32, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(1e-6))(x)
         x = tf.keras.layers.Dropout(self.dropout)(x)
