Feed Forward
FeedForward
Bases: Module
Source code in src/transformer/modules/feed_forward.py
__init__(d_model, d_ff, dropout)
Position-wise Feed-Forward block
Parameters:
Name | Type | Description | Default |
---|---|---|---|
d_model | int | dimension of the transformer model (the embedding size at every position) | required |
d_ff | int | size of the hidden (inner) layer of the feed-forward block | required |
dropout | float | dropout probability, given as a fraction between 0 and 1 (e.g. 0.1) | required |
Note
See Section 3.3, "Position-wise Feed-Forward Networks", of "Attention Is All You Need".
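For reference, Section 3.3 of the paper defines the block as two linear transformations with a ReLU in between, applied to each position separately and identically. The equation below restates the paper only; where this module applies dropout is not documented on this page.

```latex
\mathrm{FFN}(x) = \max(0,\, x W_1 + b_1)\, W_2 + b_2,
\qquad W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}},\;
b_1 \in \mathbb{R}^{d_{ff}},\;
W_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}},\;
b_2 \in \mathbb{R}^{d_{\text{model}}}
```

The paper's base model uses d_model = 512 and d_ff = 2048.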
forward(x)
Forward pass of the feed-forward block.
The tensor shapes are as follows:
- x: (batch, seq_len, d_model)
- Linear_1: (batch, seq_len, d_model) -> (batch, seq_len, d_ff)
- Linear_2: (batch, seq_len, d_ff) -> (batch, seq_len, d_model)
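A minimal PyTorch sketch of the shape flow listed for `forward`, assuming `Module` here is `torch.nn.Module` and that dropout is applied after the ReLU (a common choice, but not confirmed by this page). The class name `FeedForwardSketch` and the attribute names `linear_1`, `linear_2` and `dropout` are hypothetical and need not match the real source file.

```python
import torch
from torch import nn


class FeedForwardSketch(nn.Module):
    """Hypothetical sketch of a position-wise feed-forward block.

    Computes max(0, x W1 + b1) W2 + b2 at every position.
    Dropout placement (after the ReLU) is an assumption.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # W1, b1
        self.linear_2 = nn.Linear(d_ff, d_model)  # W2, b2
        self.dropout = nn.Dropout(dropout)        # probability of zeroing an activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        h = torch.relu(self.linear_1(x))  # -> (batch, seq_len, d_ff)
        h = self.dropout(h)               # shape unchanged
        return self.linear_2(h)           # -> (batch, seq_len, d_model)


# Usage with the paper's sizes (batch=2, seq_len=10).
ffn = FeedForwardSketch(d_model=512, d_ff=2048, dropout=0.1)
x = torch.randn(2, 10, 512)
print(ffn(x).shape)  # torch.Size([2, 10, 512])
```

Because `nn.Linear` acts only on the last dimension, each position is transformed independently and the batch and seq_len dimensions pass through unchanged.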