@@ -29,7 +29,7 @@ def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
                        feature_depth=512,
                        feedforward_depth=2048,
                        num_heads=8,
-                       dropout=0.9):
+                       dropout=0.1):
   """Transformer Encoder Stack.
 
   Args:
@@ -38,20 +38,22 @@ def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
     feature_depth: int:  depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out; note that stax follows
+      Tensorflow's keep_rate convention, so we use 1 - dropout in calls below)
 
   Returns:
     A staxlayer for implementing a raw Transformer encoder stack.  No embedding
     or positional signals are added by this layer.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)
 
   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )
 
@@ -74,11 +76,11 @@ def encoder(embedded_source, source_mask):
                                      stax.Identity,  # value
                                      source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         embedded_source,
@@ -95,8 +97,8 @@ def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                   feature_depth=512,
                   feedforward_depth=2048,
                   num_heads=8,
-                  dropout=0.9,
-                  max_len=256):
+                  dropout=0.1,
+                  max_len=512):
   """Transformer language model (only uses the decoder part of Transformer).
 
   Args:
@@ -106,20 +108,21 @@ def TransformerLM(vocab_size,  # pylint: disable=invalid-name
     feature_depth: int:  depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     max_len: int: maximum symbol length for positional encoding
 
   Returns:
     init and apply.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)
 
   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )
 
@@ -132,18 +135,18 @@ def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                                    stax.Identity,  # value
                                    stax.CausalMask(axis=-2)),  # attention mask
                     multi_attention,
-                    stax.Dropout(dropout, mode=mode)),
+                    stax.Dropout(keep_rate, mode=mode)),
       # feed-forward
       stax.residual(stax.LayerNorm(feature_depth),
                     feed_forward,
-                    stax.Dropout(dropout, mode=mode))
+                    stax.Dropout(keep_rate, mode=mode))
   )
 
   return stax.serial(
       stax.ShiftRight(),
       stax.Embedding(feature_depth, vocab_size),
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.repeat(decoder_layer, num_layers),
       stax.LayerNorm(feature_depth),
       stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
@@ -158,7 +161,7 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
                 feature_depth=512,
                 feedforward_depth=2048,
                 num_heads=8,
-                dropout=0.9,
+                dropout=0.1,
                 shared_embedding=True,
                 max_len=200,
                 return_evals=False):
@@ -172,7 +175,7 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
     feature_depth: int:  depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     shared_embedding: bool: specify whether source/target embeddings are tied.
     max_len: int: maximum symbol length for positional encoding
     return_evals: bool: whether to generate decode-time evaluation functions
@@ -182,11 +185,11 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
   the 'evals' functions that itself returns a namedtuple containing evaluation
   functions for the trained encoder, decoder, and generator substax.
   """
-
+  keep_rate = 1.0 - dropout
   # Input embedding and positional encoding
   inject_position = stax.serial(
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode)
+      stax.Dropout(keep_rate, mode=mode)
   )
   if shared_embedding:
     assert source_vocab_size == target_vocab_size
@@ -202,12 +205,12 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
 
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)
 
   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )
 
@@ -231,11 +234,11 @@ def encoder(source, source_mask):
                                      stax.Identity,  # value
                                      source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         source,
@@ -266,19 +269,19 @@ def decoder(memory, target, target_mask, memory_mask):
                                      stax.Identity,  # value
                                      target_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # target attends to encoded source
         stax.residual(stax.LayerNorm(feature_depth),
                       stax.multiplex(stax.Identity,  # query
                                      memory,  # key
                                      memory,  # value
                                      memory_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         target,
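
Note on the convention change above: the stax layers themselves still consume a keep probability, so the model builders now convert once with keep_rate = 1.0 - dropout and pass keep_rate everywhere. The old default dropout=0.9 (a keep rate) and the new default dropout=0.1 (a drop rate) therefore configure the same behavior. Below is a minimal sketch of keep-probability ("inverted") dropout for readers unfamiliar with the convention; it is illustrative only, not code from this repository.

import numpy as np

def dropout_keep(x, keep_rate, rng):
  # Keep each activation with probability keep_rate, then rescale the
  # survivors by 1/keep_rate so the expected activation is unchanged
  # (the usual inverted-dropout trick; at eval time the layer is identity).
  mask = rng.random(x.shape) < keep_rate
  return np.where(mask, x / keep_rate, 0.0)

dropout = 0.1              # new-style argument: fraction to drop
keep_rate = 1.0 - dropout  # what the stax layers consume internally
x = np.ones((4, 8))
print(dropout_keep(x, keep_rate, np.random.default_rng(0)).mean())  # ~1.0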
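Call sites must flip accordingly: code that passed dropout=0.9 to keep 90% of activations should now pass dropout=0.1. A hypothetical call based only on the signatures visible in this diff; vocab_size is an arbitrary illustrative value, and parameters not shown in these hunks (such as num_layers) are left at their defaults.

init_fun, apply_fun = TransformerLM(  # returns init and apply, per the docstring
    vocab_size=32000,   # illustrative value, not from the diff
    mode='train',
    num_heads=8,
    dropout=0.1,        # after this change: probability of dropping, not keeping
    max_len=512)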