ai fallback on ai sdk v5

dal 2025-08-05 18:24:32 -06:00
parent fcbe1838a1
commit a2d90f01a7
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 238 additions and 261 deletions
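The change moves the fallback wrapper to AI SDK v5's LanguageModelV2 interface: call options now pass `prompt` as a message array and drop the v4 `mode` option, text stream parts carry `id`/`delta` instead of `textDelta`, and usage includes `totalTokens`. As a rough usage sketch (the import paths and provider model IDs below are illustrative assumptions, not part of this commit), the fallback model drops into v5's `streamText` like any single model:

```ts
import { streamText } from 'ai';
import { openai } from '@ai-sdk/openai'; // assumed providers; any LanguageModelV2 works here
import { anthropic } from '@ai-sdk/anthropic';
import { createFallback } from './fallback'; // hypothetical import path for this module

const model = createFallback({
  // Models are tried in order; IDs below are placeholders, substitute your own.
  models: [anthropic('claude-sonnet-4'), openai('gpt-4o')],
  retryAfterOutput: true, // retry on the next model even if some text already streamed
  modelResetInterval: 60_000, // switch back to the primary model after a minute
  onError: (error, modelId) => {
    console.warn(`model ${modelId} failed, falling back:`, error.message);
  },
});

const result = streamText({ model, prompt: 'Say hello.' });
for await (const text of result.textStream) {
  process.stdout.write(text);
}
```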

View File

@@ -19,11 +19,11 @@ test('doStream switches models on error', async () => {
       stream: new ReadableStream<LanguageModelV2StreamPart>({
         start(controller) {
           controller.enqueue({ type: 'stream-start', warnings: [] });
-          controller.enqueue({ type: 'text-delta', textDelta: 'Hello from fallback' });
+          controller.enqueue({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
           controller.enqueue({
             type: 'finish',
             finishReason: 'stop',
-            usage: { inputTokens: 10, outputTokens: 5 },
+            usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
           });
           controller.close();
         },
@@ -38,8 +38,7 @@ test('doStream switches models on error', async () => {
   });
   const result = await fallback.doStream({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   // Read the stream
@@ -53,9 +52,9 @@ test('doStream switches models on error', async () => {
   }
   expect(chunks).toHaveLength(3);
-  expect(chunks[0].type).toBe('stream-start');
-  expect(chunks[1]).toEqual({ type: 'text-delta', textDelta: 'Hello from fallback' });
-  expect(chunks[2].type).toBe('finish');
+  expect(chunks[0]!.type).toBe('stream-start');
+  expect(chunks[1]).toEqual({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
+  expect(chunks[2]!.type).toBe('finish');
   expect(fallback.currentModelIndex).toBe(1);
   expect(onError).toHaveBeenCalledWith(
@@ -85,11 +84,11 @@ test('doStream handles error during streaming', async () => {
       stream: new ReadableStream<LanguageModelV2StreamPart>({
         start(controller) {
           controller.enqueue({ type: 'stream-start', warnings: [] });
-          controller.enqueue({ type: 'text-delta', textDelta: 'Fallback response' });
+          controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback response' });
           controller.enqueue({
             type: 'finish',
             finishReason: 'stop',
-            usage: { inputTokens: 10, outputTokens: 5 },
+            usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
           });
           controller.close();
         },
@@ -104,8 +103,7 @@ test('doStream handles error during streaming', async () => {
   });
   const result = await fallback.doStream({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   // Read the stream
@@ -118,9 +116,7 @@ test('doStream handles error during streaming', async () => {
     chunks.push(value);
   }
-  expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Fallback response')).toBe(
-    true
-  );
+  expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Fallback response')).toBe(true);
   expect(fallback.currentModelIndex).toBe(1);
 });
@@ -133,7 +129,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
       stream: new ReadableStream<LanguageModelV2StreamPart>({
         start(controller) {
           controller.enqueue({ type: 'stream-start', warnings: [] });
-          controller.enqueue({ type: 'text-delta', textDelta: 'Partial output' });
+          controller.enqueue({ type: 'text-delta', id: '1', delta: 'Partial output' });
           controller.error(new Error('Stream interrupted'));
         },
       }),
@@ -147,11 +143,11 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
       stream: new ReadableStream<LanguageModelV2StreamPart>({
         start(controller) {
           controller.enqueue({ type: 'stream-start', warnings: [] });
-          controller.enqueue({ type: 'text-delta', textDelta: 'Fallback continuation' });
+          controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback continuation' });
           controller.enqueue({
             type: 'finish',
             finishReason: 'stop',
-            usage: { inputTokens: 10, outputTokens: 5 },
+            usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
           });
           controller.close();
         },
@@ -167,8 +163,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
   });
   const result = await fallback.doStream({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   // Read the stream
@@ -185,7 +180,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
   // The partial output from the failed model is lost
   const textChunks = chunks.filter((c) => c.type === 'text-delta');
   expect(textChunks).toHaveLength(1);
-  expect(textChunks[0].textDelta).toBe('Fallback continuation');
+  expect(textChunks[0]!.delta).toBe('Fallback continuation');
   // Should switch models after error
   expect(fallback.currentModelIndex).toBe(1);
@@ -218,11 +213,11 @@ test('doStream handles error in stream part', async () => {
       stream: new ReadableStream<LanguageModelV2StreamPart>({
         start(controller) {
           controller.enqueue({ type: 'stream-start', warnings: [] });
-          controller.enqueue({ type: 'text-delta', textDelta: 'Success' });
+          controller.enqueue({ type: 'text-delta', id: '1', delta: 'Success' });
           controller.enqueue({
             type: 'finish',
             finishReason: 'stop',
-            usage: { inputTokens: 10, outputTokens: 5 },
+            usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
           });
           controller.close();
         },
@@ -240,8 +235,7 @@ test('doStream handles error in stream part', async () => {
   });
   const result = await fallback.doStream({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   // Read the stream
@@ -255,7 +249,7 @@ test('doStream handles error in stream part', async () => {
   }
   // Should have switched to model 2 and gotten success
-  expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Success')).toBe(true);
+  expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Success')).toBe(true);
   expect(encounteredErrors).toHaveLength(1);
   expect(encounteredErrors[0]).toBe('Overloaded');
   expect(fallback.currentModelIndex).toBe(1);
@@ -276,7 +270,7 @@ test('doGenerate switches models on error', async () => {
     doGenerate: async () => ({
       content: [{ type: 'text', text: 'Response from fallback model' }],
       finishReason: 'stop' as const,
-      usage: { inputTokens: 10, outputTokens: 5 },
+      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
       warnings: [],
     }),
   });
@@ -289,8 +283,7 @@ test('doGenerate switches models on error', async () => {
   expect(fallback.modelId).toBe('failing-model');
   const result = await fallback.doGenerate({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   expect(result.content[0]).toEqual({ type: 'text', text: 'Response from fallback model' });
@@ -322,7 +315,7 @@ test('cycles through all models until one works', async () => {
     doGenerate: async () => ({
       content: [{ type: 'text', text: 'Success from model 3' }],
       finishReason: 'stop' as const,
-      usage: { inputTokens: 10, outputTokens: 5 },
+      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
       warnings: [],
     }),
   });
@@ -332,8 +325,7 @@ test('cycles through all models until one works', async () => {
   });
   const result = await fallback.doGenerate({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   expect(result.content[0]).toEqual({ type: 'text', text: 'Success from model 3' });
@@ -362,8 +354,7 @@ test('throws error when all models fail', async () => {
   await expect(
     fallback.doGenerate({
-      prompt: { system: 'test', messages: [] },
-      mode: { type: 'regular' },
+      prompt: [],
     })
   ).rejects.toThrow('Model 2 capacity reached');
@@ -391,8 +382,7 @@ test('model reset interval resets to first model', async () => {
   // Trigger reset check
   await fallback.doGenerate({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   expect(fallback.currentModelIndex).toBe(0);
@@ -419,7 +409,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
     doGenerate: async () => ({
       content: [{ type: 'text', text: 'Should not reach here' }],
       finishReason: 'stop' as const,
-      usage: { inputTokens: 10, outputTokens: 5 },
+      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
       warnings: [],
     }),
   });
@@ -432,8 +422,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
   // Should not retry because error doesn't match
   await expect(
     fallback.doGenerate({
-      prompt: { system: 'test', messages: [] },
-      mode: { type: 'regular' },
+      prompt: [],
     })
   ).rejects.toThrow('specific-error-that-should-not-retry');
@@ -460,7 +449,7 @@ test('handles non-existent model error', async () => {
     doGenerate: async () => ({
       content: [{ type: 'text', text: 'Fallback response' }],
       finishReason: 'stop' as const,
-      usage: { inputTokens: 10, outputTokens: 5 },
+      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
       warnings: [],
     }),
   });
@@ -475,8 +464,7 @@ test('handles non-existent model error', async () => {
   });
   const result = await fallback.doGenerate({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   expect(result.content[0]).toEqual({ type: 'text', text: 'Fallback response' });
@@ -502,7 +490,7 @@ test('handles API key errors', async () => {
     doGenerate: async () => ({
       content: [{ type: 'text', text: 'Success with correct key' }],
       finishReason: 'stop' as const,
-      usage: { inputTokens: 10, outputTokens: 5 },
+      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
       warnings: [],
     }),
   });
@@ -512,8 +500,7 @@ test('handles API key errors', async () => {
   });
   const result = await fallback.doGenerate({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   expect(result.content[0]).toEqual({ type: 'text', text: 'Success with correct key' });
@@ -538,7 +525,7 @@ test('handles rate limit errors', async () => {
     doGenerate: async () => ({
       content: [{ type: 'text', text: 'Response from available model' }],
       finishReason: 'stop' as const,
-      usage: { inputTokens: 10, outputTokens: 5 },
+      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
       warnings: [],
     }),
   });
@@ -548,8 +535,7 @@ test('handles rate limit errors', async () => {
   });
   const result = await fallback.doGenerate({
-    prompt: { system: 'test', messages: [] },
-    mode: { type: 'regular' },
+    prompt: [],
   });
   expect(attempts).toBe(1);
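For reference, the v5 shapes these updated tests assert against, as a minimal sketch (the types come from `@ai-sdk/provider`, as imported above; the literal values are only examples taken from the diff):

```ts
import type { LanguageModelV2StreamPart, LanguageModelV2Usage } from '@ai-sdk/provider';

// v5 text chunks carry an `id` and a `delta` instead of v4's `textDelta`.
const chunk: LanguageModelV2StreamPart = {
  type: 'text-delta',
  id: '1',
  delta: 'Hello from fallback',
};

// v5 usage reports totalTokens alongside the input/output counts.
const usage: LanguageModelV2Usage = { inputTokens: 10, outputTokens: 5, totalTokens: 15 };

// v5 call options take `prompt` as a message array and no longer accept `mode`,
// which is why every doStream/doGenerate call above now passes `{ prompt: [] }`.
```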

View File

@@ -7,14 +7,18 @@ import type {
   LanguageModelV2StreamPart,
   LanguageModelV2Usage,
   SharedV2ProviderMetadata,
-} from "@ai-sdk/provider";
+} from '@ai-sdk/provider';
+
+interface RetryableError extends Error {
+  statusCode?: number;
+}
 interface Settings {
   models: LanguageModelV2[];
   retryAfterOutput?: boolean;
   modelResetInterval?: number;
-  shouldRetryThisError?: (error: Error) => boolean;
-  onError?: (error: Error, modelId: string) => void | Promise<void>;
+  shouldRetryThisError?: (error: RetryableError) => boolean;
+  onError?: (error: RetryableError, modelId: string) => void | Promise<void>;
 }
 export function createFallback(settings: Settings): FallbackModel {
@@ -32,52 +36,47 @@ const retryableStatusCodes = [
 ];
 // Common error messages/codes that indicate server overload or temporary issues
 const retryableErrors = [
-  "overloaded",
-  "service unavailable",
-  "bad gateway",
-  "too many requests",
-  "internal server error",
-  "gateway timeout",
-  "rate_limit",
-  "wrong-key",
-  "unexpected",
-  "capacity",
-  "timeout",
-  "server_error",
-  "429", // Too Many Requests
-  "500", // Internal Server Error
-  "502", // Bad Gateway
-  "503", // Service Unavailable
-  "504", // Gateway Timeout
+  'overloaded',
+  'service unavailable',
+  'bad gateway',
+  'too many requests',
+  'internal server error',
+  'gateway timeout',
+  'rate_limit',
+  'wrong-key',
+  'unexpected',
+  'capacity',
+  'timeout',
+  'server_error',
+  '429', // Too Many Requests
+  '500', // Internal Server Error
+  '502', // Bad Gateway
+  '503', // Service Unavailable
+  '504', // Gateway Timeout
 ];
-function defaultShouldRetryThisError(error: any): boolean {
+function defaultShouldRetryThisError(error: RetryableError): boolean {
   const statusCode = error?.statusCode;
-  if (
-    statusCode &&
-    (retryableStatusCodes.includes(statusCode) || statusCode > 500)
-  ) {
+  if (statusCode && (retryableStatusCodes.includes(statusCode) || statusCode > 500)) {
     return true;
   }
   if (error?.message) {
-    const errorString = error.message.toLowerCase() || "";
+    const errorString = error.message.toLowerCase() || '';
     return retryableErrors.some((errType) => errorString.includes(errType));
   }
-  if (error && typeof error === "object") {
-    const errorString = JSON.stringify(error).toLowerCase() || "";
+  if (error && typeof error === 'object') {
+    const errorString = JSON.stringify(error).toLowerCase() || '';
     return retryableErrors.some((errType) => errorString.includes(errType));
   }
   return false;
 }
 export class FallbackModel implements LanguageModelV2 {
-  readonly specificationVersion = "v2";
-  get supportedUrls():
-    | Record<string, RegExp[]>
-    | PromiseLike<Record<string, RegExp[]>> {
+  readonly specificationVersion = 'v2';
+  get supportedUrls(): Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>> {
     return this.getCurrentModel().supportedUrls;
   }
@@ -96,7 +95,7 @@ export class FallbackModel implements LanguageModelV2 {
     this.retryAfterOutput = settings.retryAfterOutput ?? true;
     if (!this.settings.models[this.currentModelIndex]) {
-      throw new Error("No models available in settings");
+      throw new Error('No models available in settings');
     }
   }
@@ -114,32 +113,27 @@ export class FallbackModel implements LanguageModelV2 {
   private checkAndResetModel() {
     const now = Date.now();
-    if (
-      now - this.lastModelReset >= this.modelResetInterval &&
-      this.currentModelIndex !== 0
-    ) {
+    if (now - this.lastModelReset >= this.modelResetInterval && this.currentModelIndex !== 0) {
       this.currentModelIndex = 0;
       this.lastModelReset = now;
     }
   }
   private switchToNextModel() {
-    this.currentModelIndex =
-      (this.currentModelIndex + 1) % this.settings.models.length;
+    this.currentModelIndex = (this.currentModelIndex + 1) % this.settings.models.length;
   }
   private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
-    let lastError: Error | undefined;
+    let lastError: RetryableError | undefined;
     const initialModel = this.currentModelIndex;
     do {
       try {
         return await fn();
       } catch (error) {
-        lastError = error as Error;
+        lastError = error as RetryableError;
         // Only retry if it's a server/capacity error
-        const shouldRetry =
-          this.settings.shouldRetryThisError || defaultShouldRetryThisError;
+        const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
         if (!shouldRetry(lastError)) {
           throw lastError;
         }
@@ -154,7 +148,10 @@ export class FallbackModel implements LanguageModelV2 {
           throw lastError;
         }
       }
-    } while (true);
+    } while (this.currentModelIndex !== initialModel);
+
+    // This should never be reached, but TypeScript requires it
+    throw lastError || new Error('Retry failed');
   }
   doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
@@ -182,8 +179,7 @@ export class FallbackModel implements LanguageModelV2 {
   }> {
     this.checkAndResetModel();
     const self = this;
-    const shouldRetry =
-      this.settings.shouldRetryThisError || defaultShouldRetryThisError;
+    const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
     return this.retry(async () => {
       const result = await self.getCurrentModel().doStream(options);
@@ -198,13 +194,8 @@ export class FallbackModel implements LanguageModelV2 {
             const result = await reader.read();
             const { done, value } = result;
-            if (
-              !hasStreamedAny &&
-              value &&
-              typeof value === "object" &&
-              "error" in value
-            ) {
-              const error = value.error as any;
+            if (!hasStreamedAny && value && typeof value === 'object' && 'error' in value) {
+              const error = value.error as RetryableError;
               if (shouldRetry(error)) {
                 throw error;
               }
@@ -213,14 +204,14 @@ export class FallbackModel implements LanguageModelV2 {
             if (done) break;
             controller.enqueue(value);
-            if (value?.type !== "stream-start") {
+            if (value?.type !== 'stream-start') {
               hasStreamedAny = true;
             }
           }
           controller.close();
         } catch (error) {
           if (self.settings.onError) {
-            await self.settings.onError(error as Error, self.modelId);
+            await self.settings.onError(error as RetryableError, self.modelId);
           }
           if (!hasStreamedAny || self.retryAfterOutput) {
             // If nothing was streamed yet, switch models and retry