You've already forked Nano-Banana-AI-Image-Editor
376 lines
13 KiB
TypeScript
376 lines
13 KiB
TypeScript
import React from 'react'
|
||
import { useMutation } from '@tanstack/react-query'
|
||
import { geminiService, GenerationRequest, EditRequest } from '../services/geminiService'
|
||
import { useAppStore } from '../store/useAppStore'
|
||
import { generateId } from '../utils/imageUtils'
|
||
import { Generation, Edit, Asset } from '../types'
|
||
import { useToast } from '../components/ToastContext'
|
||
import { uploadImages } from '../services/uploadService'
|
||
|
||
/**
 * React hook for image generation through `geminiService.generateImage`.
 *
 * On success it wraps every returned base64 image as an output `Asset`, uploads the
 * outputs (plus any reference images) when `VITE_ACCESS_TOKEN` is set, records a
 * `Generation` in the app store and shows the first output on the canvas.
 *
 * `cancelGeneration` makes the hook ignore the request that is currently in flight;
 * it does not abort the underlying network call.
 *
 * @returns `generate` (the mutation's `mutate`), `isGenerating`, the last `error`
 *   and `cancelGeneration`.
 */
export const useImageGeneration = () => {
  const { addGeneration, setIsGenerating, setCanvasImage } = useAppStore()
  const { addToast } = useToast()

  // Errors thrown for cancelled requests carry this name so onError can tell them
  // apart from real failures. They are matched by name, not by class identity, because
  // the mutation callbacks may be replaced by those of a later render mid-request.
  const cancelledErrorName = 'GenerationCancelledError'

  // Cancellation epoch: cancelGeneration() bumps it, and each request remembers the
  // value it started with. A mismatch means "cancelled". A boolean flag reset at the
  // start of every request would let a new generation revive an earlier cancelled one
  // that is still in flight.
  const cancelEpochRef = React.useRef(0)

  /** Throws the cancellation error if cancelGeneration() ran after `startEpoch` was taken. */
  const throwIfCancelled = (startEpoch: number) => {
    if (cancelEpochRef.current !== startEpoch) {
      const cancelled = new Error('生成已中断')
      cancelled.name = cancelledErrorName
      throw cancelled
    }
  }

  /** True when `error` was produced by throwIfCancelled. */
  const isCancellation = (error: unknown) =>
    error instanceof Error && error.name === cancelledErrorName

  const generateMutation = useMutation({
    mutationFn: async (request: GenerationRequest) => {
      const startEpoch = cancelEpochRef.current
      try {
        const result = await geminiService.generateImage(request)
        throwIfCancelled(startEpoch)
        return result
      } catch (error) {
        // If the user already cancelled, report the cancellation rather than
        // whatever error the abandoned request produced.
        throwIfCancelled(startEpoch)
        throw error
      }
    },

    onMutate: () => {
      setIsGenerating(true)
    },

    onSuccess: async (result, request) => {
      const { images, usageMetadata } = result

      if (images.length === 0) {
        // Nothing to record or display; tell the user instead of silently stopping.
        addToast('生成未返回任何图像', 'warning', 5000)
        setIsGenerating(false)
        return
      }

      const outputAssets: Asset[] = images.map(base64 => ({
        id: generateId(),
        type: 'output',
        url: `data:image/png;base64,${base64}`,
        mime: 'image/png',
        width: 1024, // assumed default Gemini output size; not read from the image
        height: 1024,
        // NOTE(review): if the data is PNG, its first 32 base64 chars encode only the
        // PNG signature plus IHDR width/height, so images of equal size share this value.
        checksum: base64.slice(0, 32),
      }))

      // NOTE(review): VITE_* variables are inlined into the client bundle, so this
      // token is readable by anyone who loads the app.
      const accessToken = import.meta.env.VITE_ACCESS_TOKEN || ''
      // Shape comes from uploadService; only `.success` is inspected here.
      let uploadResults: any[] | undefined

      // Upload the generated images and, if present, the reference images.
      if (accessToken) {
        try {
          const imageUrls = outputAssets.map(asset => asset.url)
          const outputUploadResults = await uploadImages(imageUrls, accessToken)

          let referenceUploadResults: any[] = []
          if (request.referenceImages && request.referenceImages.length > 0) {
            const referenceUrls = request.referenceImages.map(img => `data:image/png;base64,${img}`)
            referenceUploadResults = await uploadImages(referenceUrls, accessToken)
          }

          // Outputs first, then references.
          uploadResults = [...outputUploadResults, ...referenceUploadResults]

          const failedUploads = uploadResults.filter(r => !r.success)
          if (failedUploads.length > 0) {
            console.warn(`${failedUploads.length}张图像上传失败`)
            addToast(`${failedUploads.length}张图像上传失败`, 'warning', 5000)
          } else {
            console.log(`${uploadResults.length}张图像全部上传成功`)
            addToast('图像已成功上传', 'success', 3000)
          }
        } catch (error) {
          // Upload is best-effort: the generation is still recorded without results.
          console.error('上传图像时出错:', error)
          addToast('图像上传失败', 'error', 5000)
          uploadResults = undefined
        }
      } else {
        console.warn('未找到accessToken,跳过上传')
      }

      // Show token usage when the service reports it.
      if (usageMetadata?.totalTokenCount) {
        addToast(`本次生成消耗 ${usageMetadata.totalTokenCount} Tokens`, 'info', 3000)
      }

      const generation: Generation = {
        id: generateId(),
        prompt: request.prompt,
        parameters: {
          aspectRatio: '1:1',
          seed: request.seed,
          temperature: request.temperature,
        },
        sourceAssets: request.referenceImages
          ? request.referenceImages.map(img => ({
              id: generateId(),
              type: 'original' as const,
              url: `data:image/png;base64,${img}`,
              mime: 'image/png',
              width: 1024,
              height: 1024,
              checksum: img.slice(0, 32),
            }))
          : [],
        outputAssets,
        modelVersion: 'gemini-2.5-flash-image-preview',
        timestamp: Date.now(),
        uploadResults: uploadResults,
        usageMetadata: usageMetadata, // keep token usage in the history entry
      }

      addGeneration(generation)
      setCanvasImage(outputAssets[0].url)
      setIsGenerating(false)
    },

    onError: error => {
      if (isCancellation(error)) {
        // cancelGeneration() already cleared the busy flag and notified the user; a
        // newer request may be running now, so leave isGenerating untouched.
        return
      }
      console.error('生成失败:', error)
      const errorMessage = error instanceof Error ? error.message : '未知错误'
      const errorDetails = error instanceof Error ? error.stack : undefined
      addToast(`图像生成失败: ${errorMessage}`, 'error', 5000, errorDetails)
      setIsGenerating(false)
    },
  })

  /** Abandons the in-flight generation; its eventual result is discarded. */
  const cancelGeneration = () => {
    cancelEpochRef.current += 1
    // Detach the observer so this hook's isPending drops immediately instead of
    // when the abandoned request finally settles.
    generateMutation.reset()
    setIsGenerating(false)
    addToast('生成已中断', 'info', 3000)
  }

  return {
    generate: generateMutation.mutate,
    isGenerating: generateMutation.isPending,
    error: generateMutation.error,
    cancelGeneration,
  }
}
|
||
|
||
/**
 * React hook for instruction-based image editing through `geminiService.editImage`.
 *
 * The image being edited is the canvas image, or the first uploaded image when the
 * canvas is empty. When brush strokes exist, it builds two images at the source
 * image's size:
 *  - a black/white mask, with strokes in white, sent as `maskImage`;
 *  - a purple overlay of the strokes on the source, sent as the first reference image.
 *
 * On success it optionally uploads the outputs, records an `Edit` in the store, loads
 * the first output on the canvas and selects the new edit.
 *
 * `cancelEdit` makes the hook ignore the request that is currently in flight; it does
 * not abort the underlying network call.
 *
 * @returns `edit` (the mutation's `mutate`), `isEditing`, the last `error` and
 *   `cancelEdit`.
 */
export const useImageEditing = () => {
  const {
    addEdit,
    setIsGenerating,
    setCanvasImage,
    canvasImage,
    editReferenceImages,
    brushStrokes,
    selectedGenerationId,
    seed,
    temperature,
    uploadedImages,
  } = useAppStore()
  const { addToast } = useToast()

  // Errors thrown for cancelled requests carry this name so onError can tell them
  // apart from real failures. They are matched by name, not by class identity, because
  // the mutation callbacks may be replaced by those of a later render mid-request.
  const cancelledErrorName = 'EditCancelledError'

  // Cancellation epoch: cancelEdit() bumps it, and each request remembers the value it
  // started with. A mismatch means "cancelled". A boolean flag reset at the start of
  // every request would let a new edit revive an earlier cancelled one still in flight.
  const cancelEpochRef = React.useRef(0)

  /** Throws the cancellation error if cancelEdit() ran after `startEpoch` was taken. */
  const throwIfCancelled = (startEpoch: number) => {
    if (cancelEpochRef.current !== startEpoch) {
      const cancelled = new Error('编辑已中断')
      cancelled.name = cancelledErrorName
      throw cancelled
    }
  }

  /** True when `error` was produced by throwIfCancelled. */
  const isCancellation = (error: unknown) =>
    error instanceof Error && error.name === cancelledErrorName

  /** Loads `src` into an off-screen image; rejects instead of hanging when loading fails. */
  const loadImage = (src: string) =>
    new Promise<HTMLImageElement>((resolve, reject) => {
      const img = new Image()
      img.onload = () => resolve(img)
      img.onerror = () => reject(new Error('无法加载要编辑的图像'))
      img.src = src
    })

  /** Creates an off-screen canvas of the given size together with its 2D context. */
  const createCanvas = (width: number, height: number) => {
    const canvas = document.createElement('canvas')
    canvas.width = width
    canvas.height = height
    const ctx = canvas.getContext('2d')
    if (!ctx) throw new Error('无法创建画布上下文')
    return { canvas, ctx }
  }

  const editMutation = useMutation({
    mutationFn: async (instruction: string) => {
      const startEpoch = cancelEpochRef.current

      // Snapshot what the history entry should record at request time, because the
      // store can change while the request is in flight.
      const parentGenerationId = selectedGenerationId || ''
      const parameters = { seed: seed || undefined, temperature }

      // Prefer the canvas image as the edit target; otherwise use the first upload.
      const sourceImage = canvasImage || uploadedImages[0]
      if (!sourceImage) throw new Error('没有要编辑的图像')

      // Strip the data-URL prefix. NOTE(review): a non-data URL is passed through
      // unchanged, which the API presumably cannot decode as base64.
      const base64Image = sourceImage.includes('base64,') ? sourceImage.split('base64,')[1] : sourceImage

      // Style-guidance references; entries that are not data URLs are dropped.
      const styleReferences = editReferenceImages
        .filter(img => img.includes('base64,'))
        .map(img => img.split('base64,')[1])

      let maskImage: string | undefined
      let maskedReferenceImage: string | undefined
      let maskSize: { width: number; height: number } | undefined

      if (brushStrokes.length > 0) {
        // Build both images at the source image's real size.
        // NOTE(review): assumes stroke points are in source-image pixel coordinates.
        const sourceElement = await loadImage(sourceImage)
        const width = sourceElement.naturalWidth
        const height = sourceElement.naturalHeight
        maskSize = { width, height }

        // Strokes store flattened [x0, y0, x1, y1, ...] points; fewer than two points are skipped.
        const traceStrokes = (ctx: CanvasRenderingContext2D) => {
          for (const stroke of brushStrokes) {
            if (stroke.points.length < 4) continue
            ctx.lineWidth = stroke.brushSize
            ctx.beginPath()
            ctx.moveTo(stroke.points[0], stroke.points[1])
            for (let i = 2; i < stroke.points.length; i += 2) {
              ctx.lineTo(stroke.points[i], stroke.points[i + 1])
            }
            ctx.stroke()
          }
        }

        // Mask: black = keep, white = area to edit.
        const mask = createCanvas(width, height)
        mask.ctx.fillStyle = 'black'
        mask.ctx.fillRect(0, 0, width, height)
        mask.ctx.strokeStyle = 'white'
        mask.ctx.lineCap = 'round'
        mask.ctx.lineJoin = 'round'
        traceStrokes(mask.ctx)
        maskImage = mask.canvas.toDataURL('image/png').split('base64,')[1]

        // Overlay: source image with semi-transparent purple strokes. Each stroke is
        // stroked separately, so overlapping strokes render darker.
        const overlay = createCanvas(width, height)
        overlay.ctx.drawImage(sourceElement, 0, 0)
        overlay.ctx.globalAlpha = 0.4
        overlay.ctx.strokeStyle = '#A855F7'
        overlay.ctx.lineCap = 'round'
        overlay.ctx.lineJoin = 'round'
        traceStrokes(overlay.ctx)
        maskedReferenceImage = overlay.canvas.toDataURL('image/png').split('base64,')[1]
      }

      // The overlay (if any) goes first, ahead of the style references.
      const referenceImages = maskedReferenceImage ? [maskedReferenceImage, ...styleReferences] : styleReferences

      const request: EditRequest = {
        instruction,
        originalImage: base64Image,
        referenceImages: referenceImages.length > 0 ? referenceImages : undefined,
        maskImage,
        temperature,
        seed,
      }

      // Don't spend a network call if the user cancelled while the mask was being built.
      throwIfCancelled(startEpoch)

      try {
        const result = await geminiService.editImage(request)
        throwIfCancelled(startEpoch)
        return { result, maskedReferenceImage, maskSize, parentGenerationId, parameters }
      } catch (error) {
        // If the user already cancelled, report the cancellation rather than
        // whatever error the abandoned request produced.
        throwIfCancelled(startEpoch)
        throw error
      }
    },

    onMutate: () => {
      setIsGenerating(true)
    },

    onSuccess: async ({ result, maskedReferenceImage, maskSize, parentGenerationId, parameters }, instruction) => {
      const { images, usageMetadata } = result

      if (images.length === 0) {
        // Nothing to record or display; tell the user instead of silently stopping.
        addToast('编辑未返回任何图像', 'warning', 5000)
        setIsGenerating(false)
        return
      }

      const outputAssets: Asset[] = images.map(base64 => ({
        id: generateId(),
        type: 'output',
        url: `data:image/png;base64,${base64}`,
        mime: 'image/png',
        width: 1024, // assumed output size; not read from the image
        height: 1024,
        // NOTE(review): if the data is PNG, its first 32 base64 chars encode only the
        // PNG signature plus IHDR width/height, so images of equal size share this value.
        checksum: base64.slice(0, 32),
      }))

      // Asset for the stroke-overlay reference, sized like the source image it was drawn on.
      const maskReferenceAsset: Asset | undefined = maskedReferenceImage
        ? {
            id: generateId(),
            type: 'mask',
            url: `data:image/png;base64,${maskedReferenceImage}`,
            mime: 'image/png',
            width: maskSize?.width ?? 1024,
            height: maskSize?.height ?? 1024,
            checksum: maskedReferenceImage.slice(0, 32),
          }
        : undefined

      // NOTE(review): VITE_* variables are inlined into the client bundle, so this
      // token is readable by anyone who loads the app.
      const accessToken = import.meta.env.VITE_ACCESS_TOKEN || ''
      // Shape comes from uploadService; only `.success` is inspected here.
      let uploadResults: any[] | undefined

      // Upload the edited images (best-effort).
      if (accessToken) {
        try {
          const imageUrls = outputAssets.map(asset => asset.url)
          uploadResults = await uploadImages(imageUrls, accessToken)

          const failedUploads = uploadResults.filter(r => !r.success)
          if (failedUploads.length > 0) {
            console.warn(`${failedUploads.length}张编辑后的图像上传失败`)
            addToast(`${failedUploads.length}张编辑后的图像上传失败`, 'warning', 5000)
          } else {
            console.log(`${uploadResults.length}张编辑后的图像全部上传成功`)
            addToast('编辑后的图像已成功上传', 'success', 3000)
          }
        } catch (error) {
          console.error('上传编辑后的图像时出错:', error)
          addToast('编辑后的图像上传失败', 'error', 5000)
          uploadResults = undefined
        }
      } else {
        console.warn('未找到accessToken,跳过上传')
      }

      // Show token usage when the service reports it.
      if (usageMetadata?.totalTokenCount) {
        addToast(`本次编辑消耗 ${usageMetadata.totalTokenCount} Tokens`, 'info', 3000)
      }

      const edit: Edit = {
        id: generateId(),
        parentGenerationId,
        // Point at the mask asset actually stored with this edit, instead of a fresh id
        // that references nothing.
        maskAssetId: maskReferenceAsset?.id,
        maskReferenceAsset,
        instruction,
        outputAssets,
        timestamp: Date.now(),
        uploadResults: uploadResults,
        parameters,
        usageMetadata: usageMetadata, // keep token usage in the history entry
      }

      addEdit(edit)

      // Load the edited image on the canvas and make the new edit the current selection.
      const { selectEdit, selectGeneration } = useAppStore.getState()
      setCanvasImage(outputAssets[0].url)
      selectEdit(edit.id)
      selectGeneration(null)

      setIsGenerating(false)
    },

    onError: error => {
      if (isCancellation(error)) {
        // cancelEdit() already cleared the busy flag and notified the user; a newer
        // request may be running now, so leave isGenerating untouched.
        return
      }
      console.error('编辑失败:', error)
      const errorMessage = error instanceof Error ? error.message : '未知错误'
      const errorDetails = error instanceof Error ? error.stack : undefined
      addToast(`图像编辑失败: ${errorMessage}`, 'error', 5000, errorDetails)
      setIsGenerating(false)
    },
  })

  /** Abandons the in-flight edit; its eventual result is discarded. */
  const cancelEdit = () => {
    cancelEpochRef.current += 1
    // Detach the observer so this hook's isPending drops immediately instead of
    // when the abandoned request finally settles.
    editMutation.reset()
    setIsGenerating(false)
    addToast('编辑已中断', 'info', 3000)
  }

  return {
    edit: editMutation.mutate,
    isEditing: editMutation.isPending,
    error: editMutation.error,
    cancelEdit,
  }
}
|