class MLP(nn.Module):
-    def __init__(self, seq_len, num_channels, latent_space_size, gamma, normalization="None",
-                 use_sigmoid_output=False, use_dropout=False, use_batchnorm=False):
+    def __init__(self, seq_len, num_channels, latent_space_size, gamma, normalization="None"):
        super().__init__()
        self.L, self.C = seq_len, num_channels
-        self.encoder = Encoder(seq_len * num_channels, latent_space_size, use_dropout, use_batchnorm)
-        self.decoder = Decoder(seq_len * num_channels, latent_space_size, use_dropout, use_batchnorm)
+        self.encoder = Encoder(seq_len * num_channels, latent_space_size)
+        self.decoder = Decoder(seq_len * num_channels, latent_space_size)
        self.normalization = normalization

        if self.normalization == "Detrend":
@@ -23,30 +22,24 @@ def __init__(self, seq_len, num_channels, latent_space_size, gamma, normalizatio
        else:
            self.use_normalizer = False

-        self.use_sigmoid_output = use_sigmoid_output
-        if self.use_sigmoid_output:
-            self.sigmoid = torch.nn.Sigmoid()
-
-
    def forward(self, X):
        B, L, C = X.shape
        assert (L == self.L) and (C == self.C)

        if self.use_normalizer:
            X = self.normalizer(X, "norm")
+
        z = self.encoder(X.reshape(B, L * C))
        out = self.decoder(z).reshape(B, L, C)

-        if self.use_sigmoid_output:
-            out = self.sigmoid(out)
        if self.use_normalizer:
            out = self.normalizer(out, "denorm")
        return out


class Encoder(nn.Module):
-    def __init__(self, input_size, latent_space_size, use_dropout=False, use_batchnorm=False):
+    def __init__(self, input_size, latent_space_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, input_size // 2)
        self.relu1 = nn.ReLU()
@@ -56,78 +49,31 @@ def __init__(self, input_size, latent_space_size, use_dropout=False, use_batchno
        self.relu3 = nn.ReLU()

-        self.use_dropout = use_dropout
-        if self.use_dropout:
-            self.dropout1 = nn.Dropout(p=0.2)
-            self.dropout2 = nn.Dropout(p=0.2)
-            self.dropout3 = nn.Dropout(p=0.2)
-
-        self.use_batchnorm = use_batchnorm
-        if self.use_batchnorm:
-            self.batchnorm1 = nn.BatchNorm1d(input_size // 2)
-            self.batchnorm2 = nn.BatchNorm1d(input_size // 4)
-            self.batchnorm3 = nn.BatchNorm1d(latent_space_size)
-
-
    def forward(self, x):
        x = self.linear1(x)
-        if self.use_batchnorm:
-            x = self.batchnorm1(x)
        x = self.relu1(x)
-        if self.use_dropout:
-            x = self.dropout1(x)
-
        x = self.linear2(x)
-        if self.use_batchnorm:
-            x = self.batchnorm2(x)
        x = self.relu2(x)
-        if self.use_dropout:
-            x = self.dropout2(x)
-
        x = self.linear3(x)
-        if self.use_batchnorm:
-            x = self.batchnorm3(x)
        x = self.relu3(x)
-        if self.use_dropout:
-            x = self.dropout3(x)
        return x


class Decoder(nn.Module):
-    def __init__(self, input_size, latent_space_size, use_dropout=False, use_batchnorm=False):
+    def __init__(self, input_size, latent_space_size):
        super().__init__()
        self.linear1 = nn.Linear(latent_space_size, input_size // 4)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(input_size // 4, input_size // 2)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(input_size // 2, input_size)

-        self.use_dropout = use_dropout
-        if self.use_dropout:
-            self.dropout1 = nn.Dropout(p=0.2)
-            self.dropout2 = nn.Dropout(p=0.2)
-            self.dropout3 = nn.Dropout(p=0.2)
-
-        self.use_batchnorm = use_batchnorm
-        if self.use_batchnorm:
-            self.batchnorm1 = nn.BatchNorm1d(input_size // 4)
-            self.batchnorm2 = nn.BatchNorm1d(input_size // 2)

    def forward(self, x):
        x = self.linear1(x)
-        if self.use_batchnorm:
-            x = self.batchnorm1(x)
        x = self.relu1(x)
-        if self.use_dropout:
-            x = self.dropout1(x)
-
        x = self.linear2(x)
-        if self.use_batchnorm:
-            x = self.batchnorm2(x)
        x = self.relu2(x)
-        if self.use_dropout:
-            x = self.dropout2(x)
-
        out = self.linear3(x)
        return out
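
For context, a minimal usage sketch of the trimmed-down constructor after this change. It assumes the class definitions above plus a `from torch import nn` in the module; the batch size, sequence length, channel count, latent size, and gamma value below are illustrative placeholders, not values taken from the repository.

import torch

# hypothetical sizes for illustration only
seq_len, num_channels, latent_dim = 100, 8, 16
model = MLP(seq_len, num_channels, latent_dim, gamma=0.99, normalization="None")

X = torch.randn(32, seq_len, num_channels)   # (batch, seq_len, channels)
out = model(X)                               # reconstruction with the same shape as X
assert out.shape == X.shape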