From a2d90f01a73b12335e83a82b510ad977ca641dde Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 5 Aug 2025 18:24:32 -0600 Subject: [PATCH] ai fallback on ai sdk v5 --- .../ai/src/utils/models/ai-fallback.test.ts | 80 ++-- packages/ai/src/utils/models/ai-fallback.ts | 419 +++++++++--------- 2 files changed, 238 insertions(+), 261 deletions(-) diff --git a/packages/ai/src/utils/models/ai-fallback.test.ts b/packages/ai/src/utils/models/ai-fallback.test.ts index 7630a8742..189395fe1 100644 --- a/packages/ai/src/utils/models/ai-fallback.test.ts +++ b/packages/ai/src/utils/models/ai-fallback.test.ts @@ -19,11 +19,11 @@ test('doStream switches models on error', async () => { stream: new ReadableStream({ start(controller) { controller.enqueue({ type: 'stream-start', warnings: [] }); - controller.enqueue({ type: 'text-delta', textDelta: 'Hello from fallback' }); + controller.enqueue({ type: 'text-delta', id: '1', delta: 'Hello from fallback' }); controller.enqueue({ type: 'finish', finishReason: 'stop', - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, }); controller.close(); }, @@ -38,8 +38,7 @@ test('doStream switches models on error', async () => { }); const result = await fallback.doStream({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); // Read the stream @@ -53,9 +52,9 @@ test('doStream switches models on error', async () => { } expect(chunks).toHaveLength(3); - expect(chunks[0].type).toBe('stream-start'); - expect(chunks[1]).toEqual({ type: 'text-delta', textDelta: 'Hello from fallback' }); - expect(chunks[2].type).toBe('finish'); + expect(chunks[0]!.type).toBe('stream-start'); + expect(chunks[1]).toEqual({ type: 'text-delta', id: '1', delta: 'Hello from fallback' }); + expect(chunks[2]!.type).toBe('finish'); expect(fallback.currentModelIndex).toBe(1); expect(onError).toHaveBeenCalledWith( @@ -85,11 +84,11 @@ test('doStream handles error during streaming', async () => { stream: new ReadableStream({ start(controller) { controller.enqueue({ type: 'stream-start', warnings: [] }); - controller.enqueue({ type: 'text-delta', textDelta: 'Fallback response' }); + controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback response' }); controller.enqueue({ type: 'finish', finishReason: 'stop', - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, }); controller.close(); }, @@ -104,8 +103,7 @@ test('doStream handles error during streaming', async () => { }); const result = await fallback.doStream({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); // Read the stream @@ -118,9 +116,7 @@ test('doStream handles error during streaming', async () => { chunks.push(value); } - expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Fallback response')).toBe( - true - ); + expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Fallback response')).toBe(true); expect(fallback.currentModelIndex).toBe(1); }); @@ -133,7 +129,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => { stream: new ReadableStream({ start(controller) { controller.enqueue({ type: 'stream-start', warnings: [] }); - controller.enqueue({ type: 'text-delta', textDelta: 'Partial output' }); + controller.enqueue({ type: 'text-delta', id: '1', delta: 'Partial output' }); controller.error(new Error('Stream interrupted')); }, }), @@ -147,11 +143,11 @@ test('doStream with partial output and retryAfterOutput=true', async () => { stream: new ReadableStream({ start(controller) { controller.enqueue({ type: 'stream-start', warnings: [] }); - controller.enqueue({ type: 'text-delta', textDelta: 'Fallback continuation' }); + controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback continuation' }); controller.enqueue({ type: 'finish', finishReason: 'stop', - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, }); controller.close(); }, @@ -167,8 +163,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => { }); const result = await fallback.doStream({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); // Read the stream @@ -185,7 +180,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => { // The partial output from the failed model is lost const textChunks = chunks.filter((c) => c.type === 'text-delta'); expect(textChunks).toHaveLength(1); - expect(textChunks[0].textDelta).toBe('Fallback continuation'); + expect(textChunks[0]!.delta).toBe('Fallback continuation'); // Should switch models after error expect(fallback.currentModelIndex).toBe(1); @@ -218,11 +213,11 @@ test('doStream handles error in stream part', async () => { stream: new ReadableStream({ start(controller) { controller.enqueue({ type: 'stream-start', warnings: [] }); - controller.enqueue({ type: 'text-delta', textDelta: 'Success' }); + controller.enqueue({ type: 'text-delta', id: '1', delta: 'Success' }); controller.enqueue({ type: 'finish', finishReason: 'stop', - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, }); controller.close(); }, @@ -240,8 +235,7 @@ test('doStream handles error in stream part', async () => { }); const result = await fallback.doStream({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); // Read the stream @@ -255,7 +249,7 @@ test('doStream handles error in stream part', async () => { } // Should have switched to model 2 and gotten success - expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Success')).toBe(true); + expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Success')).toBe(true); expect(encounteredErrors).toHaveLength(1); expect(encounteredErrors[0]).toBe('Overloaded'); expect(fallback.currentModelIndex).toBe(1); @@ -276,7 +270,7 @@ test('doGenerate switches models on error', async () => { doGenerate: async () => ({ content: [{ type: 'text', text: 'Response from fallback model' }], finishReason: 'stop' as const, - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, warnings: [], }), }); @@ -289,8 +283,7 @@ test('doGenerate switches models on error', async () => { expect(fallback.modelId).toBe('failing-model'); const result = await fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); expect(result.content[0]).toEqual({ type: 'text', text: 'Response from fallback model' }); @@ -322,7 +315,7 @@ test('cycles through all models until one works', async () => { doGenerate: async () => ({ content: [{ type: 'text', text: 'Success from model 3' }], finishReason: 'stop' as const, - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, warnings: [], }), }); @@ -332,8 +325,7 @@ test('cycles through all models until one works', async () => { }); const result = await fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); expect(result.content[0]).toEqual({ type: 'text', text: 'Success from model 3' }); @@ -362,8 +354,7 @@ test('throws error when all models fail', async () => { await expect( fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }) ).rejects.toThrow('Model 2 capacity reached'); @@ -391,8 +382,7 @@ test('model reset interval resets to first model', async () => { // Trigger reset check await fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); expect(fallback.currentModelIndex).toBe(0); @@ -419,7 +409,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => { doGenerate: async () => ({ content: [{ type: 'text', text: 'Should not reach here' }], finishReason: 'stop' as const, - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, warnings: [], }), }); @@ -432,8 +422,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => { // Should not retry because error doesn't match await expect( fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }) ).rejects.toThrow('specific-error-that-should-not-retry'); @@ -460,7 +449,7 @@ test('handles non-existent model error', async () => { doGenerate: async () => ({ content: [{ type: 'text', text: 'Fallback response' }], finishReason: 'stop' as const, - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, warnings: [], }), }); @@ -475,8 +464,7 @@ test('handles non-existent model error', async () => { }); const result = await fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); expect(result.content[0]).toEqual({ type: 'text', text: 'Fallback response' }); @@ -502,7 +490,7 @@ test('handles API key errors', async () => { doGenerate: async () => ({ content: [{ type: 'text', text: 'Success with correct key' }], finishReason: 'stop' as const, - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, warnings: [], }), }); @@ -512,8 +500,7 @@ test('handles API key errors', async () => { }); const result = await fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); expect(result.content[0]).toEqual({ type: 'text', text: 'Success with correct key' }); @@ -538,7 +525,7 @@ test('handles rate limit errors', async () => { doGenerate: async () => ({ content: [{ type: 'text', text: 'Response from available model' }], finishReason: 'stop' as const, - usage: { inputTokens: 10, outputTokens: 5 }, + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, warnings: [], }), }); @@ -548,8 +535,7 @@ test('handles rate limit errors', async () => { }); const result = await fallback.doGenerate({ - prompt: { system: 'test', messages: [] }, - mode: { type: 'regular' }, + prompt: [], }); expect(attempts).toBe(1); diff --git a/packages/ai/src/utils/models/ai-fallback.ts b/packages/ai/src/utils/models/ai-fallback.ts index fe32f10d3..c1a5a99c3 100644 --- a/packages/ai/src/utils/models/ai-fallback.ts +++ b/packages/ai/src/utils/models/ai-fallback.ts @@ -1,254 +1,245 @@ import type { - LanguageModelV2, - LanguageModelV2CallOptions, - LanguageModelV2CallWarning, - LanguageModelV2Content, - LanguageModelV2FinishReason, - LanguageModelV2StreamPart, - LanguageModelV2Usage, - SharedV2ProviderMetadata, -} from "@ai-sdk/provider"; + LanguageModelV2, + LanguageModelV2CallOptions, + LanguageModelV2CallWarning, + LanguageModelV2Content, + LanguageModelV2FinishReason, + LanguageModelV2StreamPart, + LanguageModelV2Usage, + SharedV2ProviderMetadata, +} from '@ai-sdk/provider'; + +interface RetryableError extends Error { + statusCode?: number; +} interface Settings { - models: LanguageModelV2[]; - retryAfterOutput?: boolean; - modelResetInterval?: number; - shouldRetryThisError?: (error: Error) => boolean; - onError?: (error: Error, modelId: string) => void | Promise; + models: LanguageModelV2[]; + retryAfterOutput?: boolean; + modelResetInterval?: number; + shouldRetryThisError?: (error: RetryableError) => boolean; + onError?: (error: RetryableError, modelId: string) => void | Promise; } export function createFallback(settings: Settings): FallbackModel { - return new FallbackModel(settings); + return new FallbackModel(settings); } const retryableStatusCodes = [ - 401, // wrong API key - 403, // permission error, like cannot access model or from a non accessible region - 408, // request timeout - 409, // conflict - 413, // payload too large - 429, // too many requests/rate limits - 500, // server error (and above) + 401, // wrong API key + 403, // permission error, like cannot access model or from a non accessible region + 408, // request timeout + 409, // conflict + 413, // payload too large + 429, // too many requests/rate limits + 500, // server error (and above) ]; // Common error messages/codes that indicate server overload or temporary issues const retryableErrors = [ - "overloaded", - "service unavailable", - "bad gateway", - "too many requests", - "internal server error", - "gateway timeout", - "rate_limit", - "wrong-key", - "unexpected", - "capacity", - "timeout", - "server_error", - "429", // Too Many Requests - "500", // Internal Server Error - "502", // Bad Gateway - "503", // Service Unavailable - "504", // Gateway Timeout + 'overloaded', + 'service unavailable', + 'bad gateway', + 'too many requests', + 'internal server error', + 'gateway timeout', + 'rate_limit', + 'wrong-key', + 'unexpected', + 'capacity', + 'timeout', + 'server_error', + '429', // Too Many Requests + '500', // Internal Server Error + '502', // Bad Gateway + '503', // Service Unavailable + '504', // Gateway Timeout ]; -function defaultShouldRetryThisError(error: any): boolean { - const statusCode = error?.statusCode; +function defaultShouldRetryThisError(error: RetryableError): boolean { + const statusCode = error?.statusCode; - if ( - statusCode && - (retryableStatusCodes.includes(statusCode) || statusCode > 500) - ) { - return true; - } + if (statusCode && (retryableStatusCodes.includes(statusCode) || statusCode > 500)) { + return true; + } - if (error?.message) { - const errorString = error.message.toLowerCase() || ""; - return retryableErrors.some((errType) => errorString.includes(errType)); - } - if (error && typeof error === "object") { - const errorString = JSON.stringify(error).toLowerCase() || ""; - return retryableErrors.some((errType) => errorString.includes(errType)); - } - return false; + if (error?.message) { + const errorString = error.message.toLowerCase() || ''; + return retryableErrors.some((errType) => errorString.includes(errType)); + } + if (error && typeof error === 'object') { + const errorString = JSON.stringify(error).toLowerCase() || ''; + return retryableErrors.some((errType) => errorString.includes(errType)); + } + return false; } export class FallbackModel implements LanguageModelV2 { - readonly specificationVersion = "v2"; + readonly specificationVersion = 'v2'; - get supportedUrls(): - | Record - | PromiseLike> { - return this.getCurrentModel().supportedUrls; - } + get supportedUrls(): Record | PromiseLike> { + return this.getCurrentModel().supportedUrls; + } - get modelId(): string { - return this.getCurrentModel().modelId; - } - readonly settings: Settings; + get modelId(): string { + return this.getCurrentModel().modelId; + } + readonly settings: Settings; - currentModelIndex = 0; - private lastModelReset: number = Date.now(); - private readonly modelResetInterval: number; - retryAfterOutput: boolean; - constructor(settings: Settings) { - this.settings = settings; - this.modelResetInterval = settings.modelResetInterval ?? 3 * 60 * 1000; // Default 3 minutes in ms - this.retryAfterOutput = settings.retryAfterOutput ?? true; + currentModelIndex = 0; + private lastModelReset: number = Date.now(); + private readonly modelResetInterval: number; + retryAfterOutput: boolean; + constructor(settings: Settings) { + this.settings = settings; + this.modelResetInterval = settings.modelResetInterval ?? 3 * 60 * 1000; // Default 3 minutes in ms + this.retryAfterOutput = settings.retryAfterOutput ?? true; - if (!this.settings.models[this.currentModelIndex]) { - throw new Error("No models available in settings"); - } - } + if (!this.settings.models[this.currentModelIndex]) { + throw new Error('No models available in settings'); + } + } - get provider(): string { - return this.getCurrentModel().provider; - } + get provider(): string { + return this.getCurrentModel().provider; + } - private getCurrentModel(): LanguageModelV2 { - const model = this.settings.models[this.currentModelIndex]; - if (!model) { - throw new Error(`No model available at index ${this.currentModelIndex}`); - } - return model; - } + private getCurrentModel(): LanguageModelV2 { + const model = this.settings.models[this.currentModelIndex]; + if (!model) { + throw new Error(`No model available at index ${this.currentModelIndex}`); + } + return model; + } - private checkAndResetModel() { - const now = Date.now(); - if ( - now - this.lastModelReset >= this.modelResetInterval && - this.currentModelIndex !== 0 - ) { - this.currentModelIndex = 0; - this.lastModelReset = now; - } - } + private checkAndResetModel() { + const now = Date.now(); + if (now - this.lastModelReset >= this.modelResetInterval && this.currentModelIndex !== 0) { + this.currentModelIndex = 0; + this.lastModelReset = now; + } + } - private switchToNextModel() { - this.currentModelIndex = - (this.currentModelIndex + 1) % this.settings.models.length; - } + private switchToNextModel() { + this.currentModelIndex = (this.currentModelIndex + 1) % this.settings.models.length; + } - private async retry(fn: () => PromiseLike): Promise { - let lastError: Error | undefined; - const initialModel = this.currentModelIndex; + private async retry(fn: () => PromiseLike): Promise { + let lastError: RetryableError | undefined; + const initialModel = this.currentModelIndex; - do { - try { - return await fn(); - } catch (error) { - lastError = error as Error; - // Only retry if it's a server/capacity error - const shouldRetry = - this.settings.shouldRetryThisError || defaultShouldRetryThisError; - if (!shouldRetry(lastError)) { - throw lastError; - } + do { + try { + return await fn(); + } catch (error) { + lastError = error as RetryableError; + // Only retry if it's a server/capacity error + const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError; + if (!shouldRetry(lastError)) { + throw lastError; + } - if (this.settings.onError) { - await this.settings.onError(lastError, this.modelId); - } - this.switchToNextModel(); + if (this.settings.onError) { + await this.settings.onError(lastError, this.modelId); + } + this.switchToNextModel(); - // If we've tried all models, throw the last error - if (this.currentModelIndex === initialModel) { - throw lastError; - } - } - } while (true); - } + // If we've tried all models, throw the last error + if (this.currentModelIndex === initialModel) { + throw lastError; + } + } + } while (this.currentModelIndex !== initialModel); - doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{ - content: LanguageModelV2Content[]; - finishReason: LanguageModelV2FinishReason; - usage: LanguageModelV2Usage; - providerMetadata?: SharedV2ProviderMetadata; - request?: { body?: unknown }; - response?: { - headers?: Record; - id?: string; - timestamp?: Date; - modelId?: string; - }; - warnings: LanguageModelV2CallWarning[]; - }> { - this.checkAndResetModel(); - return this.retry(() => this.getCurrentModel().doGenerate(options)); - } + // This should never be reached, but TypeScript requires it + throw lastError || new Error('Retry failed'); + } - doStream(options: LanguageModelV2CallOptions): PromiseLike<{ - stream: ReadableStream; - request?: { body?: unknown }; - response?: { headers?: Record }; - }> { - this.checkAndResetModel(); - const self = this; - const shouldRetry = - this.settings.shouldRetryThisError || defaultShouldRetryThisError; - return this.retry(async () => { - const result = await self.getCurrentModel().doStream(options); + doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{ + content: LanguageModelV2Content[]; + finishReason: LanguageModelV2FinishReason; + usage: LanguageModelV2Usage; + providerMetadata?: SharedV2ProviderMetadata; + request?: { body?: unknown }; + response?: { + headers?: Record; + id?: string; + timestamp?: Date; + modelId?: string; + }; + warnings: LanguageModelV2CallWarning[]; + }> { + this.checkAndResetModel(); + return this.retry(() => this.getCurrentModel().doGenerate(options)); + } - let hasStreamedAny = false; - // Wrap the stream to handle errors and switch providers if needed - const wrappedStream = new ReadableStream({ - async start(controller) { - try { - const reader = result.stream.getReader(); + doStream(options: LanguageModelV2CallOptions): PromiseLike<{ + stream: ReadableStream; + request?: { body?: unknown }; + response?: { headers?: Record }; + }> { + this.checkAndResetModel(); + const self = this; + const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError; + return this.retry(async () => { + const result = await self.getCurrentModel().doStream(options); - while (true) { - const result = await reader.read(); + let hasStreamedAny = false; + // Wrap the stream to handle errors and switch providers if needed + const wrappedStream = new ReadableStream({ + async start(controller) { + try { + const reader = result.stream.getReader(); - const { done, value } = result; - if ( - !hasStreamedAny && - value && - typeof value === "object" && - "error" in value - ) { - const error = value.error as any; - if (shouldRetry(error)) { - throw error; - } - } + while (true) { + const result = await reader.read(); - if (done) break; - controller.enqueue(value); + const { done, value } = result; + if (!hasStreamedAny && value && typeof value === 'object' && 'error' in value) { + const error = value.error as RetryableError; + if (shouldRetry(error)) { + throw error; + } + } - if (value?.type !== "stream-start") { - hasStreamedAny = true; - } - } - controller.close(); - } catch (error) { - if (self.settings.onError) { - await self.settings.onError(error as Error, self.modelId); - } - if (!hasStreamedAny || self.retryAfterOutput) { - // If nothing was streamed yet, switch models and retry - self.switchToNextModel(); - try { - const nextResult = await self.doStream(options); - const nextReader = nextResult.stream.getReader(); - while (true) { - const { done, value } = await nextReader.read(); - if (done) break; - controller.enqueue(value); - } - controller.close(); - } catch (nextError) { - controller.error(nextError); - } - return; - } - controller.error(error); - } - }, - }); + if (done) break; + controller.enqueue(value); - return { - stream: wrappedStream, - ...(result.request && { request: result.request }), - ...(result.response && { response: result.response }), - }; - }); - } + if (value?.type !== 'stream-start') { + hasStreamedAny = true; + } + } + controller.close(); + } catch (error) { + if (self.settings.onError) { + await self.settings.onError(error as RetryableError, self.modelId); + } + if (!hasStreamedAny || self.retryAfterOutput) { + // If nothing was streamed yet, switch models and retry + self.switchToNextModel(); + try { + const nextResult = await self.doStream(options); + const nextReader = nextResult.stream.getReader(); + while (true) { + const { done, value } = await nextReader.read(); + if (done) break; + controller.enqueue(value); + } + controller.close(); + } catch (nextError) { + controller.error(nextError); + } + return; + } + controller.error(error); + } + }, + }); + + return { + stream: wrappedStream, + ...(result.request && { request: result.request }), + ...(result.response && { response: result.response }), + }; + }); + } }