新增 生成过程中可以中断;

新增 生成结果上传到OSS;
新增 历史记录使用上传后的图片;
This commit is contained in:
yuantao
2025-09-15 18:30:50 +08:00
parent e325d0fc8d
commit bda049fcd1
7 changed files with 443 additions and 149 deletions

View File

@@ -1,23 +1,37 @@
import { useMutation } from '@tanstack/react-query';
import { geminiService, GenerationRequest, EditRequest } from '../services/geminiService';
import { useAppStore } from '../store/useAppStore';
import { generateId } from '../utils/imageUtils';
import { Generation, Edit, Asset } from '../types';
import { useToast } from '../components/ToastContext';
import React from 'react'
import { useMutation } from '@tanstack/react-query'
import { geminiService, GenerationRequest, EditRequest } from '../services/geminiService'
import { useAppStore } from '../store/useAppStore'
import { generateId } from '../utils/imageUtils'
import { Generation, Edit, Asset } from '../types'
import { useToast } from '../components/ToastContext'
import { uploadImages } from '../services/uploadService'
export const useImageGeneration = () => {
const { addGeneration, setIsGenerating, setCanvasImage } = useAppStore();
const { addToast } = useToast();
const { addGeneration, setIsGenerating, setCanvasImage } = useAppStore()
const { addToast } = useToast()
// 创建中断标志引用
const isCancelledRef = React.useRef(false)
const generateMutation = useMutation({
mutationFn: async (request: GenerationRequest) => {
const images = await geminiService.generateImage(request);
return images;
// 重置中断标志
isCancelledRef.current = false
const images = await geminiService.generateImage(request)
// 检查是否已中断
if (isCancelledRef.current) {
throw new Error('生成已中断')
}
return images
},
onMutate: () => {
setIsGenerating(true);
setIsGenerating(true)
},
onSuccess: (images, request) => {
onSuccess: async (images, request) => {
if (images.length > 0) {
const outputAssets: Asset[] = images.map((base64, index) => ({
id: generateId(),
@@ -29,6 +43,34 @@ export const useImageGeneration = () => {
checksum: base64.slice(0, 32) // 简单校验和
}));
// 获取accessToken
const accessToken = import.meta.env.VITE_ACCESS_TOKEN || '';
let uploadResults: any[] | undefined;
// 上传生成的图像
if (accessToken) {
try {
const imageUrls = outputAssets.map(asset => asset.url);
uploadResults = await uploadImages(imageUrls, accessToken);
// 检查上传结果
const failedUploads = uploadResults.filter(r => !r.success);
if (failedUploads.length > 0) {
console.warn(`${failedUploads.length}张图像上传失败`);
addToast(`${failedUploads.length}张图像上传失败`, 'warning', 5000);
} else {
console.log(`${uploadResults.length}张图像全部上传成功`);
addToast('图像已成功上传', 'success', 3000);
}
} catch (error) {
console.error('上传图像时出错:', error);
addToast('图像上传失败', 'error', 5000);
uploadResults = undefined;
}
} else {
console.warn('未找到accessToken跳过上传');
}
const generation: Generation = {
id: generateId(),
prompt: request.prompt,
@@ -48,7 +90,8 @@ export const useImageGeneration = () => {
})) : [],
outputAssets,
modelVersion: 'gemini-2.5-flash-image-preview',
timestamp: Date.now()
timestamp: Date.now(),
uploadResults: uploadResults
};
addGeneration(generation);
@@ -56,154 +99,158 @@ export const useImageGeneration = () => {
}
setIsGenerating(false);
},
onError: (error) => {
console.error('生成失败:', error);
const errorMessage = error instanceof Error ? error.message : '未知错误';
const errorDetails = error instanceof Error ? error.stack : undefined;
addToast(`图像生成失败: ${errorMessage}`, 'error', 5000, errorDetails);
setIsGenerating(false);
}
});
onError: error => {
console.error('生成失败:', error)
const errorMessage = error instanceof Error ? error.message : '未知错误'
const errorDetails = error instanceof Error ? error.stack : undefined
addToast(`图像生成失败: ${errorMessage}`, 'error', 5000, errorDetails)
setIsGenerating(false)
},
})
const cancelGeneration = () => {
isCancelledRef.current = true
setIsGenerating(false)
addToast('生成已中断', 'info', 3000)
}
return {
generate: generateMutation.mutate,
isGenerating: generateMutation.isPending,
error: generateMutation.error
};
};
error: generateMutation.error,
cancelGeneration,
}
}
export const useImageEditing = () => {
const {
addEdit,
setIsGenerating,
setCanvasImage,
canvasImage,
editReferenceImages,
brushStrokes,
selectedGenerationId,
seed,
temperature,
uploadedImages
} = useAppStore();
const { addEdit, setIsGenerating, setCanvasImage, canvasImage, editReferenceImages, brushStrokes, selectedGenerationId, seed, temperature, uploadedImages } = useAppStore()
const { addToast } = useToast()
const { addToast } = useToast();
// 创建中断标志引用
const isCancelledRef = React.useRef(false)
const editMutation = useMutation({
mutationFn: async (instruction: string) => {
// 重置中断标志
isCancelledRef.current = false
// 如果可用,始终使用画布图像作为主要目标,否则使用第一张上传的图像
const sourceImage = canvasImage || uploadedImages[0];
if (!sourceImage) throw new Error('没有要编辑的图像');
const sourceImage = canvasImage || uploadedImages[0]
if (!sourceImage) throw new Error('没有要编辑的图像')
// 将画布图像转换为base64
const base64Image = sourceImage.includes('base64,')
? sourceImage.split('base64,')[1]
: sourceImage;
const base64Image = sourceImage.includes('base64,') ? sourceImage.split('base64,')[1] : sourceImage
// 获取用于样式指导的参考图像
let referenceImages = editReferenceImages
.filter(img => img.includes('base64,'))
.map(img => img.split('base64,')[1]);
let maskImage: string | undefined;
let maskedReferenceImage: string | undefined;
let referenceImages = editReferenceImages.filter(img => img.includes('base64,')).map(img => img.split('base64,')[1])
let maskImage: string | undefined
let maskedReferenceImage: string | undefined
// 如果存在画笔描边,则从描边创建遮罩
if (brushStrokes.length > 0) {
// 创建临时图像以获取实际尺寸
const tempImg = new Image();
tempImg.src = sourceImage;
await new Promise<void>((resolve) => {
tempImg.onload = () => resolve();
});
const tempImg = new Image()
tempImg.src = sourceImage
await new Promise<void>(resolve => {
tempImg.onload = () => resolve()
})
// 创建具有确切图像尺寸的遮罩画布
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d')!;
canvas.width = tempImg.width;
canvas.height = tempImg.height;
const canvas = document.createElement('canvas')
const ctx = canvas.getContext('2d')!
canvas.width = tempImg.width
canvas.height = tempImg.height
// 用黑色填充(未遮罩区域)
ctx.fillStyle = 'black';
ctx.fillRect(0, 0, canvas.width, canvas.height);
ctx.fillStyle = 'black'
ctx.fillRect(0, 0, canvas.width, canvas.height)
// 绘制白色描边(遮罩区域)
ctx.strokeStyle = 'white';
ctx.lineCap = 'round';
ctx.lineJoin = 'round';
ctx.strokeStyle = 'white'
ctx.lineCap = 'round'
ctx.lineJoin = 'round'
brushStrokes.forEach(stroke => {
if (stroke.points.length >= 4) {
ctx.lineWidth = stroke.brushSize;
ctx.beginPath();
ctx.moveTo(stroke.points[0], stroke.points[1]);
ctx.lineWidth = stroke.brushSize
ctx.beginPath()
ctx.moveTo(stroke.points[0], stroke.points[1])
for (let i = 2; i < stroke.points.length; i += 2) {
ctx.lineTo(stroke.points[i], stroke.points[i + 1]);
ctx.lineTo(stroke.points[i], stroke.points[i + 1])
}
ctx.stroke();
ctx.stroke()
}
});
})
// 将遮罩转换为base64
const maskDataUrl = canvas.toDataURL('image/png');
maskImage = maskDataUrl.split('base64,')[1];
const maskDataUrl = canvas.toDataURL('image/png')
maskImage = maskDataUrl.split('base64,')[1]
// 创建遮罩参考图像(带遮罩叠加的原始图像)
const maskedCanvas = document.createElement('canvas');
const maskedCtx = maskedCanvas.getContext('2d')!;
maskedCanvas.width = tempImg.width;
maskedCanvas.height = tempImg.height;
const maskedCanvas = document.createElement('canvas')
const maskedCtx = maskedCanvas.getContext('2d')!
maskedCanvas.width = tempImg.width
maskedCanvas.height = tempImg.height
// 绘制原始图像
maskedCtx.drawImage(tempImg, 0, 0);
maskedCtx.drawImage(tempImg, 0, 0)
// 绘制带透明度的遮罩叠加
maskedCtx.globalCompositeOperation = 'source-over';
maskedCtx.globalAlpha = 0.4;
maskedCtx.fillStyle = '#A855F7';
maskedCtx.globalCompositeOperation = 'source-over'
maskedCtx.globalAlpha = 0.4
maskedCtx.fillStyle = '#A855F7'
brushStrokes.forEach(stroke => {
if (stroke.points.length >= 4) {
maskedCtx.lineWidth = stroke.brushSize;
maskedCtx.strokeStyle = '#A855F7';
maskedCtx.lineCap = 'round';
maskedCtx.lineJoin = 'round';
maskedCtx.beginPath();
maskedCtx.moveTo(stroke.points[0], stroke.points[1]);
maskedCtx.lineWidth = stroke.brushSize
maskedCtx.strokeStyle = '#A855F7'
maskedCtx.lineCap = 'round'
maskedCtx.lineJoin = 'round'
maskedCtx.beginPath()
maskedCtx.moveTo(stroke.points[0], stroke.points[1])
for (let i = 2; i < stroke.points.length; i += 2) {
maskedCtx.lineTo(stroke.points[i], stroke.points[i + 1]);
maskedCtx.lineTo(stroke.points[i], stroke.points[i + 1])
}
maskedCtx.stroke();
maskedCtx.stroke()
}
});
maskedCtx.globalAlpha = 1;
maskedCtx.globalCompositeOperation = 'source-over';
const maskedDataUrl = maskedCanvas.toDataURL('image/png');
maskedReferenceImage = maskedDataUrl.split('base64,')[1];
})
maskedCtx.globalAlpha = 1
maskedCtx.globalCompositeOperation = 'source-over'
const maskedDataUrl = maskedCanvas.toDataURL('image/png')
maskedReferenceImage = maskedDataUrl.split('base64,')[1]
// 将遮罩图像作为参考添加到模型中
referenceImages = [maskedReferenceImage, ...referenceImages];
referenceImages = [maskedReferenceImage, ...referenceImages]
}
const request: EditRequest = {
instruction,
originalImage: base64Image,
referenceImages: referenceImages.length > 0 ? referenceImages : undefined,
maskImage,
temperature,
seed
};
seed,
}
const images = await geminiService.editImage(request)
const images = await geminiService.editImage(request);
return { images, maskedReferenceImage };
// 检查是否已中断
if (isCancelledRef.current) {
throw new Error('编辑已中断')
}
return { images, maskedReferenceImage }
},
onMutate: () => {
setIsGenerating(true);
setIsGenerating(true)
},
onSuccess: ({ images, maskedReferenceImage }, instruction) => {
onSuccess: async ({ images, maskedReferenceImage }, instruction) => {
if (images.length > 0) {
const outputAssets: Asset[] = images.map((base64, index) => ({
id: generateId(),
@@ -226,6 +273,34 @@ export const useImageEditing = () => {
checksum: maskedReferenceImage.slice(0, 32)
} : undefined;
// 获取accessToken
const accessToken = import.meta.env.VITE_ACCESS_TOKEN || '';
let uploadResults: any[] | undefined;
// 上传编辑后的图像
if (accessToken) {
try {
const imageUrls = outputAssets.map(asset => asset.url);
uploadResults = await uploadImages(imageUrls, accessToken);
// 检查上传结果
const failedUploads = uploadResults.filter(r => !r.success);
if (failedUploads.length > 0) {
console.warn(`${failedUploads.length}张编辑后的图像上传失败`);
addToast(`${failedUploads.length}张编辑后的图像上传失败`, 'warning', 5000);
} else {
console.log(`${uploadResults.length}张编辑后的图像全部上传成功`);
addToast('编辑后的图像已成功上传', 'success', 3000);
}
} catch (error) {
console.error('上传编辑后的图像时出错:', error);
addToast('编辑后的图像上传失败', 'error', 5000);
uploadResults = undefined;
}
} else {
console.warn('未找到accessToken跳过上传');
}
const edit: Edit = {
id: generateId(),
parentGenerationId: selectedGenerationId || '',
@@ -233,7 +308,8 @@ export const useImageEditing = () => {
maskReferenceAsset,
instruction,
outputAssets,
timestamp: Date.now()
timestamp: Date.now(),
uploadResults: uploadResults
};
addEdit(edit);
@@ -246,18 +322,25 @@ export const useImageEditing = () => {
}
setIsGenerating(false);
},
onError: (error) => {
console.error('编辑失败:', error);
const errorMessage = error instanceof Error ? error.message : '未知错误';
const errorDetails = error instanceof Error ? error.stack : undefined;
addToast(`图像编辑失败: ${errorMessage}`, 'error', 5000, errorDetails);
setIsGenerating(false);
}
});
onError: error => {
console.error('编辑失败:', error)
const errorMessage = error instanceof Error ? error.message : '未知错误'
const errorDetails = error instanceof Error ? error.stack : undefined
addToast(`图像编辑失败: ${errorMessage}`, 'error', 5000, errorDetails)
setIsGenerating(false)
},
})
const cancelEdit = () => {
isCancelledRef.current = true
setIsGenerating(false)
addToast('编辑已中断', 'info', 3000)
}
return {
edit: editMutation.mutate,
isEditing: editMutation.isPending,
error: editMutation.error
};
};
error: editMutation.error,
cancelEdit,
}
}