diff --git a/src/utils/ai/cacheStrategy.test.ts b/src/utils/ai/cacheStrategy.test.ts
new file mode 100644
index 000000000..a37d900a9
--- /dev/null
+++ b/src/utils/ai/cacheStrategy.test.ts
@@ -0,0 +1,273 @@
+import { describe, expect, test } from "bun:test";
+import { applyCacheControl } from "./cacheStrategy";
+import type { ModelMessage } from "ai";
+
+describe("applyCacheControl", () => {
+  test("should not apply cache control for non-Anthropic models", () => {
+    const messages: ModelMessage[] = [
+      { role: "user", content: "Hello" },
+      { role: "assistant", content: "Hi there" },
+    ];
+
+    const result = applyCacheControl(messages, "openai:gpt-5");
+    expect(result).toEqual(messages);
+  });
+
+  test("should not apply cache control with less than 2 messages", () => {
+    const messages: ModelMessage[] = [{ role: "user", content: "Hello" }];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+    expect(result).toEqual(messages);
+  });
+
+  test("should apply single cache breakpoint for short conversation", () => {
+    const messages: ModelMessage[] = [
+      { role: "user", content: "What is the capital of France? ".repeat(200) }, // ~6200 chars, well over 1024 tokens
+      { role: "assistant", content: "Paris is the capital. ".repeat(100) },
+      { role: "user", content: "What about Germany?" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // With the improved strategy, should cache at index 1 (second-to-last message)
+    // First message may also be cached if it has enough content
+    const hasCaching = result.some((msg) => msg.providerOptions?.anthropic?.cacheControl);
+    expect(hasCaching).toBe(true);
+
+    // The last message (current user input) should never be cached
+    expect(result[2].providerOptions?.anthropic?.cacheControl).toBeUndefined();
+  });
+
+  test("should cache system message with 1h TTL", () => {
+    const largeSystemPrompt = "You are a helpful assistant. ".repeat(200); // ~6000 chars
+    const messages: ModelMessage[] = [
+      { role: "system", content: largeSystemPrompt },
+      { role: "user", content: "Hello" },
+      { role: "assistant", content: "Hi!" },
+      { role: "user", content: "How are you?" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // System message should be cached with 1h TTL
+    expect(result[0].providerOptions?.anthropic?.cacheControl).toEqual({
+      type: "ephemeral",
+      ttl: "1h",
+    });
+
+    // Should also cache before last message with 5m TTL
+    expect(result[2].providerOptions?.anthropic?.cacheControl).toEqual({
+      type: "ephemeral",
+      ttl: "5m",
+    });
+  });
+
+  test("should apply multiple breakpoints for long conversation", () => {
+    const messages: ModelMessage[] = [
+      { role: "system", content: "System instructions. ".repeat(200) }, // Large system
+      { role: "user", content: "Question 1 ".repeat(100) },
+      { role: "assistant", content: "Answer 1 ".repeat(100) },
+      { role: "user", content: "Question 2 ".repeat(100) },
+      { role: "assistant", content: "Answer 2 ".repeat(100) },
+      { role: "user", content: "Question 3 ".repeat(100) },
+      { role: "assistant", content: "Answer 3 ".repeat(100) },
+      { role: "user", content: "Question 4 ".repeat(100) },
+      { role: "assistant", content: "Answer 4 ".repeat(100) },
+      { role: "user", content: "Question 5" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Count breakpoints
+    const breakpointIndices = result
+      .map((msg, idx) => (msg.providerOptions?.anthropic?.cacheControl ? idx : -1))
+      .filter((idx) => idx >= 0);
+
+    // Should have multiple breakpoints (max 4)
+    expect(breakpointIndices.length).toBeGreaterThan(1);
+    expect(breakpointIndices.length).toBeLessThanOrEqual(4);
+
+    // System message should have 1h TTL
+    const systemCacheControl = result[0].providerOptions?.anthropic?.cacheControl;
+    if (
+      systemCacheControl &&
+      typeof systemCacheControl === "object" &&
+      "ttl" in systemCacheControl
+    ) {
+      expect(systemCacheControl.ttl).toBe("1h");
+    }
+
+    // Last cached message should have 5m TTL
+    const lastCachedIdx = breakpointIndices[breakpointIndices.length - 1];
+    const lastCacheControl = result[lastCachedIdx].providerOptions?.anthropic?.cacheControl;
+    if (lastCacheControl && typeof lastCacheControl === "object" && "ttl" in lastCacheControl) {
+      expect(lastCacheControl.ttl).toBe("5m");
+    }
+  });
+
+  test("should respect Haiku minimum token requirement (2048)", () => {
+    // Small messages that don't meet Haiku threshold
+    const messages: ModelMessage[] = [
+      { role: "user", content: "Short question" }, // far below the 2048-token minimum
+      { role: "assistant", content: "Short answer" },
+      { role: "user", content: "Another question" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-haiku-3-5");
+
+    // Should not apply caching for Haiku with small content
+    const hasCaching = result.some((msg) => msg.providerOptions?.anthropic?.cacheControl);
+    expect(hasCaching).toBe(false);
+  });
+
+  test("should apply caching for Haiku with sufficient content", () => {
+    const messages: ModelMessage[] = [
+      { role: "user", content: "Long message ".repeat(400) }, // ~5200 chars; with the reply below, cumulative tokens exceed 2048
+      { role: "assistant", content: "Response ".repeat(400) },
+      { role: "user", content: "Follow up" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-haiku-3-5");
+
+    // Should cache with Haiku when content is large enough
+    const hasCaching = result.some((msg) => msg.providerOptions?.anthropic?.cacheControl);
+    expect(hasCaching).toBe(true);
+  });
+
+  test("should handle messages with array content", () => {
+    const messages: ModelMessage[] = [
+      {
+        role: "user",
+        content: [
+          { type: "text", text: "Here is a long document. ".repeat(200) },
+          { type: "text", text: "Additional context. ".repeat(100) },
+        ],
+      },
+      { role: "assistant", content: "I understand" },
+      { role: "user", content: "What did I say?" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Should handle multi-part content and apply caching
+    expect(result[1].providerOptions?.anthropic?.cacheControl).toEqual({
+      type: "ephemeral",
+      ttl: "5m",
+    });
+  });
+
+  test("should preserve existing providerOptions", () => {
+    const messages: ModelMessage[] = [
+      {
+        role: "system",
+        content: "System prompt with detailed instructions. ".repeat(300), // ~12600 chars > 1024 tokens
+        providerOptions: {
+          anthropic: {
+            customOption: "value",
+          },
+        },
+      },
+      { role: "user", content: "Hello" },
+      { role: "assistant", content: "Hi there!" },
+      { role: "user", content: "Continue" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Should preserve existing options while adding cacheControl
+    const anthropicOptions = result[0].providerOptions?.anthropic as Record<string, unknown>;
+    expect(anthropicOptions?.customOption).toBe("value");
+    expect(anthropicOptions?.cacheControl).toBeDefined();
+  });
+
+  test("should not exceed 4 breakpoint limit", () => {
+    // Create a very long conversation
+    const messages: ModelMessage[] = [{ role: "system", content: "System ".repeat(300) }];
+
+    // Add 20 message pairs
+    for (let i = 0; i < 20; i++) {
+      messages.push({ role: "user", content: `User message ${i} `.repeat(100) });
+      messages.push({ role: "assistant", content: `Assistant ${i} `.repeat(100) });
+    }
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Count breakpoints
+    const breakpointCount = result.filter(
+      (msg) => msg.providerOptions?.anthropic?.cacheControl
+    ).length;
+
+    // Should never exceed 4 breakpoints
+    expect(breakpointCount).toBeLessThanOrEqual(4);
+    expect(breakpointCount).toBeGreaterThan(0);
+  });
+
+  test("should place 1h TTL before 5m TTL", () => {
+    const messages: ModelMessage[] = [
+      { role: "system", content: "System instructions. ".repeat(200) },
+      { role: "user", content: "Q1 ".repeat(100) },
+      { role: "assistant", content: "A1 ".repeat(100) },
+      { role: "user", content: "Q2 ".repeat(100) },
+      { role: "assistant", content: "A2 ".repeat(100) },
+      { role: "user", content: "Q3" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Collect breakpoints with their TTLs
+    const breakpoints = result
+      .map((msg, idx) => {
+        const cacheControl = msg.providerOptions?.anthropic?.cacheControl;
+        const ttl =
+          cacheControl && typeof cacheControl === "object" && "ttl" in cacheControl
+            ? (cacheControl.ttl as "5m" | "1h" | undefined)
+            : undefined;
+        return { idx, ttl };
+      })
+      .filter((bp): bp is { idx: number; ttl: "5m" | "1h" } => bp.ttl !== undefined);
+
+    // Find first 1h and first 5m
+    const firstOneHour = breakpoints.find((bp) => bp.ttl === "1h");
+    const firstFiveMin = breakpoints.find((bp) => bp.ttl === "5m");
+
+    // If both exist, 1h should come before 5m
+    if (firstOneHour && firstFiveMin) {
+      expect(firstOneHour.idx).toBeLessThan(firstFiveMin.idx);
+    }
+  });
+
+  test("should handle image content in token estimation", () => {
+    const messages: ModelMessage[] = [
+      {
+        role: "user",
+        content: [
+          { type: "text", text: "Analyze this image: ".repeat(100) },
+          { type: "image", image: "data:image/png;base64,..." },
+        ],
+      },
+      { role: "assistant", content: "I see a test image" },
+      { role: "user", content: "What else?" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Should account for image tokens and apply caching
+    const hasCaching = result.some((msg) => msg.providerOptions?.anthropic?.cacheControl);
+    expect(hasCaching).toBe(true);
+  });
+
+  test("should handle edge case with exact minimum tokens", () => {
+    // Create content that's exactly at the threshold (1024 tokens ≈ 4096 chars)
+    const messages: ModelMessage[] = [
+      { role: "user", content: "x".repeat(4096) },
+      { role: "assistant", content: "ok" },
+      { role: "user", content: "continue" },
+    ];
+
+    const result = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
+
+    // Should apply caching at the threshold
+    const hasCaching = result.some((msg) => msg.providerOptions?.anthropic?.cacheControl);
+    expect(hasCaching).toBe(true);
+  });
+});
diff --git a/src/utils/ai/cacheStrategy.ts b/src/utils/ai/cacheStrategy.ts
index 7939ec5a9..241b79870 100644
--- a/src/utils/ai/cacheStrategy.ts
+++ b/src/utils/ai/cacheStrategy.ts
@@ -1,8 +1,171 @@
 import type { ModelMessage } from "ai";
 
+/**
+ * Minimum token counts required for caching different Anthropic models
+ * Based on https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
+ */
+const MIN_CACHE_TOKENS = {
+  haiku: 2048, // Claude Haiku 3.5 and 3
+  default: 1024, // Claude Opus 4.1, Opus 4, Sonnet 4.5, Sonnet 4, Sonnet 3.7, Opus 3
+} as const;
+
+/**
+ * Maximum number of cache breakpoints allowed by Anthropic
+ */
+const MAX_CACHE_BREAKPOINTS = 4;
+
+/**
+ * Rough estimation of tokens from characters
+ * Uses ~4 chars per token as a conservative estimate for Anthropic models
+ * This is intentionally conservative - better to cache more than miss opportunities
+ */
+function estimateTokens(text: string): number {
+  return Math.ceil(text.length / 4);
+}
+
+/**
+ * Estimate tokens in a ModelMessage including all content
+ */
+function estimateMessageTokens(message: ModelMessage): number {
+  let total = 0;
+
+  // Count text content
+  if (typeof message.content === "string") {
+    total += estimateTokens(message.content);
+  } else if (Array.isArray(message.content)) {
+    for (const part of message.content) {
+      if (part.type === "text") {
+        total += estimateTokens(part.text);
+      } else if (part.type === "image") {
+        // Images have fixed token cost - conservative estimate
+        total += 1000;
+      }
+    }
+  }
+
+  // Add overhead for message structure (role, formatting, etc)
+  total += 10;
+
+  return total;
+}
+
+/**
+ * Get minimum cacheable token count for a model
+ */
+function getMinCacheTokens(modelString: string): number {
+  if (modelString.includes("haiku")) {
+    return MIN_CACHE_TOKENS.haiku;
+  }
+  return MIN_CACHE_TOKENS.default;
+}
+
+/**
+ * Calculate cumulative token counts for messages from start to each position
+ */
+function calculateCumulativeTokens(messages: ModelMessage[]): number[] {
+  const cumulative: number[] = [];
+  let total = 0;
+
+  for (const message of messages) {
+    total += estimateMessageTokens(message);
+    cumulative.push(total);
+  }
+
+  return cumulative;
+}
+
+/**
+ * Determine optimal cache breakpoint positions using a multi-tier strategy
+ *
+ * Strategy:
+ * 1. System messages (1h TTL) - Most stable, rarely change
+ * 2. Tool definitions (1h TTL) - Stable within a session
+ * 3. Conversation history excluding last few turns (5m TTL) - Changes gradually
+ * 4. Recent history excluding current user message (5m TTL) - Fastest changing
+ *
+ * Returns array of {index, ttl} for messages to mark with cache control
+ */
+function determineBreakpoints(
+  messages: ModelMessage[],
+  minTokens: number
+): Array<{ index: number; ttl: "5m" | "1h" }> {
+  if (messages.length < 2) {
+    return [];
+  }
+
+  const breakpoints: Array<{ index: number; ttl: "5m" | "1h" }> = [];
+  const cumulative = calculateCumulativeTokens(messages);
+
+  // Find system messages (prefer 1h cache for stability)
+  // Use manual loop instead of findLastIndex for ES2021 compatibility
+  let lastSystemIndex = -1;
+  for (let i = messages.length - 1; i >= 0; i--) {
+    if (messages[i].role === "system") {
+      lastSystemIndex = i;
+      break;
+    }
+  }
+  if (lastSystemIndex >= 0 && cumulative[lastSystemIndex] >= minTokens) {
+    breakpoints.push({ index: lastSystemIndex, ttl: "1h" });
+  }
+
+  // If no system message cached yet and we have tools, cache after tools
+  // Note: In Anthropic's API, tools appear before system in the hierarchy
+  // but in ModelMessage format they're typically in early messages
+  if (breakpoints.length === 0) {
+    // Find first message with substantial content (likely includes tools/setup)
+    for (let i = 0; i < Math.min(3, messages.length - 1); i++) {
+      if (cumulative[i] >= minTokens) {
+        breakpoints.push({ index: i, ttl: "1h" });
+        break;
+      }
+    }
+  }
+
+  // Add mid-conversation breakpoint (5m TTL)
+  // Cache conversation history but not the most recent exchanges
+  if (messages.length >= 6 && breakpoints.length < MAX_CACHE_BREAKPOINTS) {
+    const midPoint = Math.floor((messages.length - 2) * 0.6);
+    // Ensure this breakpoint is after any 1h breakpoint and has enough tokens
+    const lastBreakpointIndex = breakpoints[breakpoints.length - 1]?.index ?? -1;
+    if (midPoint > lastBreakpointIndex && cumulative[midPoint] >= minTokens) {
+      breakpoints.push({ index: midPoint, ttl: "5m" });
+    }
+  }
+
+  // Always try to cache everything except the current user message (5m TTL)
+  // This is the most frequently refreshed cache
+  const lastCacheIndex = messages.length - 2;
+  const lastBreakpointIndex = breakpoints[breakpoints.length - 1]?.index ?? -1;
+
+  if (
+    lastCacheIndex > lastBreakpointIndex &&
+    cumulative[lastCacheIndex] >= minTokens &&
+    breakpoints.length < MAX_CACHE_BREAKPOINTS
+  ) {
+    breakpoints.push({ index: lastCacheIndex, ttl: "5m" });
+  }
+
+  return breakpoints;
+}
+
 /**
  * Apply cache control to messages for Anthropic models
- * MVP: Single cache breakpoint before the last message
+ *
+ * Uses a multi-tier caching strategy:
+ * - System messages and tools: 1h TTL (most stable)
+ * - Mid-conversation: 5m TTL (moderate stability)
+ * - Recent history: 5m TTL (frequently updated)
+ *
+ * Respects Anthropic's constraints:
+ * - Maximum 4 cache breakpoints
+ * - Minimum token thresholds (1024 for Sonnet/Opus, 2048 for Haiku)
+ * - 1h segments must appear before 5m segments
+ *
+ * Benefits:
+ * - Up to 90% cost reduction on cached content (10% of base price)
+ * - Up to 85% latency reduction for cached prompts
+ * - Optimal use of 4 breakpoint limit
  */
 export function applyCacheControl(messages: ModelMessage[], modelString: string): ModelMessage[] {
   // Only apply cache control for Anthropic models
@@ -15,19 +178,27 @@ export function applyCacheControl(messages: ModelMessage[], modelString: string)
     return messages;
   }
 
-  // Add cache breakpoint at the second-to-last message
-  // This caches everything up to (but not including) the current user message
-  const cacheIndex = messages.length - 2;
+  const minTokens = getMinCacheTokens(modelString);
+  const breakpoints = determineBreakpoints(messages, minTokens);
+
+  // No valid breakpoints found
+  if (breakpoints.length === 0) {
+    return messages;
+  }
 
+  // Apply cache control at determined breakpoints
   return messages.map((msg, index) => {
-    if (index === cacheIndex) {
+    const breakpoint = breakpoints.find((bp) => bp.index === index);
+    if (breakpoint) {
       return {
         ...msg,
         providerOptions: {
+          ...msg.providerOptions,
           anthropic: {
+            ...msg.providerOptions?.anthropic,
            cacheControl: {
              type: "ephemeral" as const,
-              ttl: "5m",
+              ttl: breakpoint.ttl,
            },
          },
        },
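Reviewer note (not part of the diff): a minimal sketch of how the helper might be wired into a caller. It assumes the AI SDK's streamText and the @ai-sdk/anthropic provider; the sendWithCaching wrapper name and the hard-coded model string are illustrative assumptions, not code from this PR.

import { streamText, type ModelMessage } from "ai";
import { anthropic } from "@ai-sdk/anthropic";
import { applyCacheControl } from "./cacheStrategy";

// Hypothetical caller: annotate messages with cache breakpoints before the provider call.
// applyCacheControl is a no-op for non-Anthropic model strings, so it is safe to call unconditionally.
async function sendWithCaching(messages: ModelMessage[]) {
  const cached = applyCacheControl(messages, "anthropic:claude-sonnet-4-5");
  return streamText({
    model: anthropic("claude-sonnet-4-5"),
    messages: cached,
  });
}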