All files / lib ai-provider.ts

0% Statements 0/61
0% Branches 0/43
0% Functions 0/14
0% Lines 0/53

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152                                                                                                                                                                                                                                                                                                               
import { createOpenAI } from "@ai-sdk/openai";
import type { LanguageModel } from "ai";
import { isProductionLike } from "@/lib/runtime-env";
import { normalizeOllamaBaseUrl } from "@/lib/ollama-url";
 
/** Providers this module knows how to configure, in canonical order. */
export const SUPPORTED_AI_PROVIDERS = ["openai", "ollama"] as const;

/** Union of supported provider identifiers, derived from the tuple above. */
export type AIProviderName = (typeof SUPPORTED_AI_PROVIDERS)[number];

/** Fully-resolved chat model: which provider/model was chosen plus the SDK handle. */
export type ChatModelConfig = {
  provider: AIProviderName;
  modelId: string;
  model: LanguageModel;
};

/** Caller-supplied overrides; anything omitted falls back to environment variables. */
export type ChatModelOverrides = {
  provider?: AIProviderName;
  openaiApiKey?: string;
  modelId?: string;
  baseUrl?: string;
};

/** Type guard: narrows an arbitrary string to AIProviderName. */
export function isAIProviderName(value: string): value is AIProviderName {
  return SUPPORTED_AI_PROVIDERS.some((provider) => provider === value);
}

/** Human-readable, comma-separated list of supported providers (for error messages). */
export function getSupportedAIProviderList(): string {
  return Array.from(SUPPORTED_AI_PROVIDERS).join(", ");
}
 
/**
 * Resolves the OpenAI API key: explicit override first, then OPENAI_API_KEY.
 * Truthiness (not `??`) is intentional so a blank/whitespace-only override
 * still falls through to the environment variable.
 * @throws Error when neither source yields a non-empty key.
 */
function resolveOpenAIKey(overrides: ChatModelOverrides): string {
  const fromOverride = overrides.openaiApiKey?.trim();
  if (fromOverride) return fromOverride;

  const fromEnv = process.env.OPENAI_API_KEY?.trim();
  if (fromEnv) return fromEnv;

  throw new Error("OpenAI API key is required.");
}
 
/**
 * Picks the OpenAI model id. Precedence: explicit override, OPENAI_MODEL,
 * then the generic AI_MODEL (honored only when AI_PROVIDER names openai,
 * so another provider's AI_MODEL value is never misapplied), then a default.
 */
function openAIModelId(overrides: ChatModelOverrides): string {
  if (overrides.modelId != null) return overrides.modelId;

  const dedicated = process.env.OPENAI_MODEL;
  if (dedicated != null) return dedicated;

  const configuredProvider = process.env.AI_PROVIDER?.trim().toLowerCase();
  const generic =
    configuredProvider === "openai" ? process.env.AI_MODEL : undefined;
  return generic ?? "gpt-4o-mini";
}
 
/**
 * Builds a fully-resolved OpenAI chat model config from overrides + env.
 * Base URL falls back to OPENAI_BASE_URL (undefined means the SDK default).
 * @throws Error (via resolveOpenAIKey) when no API key is available.
 */
function openAIConfig(overrides: ChatModelOverrides): ChatModelConfig {
  const modelId = openAIModelId(overrides);
  const client = createOpenAI({
    baseURL: overrides.baseUrl ?? process.env.OPENAI_BASE_URL,
    apiKey: resolveOpenAIKey(overrides),
  });
  return {
    provider: "openai",
    modelId,
    model: client.chat(modelId),
  };
}
 
/**
 * Picks the Ollama model id. Precedence: explicit override, OLLAMA_MODEL,
 * then the generic AI_MODEL (honored only when AI_PROVIDER names ollama,
 * mirroring openAIModelId), then a small local default.
 */
function ollamaModelId(overrides: ChatModelOverrides): string {
  if (overrides.modelId != null) return overrides.modelId;

  const dedicated = process.env.OLLAMA_MODEL;
  if (dedicated != null) return dedicated;

  const configuredProvider = process.env.AI_PROVIDER?.trim().toLowerCase();
  const generic =
    configuredProvider === "ollama" ? process.env.AI_MODEL : undefined;
  return generic ?? "qwen2.5:3b";
}
 
/** Default OpenAI-compatible endpoint of a locally running Ollama server. */
const OLLAMA_DEFAULT_BASE_URL = "http://localhost:11434/v1";

/**
 * Resolves the Ollama base URL from the override or OLLAMA_BASE_URL.
 * normalizeOllamaBaseUrl appends /v1 if missing — guards against bare
 * host:port in env; presumably it can also return a nullish value for
 * unusable input, in which case we fall back to the default.
 */
function resolveOllamaBaseUrl(overrides: ChatModelOverrides): string {
  const configured = overrides.baseUrl ?? process.env.OLLAMA_BASE_URL;
  if (!configured) return OLLAMA_DEFAULT_BASE_URL;
  return normalizeOllamaBaseUrl(configured) ?? OLLAMA_DEFAULT_BASE_URL;
}
 
/**
 * Builds a chat model config that talks to Ollama through its
 * OpenAI-compatible endpoint. The "ollama" API key is a placeholder —
 * the local server ignores it, but the OpenAI SDK requires one.
 */
function ollamaConfig(overrides: ChatModelOverrides): ChatModelConfig {
  const modelId = ollamaModelId(overrides);
  const client = createOpenAI({
    baseURL: resolveOllamaBaseUrl(overrides),
    apiKey: "ollama",
  });
  // Ollama OpenAI-compatible endpoint works best with chat mode.
  return {
    provider: "ollama",
    modelId,
    model: client.chat(modelId),
  };
}
 
/**
 * Default provider when nothing is configured: hosted OpenAI in
 * production-like environments, local Ollama during development.
 */
function resolveDefaultProvider(): AIProviderName {
  if (isProductionLike()) return "openai";
  return "ollama";
}
 
/**
 * Determines which provider to use: an explicit override wins, then a
 * valid AI_PROVIDER env value (trimmed, case-insensitive), then the
 * environment-dependent default.
 * @throws Error when AI_PROVIDER is set to an unsupported value.
 */
function resolveProvider(overrides: ChatModelOverrides): AIProviderName {
  if (overrides.provider) return overrides.provider;

  const configured = process.env.AI_PROVIDER?.trim().toLowerCase();
  if (!configured) return resolveDefaultProvider();

  if (!isAIProviderName(configured)) {
    throw new Error(
      `Unsupported AI_PROVIDER "${configured}". Supported values: ${getSupportedAIProviderList()}.`,
    );
  }
  return configured;
}
 
/**
 * Orders providers for candidate construction. A caller-pinned provider
 * gets no fallbacks; otherwise the resolved primary is tried first,
 * followed by every other supported provider.
 */
function getProviderResolutionOrder(
  overrides: ChatModelOverrides,
): AIProviderName[] {
  if (overrides.provider) return [overrides.provider];

  const primary = resolveProvider(overrides);
  return [
    primary,
    ...SUPPORTED_AI_PROVIDERS.filter((candidate) => candidate !== primary),
  ];
}
 
/**
 * Dispatches to the provider-specific config builder, pinning the chosen
 * provider onto the overrides so downstream helpers see a consistent view.
 */
function buildConfig(
  provider: AIProviderName,
  overrides: ChatModelOverrides,
): ChatModelConfig {
  const effective: ChatModelOverrides = { ...overrides, provider };
  if (provider === "openai") return openAIConfig(effective);
  return ollamaConfig(effective);
}
 
/**
 * Builds the ordered list of usable chat model configs.
 *
 * Each provider in the resolution order is attempted; ones that fail to
 * configure (e.g. missing API key) are skipped so callers can fall back.
 * A caller-pinned provider has no fallback, so its failure is rethrown
 * immediately. If nothing configures, the last failure (or a generic
 * error) is thrown.
 *
 * @throws Error when no provider can be configured, or when the pinned
 *         provider fails.
 */
export function getChatModelCandidates(
  overrides: ChatModelOverrides = {},
): ChatModelConfig[] {
  const candidates: ChatModelConfig[] = [];
  let lastError: unknown = null;

  for (const provider of getProviderResolutionOrder(overrides)) {
    try {
      candidates.push(buildConfig(provider, overrides));
    } catch (error) {
      // No fallback exists for an explicitly requested provider.
      if (overrides.provider) throw error;
      lastError = error;
    }
  }

  if (candidates.length > 0) return candidates;
  if (lastError instanceof Error) throw lastError;
  throw new Error("No available AI provider configuration was resolved.");
}