@@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock {
62
62
int num_blocks = 3 ;
63
63
64
64
public:
65
- TinyEncoder () {
65
+ TinyEncoder (int z_channels = 4 )
66
+ : z_channels(z_channels) {
66
67
int index = 0 ;
67
68
blocks[std::to_string (index ++)] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
68
69
blocks[std::to_string (index ++)] = std::shared_ptr<GGMLBlock>(new TAEBlock (channels, channels));
@@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock {
106
107
int num_blocks = 3 ;
107
108
108
109
public:
109
- TinyDecoder (int index = 0 ) {
110
+ TinyDecoder (int z_channels = 4 )
111
+ : z_channels(z_channels) {
112
+ int index = 0 ;
113
+
110
114
blocks[std::to_string (index ++)] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels, channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
111
115
index ++; // nn.ReLU()
112
116
@@ -163,12 +167,16 @@ class TAESD : public GGMLBlock {
163
167
bool decode_only;
164
168
165
169
public:
166
- TAESD (bool decode_only = true )
170
+ TAESD (bool decode_only = true , SDVersion version = VERSION_SD1 )
167
171
: decode_only(decode_only) {
168
- blocks[" decoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyDecoder ());
172
+ int z_channels = 4 ;
173
+ if (sd_version_is_dit (version)) {
174
+ z_channels = 16 ;
175
+ }
176
+ blocks[" decoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyDecoder (z_channels));
169
177
170
178
if (!decode_only) {
171
- blocks[" encoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyEncoder ());
179
+ blocks[" encoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyEncoder (z_channels ));
172
180
}
173
181
}
174
182
@@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner {
190
198
TinyAutoEncoder (ggml_backend_t backend,
191
199
std::map<std::string, enum ggml_type>& tensor_types,
192
200
const std::string prefix,
193
- bool decoder_only = true )
201
+ bool decoder_only = true ,
202
+ SDVersion version = VERSION_SD1)
194
203
: decode_only(decoder_only),
195
- taesd (decode_only),
204
+ taesd (decode_only, version ),
196
205
GGMLRunner(backend) {
197
206
taesd.init (params_ctx, tensor_types, prefix);
198
207
}
0 commit comments