新增 连续生成功能;

添加了自动化测试套件;
This commit is contained in:
2025-10-02 18:13:44 +08:00
parent d7e355e9c6
commit d70e9e62b8
14 changed files with 985 additions and 47 deletions

View File

@@ -1,5 +1,6 @@
import React, { useRef, useEffect, useState, useCallback } from 'react';
import { Stage, Layer, Image as KonvaImage, Line } from 'react-konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { useAppStore } from '../store/useAppStore';
import { Button } from './ui/Button';
import { ZoomIn, ZoomOut, RotateCcw, Download } from 'lucide-react';
@@ -8,7 +9,7 @@ export const ImageCanvas: React.FC = () => {
const {
canvasImage,
canvasZoom,
canvasPan,
// canvasPan,
setCanvasZoom,
setCanvasPan,
brushStrokes,
@@ -16,6 +17,8 @@ export const ImageCanvas: React.FC = () => {
showMasks,
selectedTool,
isGenerating,
isContinuousGenerating,
retryCount,
brushSize,
showHistory,
showPromptPanel
@@ -296,14 +299,16 @@ export const ImageCanvas: React.FC = () => {
return () => container.removeEventListener('wheel', handleWheel);
}, [canvasZoom, handleZoom]);
const handleMouseDown = (e: Konva.KonvaEventObject<MouseEvent>) => {
const handleMouseDown = (e: KonvaEventObject<MouseEvent>) => {
if (selectedTool !== 'mask' || !image) return;
setIsDrawing(true);
const stage = e.target.getStage();
if (!stage) return;
// 使用 Konva 的 getRelativePointerPosition 获取准确坐标
const relativePos = stage.getRelativePointerPosition();
if (!relativePos) return;
// 计算图像在舞台上的边界
const imageX = (stageSize.width / canvasZoom - image.width) / 2;
@@ -319,13 +324,15 @@ export const ImageCanvas: React.FC = () => {
}
};
const handleMouseMove = (e: Konva.KonvaEventObject<MouseEvent>) => {
const handleMouseMove = (e: KonvaEventObject<MouseEvent>) => {
if (!isDrawing || selectedTool !== 'mask' || !image) return;
const stage = e.target.getStage();
if (!stage) return;
// 使用 Konva 的 getRelativePointerPosition 获取准确坐标
const relativePos = stage.getRelativePointerPosition();
if (!relativePos) return;
// 计算图像在舞台上的边界
const imageX = (stageSize.width / canvasZoom - image.width) / 2;
@@ -353,6 +360,7 @@ export const ImageCanvas: React.FC = () => {
id: `stroke-${Date.now()}`,
points: currentStroke,
brushSize,
color: '#A855F7',
});
setCurrentStroke([]);
};
@@ -424,12 +432,14 @@ export const ImageCanvas: React.FC = () => {
console.error('下载图像失败:', error);
// 如果fetch失败回退到直接使用a标签
const link = document.createElement('a');
link.href = uploadResult.url;
link.download = `nano-banana-${Date.now()}.png`;
link.target = '_blank';
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
if (uploadResult.url) {
link.href = uploadResult.url;
link.download = `nano-banana-${Date.now()}.png`;
link.target = '_blank';
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
}
});
// 立即返回
@@ -571,6 +581,12 @@ export const ImageCanvas: React.FC = () => {
<div className="text-center bg-white/90 rounded-xl p-6 card-lg backdrop-blur-sm animate-in scale-in duration-200">
<div className="animate-spin rounded-full h-10 w-10 border-2 border-yellow-400 border-t-transparent mx-auto mb-3" />
<p className="text-gray-700 text-sm font-medium">...</p>
{/* 显示重试次数 */}
{isContinuousGenerating && (
<p className="text-gray-500 text-xs mt-2">
: {retryCount}
</p>
)}
</div>
</div>
)}
@@ -592,8 +608,8 @@ export const ImageCanvas: React.FC = () => {
}
}}
onMouseDown={handleMouseDown}
onMousemove={handleMouseMove}
onMouseup={handleMouseUp}
onMouseMove={handleMouseMove}
onMouseUp={handleMouseUp}
style={{
cursor: selectedTool === 'mask' ? 'crosshair' : 'default',
zIndex: 10

View File

@@ -126,6 +126,8 @@ export const PromptComposer: React.FC = () => {
seed,
setSeed,
isGenerating,
isContinuousGenerating,
retryCount,
uploadedImages,
addUploadedImage,
removeUploadedImage,
@@ -139,11 +141,15 @@ export const PromptComposer: React.FC = () => {
setCanvasImage,
showPromptPanel,
setShowPromptPanel,
clearBrushStrokes
clearBrushStrokes,
setIsContinuousGenerating,
setRetryCount
} = useAppStore();
const { generate, cancelGeneration } = useImageGeneration();
const { generate, generateAsync, cancelGeneration } = useImageGeneration();
const { edit, cancelEdit } = useImageEditing();
// 连续生成状态已在AppStore中管理
const [showAdvanced, setShowAdvanced] = useState(false);
const [showPromptSuggestions, setShowPromptSuggestions] = useState(true);
const [showClearConfirm, setShowClearConfirm] = useState(false);
@@ -248,6 +254,121 @@ export const PromptComposer: React.FC = () => {
}
};
const handleContinuousGenerate = async () => {
if (!currentPrompt.trim()) return;
// 重置重试计数
setRetryCount(0);
setIsContinuousGenerating(true);
// 将上传的图像转换为Blob对象
const referenceImageBlobs: Blob[] = [];
for (const img of uploadedImages) {
if (img.startsWith('data:')) {
// 从base64数据创建Blob
const base64 = img.split('base64,')[1];
const byteString = atob(base64);
const mimeString = img.split(',')[0].split(':')[1].split(';')[0];
const ab = new ArrayBuffer(byteString.length);
const ia = new Uint8Array(ab);
for (let i = 0; i < byteString.length; i++) {
ia[i] = byteString.charCodeAt(i);
}
referenceImageBlobs.push(new Blob([ab], { type: mimeString }));
} else if (img.startsWith('indexeddb://')) {
// 从IndexedDB获取参考图像
const imageId = img.replace('indexeddb://', '');
try {
const blob = await referenceImageService.getReferenceImage(imageId);
if (blob) {
referenceImageBlobs.push(blob);
} else {
console.warn('无法从IndexedDB获取参考图像:', imageId);
// 如果无法获取图像,尝试重新上传
console.log('尝试重新处理参考图像...');
}
} catch (error) {
console.warn('无法从IndexedDB获取参考图像:', imageId, error);
// 如果无法获取图像,尝试重新上传
console.log('尝试重新处理参考图像...');
}
} else if (img.startsWith('blob:')) {
// 从Blob URL获取Blob
const { getBlob } = useAppStore.getState();
const blob = getBlob(img);
if (blob) {
referenceImageBlobs.push(blob);
} else {
// 如果在AppStore中找不到Blob尝试重新创建
try {
const response = await fetch(img);
if (response.ok) {
const blob = await response.blob();
referenceImageBlobs.push(blob);
} else {
console.warn('无法重新获取参考图像:', img);
}
} catch (error) {
console.warn('无法重新获取参考图像:', img, error);
}
}
} else {
// 从URL获取Blob
try {
const blob = await urlToBlob(img);
referenceImageBlobs.push(blob);
} catch (error) {
console.warn('无法获取参考图像:', img, error);
}
}
}
// 过滤掉无效的Blob只保留有效的参考图像
const validBlobs = referenceImageBlobs.filter(blob => blob.size > 0);
// 开始连续生成循环
const generateWithRetry = async () => {
try {
// 即使没有参考图像也继续生成,因为提示文本是必需的
await new Promise<void>((resolve, reject) => {
// 使用mutateAsync来等待结果
generateAsync({
prompt: currentPrompt,
referenceImages: validBlobs.length > 0 ? validBlobs : undefined,
temperature,
seed: seed !== null ? seed : undefined
}).then(() => {
// 生成成功,停止连续生成
setIsContinuousGenerating(false);
resolve();
}).catch((error) => {
// 生成失败,增加重试计数并继续
const newCount = useAppStore.getState().retryCount + 1;
setRetryCount(newCount);
console.log(`生成失败,重试次数: ${newCount}`);
reject(error);
});
});
} catch (error) {
// 如果仍在连续生成模式下,继续重试
if (useAppStore.getState().isContinuousGenerating) {
console.log('生成失败,正在重试...');
setTimeout(generateWithRetry, 1000); // 1秒后重试
}
}
};
// 启动连续生成
generateWithRetry();
};
// 取消连续生成
const cancelContinuousGeneration = () => {
setIsContinuousGenerating(false);
cancelGeneration();
};
const handleFileUpload = async (file: File) => {
if (file && file.type.startsWith('image/')) {
try {
@@ -329,7 +450,7 @@ export const PromptComposer: React.FC = () => {
e.dataTransfer.setData('text/plain', index.toString());
};
const handleDragOverPreview = (e: React.DragEvent<HTMLDivElement>, index: number) => {
const handleDragOverPreview = (e: React.DragEvent<HTMLDivElement>, _index: number) => {
e.preventDefault();
e.dataTransfer.dropEffect = 'move';
};
@@ -566,25 +687,45 @@ export const PromptComposer: React.FC = () => {
{/* 生成按钮 */}
<div className="flex-shrink-0">
{isGenerating ? (
{isGenerating || isContinuousGenerating ? (
<div className="flex gap-3">
<Button
onClick={() => selectedTool === 'generate' ? cancelGeneration() : cancelEdit()}
onClick={() => selectedTool === 'generate' ? cancelContinuousGeneration() : cancelEdit()}
className="flex-1 h-14 text-base font-semibold bg-red-500 hover:bg-red-600 rounded-xl card"
>
<div className="animate-spin rounded-full h-5 w-5 border-b-2 border-white mr-2" />
</Button>
{isContinuousGenerating && (
<div className="flex items-center justify-center bg-yellow-100 text-yellow-800 rounded-lg px-3 py-2 text-sm font-medium">
<span>: {retryCount}</span>
</div>
)}
</div>
) : (
<Button
onClick={handleGenerate}
disabled={!currentPrompt.trim()}
className="w-full h-14 text-base font-semibold rounded-xl shadow-md hover:shadow-lg transition-all card"
>
<Wand2 className="h-5 w-5 mr-2" />
{selectedTool === 'generate' ? '生成图像' : '应用编辑'}
</Button>
<div className="flex gap-2">
<Button
onClick={handleGenerate}
disabled={!currentPrompt.trim()}
className="flex-1 h-14 text-base font-semibold rounded-xl shadow-md hover:shadow-lg transition-all card"
>
<Wand2 className="h-5 w-5 mr-2" />
{selectedTool === 'generate' ? '生成图像' : '应用编辑'}
</Button>
{selectedTool === 'generate' && (
<Button
onClick={handleContinuousGenerate}
disabled={!currentPrompt.trim()}
className="h-14 px-3 text-sm font-semibold rounded-xl shadow-md hover:shadow-lg transition-all card bg-purple-500 hover:bg-purple-600"
title="连续生成直到成功"
>
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" className="mr-1">
<path d="M17 3a2.85 2.83 0 1 1 4 4L7.5 20.5 2 22l1.5-5.5Z"></path>
</svg>
</Button>
)}
</div>
)}
</div>

View File

@@ -1,6 +1,11 @@
import React from 'react';
import { cva, type VariantProps } from 'class-variance-authority';
import { cn } from '../../utils/cn';
const textareaVariants = cva(
'flex min-h-[80px] w-full rounded-lg border border-gray-300 bg-white px-3 py-2 text-sm text-gray-900 ring-offset-white placeholder:text-gray-400 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-yellow-400 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 resize-none'
);
export interface TextareaProps
extends React.TextareaHTMLAttributes<HTMLTextAreaElement>,
VariantProps<typeof textareaVariants> {
@@ -11,10 +16,7 @@ export const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
({ className, ...props }, ref) => {
return (
<textarea
className={cn(
'flex min-h-[80px] w-full rounded-lg border border-gray-300 bg-white px-3 py-2 text-sm text-gray-900 ring-offset-white placeholder:text-gray-400 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-yellow-400 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 resize-none',
className
)}
className={cn(textareaVariants(), className)}
ref={ref}
{...props}
/>