mirror of
https://github.com/Anil-matcha/Open-Generative-AI.git
synced 2026-05-07 01:17:18 +00:00
feat(image-studio): implement batch generation (1–4 images per click)
Closes #69 The Vite/Electron studio already exposed a 1–4 "Batch Count" slider but the value was only tracked in state and never applied to the request, so every generation returned a single image regardless of the slider. This wires the slider up and adds the same capability to the Next.js studio. Vite/Electron (src/components/ImageStudio.js): - Run the generation in N parallel slots when batchCount > 1 (clamped 1..4) - Each slot tracks its own request_id via savePendingJob/removePendingJob so pending-job resume and error handling stay per-image - When a fixed seed is set, offset it per slot (seed + i) so the batch yields distinct variations; a -1 seed stays random per call - Replace the single resultImg with a dynamic image container that renders one image full-size or a 2-col grid for 2–4 images - Grid tiles are clickable to expand to full size; Download button now targets the currently selected url - Partial failure is non-fatal — surface any successful results and log the rest; only throw if every slot failed Next.js studio (packages/studio/src/components/ImageStudio.jsx): - Add a batchCount state (persisted alongside the other studio prefs) - Add a compact "x1/x2/x3/x4" batch-size dropdown to the control row - Refactor handleGenerate to fan out N calls with Promise.allSettled, push each successful result into history, and keep the existing error UX when all calls fail The muapi /v1 endpoints don't expose a server-side batch parameter, so the fan-out happens client-side — one request per image.
This commit is contained in:
parent
9de0de3430
commit
dc0a552297
2 changed files with 222 additions and 64 deletions
|
|
@ -733,6 +733,8 @@ export default function ImageStudio({
|
|||
return resolutions[0] || null;
|
||||
});
|
||||
const [maxImages, setMaxImages] = useState(1);
|
||||
// Number of images to generate in parallel per "Generate" click (1..4).
|
||||
const [batchCount, setBatchCount] = useState(1);
|
||||
|
||||
// ── Prompt / upload state ───────────────────────────────────────────────
|
||||
const [prompt, setPrompt] = useState("");
|
||||
|
|
@ -781,6 +783,7 @@ export default function ImageStudio({
|
|||
if (data.selectedAr) setSelectedAr(data.selectedAr);
|
||||
if (data.selectedQuality) setSelectedQuality(data.selectedQuality);
|
||||
if (data.maxImages) setMaxImages(data.maxImages);
|
||||
if (data.batchCount) setBatchCount(data.batchCount);
|
||||
if (data.prompt) setPrompt(data.prompt);
|
||||
if (data.uploadedImageUrls) setUploadedImageUrls(data.uploadedImageUrls);
|
||||
if (data.localHistory) setLocalHistory(data.localHistory);
|
||||
|
|
@ -801,6 +804,7 @@ export default function ImageStudio({
|
|||
selectedAr,
|
||||
selectedQuality,
|
||||
maxImages,
|
||||
batchCount,
|
||||
prompt,
|
||||
uploadedImageUrls,
|
||||
localHistory,
|
||||
|
|
@ -818,6 +822,7 @@ export default function ImageStudio({
|
|||
selectedAr,
|
||||
selectedQuality,
|
||||
maxImages,
|
||||
batchCount,
|
||||
prompt,
|
||||
uploadedImageUrls,
|
||||
localHistory,
|
||||
|
|
@ -942,8 +947,11 @@ export default function ImageStudio({
|
|||
setGenerating(true);
|
||||
setGenerateError(null);
|
||||
|
||||
try {
|
||||
let res;
|
||||
const count = Math.max(1, Math.min(4, parseInt(batchCount, 10) || 1));
|
||||
const trimmedPrompt = prompt.trim();
|
||||
|
||||
// Build params once; each slot reuses the same payload (seed is left random).
|
||||
const buildParams = () => {
|
||||
if (imageMode) {
|
||||
const genParams = {
|
||||
model: selectedModelId,
|
||||
|
|
@ -951,28 +959,48 @@ export default function ImageStudio({
|
|||
image_url: uploadedImageUrls[0],
|
||||
aspect_ratio: selectedAr,
|
||||
};
|
||||
if (prompt.trim()) genParams.prompt = prompt.trim();
|
||||
if (trimmedPrompt) genParams.prompt = trimmedPrompt;
|
||||
if (currentQualityField && selectedQuality) {
|
||||
genParams[currentQualityField] = selectedQuality;
|
||||
}
|
||||
res = await generateI2I(apiKey, genParams);
|
||||
} else {
|
||||
const genParams = {
|
||||
model: selectedModelId,
|
||||
prompt: prompt.trim(),
|
||||
aspect_ratio: selectedAr,
|
||||
};
|
||||
if (currentQualityField && selectedQuality) {
|
||||
genParams[currentQualityField] = selectedQuality;
|
||||
}
|
||||
res = await generateImage(apiKey, genParams);
|
||||
return genParams;
|
||||
}
|
||||
const genParams = {
|
||||
model: selectedModelId,
|
||||
prompt: trimmedPrompt,
|
||||
aspect_ratio: selectedAr,
|
||||
};
|
||||
if (currentQualityField && selectedQuality) {
|
||||
genParams[currentQualityField] = selectedQuality;
|
||||
}
|
||||
return genParams;
|
||||
};
|
||||
|
||||
try {
|
||||
// Run all batch slots in parallel. Each call returns its own image URL
|
||||
// (the API doesn't expose a server-side batch option on the /v1 endpoints,
|
||||
// so we fan out client-side and treat each response as an independent item).
|
||||
const results = await Promise.allSettled(
|
||||
Array.from({ length: count }, () =>
|
||||
imageMode ? generateI2I(apiKey, buildParams()) : generateImage(apiKey, buildParams()),
|
||||
),
|
||||
);
|
||||
|
||||
const successes = results
|
||||
.filter((r) => r.status === "fulfilled" && r.value && r.value.url)
|
||||
.map((r) => r.value);
|
||||
|
||||
if (successes.length === 0) {
|
||||
const firstRejection = results.find((r) => r.status === "rejected");
|
||||
const err = firstRejection?.reason;
|
||||
throw err instanceof Error ? err : new Error("No image URL returned by API");
|
||||
}
|
||||
|
||||
if (res && res.url) {
|
||||
successes.forEach((res, i) => {
|
||||
const entry = {
|
||||
id: res.id || Date.now().toString(),
|
||||
id: res.id || `${Date.now()}-${i}`,
|
||||
url: res.url,
|
||||
prompt: prompt.trim(),
|
||||
prompt: trimmedPrompt,
|
||||
model: selectedModelId,
|
||||
aspect_ratio: selectedAr,
|
||||
timestamp: new Date().toISOString(),
|
||||
|
|
@ -981,11 +1009,14 @@ export default function ImageStudio({
|
|||
onGenerationComplete?.({
|
||||
url: res.url,
|
||||
model: selectedModelId,
|
||||
prompt: prompt.trim(),
|
||||
prompt: trimmedPrompt,
|
||||
type: "image",
|
||||
});
|
||||
} else {
|
||||
throw new Error("No image URL returned by API");
|
||||
});
|
||||
|
||||
const failureCount = results.length - successes.length;
|
||||
if (failureCount > 0) {
|
||||
console.warn(`[ImageStudio] ${failureCount}/${count} batch slot(s) failed`);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("[ImageStudio] Generation failed:", e);
|
||||
|
|
@ -1254,6 +1285,44 @@ export default function ImageStudio({
|
|||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Batch count button (1–4 images per generation) */}
|
||||
<div className="relative">
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setDropdownOpen((o) => (o === "batch" ? null : "batch"));
|
||||
}}
|
||||
title="Number of images per generation"
|
||||
className="flex items-center gap-2 px-3 py-2 bg-white/[0.03] hover:bg-white/[0.06] rounded-md transition-all border border-white/[0.03] group whitespace-nowrap"
|
||||
>
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" className="opacity-40 text-white">
|
||||
<rect x="3" y="3" width="7" height="7" rx="1" />
|
||||
<rect x="14" y="3" width="7" height="7" rx="1" />
|
||||
<rect x="3" y="14" width="7" height="7" rx="1" />
|
||||
<rect x="14" y="14" width="7" height="7" rx="1" />
|
||||
</svg>
|
||||
<span className="text-[11px] font-semibold text-white/70 group-hover:text-[#d9ff00] transition-colors">
|
||||
x{batchCount}
|
||||
</span>
|
||||
</button>
|
||||
|
||||
{dropdownOpen === "batch" && (
|
||||
<div
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute bottom-[calc(100%+12px)] left-0 z-50 bg-[#0a0a0a] rounded-md p-3 shadow-2xl border border-white/[0.05] min-w-[140px]"
|
||||
>
|
||||
<SimpleDropdown
|
||||
title="Batch Size"
|
||||
options={[1, 2, 3, 4]}
|
||||
selected={batchCount}
|
||||
onSelect={(val) => setBatchCount(val)}
|
||||
onClose={() => setDropdownOpen(null)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Generate button */}
|
||||
|
|
|
|||
|
|
@ -844,9 +844,9 @@ export function ImageStudio() {
|
|||
const imageContainer = document.createElement('div');
|
||||
imageContainer.className = 'relative group';
|
||||
|
||||
const resultImg = document.createElement('img');
|
||||
resultImg.className = 'max-h-[60vh] max-w-[80vw] rounded-3xl shadow-3xl border border-white/10 interactive-glow object-contain';
|
||||
imageContainer.appendChild(resultImg);
|
||||
// Tracks the URL currently featured in the canvas (used by the Download button
|
||||
// and to know which image out of a batch the user has selected).
|
||||
let currentCanvasUrl = null;
|
||||
|
||||
// Canvas Controls
|
||||
const canvasControls = document.createElement('div');
|
||||
|
|
@ -872,21 +872,56 @@ export function ImageStudio() {
|
|||
canvas.appendChild(canvasControls);
|
||||
container.appendChild(canvas);
|
||||
|
||||
// --- Helper: Show image in canvas ---
|
||||
const showImageInCanvas = (imageUrl) => {
|
||||
// --- Helper: Reveal the canvas once the first image has loaded ---
|
||||
const revealCanvas = () => {
|
||||
canvas.classList.remove('opacity-0', 'pointer-events-none', 'translate-y-10', 'scale-95');
|
||||
canvas.classList.add('opacity-100', 'translate-y-0', 'scale-100');
|
||||
canvasControls.classList.remove('opacity-0');
|
||||
canvasControls.classList.add('opacity-100');
|
||||
};
|
||||
|
||||
// --- Helper: Show one or many images in the canvas ---
|
||||
// A single URL renders as before (big featured image). Multiple URLs render as
|
||||
// a responsive grid where each tile can be clicked to expand to full size.
|
||||
const showImagesInCanvas = (imageUrls) => {
|
||||
const urls = Array.isArray(imageUrls) ? imageUrls : [imageUrls];
|
||||
if (urls.length === 0) return;
|
||||
|
||||
// Fully hide hero and prompt
|
||||
hero.classList.add('hidden');
|
||||
promptWrapper.classList.add('hidden');
|
||||
|
||||
resultImg.src = imageUrl;
|
||||
resultImg.onload = () => {
|
||||
canvas.classList.remove('opacity-0', 'pointer-events-none', 'translate-y-10', 'scale-95');
|
||||
canvas.classList.add('opacity-100', 'translate-y-0', 'scale-100');
|
||||
canvasControls.classList.remove('opacity-0');
|
||||
canvasControls.classList.add('opacity-100');
|
||||
};
|
||||
imageContainer.innerHTML = '';
|
||||
currentCanvasUrl = urls[0];
|
||||
|
||||
if (urls.length === 1) {
|
||||
const img = document.createElement('img');
|
||||
img.className = 'max-h-[60vh] max-w-[80vw] rounded-3xl shadow-3xl border border-white/10 interactive-glow object-contain';
|
||||
img.src = urls[0];
|
||||
img.onload = revealCanvas;
|
||||
imageContainer.appendChild(img);
|
||||
} else {
|
||||
// Up to 4 images — render as a 2-column grid (2x1 for 2, 2x2 for 3–4).
|
||||
const grid = document.createElement('div');
|
||||
grid.className = 'grid grid-cols-2 gap-3 md:gap-4 max-h-[70vh]';
|
||||
urls.forEach((u, i) => {
|
||||
const tile = document.createElement('img');
|
||||
tile.className = 'max-h-[32vh] max-w-[38vw] rounded-2xl shadow-2xl border border-white/10 object-contain cursor-pointer hover:scale-[1.02] hover:border-primary/50 transition-all';
|
||||
tile.src = u;
|
||||
if (i === 0) tile.onload = revealCanvas;
|
||||
tile.onclick = () => showImagesInCanvas([u]);
|
||||
tile.title = 'Click to expand';
|
||||
grid.appendChild(tile);
|
||||
});
|
||||
imageContainer.appendChild(grid);
|
||||
// Fallback reveal in case the first image is cached and onload already fired
|
||||
revealCanvas();
|
||||
}
|
||||
};
|
||||
|
||||
// Backward-compatible single-image wrapper
|
||||
const showImageInCanvas = (imageUrl) => showImagesInCanvas([imageUrl]);
|
||||
|
||||
// --- Helper: Add to history ---
|
||||
const addToHistory = (entry) => {
|
||||
generationHistory.unshift(entry);
|
||||
|
|
@ -1001,7 +1036,7 @@ export function ImageStudio() {
|
|||
|
||||
// --- Button Handlers ---
|
||||
downloadBtn.onclick = () => {
|
||||
const current = resultImg.src;
|
||||
const current = currentCanvasUrl;
|
||||
if (current) {
|
||||
const entry = generationHistory.find(e => e.url === current);
|
||||
downloadImage(current, `muapi-${entry?.id || 'image'}.jpg`);
|
||||
|
|
@ -1064,65 +1099,119 @@ export function ImageStudio() {
|
|||
|
||||
hero.classList.add('opacity-0', 'scale-95', '-translate-y-10', 'pointer-events-none');
|
||||
generateBtn.disabled = true;
|
||||
generateBtn.innerHTML = `<span class="animate-spin inline-block mr-2 text-black">◌</span> Generating...`;
|
||||
|
||||
let hadError = false;
|
||||
let capturedRequestId = null;
|
||||
// Clamp batch count to the slider's configured range as a defensive guard.
|
||||
const count = Math.max(1, Math.min(4, parseInt(batchCount, 10) || 1));
|
||||
const updateBtnProgress = (done) => {
|
||||
const label = count > 1 ? `Generating ${done}/${count}...` : 'Generating...';
|
||||
generateBtn.innerHTML = `<span class="animate-spin inline-block mr-2 text-black">◌</span> ${label}`;
|
||||
};
|
||||
updateBtnProgress(0);
|
||||
|
||||
const historyMeta = { prompt, model: selectedModel, aspect_ratio: selectedAr };
|
||||
const qualityLabel = document.getElementById('quality-btn-label')?.textContent;
|
||||
const qualityField = getCurrentQualityField(selectedModel);
|
||||
|
||||
try {
|
||||
let res;
|
||||
const qualityLabel = document.getElementById('quality-btn-label')?.textContent;
|
||||
// Build one set of generation params for a single slot in the batch.
|
||||
// Each slot tracks its own requestId so pending jobs can be individually
|
||||
// cleared (or resumed on reload) without one failure affecting the others.
|
||||
const buildSlot = (slotIdx) => {
|
||||
const slot = { requestId: null };
|
||||
const onRequestId = (rid) => {
|
||||
slot.requestId = rid;
|
||||
savePendingJob({
|
||||
requestId: rid,
|
||||
studioType: 'image',
|
||||
historyMeta,
|
||||
maxAttempts: 60,
|
||||
interval: 2000,
|
||||
submittedAt: Date.now()
|
||||
});
|
||||
};
|
||||
|
||||
let genParams;
|
||||
if (imageMode) {
|
||||
const genParams = {
|
||||
genParams = {
|
||||
model: selectedModel,
|
||||
images_list: uploadedImageUrls,
|
||||
image_url: uploadedImageUrls[0], // backward compat for single-image models
|
||||
aspect_ratio: selectedAr,
|
||||
onRequestId: (rid) => {
|
||||
capturedRequestId = rid;
|
||||
savePendingJob({ requestId: rid, studioType: 'image', historyMeta, maxAttempts: 60, interval: 2000, submittedAt: Date.now() });
|
||||
}
|
||||
onRequestId
|
||||
};
|
||||
if (prompt) genParams.prompt = prompt;
|
||||
const qualityField = getCurrentQualityField(selectedModel);
|
||||
if (qualityField && qualityLabel) genParams[qualityField] = qualityLabel;
|
||||
res = await muapi.generateI2I(genParams);
|
||||
} else {
|
||||
const genParams = {
|
||||
genParams = {
|
||||
model: selectedModel,
|
||||
prompt,
|
||||
aspect_ratio: selectedAr,
|
||||
onRequestId: (rid) => {
|
||||
capturedRequestId = rid;
|
||||
savePendingJob({ requestId: rid, studioType: 'image', historyMeta, maxAttempts: 60, interval: 2000, submittedAt: Date.now() });
|
||||
}
|
||||
onRequestId
|
||||
};
|
||||
const qualityField = getCurrentQualityField(selectedModel);
|
||||
if (qualityField && qualityLabel) genParams[qualityField] = qualityLabel;
|
||||
res = await muapi.generateImage(genParams);
|
||||
}
|
||||
if (qualityField && qualityLabel) genParams[qualityField] = qualityLabel;
|
||||
|
||||
// When the user has pinned a specific seed we offset it per slot so the
|
||||
// batch yields distinct variations rather than 4 identical images.
|
||||
// A seed of -1 means "random" and is left alone for each call.
|
||||
if (typeof seed === 'number' && seed !== -1) {
|
||||
genParams.seed = seed + slotIdx;
|
||||
}
|
||||
|
||||
console.log('[ImageStudio] Full response:', res);
|
||||
slot.params = genParams;
|
||||
return slot;
|
||||
};
|
||||
|
||||
if (res && res.url) {
|
||||
if (capturedRequestId) removePendingJob(capturedRequestId);
|
||||
const slots = Array.from({ length: count }, (_, i) => buildSlot(i));
|
||||
|
||||
// Kick off all slots in parallel and report progress as each one finishes.
|
||||
let completed = 0;
|
||||
const runSlot = async (slot) => {
|
||||
try {
|
||||
const res = imageMode
|
||||
? await muapi.generateI2I(slot.params)
|
||||
: await muapi.generateImage(slot.params);
|
||||
if (slot.requestId) removePendingJob(slot.requestId);
|
||||
if (!res || !res.url) throw new Error('No image URL returned by API');
|
||||
return { ok: true, res, slot };
|
||||
} catch (err) {
|
||||
if (slot.requestId) removePendingJob(slot.requestId);
|
||||
return { ok: false, error: err, slot };
|
||||
} finally {
|
||||
completed++;
|
||||
updateBtnProgress(completed);
|
||||
}
|
||||
};
|
||||
|
||||
let hadError = false;
|
||||
try {
|
||||
const results = await Promise.all(slots.map(runSlot));
|
||||
const successes = results.filter(r => r.ok);
|
||||
const failures = results.filter(r => !r.ok);
|
||||
|
||||
if (successes.length === 0) {
|
||||
const firstErr = failures[0]?.error;
|
||||
throw firstErr instanceof Error ? firstErr : new Error('All generations failed');
|
||||
}
|
||||
|
||||
const urls = successes.map(r => {
|
||||
const { res, slot } = r;
|
||||
addToHistory({
|
||||
id: res.id || capturedRequestId || Date.now().toString(),
|
||||
id: res.id || slot.requestId || `${Date.now()}-${Math.random().toString(36).slice(2, 7)}`,
|
||||
url: res.url,
|
||||
prompt: prompt,
|
||||
model: selectedModel,
|
||||
aspect_ratio: selectedAr,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
showImageInCanvas(res.url);
|
||||
} else {
|
||||
console.error('[ImageStudio] No image URL in response:', res);
|
||||
throw new Error('No image URL returned by API');
|
||||
return res.url;
|
||||
});
|
||||
|
||||
showImagesInCanvas(urls);
|
||||
|
||||
if (failures.length > 0) {
|
||||
console.warn(`[ImageStudio] ${failures.length}/${count} batch slot(s) failed:`, failures.map(f => f.error?.message));
|
||||
}
|
||||
} catch (e) {
|
||||
hadError = true;
|
||||
if (capturedRequestId) removePendingJob(capturedRequestId);
|
||||
console.error(e);
|
||||
// Restore hero so the page doesn't look broken after a failed generation
|
||||
hero.classList.remove('opacity-0', 'scale-95', '-translate-y-10', 'pointer-events-none');
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue