diff --git a/src/hooks/useImageGeneration.ts b/src/hooks/useImageGeneration.ts
index 1ca7450..2d6ecf1 100644
--- a/src/hooks/useImageGeneration.ts
+++ b/src/hooks/useImageGeneration.ts
@@ -1,23 +1,37 @@
-import { useMutation } from '@tanstack/react-query';
-import { geminiService, GenerationRequest, EditRequest } from '../services/geminiService';
-import { useAppStore } from '../store/useAppStore';
-import { generateId } from '../utils/imageUtils';
-import { Generation, Edit, Asset } from '../types';
-import { useToast } from '../components/ToastContext';
+import React from 'react'
+import { useMutation } from '@tanstack/react-query'
+import { geminiService, GenerationRequest, EditRequest } from '../services/geminiService'
+import { useAppStore } from '../store/useAppStore'
+import { generateId } from '../utils/imageUtils'
+import { Generation, Edit, Asset } from '../types'
+import { useToast } from '../components/ToastContext'
+import { uploadImages } from '../services/uploadService'
export const useImageGeneration = () => {
- const { addGeneration, setIsGenerating, setCanvasImage } = useAppStore();
- const { addToast } = useToast();
+ const { addGeneration, setIsGenerating, setCanvasImage } = useAppStore()
+ const { addToast } = useToast()
+
+ // 创建中断标志引用
+ const isCancelledRef = React.useRef(false)
const generateMutation = useMutation({
mutationFn: async (request: GenerationRequest) => {
- const images = await geminiService.generateImage(request);
- return images;
+ // 重置中断标志
+ isCancelledRef.current = false
+
+ const images = await geminiService.generateImage(request)
+
+ // 检查是否已中断
+ if (isCancelledRef.current) {
+ throw new Error('生成已中断')
+ }
+
+ return images
},
onMutate: () => {
- setIsGenerating(true);
+ setIsGenerating(true)
},
- onSuccess: (images, request) => {
+ onSuccess: async (images, request) => {
if (images.length > 0) {
const outputAssets: Asset[] = images.map((base64, index) => ({
id: generateId(),
@@ -29,6 +43,34 @@ export const useImageGeneration = () => {
checksum: base64.slice(0, 32) // 简单校验和
}));
+ // 获取accessToken
+ const accessToken = import.meta.env.VITE_ACCESS_TOKEN || '';
+ let uploadResults: any[] | undefined;
+
+ // 上传生成的图像
+ if (accessToken) {
+ try {
+ const imageUrls = outputAssets.map(asset => asset.url);
+ uploadResults = await uploadImages(imageUrls, accessToken);
+
+ // 检查上传结果
+ const failedUploads = uploadResults.filter(r => !r.success);
+ if (failedUploads.length > 0) {
+ console.warn(`${failedUploads.length}张图像上传失败`);
+ addToast(`${failedUploads.length}张图像上传失败`, 'warning', 5000);
+ } else {
+ console.log(`${uploadResults.length}张图像全部上传成功`);
+ addToast('图像已成功上传', 'success', 3000);
+ }
+ } catch (error) {
+ console.error('上传图像时出错:', error);
+ addToast('图像上传失败', 'error', 5000);
+ uploadResults = undefined;
+ }
+ } else {
+ console.warn('未找到accessToken,跳过上传');
+ }
+
const generation: Generation = {
id: generateId(),
prompt: request.prompt,
@@ -48,7 +90,8 @@ export const useImageGeneration = () => {
})) : [],
outputAssets,
modelVersion: 'gemini-2.5-flash-image-preview',
- timestamp: Date.now()
+ timestamp: Date.now(),
+ uploadResults: uploadResults
};
addGeneration(generation);
@@ -56,154 +99,158 @@ export const useImageGeneration = () => {
}
setIsGenerating(false);
},
- onError: (error) => {
- console.error('生成失败:', error);
- const errorMessage = error instanceof Error ? error.message : '未知错误';
- const errorDetails = error instanceof Error ? error.stack : undefined;
- addToast(`图像生成失败: ${errorMessage}`, 'error', 5000, errorDetails);
- setIsGenerating(false);
- }
- });
+ onError: error => {
+ console.error('生成失败:', error)
+ const errorMessage = error instanceof Error ? error.message : '未知错误'
+ const errorDetails = error instanceof Error ? error.stack : undefined
+ addToast(`图像生成失败: ${errorMessage}`, 'error', 5000, errorDetails)
+ setIsGenerating(false)
+ },
+ })
+
+ const cancelGeneration = () => {
+ isCancelledRef.current = true
+ setIsGenerating(false)
+ addToast('生成已中断', 'info', 3000)
+ }
return {
generate: generateMutation.mutate,
isGenerating: generateMutation.isPending,
- error: generateMutation.error
- };
-};
+ error: generateMutation.error,
+ cancelGeneration,
+ }
+}
export const useImageEditing = () => {
- const {
- addEdit,
- setIsGenerating,
- setCanvasImage,
- canvasImage,
- editReferenceImages,
- brushStrokes,
- selectedGenerationId,
- seed,
- temperature,
- uploadedImages
- } = useAppStore();
+ const { addEdit, setIsGenerating, setCanvasImage, canvasImage, editReferenceImages, brushStrokes, selectedGenerationId, seed, temperature, uploadedImages } = useAppStore()
+
+ const { addToast } = useToast()
- const { addToast } = useToast();
+ // 创建中断标志引用
+ const isCancelledRef = React.useRef(false)
const editMutation = useMutation({
mutationFn: async (instruction: string) => {
+ // 重置中断标志
+ isCancelledRef.current = false
+
// 如果可用,始终使用画布图像作为主要目标,否则使用第一张上传的图像
- const sourceImage = canvasImage || uploadedImages[0];
- if (!sourceImage) throw new Error('没有要编辑的图像');
-
+ const sourceImage = canvasImage || uploadedImages[0]
+ if (!sourceImage) throw new Error('没有要编辑的图像')
+
// 将画布图像转换为base64
- const base64Image = sourceImage.includes('base64,')
- ? sourceImage.split('base64,')[1]
- : sourceImage;
-
+ const base64Image = sourceImage.includes('base64,') ? sourceImage.split('base64,')[1] : sourceImage
+
// 获取用于样式指导的参考图像
- let referenceImages = editReferenceImages
- .filter(img => img.includes('base64,'))
- .map(img => img.split('base64,')[1]);
-
- let maskImage: string | undefined;
- let maskedReferenceImage: string | undefined;
-
+ let referenceImages = editReferenceImages.filter(img => img.includes('base64,')).map(img => img.split('base64,')[1])
+
+ let maskImage: string | undefined
+ let maskedReferenceImage: string | undefined
+
// 如果存在画笔描边,则从描边创建遮罩
if (brushStrokes.length > 0) {
// 创建临时图像以获取实际尺寸
- const tempImg = new Image();
- tempImg.src = sourceImage;
- await new Promise
((resolve) => {
- tempImg.onload = () => resolve();
- });
-
+ const tempImg = new Image()
+ tempImg.src = sourceImage
+ await new Promise(resolve => {
+ tempImg.onload = () => resolve()
+ })
+
// 创建具有确切图像尺寸的遮罩画布
- const canvas = document.createElement('canvas');
- const ctx = canvas.getContext('2d')!;
- canvas.width = tempImg.width;
- canvas.height = tempImg.height;
-
+ const canvas = document.createElement('canvas')
+ const ctx = canvas.getContext('2d')!
+ canvas.width = tempImg.width
+ canvas.height = tempImg.height
+
// 用黑色填充(未遮罩区域)
- ctx.fillStyle = 'black';
- ctx.fillRect(0, 0, canvas.width, canvas.height);
-
+ ctx.fillStyle = 'black'
+ ctx.fillRect(0, 0, canvas.width, canvas.height)
+
// 绘制白色描边(遮罩区域)
- ctx.strokeStyle = 'white';
- ctx.lineCap = 'round';
- ctx.lineJoin = 'round';
-
+ ctx.strokeStyle = 'white'
+ ctx.lineCap = 'round'
+ ctx.lineJoin = 'round'
+
brushStrokes.forEach(stroke => {
if (stroke.points.length >= 4) {
- ctx.lineWidth = stroke.brushSize;
- ctx.beginPath();
- ctx.moveTo(stroke.points[0], stroke.points[1]);
-
+ ctx.lineWidth = stroke.brushSize
+ ctx.beginPath()
+ ctx.moveTo(stroke.points[0], stroke.points[1])
+
for (let i = 2; i < stroke.points.length; i += 2) {
- ctx.lineTo(stroke.points[i], stroke.points[i + 1]);
+ ctx.lineTo(stroke.points[i], stroke.points[i + 1])
}
- ctx.stroke();
+ ctx.stroke()
}
- });
-
+ })
+
// 将遮罩转换为base64
- const maskDataUrl = canvas.toDataURL('image/png');
- maskImage = maskDataUrl.split('base64,')[1];
-
+ const maskDataUrl = canvas.toDataURL('image/png')
+ maskImage = maskDataUrl.split('base64,')[1]
+
// 创建遮罩参考图像(带遮罩叠加的原始图像)
- const maskedCanvas = document.createElement('canvas');
- const maskedCtx = maskedCanvas.getContext('2d')!;
- maskedCanvas.width = tempImg.width;
- maskedCanvas.height = tempImg.height;
-
+ const maskedCanvas = document.createElement('canvas')
+ const maskedCtx = maskedCanvas.getContext('2d')!
+ maskedCanvas.width = tempImg.width
+ maskedCanvas.height = tempImg.height
+
// 绘制原始图像
- maskedCtx.drawImage(tempImg, 0, 0);
-
+ maskedCtx.drawImage(tempImg, 0, 0)
+
// 绘制带透明度的遮罩叠加
- maskedCtx.globalCompositeOperation = 'source-over';
- maskedCtx.globalAlpha = 0.4;
- maskedCtx.fillStyle = '#A855F7';
-
+ maskedCtx.globalCompositeOperation = 'source-over'
+ maskedCtx.globalAlpha = 0.4
+ maskedCtx.fillStyle = '#A855F7'
+
brushStrokes.forEach(stroke => {
if (stroke.points.length >= 4) {
- maskedCtx.lineWidth = stroke.brushSize;
- maskedCtx.strokeStyle = '#A855F7';
- maskedCtx.lineCap = 'round';
- maskedCtx.lineJoin = 'round';
- maskedCtx.beginPath();
- maskedCtx.moveTo(stroke.points[0], stroke.points[1]);
-
+ maskedCtx.lineWidth = stroke.brushSize
+ maskedCtx.strokeStyle = '#A855F7'
+ maskedCtx.lineCap = 'round'
+ maskedCtx.lineJoin = 'round'
+ maskedCtx.beginPath()
+ maskedCtx.moveTo(stroke.points[0], stroke.points[1])
+
for (let i = 2; i < stroke.points.length; i += 2) {
- maskedCtx.lineTo(stroke.points[i], stroke.points[i + 1]);
+ maskedCtx.lineTo(stroke.points[i], stroke.points[i + 1])
}
- maskedCtx.stroke();
+ maskedCtx.stroke()
}
- });
-
- maskedCtx.globalAlpha = 1;
- maskedCtx.globalCompositeOperation = 'source-over';
-
- const maskedDataUrl = maskedCanvas.toDataURL('image/png');
- maskedReferenceImage = maskedDataUrl.split('base64,')[1];
-
+ })
+
+ maskedCtx.globalAlpha = 1
+ maskedCtx.globalCompositeOperation = 'source-over'
+
+ const maskedDataUrl = maskedCanvas.toDataURL('image/png')
+ maskedReferenceImage = maskedDataUrl.split('base64,')[1]
+
// 将遮罩图像作为参考添加到模型中
- referenceImages = [maskedReferenceImage, ...referenceImages];
+ referenceImages = [maskedReferenceImage, ...referenceImages]
}
-
+
const request: EditRequest = {
instruction,
originalImage: base64Image,
referenceImages: referenceImages.length > 0 ? referenceImages : undefined,
maskImage,
temperature,
- seed
- };
+ seed,
+ }
+
+ const images = await geminiService.editImage(request)
- const images = await geminiService.editImage(request);
- return { images, maskedReferenceImage };
+ // 检查是否已中断
+ if (isCancelledRef.current) {
+ throw new Error('编辑已中断')
+ }
+
+ return { images, maskedReferenceImage }
},
onMutate: () => {
- setIsGenerating(true);
+ setIsGenerating(true)
},
- onSuccess: ({ images, maskedReferenceImage }, instruction) => {
+ onSuccess: async ({ images, maskedReferenceImage }, instruction) => {
if (images.length > 0) {
const outputAssets: Asset[] = images.map((base64, index) => ({
id: generateId(),
@@ -226,6 +273,34 @@ export const useImageEditing = () => {
checksum: maskedReferenceImage.slice(0, 32)
} : undefined;
+ // 获取accessToken
+ const accessToken = import.meta.env.VITE_ACCESS_TOKEN || '';
+ let uploadResults: any[] | undefined;
+
+ // 上传编辑后的图像
+ if (accessToken) {
+ try {
+ const imageUrls = outputAssets.map(asset => asset.url);
+ uploadResults = await uploadImages(imageUrls, accessToken);
+
+ // 检查上传结果
+ const failedUploads = uploadResults.filter(r => !r.success);
+ if (failedUploads.length > 0) {
+ console.warn(`${failedUploads.length}张编辑后的图像上传失败`);
+ addToast(`${failedUploads.length}张编辑后的图像上传失败`, 'warning', 5000);
+ } else {
+ console.log(`${uploadResults.length}张编辑后的图像全部上传成功`);
+ addToast('编辑后的图像已成功上传', 'success', 3000);
+ }
+ } catch (error) {
+ console.error('上传编辑后的图像时出错:', error);
+ addToast('编辑后的图像上传失败', 'error', 5000);
+ uploadResults = undefined;
+ }
+ } else {
+ console.warn('未找到accessToken,跳过上传');
+ }
+
const edit: Edit = {
id: generateId(),
parentGenerationId: selectedGenerationId || '',
@@ -233,7 +308,8 @@ export const useImageEditing = () => {
maskReferenceAsset,
instruction,
outputAssets,
- timestamp: Date.now()
+ timestamp: Date.now(),
+ uploadResults: uploadResults
};
addEdit(edit);
@@ -246,18 +322,25 @@ export const useImageEditing = () => {
}
setIsGenerating(false);
},
- onError: (error) => {
- console.error('编辑失败:', error);
- const errorMessage = error instanceof Error ? error.message : '未知错误';
- const errorDetails = error instanceof Error ? error.stack : undefined;
- addToast(`图像编辑失败: ${errorMessage}`, 'error', 5000, errorDetails);
- setIsGenerating(false);
- }
- });
+ onError: error => {
+ console.error('编辑失败:', error)
+ const errorMessage = error instanceof Error ? error.message : '未知错误'
+ const errorDetails = error instanceof Error ? error.stack : undefined
+ addToast(`图像编辑失败: ${errorMessage}`, 'error', 5000, errorDetails)
+ setIsGenerating(false)
+ },
+ })
+
+ const cancelEdit = () => {
+ isCancelledRef.current = true
+ setIsGenerating(false)
+ addToast('编辑已中断', 'info', 3000)
+ }
return {
edit: editMutation.mutate,
isEditing: editMutation.isPending,
- error: editMutation.error
- };
-};
\ No newline at end of file
+ error: editMutation.error,
+ cancelEdit,
+ }
+}
diff --git a/src/services/uploadService.ts b/src/services/uploadService.ts
new file mode 100644
index 0000000..b4d3099
--- /dev/null
+++ b/src/services/uploadService.ts
@@ -0,0 +1,103 @@
+// src/services/uploadService.ts
+import { UploadResult } from '../types'
+
+// 上传接口URL
+const UPLOAD_URL = 'https://api.pandorastudio.cn/auth/OSSupload'
+
+/**
+ * 将base64图像数据上传到指定接口
+ * @param base64Data - base64编码的图像数据
+ * @param accessToken - 访问令牌
+ * @returns 上传结果
+ */
+export const uploadImage = async (base64Data: string, accessToken: string): Promise<{ success: boolean; url?: string; error?: string }> => {
+ try {
+ // 将base64数据转换为Blob
+ const byteString = atob(base64Data.split(',')[1])
+ const mimeString = base64Data.split(',')[0].split(':')[1].split(';')[0]
+ const ab = new ArrayBuffer(byteString.length)
+ const ia = new Uint8Array(ab)
+ for (let i = 0; i < byteString.length; i++) {
+ ia[i] = byteString.charCodeAt(i)
+ }
+ const blob = new Blob([ab], { type: mimeString })
+
+ // 创建FormData对象
+ const formData = new FormData()
+ formData.append('file', blob, 'generated-image.png')
+
+ // 发送POST请求
+ const response = await fetch(UPLOAD_URL, {
+ method: 'POST',
+ headers: { accessToken },
+ body: formData,
+ })
+
+ if (!response.ok) {
+ const errorText = await response.text()
+ throw new Error(`上传失败: ${response.status} ${response.statusText} - ${errorText}`)
+ }
+
+ const result = await response.json()
+ // 根据返回格式处理结果: {"code": 200,"msg": "上传成功","data": "9ecbaa0a0.jpg"}
+ if (result.code === 200) {
+ // 使用环境变量中的VITE_UPLOAD_ASSET_URL作为前缀
+ const uploadAssetUrl = import.meta.env.VITE_UPLOAD_ASSET_URL || ''
+ const fullUrl = uploadAssetUrl ? `${uploadAssetUrl}/${result.data}` : result.data
+ return { success: true, url: fullUrl, error: undefined }
+ } else {
+ throw new Error(`上传失败: ${result.msg}`)
+ }
+ } catch (error) {
+ console.error('上传图像时出错:', error)
+ return { success: false, url: undefined, error: error instanceof Error ? error.message : String(error) }
+ }
+}
+
+/**
+ * 上传多个图像
+ * @param base64Images - base64编码的图像数组
+ * @param accessToken - 访问令牌
+ * @returns 上传结果数组
+ */
+export const uploadImages = async (base64Images: string[], accessToken: string): Promise => {
+ try {
+ const results: UploadResult[] = []
+
+ for (let i = 0; i < base64Images.length; i++) {
+ const base64Data = base64Images[i]
+ try {
+ const uploadResult = await uploadImage(base64Data, accessToken)
+ const result: UploadResult = {
+ success: uploadResult.success,
+ url: uploadResult.url,
+ error: uploadResult.error,
+ timestamp: Date.now(),
+ }
+ results.push(result)
+ console.log(`第${i + 1}张图像上传${uploadResult.success ? '成功' : '失败'}:`, uploadResult)
+ } catch (error) {
+ const result: UploadResult = {
+ success: false,
+ error: error instanceof Error ? error.message : String(error),
+ timestamp: Date.now(),
+ }
+ results.push(result)
+ console.error(`第${i + 1}张图像上传失败:`, error)
+ }
+ }
+
+ // 检查是否有任何上传失败
+ const failedUploads = results.filter(r => !r.success)
+ if (failedUploads.length > 0) {
+ console.warn(`${failedUploads.length}张图像上传失败`)
+ } else {
+ console.log(`所有${results.length}张图像上传成功`)
+ }
+
+ return results
+ } catch (error) {
+ console.error('批量上传图像时出错:', error)
+ throw error
+ }
+}
diff --git a/src/store/useAppStore.ts b/src/store/useAppStore.ts
index 9fd33c0..5d745ed 100644
--- a/src/store/useAppStore.ts
+++ b/src/store/useAppStore.ts
@@ -1,6 +1,6 @@
import { create } from 'zustand';
import { devtools, persist } from 'zustand/middleware';
-import { Project, Generation, Edit, SegmentationMask, BrushStroke } from '../types';
+import { Project, Generation, Edit, SegmentationMask, BrushStroke, UploadResult } from '../types';
import { generateId } from '../utils/imageUtils';
// 定义不包含图像数据的轻量级项目结构
@@ -25,6 +25,7 @@ interface LightweightProject {
outputAssetsBlobUrls: string[];
modelVersion: string;
timestamp: number;
+ uploadResults?: UploadResult[];
}>;
edits: Array<{
id: string;
@@ -36,6 +37,7 @@ interface LightweightProject {
// 存储输出资产的Blob URL
outputAssetsBlobUrls: string[];
timestamp: number;
+ uploadResults?: UploadResult[];
}>;
createdAt: number;
updatedAt: number;
@@ -281,7 +283,8 @@ export const useAppStore = create()(
sourceAssets,
outputAssetsBlobUrls,
modelVersion: generation.modelVersion,
- timestamp: generation.timestamp
+ timestamp: generation.timestamp,
+ uploadResults: generation.uploadResults
};
const updatedProject = state.currentProject ? {
@@ -368,7 +371,8 @@ export const useAppStore = create()(
maskReferenceAssetBlobUrl,
instruction: edit.instruction,
outputAssetsBlobUrls,
- timestamp: edit.timestamp
+ timestamp: edit.timestamp,
+ uploadResults: edit.uploadResults
};
if (!state.currentProject) return {};
diff --git a/src/types/index.ts b/src/types/index.ts
index 7bf5ee9..a163b21 100644
--- a/src/types/index.ts
+++ b/src/types/index.ts
@@ -8,6 +8,13 @@ export interface Asset {
checksum: string;
}
+export interface UploadResult {
+ success: boolean;
+ url?: string;
+ error?: string;
+ timestamp: number;
+}
+
export interface Generation {
id: string;
prompt: string;
@@ -20,6 +27,7 @@ export interface Generation {
modelVersion: string;
timestamp: number;
costEstimate?: number;
+ uploadResults?: UploadResult[];
}
export interface Edit {
@@ -30,6 +38,7 @@ export interface Edit {
instruction: string;
outputAssets: Asset[];
timestamp: number;
+ uploadResults?: UploadResult[];
}
export interface Project {