
import torch
import torch.nn as nn

from src.nn.attention import MultiHeadAttention, PositionwiseFeedForward


class PointAttentionCLF(nn.Module):
    """Point-cloud classifier: per-point embedding, stacked self-attention blocks,
    max-pooling over points, and an MLP head over 40 classes."""

    def __init__(self, in_dim, hidden_dim=50, ffn_dim=100, n_head=4, dropout=0.,
                 kernel='linear', kernel_params=None):
        super().__init__()
        # Per-point embedding (1x1 conv over the channel dimension).
        # Note: the output channel count must match d_model of the attention blocks,
        # so we project to hidden_dim instead of a hard-coded 64.
        self.input_fc = nn.Conv1d(in_dim, hidden_dim, 1)

        # Four residual attention + feed-forward blocks, followed by a final block
        # that expands the feature dimension to hidden_dim * 4 without a residual.
        layers = []
        for _ in range(4):
            layers += [
                MultiHeadAttention(n_head=n_head, d_model=hidden_dim, dropout=dropout,
                                   use_residual=True, kernel=kernel, kernel_params=kernel_params),
                PositionwiseFeedForward(hidden_dim, ffn_dim, d_out=hidden_dim,
                                        use_residual=True, dropout=dropout),
            ]
        layers += [
            MultiHeadAttention(n_head=n_head, d_model=hidden_dim, dropout=dropout,
                               use_residual=True, kernel=kernel, kernel_params=kernel_params),
            PositionwiseFeedForward(hidden_dim, ffn_dim, d_out=hidden_dim * 4,
                                    use_residual=False, dropout=dropout),
        ]
        self.attn_block = nn.Sequential(*layers)
        # self.attn_block = nn.DataParallel(self.attn_block, device_ids=[0, 1], output_device=0)

        # Classification head applied to the pooled per-cloud feature.
        self.out_block = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 40),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        # x: (batch, n_points, in_dim)
        x = x.transpose(2, 1)      # -> (batch, in_dim, n_points) for Conv1d
        x = self.input_fc(x)       # -> (batch, hidden_dim, n_points)
        x = x.transpose(2, 1)      # -> (batch, n_points, hidden_dim)
        x = self.attn_block(x)     # -> (batch, n_points, hidden_dim * 4)
        x = x.max(dim=1)[0]        # max-pool over points -> (batch, hidden_dim * 4)
        x = self.out_block(x)      # -> (batch, 40) class probabilities
        return x
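

# Minimal usage sketch. Assumes point clouds of shape (batch, n_points, in_dim),
# e.g. xyz coordinates for a ModelNet40-style 40-class task; the input sizes and
# the smoke-test setup below are illustrative assumptions, not part of the model.
if __name__ == "__main__":
    model = PointAttentionCLF(in_dim=3)    # 3 input channels (x, y, z)
    points = torch.randn(2, 1024, 3)       # 2 clouds of 1024 points each
    probs = model(points)                  # -> (2, 40)
    print(probs.shape, probs.sum(dim=1))   # rows sum to ~1 because of the Softmax head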