@@ -30,6 +30,8 @@
 import zero_7_exp_decay, zero_95_exp_decay, simple_sigmoid
 from ast import literal_eval
 
+from custom.custom import GPT2Layer
+
 #
 # Load the email data
 #
@@ -73,50 +75,8 @@
 INPUT_SHAPES = [()]
 OUTPUT_SHAPES = [1]
 
-"""### A custom GPT2 encoder layer for text embedding"""
-
-class GPT2Layer(tf.keras.layers.Layer):
-
-    def __init__(self, max_seq_length, **kwargs):
-        #
-        super(GPT2Layer, self).__init__(**kwargs)
-        #
-        # Load the GPT2 tokenizer, preprocessor and model
-        self.tokenizer = GPT2Tokenizer.from_preset("gpt2_base_en")
-        self.preprocessor = GPT2Preprocessor(self.tokenizer,
-                                             sequence_length=max_seq_length)
-        self.encoder = GPT2Backbone.from_preset("gpt2_base_en")
-        #
-        # Set whether the GPT2 model's layers are trainable
-        #self.encoder.trainable = False
-        for layer in self.encoder.layers:
-            layer.trainable = False
-        #
-        self.encoder.layers[-2].trainable = True
-        #
-        # Set the maximum sequence length for tokenization
-        self.max_seq_length = max_seq_length
-
-    def call(self, inputs):
-        #
-        # Output the GPT2 embedding
-        prep = self.preprocessor([inputs])
-        embedding = self.encoder(prep)
-        avg_pool = tf.reduce_mean(embedding, axis=1)
-        #
-        return avg_pool
-
-    def get_config(self):
-        #
-        config = super(GPT2Layer, self).get_config()
-        config.update({'max_seq_length': self.max_seq_length})
-        #
-        return config
-
-    @classmethod
-    def from_config(cls, config):
-        #
-        return cls(max_seq_length=config['max_seq_length'])
+
+
 
 # GPT2 configurables
 max_seq_length = 96