
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, num_dims=3, min_timescale=1.0, max_timescale=1.0e4):
        super().__init__()
        self.d_model = d_model
        self.num_dims = num_dims  # t, h, w
        # each of the num_dims axes gets an equal slice of the channel dim (sin + cos halves)
        self.num_timescales = self.d_model // (self.num_dims * 2)
        log_timescale_increment = np.log(max_timescale / min_timescale) / (self.num_timescales - 1)
        inv_timescales = min_timescale * torch.exp(torch.arange(self.num_timescales).float() * -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

    def forward(self, x):
        b, c, t, h, w = x.shape
        # permute (not view) so the channel axis actually moves to the last position
        x = x.permute(0, 2, 3, 4, 1)
        for dim in range(self.num_dims):
            length = x.shape[dim + 1]  # add 1 to skip the batch dim
            # sin/cos of (positions * inv_timescales)
            position = torch.arange(length, dtype=torch.float32, device=x.device)
            scaled_time = position.view(-1, 1) * self.inv_timescales.view(1, -1)
            signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
            # pad the signal out to d_model so this axis' block lands in its own channel slice
            prepad = dim * 2 * self.num_timescales
            postpad = self.d_model - (dim + 1) * 2 * self.num_timescales
            signal = F.pad(signal, (prepad, postpad))
            # add singleton dims so signal broadcasts against x along the current axis
            for _ in range(1 + dim):
                signal = signal.unsqueeze(0)
            for _ in range(self.num_dims - 1 - dim):
                signal = signal.unsqueeze(-2)
            x = x + signal  # out-of-place add avoids mutating the caller's tensor
        # move channels back to dim 1
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x
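
A quick usage sketch, assuming a 5D video tensor laid out as (batch, channels, time, height, width) and a channel count divisible by 2 * num_dims; the concrete sizes below are illustrative, not from the original:

# illustrative shapes: d_model = 96 gives 96 // (3 * 2) = 16 timescales per axis
d_model = 96
pe = PositionalEncoding(d_model, num_dims=3)

x = torch.randn(2, d_model, 8, 16, 16)   # (batch, channels, time, height, width)
out = pe(x)
print(out.shape)                          # torch.Size([2, 96, 8, 16, 16]) -- shape is unchanged

The encoding is added to the input rather than concatenated, so the module leaves the tensor shape untouched and can be dropped in front of any block that expects (b, c, t, h, w) features.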