Feed Forward

FeedForward

Bases: Module

Source code in src/transformer/modules/feed_forward.py
import torch
import torch.nn as nn


class FeedForward(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        """Position-wise Feed-Forward block

        Args:
            d_model: dimension of the transformer model
            d_ff: hidden layer size in the feed forward block
            dropout: the dropout probability

        Note:
            See section 3.3 Position-wise Feed-Forward Networks of 
            "Attention is All You Need"
        """
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Forward pass for the feed forward block

        Note: The tensor shapes are as follows:
              1. input x: `(batch, seq_len, d_model)`
              2. linear_1: `(batch, seq_len, d_model)` -> `(batch, seq_len, d_ff)`
              3. linear_2: `(batch, seq_len, d_ff)` -> `(batch, seq_len, d_model)`
        """
        x = self.linear_1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.linear_2(x)
        return x
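
A minimal usage sketch (hedged): the import path is assumed from the file location, and the sizes follow the base configuration in "Attention is All You Need"; adjust both to your project.

import torch

from transformer.modules.feed_forward import FeedForward  # assumed import path

ff = FeedForward(d_model=512, d_ff=2048, dropout=0.1)  # base-model sizes from the paper
x = torch.randn(2, 10, 512)                            # (batch, seq_len, d_model)
out = ff(x)
print(out.shape)                                       # torch.Size([2, 10, 512])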

__init__(d_model, d_ff, dropout)

Position-wise Feed-Forward block

Parameters:

Name      Type    Description                                    Default
d_model   int     dimension of the transformer model             required
d_ff      int     hidden layer size in the feed forward block    required
dropout   float   the dropout probability                        required
Note

See section 3.3 Position-wise Feed-Forward Networks of "Attention is All You Need"
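
For reference, section 3.3 of "Attention is All You Need" defines the block as

    FFN(x) = max(0, x·W1 + b1)·W2 + b2

where linear_1 supplies W1 and b1, and linear_2 supplies W2 and b2. In this implementation, dropout is additionally applied to the ReLU output before linear_2.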

Source code in src/transformer/modules/feed_forward.py
def __init__(self, d_model: int, d_ff: int, dropout: float):
    """Position-wise Feed-Forward block

    Args:
        d_model: dimension of the transformer model
        d_ff: hidden layer size in the feed forward block
        dropout: the dropout probability

    Note:
        See section 3.3 Position-wise Feed-Forward Networks of 
        "Attention is All You Need"
    """
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.linear_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

forward(x)

Forward pass for the feed forward block

The tensor shapes are as follows:
  1. input x: (batch, seq_len, d_model)
  2. linear_1: (batch, seq_len, d_model) -> (batch, seq_len, d_ff)
  3. linear_2: (batch, seq_len, d_ff) -> (batch, seq_len, d_model)
Source code in src/transformer/modules/feed_forward.py
def forward(self, x):
    """Forward pass for the feed forward block

    Note: The tensor shapes are as follows:
          1. input x: `(batch, seq_len, d_model)`
          2. linear_1: `(batch, seq_len, d_model)` -> `(batch, seq_len, d_ff)`
          3. linear_2: `(batch, seq_len, d_ff)` -> `(batch, seq_len, d_model)`
    """
    x = self.linear_1(x)
    x = torch.relu(x)
    x = self.dropout(x)
    x = self.linear_2(x)
    return x
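
To make the shape transitions above concrete, here is a hedged sketch that calls the two linear layers directly; eval() turns dropout into a no-op, so the step-by-step result matches the module's output. The import path and sizes are the same assumptions as in the usage example above.

import torch

from transformer.modules.feed_forward import FeedForward  # assumed import path

ff = FeedForward(d_model=512, d_ff=2048, dropout=0.1).eval()  # eval() disables dropout
x = torch.randn(2, 10, 512)                                   # (batch, seq_len, d_model)

h = ff.linear_1(x)               # (2, 10, 2048) = (batch, seq_len, d_ff)
y = ff.linear_2(torch.relu(h))   # (2, 10, 512)  = (batch, seq_len, d_model)
assert torch.allclose(y, ff(x))  # dropout is identity in eval mode, so the results agree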