Compare commits

...

1 commit

Author SHA1 Message Date
GuaGua
850dda51ac fix(translate): use live provider config for connection test 2026-04-07 18:47:44 -07:00
3 changed files with 81 additions and 15 deletions

View file

@ -0,0 +1,57 @@
import type { LLMProviderConfig } from "@/types/config/provider"
import { beforeEach, describe, expect, it, vi } from "vitest"
const getModelByConfigMock = vi.fn()
const generateTextMock = vi.fn()
vi.mock("@/utils/providers/model", () => ({
getModelByConfig: getModelByConfigMock,
}))
vi.mock("ai", () => ({
generateText: generateTextMock,
}))
const providerConfig: LLMProviderConfig = {
id: "minimax-default",
name: "MiniMax",
enabled: true,
provider: "minimax",
apiKey: "test-key",
baseURL: "https://api.minimaxi.com/anthropic/v1",
model: {
model: "MiniMax-M2.7",
isCustomModel: false,
customModel: null,
},
}
describe("aiTranslate", () => {
beforeEach(() => {
vi.resetModules()
vi.clearAllMocks()
})
it("builds the model from the current provider config", async () => {
getModelByConfigMock.mockResolvedValue("mock-model")
generateTextMock.mockResolvedValue({ text: "你好" })
const { aiTranslate } = await import("../ai")
const promptResolver = vi.fn().mockResolvedValue({
systemPrompt: "system prompt",
prompt: "translate this",
})
const result = await aiTranslate("Hi", "Chinese", providerConfig, promptResolver)
expect(getModelByConfigMock).toHaveBeenCalledWith(providerConfig)
expect(promptResolver).toHaveBeenCalledWith("Chinese", "Hi", undefined)
expect(generateTextMock).toHaveBeenCalledWith(expect.objectContaining({
model: "mock-model",
system: "system prompt",
prompt: "translate this",
maxRetries: 0,
}))
expect(result).toBe("你好")
})
})

View file

@ -3,7 +3,7 @@ import type { ArticleContent } from "@/types/content"
import type { TranslatePromptOptions, TranslatePromptResult } from "@/utils/prompts/translate"
import { generateText } from "ai"
import { extractAISDKErrorMessage } from "@/utils/error/extract-message"
import { getModelById } from "@/utils/providers/model"
import { getModelByConfig } from "@/utils/providers/model"
import { resolveModelId } from "@/utils/providers/model-id"
import { getProviderOptionsWithOverride } from "@/utils/providers/options"
@ -22,9 +22,9 @@ export async function aiTranslate(
promptResolver: PromptResolver,
options?: { isBatch?: boolean, content?: ArticleContent },
) {
const { id: providerId, model: providerModel, provider, providerOptions: userProviderOptions, temperature } = providerConfig
const { model: providerModel, provider, providerOptions: userProviderOptions, temperature } = providerConfig
const modelName = resolveModelId(providerModel)
const model = await getModelById(providerId)
const model = await getModelByConfig(providerConfig)
const providerOptions = getProviderOptionsWithOverride(modelName ?? "", provider, userProviderOptions)
const { systemPrompt, prompt } = await promptResolver(targetLangName, text, options)

View file

@ -1,4 +1,5 @@
import type { Config } from "@/types/config/config"
import type { LLMProviderConfig } from "@/types/config/provider"
import { storage } from "#imports"
import { createAlibaba } from "@ai-sdk/alibaba"
import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock"
@ -63,18 +64,7 @@ const CUSTOM_HEADER_MAP: Partial<Record<keyof typeof CREATE_AI_MAPPER, Record<st
anthropic: { "anthropic-dangerous-direct-browser-access": "true" },
}
async function getLanguageModelById(providerId: string) {
const config = await storage.getItem<Config>(`local:${CONFIG_STORAGE_KEY}`)
if (!config) {
throw new Error("Config not found")
}
const LLMProvidersConfig = getLLMProvidersConfig(config.providersConfig)
const providerConfig = getProviderConfigById(LLMProvidersConfig, providerId)
if (!providerConfig) {
throw new Error(`Provider ${providerId} not found`)
}
function getLanguageModelByConfig(providerConfig: LLMProviderConfig) {
const customHeaders = CUSTOM_HEADER_MAP[providerConfig.provider]
const connectionOptions = compactObject(providerConfig.connectionOptions ?? {})
@ -103,6 +93,25 @@ async function getLanguageModelById(providerId: string) {
return provider.languageModel(modelId)
}
async function getLanguageModelById(providerId: string) {
const config = await storage.getItem<Config>(`local:${CONFIG_STORAGE_KEY}`)
if (!config) {
throw new Error("Config not found")
}
const LLMProvidersConfig = getLLMProvidersConfig(config.providersConfig)
const providerConfig = getProviderConfigById(LLMProvidersConfig, providerId)
if (!providerConfig) {
throw new Error(`Provider ${providerId} not found`)
}
return getLanguageModelByConfig(providerConfig)
}
export async function getModelById(providerId: string) {
return getLanguageModelById(providerId)
}
export async function getModelByConfig(providerConfig: LLMProviderConfig) {
return getLanguageModelByConfig(providerConfig)
}