// Source: Nano-Banana-AI-Image-Editor (forked copy) — 436 lines, 15 KiB, TypeScript
import { create } from 'zustand';
import { devtools, persist } from 'zustand/middleware';
import { Project, Generation, Edit, SegmentationMask, BrushStroke } from '../types';
import { generateId } from '../utils/imageUtils';
/** A source (input) asset of a generation, referenced by URL rather than inline image data. */
interface LightweightSourceAsset {
  id: string;
  type: 'original';
  mime: string;
  width: number;
  height: number;
  checksum: string;
  /** Blob URL of the image, stored instead of base64 data. */
  blobUrl: string;
}

/** A generation record that carries no embedded image data. */
interface LightweightGenerationRecord {
  id: string;
  prompt: string;
  parameters: Generation['parameters'];
  sourceAssets: LightweightSourceAsset[];
  /** Blob URLs of the output assets. */
  outputAssetsBlobUrls: string[];
  modelVersion: string;
  timestamp: number;
}

/** An edit record that carries no embedded image data. */
interface LightweightEditRecord {
  id: string;
  parentGenerationId: string;
  maskAssetId?: string;
  /** Blob URL of the mask reference asset, when the edit has one. */
  maskReferenceAssetBlobUrl?: string;
  instruction: string;
  /** Blob URLs of the output assets. */
  outputAssetsBlobUrls: string[];
  timestamp: number;
}

/**
 * Lightweight project structure: metadata plus URLs pointing at images,
 * with no image data of its own, so it can be persisted cheaply.
 */
interface LightweightProject {
  id: string;
  title: string;
  generations: LightweightGenerationRecord[];
  edits: LightweightEditRecord[];
  createdAt: number;
  updatedAt: number;
}
/** Tools the user can select in the editor UI. */
type EditorTool = 'generate' | 'edit' | 'mask';

/** A 2D offset used for canvas panning. */
interface PanOffset {
  x: number;
  y: number;
}

/** Plain (non-function) state fields of the application store. */
interface AppDataState {
  /** Current project (lightweight version, holding no actual image data). */
  currentProject: LightweightProject | null;

  // --- Canvas ---
  canvasImage: string | null;
  canvasZoom: number;
  canvasPan: PanOffset;

  // --- Uploads ---
  uploadedImages: string[];
  editReferenceImages: string[];

  // --- Mask painting ---
  /** Brush strokes used to draw the mask. */
  brushStrokes: BrushStroke[];
  brushSize: number;
  showMasks: boolean;

  // --- Generation ---
  isGenerating: boolean;
  currentPrompt: string;
  temperature: number;
  seed: number | null;

  // --- History and variants ---
  selectedGenerationId: string | null;
  selectedEditId: string | null;
  showHistory: boolean;

  // --- Panel visibility ---
  showPromptPanel: boolean;

  // --- UI ---
  selectedTool: EditorTool;

  /** Map from object URL to the Blob object behind it. */
  blobStore: Map<string, Blob>;
}

/** Actions (state mutators and accessors) exposed by the application store. */
interface AppActions {
  setCurrentProject: (project: LightweightProject | null) => void;
  setCanvasImage: (url: string | null) => void;
  setCanvasZoom: (zoom: number) => void;
  setCanvasPan: (pan: PanOffset) => void;

  addUploadedImage: (url: string) => void;
  removeUploadedImage: (index: number) => void;
  clearUploadedImages: () => void;

  addEditReferenceImage: (url: string) => void;
  removeEditReferenceImage: (index: number) => void;
  clearEditReferenceImages: () => void;

  addBrushStroke: (stroke: BrushStroke) => void;
  clearBrushStrokes: () => void;
  setBrushSize: (size: number) => void;
  setShowMasks: (show: boolean) => void;

  setIsGenerating: (generating: boolean) => void;
  setCurrentPrompt: (prompt: string) => void;
  setTemperature: (temp: number) => void;
  setSeed: (seed: number | null) => void;

  addGeneration: (generation: Generation) => void;
  addEdit: (edit: Edit) => void;
  selectGeneration: (id: string | null) => void;
  selectEdit: (id: string | null) => void;
  setShowHistory: (show: boolean) => void;

  setShowPromptPanel: (show: boolean) => void;

  setSelectedTool: (tool: EditorTool) => void;

  // --- Blob store operations ---
  /** Registers a Blob and returns the URL under which it is stored. */
  addBlob: (blob: Blob) => string;
  /** Looks up the Blob stored under `url`, if any. */
  getBlob: (url: string) => Blob | undefined;
  cleanupOldHistory: () => void;
}

/** Complete shape of the zustand store: state fields plus actions. */
interface AppState extends AppDataState, AppActions {}
/** Maximum number of generations, and separately of edits, kept in project history. */
const MAX_HISTORY_ITEMS = 10;

/**
 * Decodes a `data:` URL into a Blob typed with the media type from the URL header.
 *
 * Supports both base64 payloads (`data:image/png;base64,...`) and
 * percent-encoded payloads (`data:image/svg+xml,%3Csvg...`). The previous
 * inline copies always used `atob` and `split(',')[1]`, which throws on
 * percent-encoded payloads and truncates payloads that contain commas.
 *
 * @param dataUrl - a URL that starts with `data:`
 * @returns the decoded Blob
 * @throws Error if the URL has no `,` separating header and payload
 * @throws DOMException / URIError if the payload is not valid base64 / percent-encoding
 */
function dataUrlToBlob(dataUrl: string): Blob {
  const commaIndex = dataUrl.indexOf(',');
  if (commaIndex === -1) {
    throw new Error('Malformed data URL: missing "," separator');
  }
  // Header is everything between "data:" and the first comma, e.g. "image/png;base64".
  const [mime = '', ...params] = dataUrl.slice('data:'.length, commaIndex).split(';');
  const payload = dataUrl.slice(commaIndex + 1);
  const isBase64 = params.some((param) => param.trim().toLowerCase() === 'base64');

  if (!isBase64) {
    // Percent-encoded text payload; Blob encodes the decoded string as UTF-8.
    return new Blob([decodeURIComponent(payload)], { type: mime });
  }

  const binary = atob(payload);
  const buffer = new ArrayBuffer(binary.length);
  const bytes = new Uint8Array(buffer);
  for (let i = 0; i < binary.length; i++) {
    bytes[i] = binary.charCodeAt(i);
  }
  return new Blob([buffer], { type: mime });
}

/**
 * Converts an asset URL into the URL stored in the lightweight project.
 *
 * `data:` URLs are decoded into a Blob and replaced by an object URL; the new
 * (url, blob) pair is recorded in `createdBlobs` so the caller can register all
 * of them in `blobStore` with a single state update. Any other URL is returned
 * unchanged.
 */
function toStoredUrl(url: string, createdBlobs: Map<string, Blob>): string {
  if (!url.startsWith('data:')) {
    return url;
  }
  const blob = dataUrlToBlob(url);
  const blobUrl = URL.createObjectURL(blob);
  createdBlobs.set(blobUrl, blob);
  return blobUrl;
}

/** Lists every asset URL referenced by the given generation and edit records. */
function collectRecordUrls(
  generations: LightweightProject['generations'],
  edits: LightweightProject['edits'],
): string[] {
  const urls: string[] = [];
  for (const generation of generations) {
    for (const asset of generation.sourceAssets) {
      urls.push(asset.blobUrl);
    }
    urls.push(...generation.outputAssetsBlobUrls);
  }
  for (const edit of edits) {
    if (edit.maskReferenceAssetBlobUrl) {
      urls.push(edit.maskReferenceAssetBlobUrl);
    }
    urls.push(...edit.outputAssetsBlobUrls);
  }
  return urls;
}

/**
 * Keeps only the newest MAX_HISTORY_ITEMS generations and edits of `project`.
 *
 * Blobs that belonged exclusively to the dropped records are removed from a
 * copy of `blobStore`, and their URLs are returned in `releasedUrls` so the
 * caller can revoke them after committing state. A URL is kept if any kept
 * record, the canvas, or the upload/reference lists still use it, or if it is
 * not in `blobStore` (this store did not create it, so it is not ours to revoke).
 *
 * @returns the trimmed project, the updated blob map, the URLs to revoke, and
 *          whether anything was trimmed. When nothing is trimmed, `project` and
 *          `blobStore` are returned as-is.
 */
function pruneHistory(
  project: LightweightProject,
  state: AppState,
  blobStore: Map<string, Blob>,
): {
  project: LightweightProject;
  blobStore: Map<string, Blob>;
  releasedUrls: string[];
  trimmed: boolean;
} {
  const dropGenerations = Math.max(0, project.generations.length - MAX_HISTORY_ITEMS);
  const dropEdits = Math.max(0, project.edits.length - MAX_HISTORY_ITEMS);
  if (dropGenerations === 0 && dropEdits === 0) {
    return { project, blobStore, releasedUrls: [], trimmed: false };
  }

  const keptProject: LightweightProject = {
    ...project,
    generations: project.generations.slice(dropGenerations),
    edits: project.edits.slice(dropEdits),
  };

  const stillReferenced = new Set<string>([
    ...collectRecordUrls(keptProject.generations, keptProject.edits),
    ...state.uploadedImages,
    ...state.editReferenceImages,
  ]);
  if (state.canvasImage !== null) {
    stillReferenced.add(state.canvasImage);
  }

  const nextBlobStore = new Map(blobStore);
  const releasedUrls: string[] = [];
  const droppedUrls = collectRecordUrls(
    project.generations.slice(0, dropGenerations),
    project.edits.slice(0, dropEdits),
  );
  for (const url of droppedUrls) {
    // `delete` returns false for URLs we never stored (or already released).
    if (!stillReferenced.has(url) && nextBlobStore.delete(url)) {
      releasedUrls.push(url);
    }
  }

  return { project: keptProject, blobStore: nextBlobStore, releasedUrls, trimmed: true };
}

export const useAppStore = create<AppState>()(
  devtools(
    persist(
      (set, get) => {
        /**
         * Commits `project` as the current project in ONE state update.
         * It registers the blobs decoded for it, trims history to
         * MAX_HISTORY_ITEMS, and afterwards revokes object URLs that only the
         * trimmed records used.
         *
         * This replaces calling `set` from inside a `set` updater. That pattern
         * ran side effects in the updater and fired a listener notification,
         * plus a persist write, for every decoded image.
         */
        const commitProject = (
          project: LightweightProject,
          createdBlobs: Map<string, Blob>,
        ): void => {
          const state = get();
          let blobStore = state.blobStore;
          if (createdBlobs.size > 0) {
            blobStore = new Map(blobStore);
            createdBlobs.forEach((blob, url) => blobStore.set(url, blob));
          }
          const pruned = pruneHistory(project, state, blobStore);
          set({ currentProject: pruned.project, blobStore: pruned.blobStore });
          pruned.releasedUrls.forEach((url) => URL.revokeObjectURL(url));
        };

        return {
          // Initial state
          currentProject: null,
          canvasImage: null,
          canvasZoom: 1,
          canvasPan: { x: 0, y: 0 },

          uploadedImages: [],
          editReferenceImages: [],

          brushStrokes: [],
          brushSize: 20,
          showMasks: true,

          isGenerating: false,
          currentPrompt: '',
          temperature: 1,
          seed: null,

          selectedGenerationId: null,
          selectedEditId: null,
          showHistory: true,

          showPromptPanel: true,

          selectedTool: 'generate',

          // Blob storage (excluded from persistence by `partialize` below)
          blobStore: new Map(),

          // Actions
          setCurrentProject: (project) => set({ currentProject: project }),
          setCanvasImage: (url) => set({ canvasImage: url }),
          setCanvasZoom: (zoom) => set({ canvasZoom: zoom }),
          setCanvasPan: (pan) => set({ canvasPan: pan }),

          addUploadedImage: (url) =>
            set((state) => ({ uploadedImages: [...state.uploadedImages, url] })),
          removeUploadedImage: (index) =>
            set((state) => ({
              uploadedImages: state.uploadedImages.filter((_, i) => i !== index),
            })),
          clearUploadedImages: () => set({ uploadedImages: [] }),

          addEditReferenceImage: (url) =>
            set((state) => ({ editReferenceImages: [...state.editReferenceImages, url] })),
          removeEditReferenceImage: (index) =>
            set((state) => ({
              editReferenceImages: state.editReferenceImages.filter((_, i) => i !== index),
            })),
          clearEditReferenceImages: () => set({ editReferenceImages: [] }),

          addBrushStroke: (stroke) =>
            set((state) => ({ brushStrokes: [...state.brushStrokes, stroke] })),
          clearBrushStrokes: () => set({ brushStrokes: [] }),
          setBrushSize: (size) => set({ brushSize: size }),
          setShowMasks: (show) => set({ showMasks: show }),

          setIsGenerating: (generating) => set({ isGenerating: generating }),
          setCurrentPrompt: (prompt) => set({ currentPrompt: prompt }),
          setTemperature: (temp) => set({ temperature: temp }),
          setSeed: (seed) => set({ seed: seed }),

          // Creates an object URL for the Blob, stores the Blob under it, and returns the URL.
          addBlob: (blob: Blob) => {
            const url = URL.createObjectURL(blob);
            set((state) => {
              const newBlobStore = new Map(state.blobStore);
              newBlobStore.set(url, blob);
              return { blobStore: newBlobStore };
            });
            return url;
          },

          // Returns the Blob stored under `url`, if any.
          getBlob: (url: string) => get().blobStore.get(url),

          // Appends a generation (converting data: URLs to blob URLs), creating a
          // project if none exists, then trims history to MAX_HISTORY_ITEMS.
          addGeneration: (generation) => {
            const { currentProject } = get();
            const createdBlobs = new Map<string, Blob>();

            const lightweightGeneration: LightweightProject['generations'][number] = {
              id: generation.id,
              prompt: generation.prompt,
              parameters: generation.parameters,
              sourceAssets: generation.sourceAssets.map((asset) => ({
                id: asset.id,
                type: asset.type,
                mime: asset.mime,
                width: asset.width,
                height: asset.height,
                checksum: asset.checksum,
                blobUrl: toStoredUrl(asset.url, createdBlobs),
              })),
              outputAssetsBlobUrls: generation.outputAssets.map((asset) =>
                toStoredUrl(asset.url, createdBlobs),
              ),
              modelVersion: generation.modelVersion,
              timestamp: generation.timestamp,
            };

            // One timestamp so createdAt and updatedAt agree on a new project.
            const now = Date.now();
            const project: LightweightProject = currentProject
              ? {
                  ...currentProject,
                  generations: [...currentProject.generations, lightweightGeneration],
                  updatedAt: now,
                }
              : {
                  // No project yet: create one containing this generation.
                  id: generateId(),
                  title: '未命名项目',
                  generations: [lightweightGeneration],
                  edits: [],
                  createdAt: now,
                  updatedAt: now,
                };

            commitProject(project, createdBlobs);
          },

          // Appends an edit (converting data: URLs to blob URLs) to the current
          // project, then trims history. Ignored when there is no project.
          addEdit: (edit) => {
            const { currentProject } = get();
            // Bail out before decoding anything: the original created and stored
            // blobs first and then dropped the edit, leaking them.
            if (!currentProject) return;

            const createdBlobs = new Map<string, Blob>();
            const lightweightEdit: LightweightProject['edits'][number] = {
              id: edit.id,
              parentGenerationId: edit.parentGenerationId,
              maskAssetId: edit.maskAssetId,
              maskReferenceAssetBlobUrl: edit.maskReferenceAsset
                ? toStoredUrl(edit.maskReferenceAsset.url, createdBlobs)
                : undefined,
              instruction: edit.instruction,
              outputAssetsBlobUrls: edit.outputAssets.map((asset) =>
                toStoredUrl(asset.url, createdBlobs),
              ),
              timestamp: edit.timestamp,
            };

            commitProject(
              {
                ...currentProject,
                edits: [...currentProject.edits, lightweightEdit],
                updatedAt: Date.now(),
              },
              createdBlobs,
            );
          },

          selectGeneration: (id) => set({ selectedGenerationId: id }),
          selectEdit: (id) => set({ selectedEditId: id }),
          setShowHistory: (show) => set({ showHistory: show }),

          setShowPromptPanel: (show) => set({ showPromptPanel: show }),

          setSelectedTool: (tool) => set({ selectedTool: tool }),

          // Trims history to at most MAX_HISTORY_ITEMS generations and edits and
          // releases the blobs only the dropped records used. No-op (no state
          // update, no updatedAt bump) when already within the limit.
          cleanupOldHistory: () => {
            const state = get();
            if (!state.currentProject) return;

            const pruned = pruneHistory(state.currentProject, state, state.blobStore);
            if (!pruned.trimmed) return;

            set({
              currentProject: { ...pruned.project, updatedAt: Date.now() },
              blobStore: pruned.blobStore,
            });
            pruned.releasedUrls.forEach((url) => URL.revokeObjectURL(url));
          },
        };
      },
      {
        name: 'nano-banana-store',
        // Only the lightweight project data is persisted; Blob objects are not.
        // NOTE(review): object (blob:) URLs are only valid in the document that
        // created them, so after a page reload the URLs inside the rehydrated
        // currentProject no longer resolve and blobStore starts empty. Consider
        // persisting image data elsewhere (e.g. IndexedDB) if history must
        // survive reloads.
        partialize: (state) => ({
          currentProject: state.currentProject,
        }),
      },
    ),
  ),
);