Skip to content

Commit d50473d

Browse files
authored
feat: support 16 channel tae (taesd/taef1) (#527)
1 parent b5cc142 commit d50473d

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

stable-diffusion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ class StableDiffusionGGML {
360360
first_stage_model->alloc_params_buffer();
361361
first_stage_model->get_param_tensors(tensors, "first_stage_model");
362362
} else {
363-
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only);
363+
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version);
364364
}
365365
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
366366

tae.hpp

+16-7
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock {
6262
int num_blocks = 3;
6363

6464
public:
65-
TinyEncoder() {
65+
TinyEncoder(int z_channels = 4)
66+
: z_channels(z_channels) {
6667
int index = 0;
6768
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
6869
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
@@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock {
106107
int num_blocks = 3;
107108

108109
public:
109-
TinyDecoder(int index = 0) {
110+
TinyDecoder(int z_channels = 4)
111+
: z_channels(z_channels) {
112+
int index = 0;
113+
110114
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
111115
index++; // nn.ReLU()
112116

@@ -163,12 +167,16 @@ class TAESD : public GGMLBlock {
163167
bool decode_only;
164168

165169
public:
166-
TAESD(bool decode_only = true)
170+
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
167171
: decode_only(decode_only) {
168-
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder());
172+
int z_channels = 4;
173+
if (sd_version_is_dit(version)) {
174+
z_channels = 16;
175+
}
176+
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
169177

170178
if (!decode_only) {
171-
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder());
179+
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
172180
}
173181
}
174182

@@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner {
190198
TinyAutoEncoder(ggml_backend_t backend,
191199
std::map<std::string, enum ggml_type>& tensor_types,
192200
const std::string prefix,
193-
bool decoder_only = true)
201+
bool decoder_only = true,
202+
SDVersion version = VERSION_SD1)
194203
: decode_only(decoder_only),
195-
taesd(decode_only),
204+
taesd(decode_only, version),
196205
GGMLRunner(backend) {
197206
taesd.init(params_ctx, tensor_types, prefix);
198207
}

0 commit comments

Comments
 (0)