Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d9de0a5

Browse files
committedJan 2, 2025·
basic webui
1 parent b9cb645 commit d9de0a5

File tree

2 files changed

+319
-19
lines changed

2 files changed

+319
-19
lines changed
 

‎examples/server/main.cpp

+315-16
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,8 @@ void worker_thread() {
854854
task_queue.pop();
855855
lock.unlock();
856856
task();
857-
is_busy = false;
857+
is_busy = false;
858+
running_task_id = "";
858859
}
859860
}
860861
}
@@ -926,14 +927,16 @@ void start_server(SDParams params) {
926927
using json = nlohmann::json;
927928
std::string task_id = std::to_string(std::chrono::system_clock::now().time_since_epoch().count());
928929

929-
json pending_task_json = json::object();
930-
pending_task_json["status"] = "Pending";
931-
pending_task_json["data"] = json::array();
932-
pending_task_json["step"] = -1;
933-
pending_task_json["eta"] = "?";
930+
{
931+
json pending_task_json = json::object();
932+
pending_task_json["status"] = "Pending";
933+
pending_task_json["data"] = json::array();
934+
pending_task_json["step"] = -1;
935+
pending_task_json["eta"] = "?";
934936

935-
std::lock_guard<std::mutex> results_lock(results_mutex);
936-
task_results[task_id] = pending_task_json;
937+
std::lock_guard<std::mutex> results_lock(results_mutex);
938+
task_results[task_id] = pending_task_json;
939+
}
937940

938941
auto task = [&req, &sd_ctx, &params, &n_prompts, task_id]() {
939942
running_task_id = task_id;
@@ -1044,12 +1047,7 @@ void start_server(SDParams params) {
10441047
end_task_json["eta"] = "?";
10451048
std::lock_guard<std::mutex> results_lock(results_mutex);
10461049
task_results[task_id] = end_task_json;
1047-
return;
10481050
}
1049-
1050-
std::lock_guard<std::mutex> results_lock(results_mutex);
1051-
task_results[task_id]["status"] = "Failed";
1052-
return;
10531051
};
10541052
// Add the task to the queue
10551053
add_task(task_id, task);
@@ -1059,20 +1057,321 @@ void start_server(SDParams params) {
10591057
res.set_content(response.dump(), "application/json");
10601058
});
10611059

1062-
svr->Post("/result", [](const httplib::Request& req, httplib::Response& res) {
1060+
svr->Get("/params", [&params](const httplib::Request& req, httplib::Response& res) {
1061+
using json = nlohmann::json;
1062+
json response;
1063+
json params_json = json::object();
1064+
params_json["prompt"] = params.lastRequest.prompt;
1065+
params_json["negative_prompt"] = params.lastRequest.negative_prompt;
1066+
params_json["clip_skip"] = params.lastRequest.clip_skip;
1067+
params_json["cfg_scale"] = params.lastRequest.cfg_scale;
1068+
params_json["guidance"] = params.lastRequest.guidance;
1069+
params_json["width"] = params.lastRequest.width;
1070+
params_json["height"] = params.lastRequest.height;
1071+
params_json["sample_method"] = sample_method_str[params.lastRequest.sample_method];
1072+
params_json["sample_steps"] = params.lastRequest.sample_steps;
1073+
params_json["seed"] = params.lastRequest.seed;
1074+
params_json["batch_count"] = params.lastRequest.batch_count;
1075+
params_json["normalize_input"] = params.lastRequest.normalize_input;
1076+
// params_json["input_id_images_path"] = params.input_id_images_path;
1077+
response["generation_params"] = params_json;
1078+
1079+
json context_params = json::object();
1080+
// Do not expose paths
1081+
// context_params["model_path"] = params.ctxParams.model_path;
1082+
// context_params["clip_l_path"] = params.ctxParams.clip_l_path;
1083+
// context_params["clip_g_path"] = params.ctxParams.clip_g_path;
1084+
// context_params["t5xxl_path"] = params.ctxParams.t5xxl_path;
1085+
// context_params["diffusion_model_path"] = params.ctxParams.diffusion_model_path;
1086+
// context_params["vae_path"] = params.ctxParams.vae_path;
1087+
// context_params["controlnet_path"] = params.ctxParams.controlnet_path;
1088+
context_params["lora_model_dir"] = params.ctxParams.lora_model_dir;
1089+
// context_params["embeddings_path"] = params.ctxParams.embeddings_path;
1090+
// context_params["stacked_id_embeddings_path"] = params.ctxParams.stacked_id_embeddings_path;
1091+
context_params["vae_decode_only"] = params.ctxParams.vae_decode_only;
1092+
context_params["vae_tiling"] = params.ctxParams.vae_tiling;
1093+
context_params["n_threads"] = params.ctxParams.n_threads;
1094+
context_params["wtype"] = params.ctxParams.wtype;
1095+
context_params["rng_type"] = params.ctxParams.rng_type;
1096+
context_params["schedule"] = params.ctxParams.schedule;
1097+
context_params["clip_on_cpu"] = params.ctxParams.clip_on_cpu;
1098+
context_params["control_net_cpu"] = params.ctxParams.control_net_cpu;
1099+
context_params["vae_on_cpu"] = params.ctxParams.vae_on_cpu;
1100+
context_params["diffusion_flash_attn"] = params.ctxParams.diffusion_flash_attn;
1101+
response["context_params"] = context_params;
1102+
1103+
res.set_content(response.dump(), "application/json");
1104+
});
1105+
1106+
1107+
svr->Get("/result", [](const httplib::Request& req, httplib::Response& res) {
10631108
using json = nlohmann::json;
10641109
// Parse task ID from query parameters
10651110
try {
1066-
std::string task_id = json::parse(req.body)["task_id"];
1111+
std::string task_id = req.get_param_value("task_id");
10671112
std::lock_guard<std::mutex> lock(results_mutex);
10681113
if (task_results.find(task_id) != task_results.end()) {
10691114
json result = task_results[task_id];
10701115
res.set_content(result.dump(), "application/json");
1116+
// Erase data after sending
1117+
result["data"] = json::array();
1118+
task_results[task_id] = result;
10711119
} else {
10721120
res.set_content("Cannot find task " + task_id + " in queue", "text/plain");
10731121
}
10741122
} catch (...) {
1075-
sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid request body: %s\n", req.body.c_str());
1123+
sd_log(sd_log_level_t::SD_LOG_WARN, "Error when fetching result");
1124+
}
1125+
});
1126+
1127+
svr->Get("/sample_methods", [](const httplib::Request& req, httplib::Response& res) {
1128+
using json = nlohmann::json;
1129+
json response;
1130+
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
1131+
response.push_back(sample_method_str[m]);
1132+
}
1133+
res.set_content(response.dump(), "application/json");
1134+
});
1135+
1136+
svr->Get("/schedules", [](const httplib::Request& req, httplib::Response& res) {
1137+
using json = nlohmann::json;
1138+
json response;
1139+
for (int s = 0; s < N_SCHEDULES; s++) {
1140+
response.push_back(schedule_str[s]);
1141+
}
1142+
res.set_content(response.dump(), "application/json");
1143+
});
1144+
1145+
1146+
svr->Get("/index.html", [](const httplib::Request& req, httplib::Response& res) {
1147+
try {
1148+
std::string html_content = R"xxx(
1149+
<!DOCTYPE html>
1150+
<html>
1151+
<head>
1152+
<meta charset="UTF-8">
1153+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1154+
<title>SDCPP Server</title>
1155+
<style>
1156+
body {
1157+
font-family: Arial, sans-serif;
1158+
display: flex;
1159+
align-items: center;
1160+
justify-content: center;
1161+
height: 100vh;
1162+
margin: 0;
1163+
background-color: #f0f0f0;
1164+
}
1165+
.container {
1166+
display: flex;
1167+
width: 80%;
1168+
background: white;
1169+
padding: 20px;
1170+
border-radius: 10px;
1171+
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
1172+
}
1173+
.input-group {
1174+
display: flex;
1175+
align-items: center;
1176+
margin-bottom: 10px;
1177+
}
1178+
.input-group label {
1179+
width: 150px;
1180+
text-align: right;
1181+
margin-right: 10px;
1182+
}
1183+
.prompt-input, .param-input {
1184+
width: 400px;
1185+
}
1186+
canvas {
1187+
border: 1px solid #ccc;
1188+
}
1189+
.left-section {
1190+
flex: 1;
1191+
padding-right: 20px;
1192+
}
1193+
.right-section {
1194+
flex: 1;
1195+
display: flex;
1196+
align-items: center;
1197+
justify-content: center;
1198+
}
1199+
</style>
1200+
</head>
1201+
<body>
1202+
<div class="container">
1203+
<div class="left-section">
1204+
<h1>SDCPP Server</h1>
1205+
<div id="prompts">
1206+
<div class="input-group">
1207+
<label for="prompt">Prompt:</label>
1208+
<input type="text" id="prompt" class="prompt-input">
1209+
</div>
1210+
<div class="input-group">
1211+
<label for="neg_prompt">Negative Prompt:</label>
1212+
<input type="text" id="neg_prompt" class="prompt-input">
1213+
</div>
1214+
</div>
1215+
<div id="params">
1216+
<div class="input-group">
1217+
<label for="width">Width:</label>
1218+
<input type="number" id="width" class="param-input">
1219+
</div>
1220+
<div class="input-group">
1221+
<label for="height">Height:</label>
1222+
<input type="number" id="height" class="param-input">
1223+
</div>
1224+
<div class="input-group">
1225+
<label for="cfg_scale">CFG Scale:</label>
1226+
<input type="number" id="cfg_scale" class="param-input">
1227+
</div>
1228+
<div class="input-group">
1229+
<label for="guidance">Guidance (Flux):</label>
1230+
<input type="number" id="guidance" class="param-input">
1231+
</div>
1232+
<div class="input-group">
1233+
<label for="steps">Steps:</label>
1234+
<input type="number" id="steps" class="param-input">
1235+
</div>
1236+
<div class="input-group">
1237+
<label for="sample_method">Sample Method:</label>
1238+
<select id="sample_method" class="param-input"></select>
1239+
</div>
1240+
<div class="input-group">
1241+
<label for="seed">Seed:</label>
1242+
<input type="number" id="seed" class="param-input">
1243+
</div>
1244+
<div class="input-group">
1245+
<label for="batch_count">Batch Count:</label>
1246+
<input type="number" id="batch_count" class="param-input">
1247+
</div>
1248+
</div>
1249+
<button onclick="generateImage()">Generate</button>
1250+
<a id="downloadLink" style="display: none;" download="generated_image.png">Download Image</a>
1251+
</div>
1252+
<div class="right-section">
1253+
<canvas id="imageCanvas" width="500" height="500"></canvas>
1254+
</div>
1255+
</div>
1256+
<script>
1257+
// Fetch sample methods from the server and populate the dropdown list
1258+
async function fetchSampleMethods() {
1259+
const response = await fetch('/sample_methods');
1260+
const data = await response.json();
1261+
1262+
const select = document.getElementById('sample_method');
1263+
data.forEach(method => {
1264+
const option = document.createElement('option');
1265+
option.value = method;
1266+
option.textContent = method;
1267+
select.appendChild(option);
1268+
});
1269+
}
1270+
1271+
// Call the function to fetch and populate the sample methods list
1272+
fetchSampleMethods();
1273+
1274+
// Fetch parameters from the server and populate the input fields
1275+
async function fetchParams() {
1276+
const response = await fetch('/params');
1277+
const data = await response.json();
1278+
1279+
document.getElementById('prompt').value = data.generation_params.prompt;
1280+
document.getElementById('neg_prompt').value = data.generation_params.negative_prompt;
1281+
document.getElementById('width').value = data.generation_params.width;
1282+
document.getElementById('height').value = data.generation_params.height;
1283+
document.getElementById('cfg_scale').value = data.generation_params.cfg_scale;
1284+
document.getElementById('guidance').value = data.generation_params.guidance;
1285+
document.getElementById('steps').value = data.generation_params.sample_steps;
1286+
document.getElementById('sample_method').value = data.generation_params.sample_method;
1287+
document.getElementById('seed').value = data.generation_params.seed;
1288+
document.getElementById('batch_count').value = data.generation_params.batch_count;
1289+
}
1290+
1291+
// Call the function to fetch and populate the input fields
1292+
fetchParams();
1293+
1294+
async function generateImage() {
1295+
const prompt = document.getElementById('prompt').value;
1296+
const neg_prompt = document.getElementById('neg_prompt').value;
1297+
const width = document.getElementById('width').value;
1298+
const height = document.getElementById('height').value;
1299+
const cfg_scale = document.getElementById('cfg_scale').value;
1300+
const steps = document.getElementById('steps').value;
1301+
const guidance = document.getElementById('guidance').value;
1302+
const sample_method = document.getElementById('sample_method').value;
1303+
const seed = document.getElementById('seed').value;
1304+
const batch_count = document.getElementById('batch_count').value;
1305+
const canvas = document.getElementById('imageCanvas');
1306+
const ctx = canvas.getContext('2d');
1307+
const downloadLink = document.getElementById('downloadLink');
1308+
1309+
const requestBody = {
1310+
prompt: prompt,
1311+
negative_prompt: neg_prompt,
1312+
...(width && { width: parseInt(width) }),
1313+
...(height && { height: parseInt(height) }),
1314+
...(cfg_scale && { cfg_scale: parseFloat(cfg_scale) }),
1315+
...(steps && { steps: parseInt(steps) }),
1316+
...(guidance && { guidance: parseFloat(guidance) }),
1317+
...(sample_method && { sample_method: sample_method }),
1318+
...(seed && { seed: parseInt(seed) }),
1319+
...(batch_count && { batch_count: parseInt(batch_count) })
1320+
};
1321+
1322+
const response = await fetch('/txt2img', {
1323+
method: 'POST',
1324+
headers: {
1325+
'Content-Type': 'application/json'
1326+
},
1327+
body: JSON.stringify(requestBody)
1328+
});
1329+
1330+
const data = await response.json();
1331+
const taskId = data.task_id;
1332+
1333+
let status = '';
1334+
while (status !== 'Done' && status !== 'Failed') {
1335+
const statusResponse = await fetch(`/result?task_id=${taskId}`);
1336+
const statusData = await statusResponse.json();
1337+
status = statusData.status;
1338+
1339+
if (status === 'Done' || status === 'Working' && statusData.data.length > 0 ) {
1340+
const imageData = statusData.data[0].data;
1341+
const width = statusData.data[0].width;
1342+
const height = statusData.data[0].height;
1343+
1344+
const img = new Image();
1345+
img.src = `data:image/png;base64,${imageData}`;
1346+
img.onload = () => {
1347+
canvas.width = width;
1348+
canvas.height = height;
1349+
ctx.drawImage(img, 0, 0, width, height);
1350+
downloadLink.href = img.src;
1351+
downloadLink.style.display = 'inline-block';
1352+
};
1353+
} else if (status === 'Failed') {
1354+
alert('Image generation failed');
1355+
}
1356+
1357+
await new Promise(resolve => setTimeout(resolve, 250));
1358+
}
1359+
}
1360+
document.querySelectorAll('.prompt-input,.param-input').forEach(input => {
1361+
input.addEventListener('keydown', function(event) {
1362+
if (event.key === 'Enter') {
1363+
event.preventDefault();
1364+
generateImage();
1365+
}
1366+
});
1367+
});
1368+
</script>
1369+
</body>
1370+
</html>
1371+
)xxx";
1372+
res.set_content(html_content, "text/html");
1373+
} catch (const std::exception& e) {
1374+
res.set_content("Error loading page", "text/plain");
10761375
}
10771376
});
10781377

‎examples/server/test_client.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,16 @@ def update_url(protocol=None, server=None, port=None, endpoint=None) -> str:
8585
# set default url value
8686
update_url()
8787

88-
def poll_result(id: str):
88+
def poll_result(id: str, show_previews = False):
8989
global _protocol
9090
global _server
9191
global _port
9292

9393
res = {'status':""}
9494
while res['status'] != "Done":
95-
time.sleep(0.1)
96-
res = requests.post(f"{_protocol}://{_server}:{_port}/result", json.dumps({'task_id':id})).json()
95+
res = requests.get(f"{_protocol}://{_server}:{_port}/result", params={'task_id':id}, timeout=.25).json()
96+
if(show_previews and res['status'] == "Working" and len(res['data'])>0):
97+
showImages(getImages(json.dumps(res['data'])))
9798

9899
return json.dumps(res['data'])
99100

0 commit comments

Comments
 (0)
Please sign in to comment.