mirror of
https://github.com/mengxi-ream/read-frog.git
synced 2026-04-30 01:56:46 +00:00
Compare commits
1 commit
main
...
fix/connec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
850dda51ac |
3 changed files with 81 additions and 15 deletions
57
src/utils/host/translate/api/__tests__/ai.test.ts
Normal file
57
src/utils/host/translate/api/__tests__/ai.test.ts
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import type { LLMProviderConfig } from "@/types/config/provider"
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest"
|
||||
|
||||
const getModelByConfigMock = vi.fn()
|
||||
const generateTextMock = vi.fn()
|
||||
|
||||
vi.mock("@/utils/providers/model", () => ({
|
||||
getModelByConfig: getModelByConfigMock,
|
||||
}))
|
||||
|
||||
vi.mock("ai", () => ({
|
||||
generateText: generateTextMock,
|
||||
}))
|
||||
|
||||
const providerConfig: LLMProviderConfig = {
|
||||
id: "minimax-default",
|
||||
name: "MiniMax",
|
||||
enabled: true,
|
||||
provider: "minimax",
|
||||
apiKey: "test-key",
|
||||
baseURL: "https://api.minimaxi.com/anthropic/v1",
|
||||
model: {
|
||||
model: "MiniMax-M2.7",
|
||||
isCustomModel: false,
|
||||
customModel: null,
|
||||
},
|
||||
}
|
||||
|
||||
describe("aiTranslate", () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it("builds the model from the current provider config", async () => {
|
||||
getModelByConfigMock.mockResolvedValue("mock-model")
|
||||
generateTextMock.mockResolvedValue({ text: "你好" })
|
||||
|
||||
const { aiTranslate } = await import("../ai")
|
||||
const promptResolver = vi.fn().mockResolvedValue({
|
||||
systemPrompt: "system prompt",
|
||||
prompt: "translate this",
|
||||
})
|
||||
|
||||
const result = await aiTranslate("Hi", "Chinese", providerConfig, promptResolver)
|
||||
|
||||
expect(getModelByConfigMock).toHaveBeenCalledWith(providerConfig)
|
||||
expect(promptResolver).toHaveBeenCalledWith("Chinese", "Hi", undefined)
|
||||
expect(generateTextMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: "mock-model",
|
||||
system: "system prompt",
|
||||
prompt: "translate this",
|
||||
maxRetries: 0,
|
||||
}))
|
||||
expect(result).toBe("你好")
|
||||
})
|
||||
})
|
||||
|
|
@ -3,7 +3,7 @@ import type { ArticleContent } from "@/types/content"
|
|||
import type { TranslatePromptOptions, TranslatePromptResult } from "@/utils/prompts/translate"
|
||||
import { generateText } from "ai"
|
||||
import { extractAISDKErrorMessage } from "@/utils/error/extract-message"
|
||||
import { getModelById } from "@/utils/providers/model"
|
||||
import { getModelByConfig } from "@/utils/providers/model"
|
||||
import { resolveModelId } from "@/utils/providers/model-id"
|
||||
import { getProviderOptionsWithOverride } from "@/utils/providers/options"
|
||||
|
||||
|
|
@ -22,9 +22,9 @@ export async function aiTranslate(
|
|||
promptResolver: PromptResolver,
|
||||
options?: { isBatch?: boolean, content?: ArticleContent },
|
||||
) {
|
||||
const { id: providerId, model: providerModel, provider, providerOptions: userProviderOptions, temperature } = providerConfig
|
||||
const { model: providerModel, provider, providerOptions: userProviderOptions, temperature } = providerConfig
|
||||
const modelName = resolveModelId(providerModel)
|
||||
const model = await getModelById(providerId)
|
||||
const model = await getModelByConfig(providerConfig)
|
||||
|
||||
const providerOptions = getProviderOptionsWithOverride(modelName ?? "", provider, userProviderOptions)
|
||||
const { systemPrompt, prompt } = await promptResolver(targetLangName, text, options)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import type { Config } from "@/types/config/config"
|
||||
import type { LLMProviderConfig } from "@/types/config/provider"
|
||||
import { storage } from "#imports"
|
||||
import { createAlibaba } from "@ai-sdk/alibaba"
|
||||
import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock"
|
||||
|
|
@ -63,18 +64,7 @@ const CUSTOM_HEADER_MAP: Partial<Record<keyof typeof CREATE_AI_MAPPER, Record<st
|
|||
anthropic: { "anthropic-dangerous-direct-browser-access": "true" },
|
||||
}
|
||||
|
||||
async function getLanguageModelById(providerId: string) {
|
||||
const config = await storage.getItem<Config>(`local:${CONFIG_STORAGE_KEY}`)
|
||||
if (!config) {
|
||||
throw new Error("Config not found")
|
||||
}
|
||||
|
||||
const LLMProvidersConfig = getLLMProvidersConfig(config.providersConfig)
|
||||
const providerConfig = getProviderConfigById(LLMProvidersConfig, providerId)
|
||||
if (!providerConfig) {
|
||||
throw new Error(`Provider ${providerId} not found`)
|
||||
}
|
||||
|
||||
function getLanguageModelByConfig(providerConfig: LLMProviderConfig) {
|
||||
const customHeaders = CUSTOM_HEADER_MAP[providerConfig.provider]
|
||||
const connectionOptions = compactObject(providerConfig.connectionOptions ?? {})
|
||||
|
||||
|
|
@ -103,6 +93,25 @@ async function getLanguageModelById(providerId: string) {
|
|||
return provider.languageModel(modelId)
|
||||
}
|
||||
|
||||
async function getLanguageModelById(providerId: string) {
|
||||
const config = await storage.getItem<Config>(`local:${CONFIG_STORAGE_KEY}`)
|
||||
if (!config) {
|
||||
throw new Error("Config not found")
|
||||
}
|
||||
|
||||
const LLMProvidersConfig = getLLMProvidersConfig(config.providersConfig)
|
||||
const providerConfig = getProviderConfigById(LLMProvidersConfig, providerId)
|
||||
if (!providerConfig) {
|
||||
throw new Error(`Provider ${providerId} not found`)
|
||||
}
|
||||
|
||||
return getLanguageModelByConfig(providerConfig)
|
||||
}
|
||||
|
||||
export async function getModelById(providerId: string) {
|
||||
return getLanguageModelById(providerId)
|
||||
}
|
||||
|
||||
export async function getModelByConfig(providerConfig: LLMProviderConfig) {
|
||||
return getLanguageModelByConfig(providerConfig)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue