Files
Nano-Banana-AI-Image-Editor/src/hooks/useImageGeneration.ts
2025-09-16 18:38:02 +08:00

376 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import React from 'react'
import { useMutation } from '@tanstack/react-query'
import { geminiService, GenerationRequest, EditRequest } from '../services/geminiService'
import { useAppStore } from '../store/useAppStore'
import { generateId } from '../utils/imageUtils'
import { Generation, Edit, Asset } from '../types'
import { useToast } from '../components/ToastContext'
import { uploadImages } from '../services/uploadService'
// Marker placed on the Error thrown for a user-cancelled run so `onError` can
// tell a cancellation apart from a real failure.
const GENERATION_CANCELLED_ERROR_NAME = 'GenerationCancelledError'

/**
 * Hook exposing image generation as a TanStack Query mutation.
 *
 * A successful run converts the returned base64 images into `Asset`s, uploads
 * them (together with any reference images) when `VITE_ACCESS_TOKEN` is set,
 * records a `Generation` via `addGeneration`, and shows the first image on the
 * canvas. Upload problems are reported through toasts but never fail the run.
 *
 * @returns `generate` (the mutation trigger), `isGenerating`, the last `error`,
 *          and `cancelGeneration`.
 */
export const useImageGeneration = () => {
  const { addGeneration, setIsGenerating, setCanvasImage } = useAppStore()
  const { addToast } = useToast()

  // Cancellation is tracked per run: every run gets a monotonically increasing
  // id and `cancelGeneration` records the highest id it has cancelled. Starting
  // a new run therefore cannot "un-cancel" an older request still in flight.
  const latestRunIdRef = React.useRef(0)
  const cancelledUpToRunIdRef = React.useRef(0)

  const generateMutation = useMutation({
    mutationFn: async (request: GenerationRequest) => {
      const runId = ++latestRunIdRef.current
      const result = await geminiService.generateImage(request)

      // The request itself is not aborted; its result is discarded afterwards.
      if (runId <= cancelledUpToRunIdRef.current) {
        const cancelled = new Error('生成已中断')
        cancelled.name = GENERATION_CANCELLED_ERROR_NAME
        throw cancelled
      }
      return result
    },
    onMutate: () => {
      setIsGenerating(true)
    },
    onSuccess: async (result, request) => {
      const { images, usageMetadata } = result
      if (images.length === 0) {
        setIsGenerating(false)
        return
      }

      const outputAssets: Asset[] = images.map(base64 => ({
        id: generateId(),
        type: 'output',
        url: `data:image/png;base64,${base64}`,
        mime: 'image/png',
        width: 1024, // default size assumed for Gemini output
        height: 1024,
        checksum: base64.slice(0, 32), // cheap fingerprint: first 32 base64 chars
      }))

      const accessToken = import.meta.env.VITE_ACCESS_TOKEN || ''
      let uploadResults: any[] | undefined

      if (accessToken) {
        try {
          // Upload the generated images first, then the reference images (if any).
          const outputUploadResults = await uploadImages(
            outputAssets.map(asset => asset.url),
            accessToken
          )
          let referenceUploadResults: any[] = []
          if (request.referenceImages && request.referenceImages.length > 0) {
            const referenceUrls = request.referenceImages.map(img => `data:image/png;base64,${img}`)
            referenceUploadResults = await uploadImages(referenceUrls, accessToken)
          }
          uploadResults = [...outputUploadResults, ...referenceUploadResults]

          const failedUploads = uploadResults.filter(r => !r.success)
          if (failedUploads.length > 0) {
            console.warn(`${failedUploads.length}张图像上传失败`)
            addToast(`${failedUploads.length}张图像上传失败`, 'warning', 5000)
          } else {
            console.log(`${uploadResults.length}张图像全部上传成功`)
            addToast('图像已成功上传', 'success', 3000)
          }
        } catch (error) {
          // Upload is best-effort: report it and keep the generation itself.
          console.error('上传图像时出错:', error)
          addToast('图像上传失败', 'error', 5000)
          uploadResults = undefined
        }
      } else {
        console.warn('未找到accessToken跳过上传')
      }

      // Surface token usage when the service reported it.
      if (usageMetadata?.totalTokenCount) {
        addToast(`本次生成消耗 ${usageMetadata.totalTokenCount} Tokens`, 'info', 3000)
      }

      const generation: Generation = {
        id: generateId(),
        prompt: request.prompt,
        parameters: {
          aspectRatio: '1:1',
          seed: request.seed,
          temperature: request.temperature,
        },
        sourceAssets: (request.referenceImages ?? []).map(img => ({
          id: generateId(),
          type: 'original' as const,
          url: `data:image/png;base64,${img}`,
          mime: 'image/png',
          width: 1024,
          height: 1024,
          checksum: img.slice(0, 32),
        })),
        outputAssets,
        modelVersion: 'gemini-2.5-flash-image-preview',
        timestamp: Date.now(),
        uploadResults,
        usageMetadata, // kept so the history entry can show token usage
      }

      addGeneration(generation)
      setCanvasImage(outputAssets[0].url)
      setIsGenerating(false)
    },
    onError: error => {
      if (error instanceof Error && error.name === GENERATION_CANCELLED_ERROR_NAME) {
        // `cancelGeneration` already cleared the busy flag and informed the user.
        // Doing either again here would show a bogus "failed" toast and could
        // clear the busy flag of a newer run started after the cancel.
        return
      }
      console.error('生成失败:', error)
      const errorMessage = error instanceof Error ? error.message : '未知错误'
      const errorDetails = error instanceof Error ? error.stack : undefined
      addToast(`图像生成失败: ${errorMessage}`, 'error', 5000, errorDetails)
      setIsGenerating(false)
    },
  })

  /** Marks every run started so far as cancelled and resets the busy flag. */
  const cancelGeneration = () => {
    cancelledUpToRunIdRef.current = latestRunIdRef.current
    setIsGenerating(false)
    addToast('生成已中断', 'info', 3000)
  }

  return {
    generate: generateMutation.mutate,
    isGenerating: generateMutation.isPending,
    error: generateMutation.error,
    cancelGeneration,
  }
}
// Marker placed on the Error thrown for a user-cancelled edit so `onError` can
// tell a cancellation apart from a real failure.
const EDIT_CANCELLED_ERROR_NAME = 'EditCancelledError'

/** Loads `src` into an image element, rejecting (instead of hanging) when it cannot be loaded. */
const loadEditSourceImage = (src: string): Promise<HTMLImageElement> =>
  new Promise((resolve, reject) => {
    const img = new Image()
    img.onload = () => resolve(img)
    img.onerror = () => reject(new Error('无法加载要编辑的图像'))
    // Assign `src` last so both handlers are attached before loading starts.
    img.src = src
  })

/** Creates an off-screen canvas of the given size together with its 2D context. */
const createEditCanvas = (width: number, height: number) => {
  const canvas = document.createElement('canvas')
  canvas.width = width
  canvas.height = height
  const ctx = canvas.getContext('2d')
  if (!ctx) throw new Error('无法创建画布上下文')
  return { canvas, ctx }
}

/** Strokes every brush path that has at least two points onto `ctx` in `color`, with round caps and joins. */
const strokeEditBrushPaths = (
  ctx: CanvasRenderingContext2D,
  strokes: ReadonlyArray<{ readonly points: readonly number[]; readonly brushSize: number }>,
  color: string
) => {
  ctx.strokeStyle = color
  ctx.lineCap = 'round'
  ctx.lineJoin = 'round'
  for (const stroke of strokes) {
    // `points` is a flat [x0, y0, x1, y1, ...] list; a line needs two points.
    if (stroke.points.length < 4) continue
    ctx.lineWidth = stroke.brushSize
    ctx.beginPath()
    ctx.moveTo(stroke.points[0], stroke.points[1])
    // `i + 1 < length` skips a dangling x coordinate in an odd-length list.
    for (let i = 2; i + 1 < stroke.points.length; i += 2) {
      ctx.lineTo(stroke.points[i], stroke.points[i + 1])
    }
    ctx.stroke()
  }
}

/**
 * Hook exposing instruction-based image editing as a TanStack Query mutation.
 *
 * The edit target is the canvas image, falling back to the first uploaded
 * image. When brush strokes exist, a black/white mask and a copy of the source
 * with a translucent purple overlay are rendered; the overlay copy is sent as
 * the first reference image. A successful run uploads the outputs when
 * `VITE_ACCESS_TOKEN` is set, records an `Edit` via `addEdit`, and shows the
 * first output on the canvas.
 *
 * @returns `edit` (the mutation trigger, taking the instruction text),
 *          `isEditing`, the last `error`, and `cancelEdit`.
 */
export const useImageEditing = () => {
  const {
    addEdit,
    setIsGenerating,
    setCanvasImage,
    canvasImage,
    editReferenceImages,
    brushStrokes,
    selectedGenerationId,
    seed,
    temperature,
    uploadedImages,
  } = useAppStore()
  const { addToast } = useToast()

  // Cancellation is tracked per run: every run gets a monotonically increasing
  // id and `cancelEdit` records the highest id it has cancelled. Starting a new
  // run therefore cannot "un-cancel" an older request still in flight.
  const latestRunIdRef = React.useRef(0)
  const cancelledUpToRunIdRef = React.useRef(0)

  const editMutation = useMutation({
    mutationFn: async (instruction: string) => {
      const runId = ++latestRunIdRef.current

      // Prefer the canvas image; otherwise use the first uploaded image.
      const sourceImage = canvasImage || uploadedImages[0]
      if (!sourceImage) throw new Error('没有要编辑的图像')

      // Strip the data-URL prefix when present to get the raw base64 payload.
      const base64Image = sourceImage.includes('base64,') ? sourceImage.split('base64,')[1] : sourceImage

      // Reference images; only data URLs are usable.
      let referenceImages = editReferenceImages
        .filter(img => img.includes('base64,'))
        .map(img => img.split('base64,')[1])

      let maskImage: string | undefined
      let maskedReferenceImage: string | undefined
      const hasMask = brushStrokes.length > 0

      if (hasMask) {
        // Load the source to size both canvases to its actual pixel dimensions.
        const sourceElement = await loadEditSourceImage(sourceImage)
        const { width, height } = sourceElement

        // Mask: black background (untouched area), white strokes (area to edit).
        const mask = createEditCanvas(width, height)
        mask.ctx.fillStyle = 'black'
        mask.ctx.fillRect(0, 0, width, height)
        strokeEditBrushPaths(mask.ctx, brushStrokes, 'white')
        maskImage = mask.canvas.toDataURL('image/png').split('base64,')[1]

        // Masked reference: the source with the strokes drawn as a translucent overlay.
        const overlay = createEditCanvas(width, height)
        overlay.ctx.drawImage(sourceElement, 0, 0)
        overlay.ctx.globalAlpha = 0.4
        strokeEditBrushPaths(overlay.ctx, brushStrokes, '#A855F7')
        overlay.ctx.globalAlpha = 1
        maskedReferenceImage = overlay.canvas.toDataURL('image/png').split('base64,')[1]

        // The masked reference goes first among the reference images.
        referenceImages = [maskedReferenceImage, ...referenceImages]
      }

      const request: EditRequest = {
        instruction,
        originalImage: base64Image,
        referenceImages: referenceImages.length > 0 ? referenceImages : undefined,
        maskImage,
        temperature,
        seed,
      }
      const result = await geminiService.editImage(request)

      // The request itself is not aborted; its result is discarded afterwards.
      if (runId <= cancelledUpToRunIdRef.current) {
        const cancelled = new Error('编辑已中断')
        cancelled.name = EDIT_CANCELLED_ERROR_NAME
        throw cancelled
      }

      // Snapshot the inputs this run used, so the history entry describes this
      // request even if the store changes before `onSuccess` runs.
      return {
        result,
        maskedReferenceImage,
        hasMask,
        parentGenerationId: selectedGenerationId || '',
        // `??` keeps a seed of 0, which `||` would have recorded as "no seed".
        parameters: { seed: seed ?? undefined, temperature },
      }
    },
    onMutate: () => {
      setIsGenerating(true)
    },
    onSuccess: async ({ result, maskedReferenceImage, hasMask, parentGenerationId, parameters }, instruction) => {
      const { images, usageMetadata } = result
      if (images.length === 0) {
        setIsGenerating(false)
        return
      }

      const outputAssets: Asset[] = images.map(base64 => ({
        id: generateId(),
        type: 'output',
        url: `data:image/png;base64,${base64}`,
        mime: 'image/png',
        width: 1024,
        height: 1024,
        checksum: base64.slice(0, 32),
      }))

      // Keep the masked reference as an asset when a mask was drawn.
      const maskReferenceAsset: Asset | undefined = maskedReferenceImage
        ? {
            id: generateId(),
            type: 'mask',
            url: `data:image/png;base64,${maskedReferenceImage}`,
            mime: 'image/png',
            width: 1024,
            height: 1024,
            checksum: maskedReferenceImage.slice(0, 32),
          }
        : undefined

      const accessToken = import.meta.env.VITE_ACCESS_TOKEN || ''
      let uploadResults: any[] | undefined

      if (accessToken) {
        try {
          uploadResults = await uploadImages(
            outputAssets.map(asset => asset.url),
            accessToken
          )
          const failedUploads = uploadResults.filter(r => !r.success)
          if (failedUploads.length > 0) {
            console.warn(`${failedUploads.length}张编辑后的图像上传失败`)
            addToast(`${failedUploads.length}张编辑后的图像上传失败`, 'warning', 5000)
          } else {
            console.log(`${uploadResults.length}张编辑后的图像全部上传成功`)
            addToast('编辑后的图像已成功上传', 'success', 3000)
          }
        } catch (error) {
          // Upload is best-effort: report it and keep the edit itself.
          console.error('上传编辑后的图像时出错:', error)
          addToast('编辑后的图像上传失败', 'error', 5000)
          uploadResults = undefined
        }
      } else {
        console.warn('未找到accessToken跳过上传')
      }

      // Surface token usage when the service reported it.
      if (usageMetadata?.totalTokenCount) {
        addToast(`本次编辑消耗 ${usageMetadata.totalTokenCount} Tokens`, 'info', 3000)
      }

      const edit: Edit = {
        id: generateId(),
        parentGenerationId,
        maskAssetId: hasMask ? generateId() : undefined,
        maskReferenceAsset,
        instruction,
        outputAssets,
        timestamp: Date.now(),
        uploadResults,
        parameters,
        usageMetadata, // kept so the history entry can show token usage
      }
      addEdit(edit)

      // Show the edited image on the canvas and make this edit the selection.
      const { selectEdit, selectGeneration } = useAppStore.getState()
      setCanvasImage(outputAssets[0].url)
      selectEdit(edit.id)
      selectGeneration(null)
      setIsGenerating(false)
    },
    onError: error => {
      if (error instanceof Error && error.name === EDIT_CANCELLED_ERROR_NAME) {
        // `cancelEdit` already cleared the busy flag and informed the user.
        // Doing either again here would show a bogus "failed" toast and could
        // clear the busy flag of a newer run started after the cancel.
        return
      }
      console.error('编辑失败:', error)
      const errorMessage = error instanceof Error ? error.message : '未知错误'
      const errorDetails = error instanceof Error ? error.stack : undefined
      addToast(`图像编辑失败: ${errorMessage}`, 'error', 5000, errorDetails)
      setIsGenerating(false)
    },
  })

  /** Marks every run started so far as cancelled and resets the busy flag. */
  const cancelEdit = () => {
    cancelledUpToRunIdRef.current = latestRunIdRef.current
    setIsGenerating(false)
    addToast('编辑已中断', 'info', 3000)
  }

  return {
    edit: editMutation.mutate,
    isEditing: editMutation.isPending,
    error: editMutation.error,
    cancelEdit,
  }
}