AI fallback on AI SDK v5

This commit is contained in:
dal 2025-08-05 18:24:32 -06:00
parent fcbe1838a1
commit a2d90f01a7
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 238 additions and 261 deletions

View File

@ -19,11 +19,11 @@ test('doStream switches models on error', async () => {
stream: new ReadableStream<LanguageModelV2StreamPart>({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({ type: 'text-delta', textDelta: 'Hello from fallback' });
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
controller.enqueue({
type: 'finish',
finishReason: 'stop',
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
});
controller.close();
},
@ -38,8 +38,7 @@ test('doStream switches models on error', async () => {
});
const result = await fallback.doStream({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
// Read the stream
@ -53,9 +52,9 @@ test('doStream switches models on error', async () => {
}
expect(chunks).toHaveLength(3);
expect(chunks[0].type).toBe('stream-start');
expect(chunks[1]).toEqual({ type: 'text-delta', textDelta: 'Hello from fallback' });
expect(chunks[2].type).toBe('finish');
expect(chunks[0]!.type).toBe('stream-start');
expect(chunks[1]).toEqual({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
expect(chunks[2]!.type).toBe('finish');
expect(fallback.currentModelIndex).toBe(1);
expect(onError).toHaveBeenCalledWith(
@ -85,11 +84,11 @@ test('doStream handles error during streaming', async () => {
stream: new ReadableStream<LanguageModelV2StreamPart>({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({ type: 'text-delta', textDelta: 'Fallback response' });
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback response' });
controller.enqueue({
type: 'finish',
finishReason: 'stop',
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
});
controller.close();
},
@ -104,8 +103,7 @@ test('doStream handles error during streaming', async () => {
});
const result = await fallback.doStream({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
// Read the stream
@ -118,9 +116,7 @@ test('doStream handles error during streaming', async () => {
chunks.push(value);
}
expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Fallback response')).toBe(
true
);
expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Fallback response')).toBe(true);
expect(fallback.currentModelIndex).toBe(1);
});
@ -133,7 +129,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
stream: new ReadableStream<LanguageModelV2StreamPart>({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({ type: 'text-delta', textDelta: 'Partial output' });
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Partial output' });
controller.error(new Error('Stream interrupted'));
},
}),
@ -147,11 +143,11 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
stream: new ReadableStream<LanguageModelV2StreamPart>({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({ type: 'text-delta', textDelta: 'Fallback continuation' });
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback continuation' });
controller.enqueue({
type: 'finish',
finishReason: 'stop',
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
});
controller.close();
},
@ -167,8 +163,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
});
const result = await fallback.doStream({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
// Read the stream
@ -185,7 +180,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
// The partial output from the failed model is lost
const textChunks = chunks.filter((c) => c.type === 'text-delta');
expect(textChunks).toHaveLength(1);
expect(textChunks[0].textDelta).toBe('Fallback continuation');
expect(textChunks[0]!.delta).toBe('Fallback continuation');
// Should switch models after error
expect(fallback.currentModelIndex).toBe(1);
@ -218,11 +213,11 @@ test('doStream handles error in stream part', async () => {
stream: new ReadableStream<LanguageModelV2StreamPart>({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({ type: 'text-delta', textDelta: 'Success' });
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Success' });
controller.enqueue({
type: 'finish',
finishReason: 'stop',
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
});
controller.close();
},
@ -240,8 +235,7 @@ test('doStream handles error in stream part', async () => {
});
const result = await fallback.doStream({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
// Read the stream
@ -255,7 +249,7 @@ test('doStream handles error in stream part', async () => {
}
// Should have switched to model 2 and gotten success
expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Success')).toBe(true);
expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Success')).toBe(true);
expect(encounteredErrors).toHaveLength(1);
expect(encounteredErrors[0]).toBe('Overloaded');
expect(fallback.currentModelIndex).toBe(1);
@ -276,7 +270,7 @@ test('doGenerate switches models on error', async () => {
doGenerate: async () => ({
content: [{ type: 'text', text: 'Response from fallback model' }],
finishReason: 'stop' as const,
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
warnings: [],
}),
});
@ -289,8 +283,7 @@ test('doGenerate switches models on error', async () => {
expect(fallback.modelId).toBe('failing-model');
const result = await fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
expect(result.content[0]).toEqual({ type: 'text', text: 'Response from fallback model' });
@ -322,7 +315,7 @@ test('cycles through all models until one works', async () => {
doGenerate: async () => ({
content: [{ type: 'text', text: 'Success from model 3' }],
finishReason: 'stop' as const,
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
warnings: [],
}),
});
@ -332,8 +325,7 @@ test('cycles through all models until one works', async () => {
});
const result = await fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
expect(result.content[0]).toEqual({ type: 'text', text: 'Success from model 3' });
@ -362,8 +354,7 @@ test('throws error when all models fail', async () => {
await expect(
fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
})
).rejects.toThrow('Model 2 capacity reached');
@ -391,8 +382,7 @@ test('model reset interval resets to first model', async () => {
// Trigger reset check
await fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
expect(fallback.currentModelIndex).toBe(0);
@ -419,7 +409,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
doGenerate: async () => ({
content: [{ type: 'text', text: 'Should not reach here' }],
finishReason: 'stop' as const,
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
warnings: [],
}),
});
@ -432,8 +422,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
// Should not retry because error doesn't match
await expect(
fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
})
).rejects.toThrow('specific-error-that-should-not-retry');
@ -460,7 +449,7 @@ test('handles non-existent model error', async () => {
doGenerate: async () => ({
content: [{ type: 'text', text: 'Fallback response' }],
finishReason: 'stop' as const,
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
warnings: [],
}),
});
@ -475,8 +464,7 @@ test('handles non-existent model error', async () => {
});
const result = await fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
expect(result.content[0]).toEqual({ type: 'text', text: 'Fallback response' });
@ -502,7 +490,7 @@ test('handles API key errors', async () => {
doGenerate: async () => ({
content: [{ type: 'text', text: 'Success with correct key' }],
finishReason: 'stop' as const,
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
warnings: [],
}),
});
@ -512,8 +500,7 @@ test('handles API key errors', async () => {
});
const result = await fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
expect(result.content[0]).toEqual({ type: 'text', text: 'Success with correct key' });
@ -538,7 +525,7 @@ test('handles rate limit errors', async () => {
doGenerate: async () => ({
content: [{ type: 'text', text: 'Response from available model' }],
finishReason: 'stop' as const,
usage: { inputTokens: 10, outputTokens: 5 },
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
warnings: [],
}),
});
@ -548,8 +535,7 @@ test('handles rate limit errors', async () => {
});
const result = await fallback.doGenerate({
prompt: { system: 'test', messages: [] },
mode: { type: 'regular' },
prompt: [],
});
expect(attempts).toBe(1);

View File

@ -1,254 +1,245 @@
import type {
LanguageModelV2,
LanguageModelV2CallOptions,
LanguageModelV2CallWarning,
LanguageModelV2Content,
LanguageModelV2FinishReason,
LanguageModelV2StreamPart,
LanguageModelV2Usage,
SharedV2ProviderMetadata,
} from "@ai-sdk/provider";
LanguageModelV2,
LanguageModelV2CallOptions,
LanguageModelV2CallWarning,
LanguageModelV2Content,
LanguageModelV2FinishReason,
LanguageModelV2StreamPart,
LanguageModelV2Usage,
SharedV2ProviderMetadata,
} from '@ai-sdk/provider';
interface RetryableError extends Error {
statusCode?: number;
}
interface Settings {
models: LanguageModelV2[];
retryAfterOutput?: boolean;
modelResetInterval?: number;
shouldRetryThisError?: (error: Error) => boolean;
onError?: (error: Error, modelId: string) => void | Promise<void>;
models: LanguageModelV2[];
retryAfterOutput?: boolean;
modelResetInterval?: number;
shouldRetryThisError?: (error: RetryableError) => boolean;
onError?: (error: RetryableError, modelId: string) => void | Promise<void>;
}
export function createFallback(settings: Settings): FallbackModel {
return new FallbackModel(settings);
return new FallbackModel(settings);
}
const retryableStatusCodes = [
401, // wrong API key
403, // permission error, like cannot access model or from a non accessible region
408, // request timeout
409, // conflict
413, // payload too large
429, // too many requests/rate limits
500, // server error (and above)
401, // wrong API key
403, // permission error, like cannot access model or from a non accessible region
408, // request timeout
409, // conflict
413, // payload too large
429, // too many requests/rate limits
500, // server error (and above)
];
// Common error messages/codes that indicate server overload or temporary issues
const retryableErrors = [
"overloaded",
"service unavailable",
"bad gateway",
"too many requests",
"internal server error",
"gateway timeout",
"rate_limit",
"wrong-key",
"unexpected",
"capacity",
"timeout",
"server_error",
"429", // Too Many Requests
"500", // Internal Server Error
"502", // Bad Gateway
"503", // Service Unavailable
"504", // Gateway Timeout
'overloaded',
'service unavailable',
'bad gateway',
'too many requests',
'internal server error',
'gateway timeout',
'rate_limit',
'wrong-key',
'unexpected',
'capacity',
'timeout',
'server_error',
'429', // Too Many Requests
'500', // Internal Server Error
'502', // Bad Gateway
'503', // Service Unavailable
'504', // Gateway Timeout
];
function defaultShouldRetryThisError(error: any): boolean {
const statusCode = error?.statusCode;
function defaultShouldRetryThisError(error: RetryableError): boolean {
const statusCode = error?.statusCode;
if (
statusCode &&
(retryableStatusCodes.includes(statusCode) || statusCode > 500)
) {
return true;
}
if (statusCode && (retryableStatusCodes.includes(statusCode) || statusCode > 500)) {
return true;
}
if (error?.message) {
const errorString = error.message.toLowerCase() || "";
return retryableErrors.some((errType) => errorString.includes(errType));
}
if (error && typeof error === "object") {
const errorString = JSON.stringify(error).toLowerCase() || "";
return retryableErrors.some((errType) => errorString.includes(errType));
}
return false;
if (error?.message) {
const errorString = error.message.toLowerCase() || '';
return retryableErrors.some((errType) => errorString.includes(errType));
}
if (error && typeof error === 'object') {
const errorString = JSON.stringify(error).toLowerCase() || '';
return retryableErrors.some((errType) => errorString.includes(errType));
}
return false;
}
export class FallbackModel implements LanguageModelV2 {
readonly specificationVersion = "v2";
readonly specificationVersion = 'v2';
get supportedUrls():
| Record<string, RegExp[]>
| PromiseLike<Record<string, RegExp[]>> {
return this.getCurrentModel().supportedUrls;
}
get supportedUrls(): Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>> {
return this.getCurrentModel().supportedUrls;
}
get modelId(): string {
return this.getCurrentModel().modelId;
}
readonly settings: Settings;
get modelId(): string {
return this.getCurrentModel().modelId;
}
readonly settings: Settings;
currentModelIndex = 0;
private lastModelReset: number = Date.now();
private readonly modelResetInterval: number;
retryAfterOutput: boolean;
constructor(settings: Settings) {
this.settings = settings;
this.modelResetInterval = settings.modelResetInterval ?? 3 * 60 * 1000; // Default 3 minutes in ms
this.retryAfterOutput = settings.retryAfterOutput ?? true;
currentModelIndex = 0;
private lastModelReset: number = Date.now();
private readonly modelResetInterval: number;
retryAfterOutput: boolean;
constructor(settings: Settings) {
this.settings = settings;
this.modelResetInterval = settings.modelResetInterval ?? 3 * 60 * 1000; // Default 3 minutes in ms
this.retryAfterOutput = settings.retryAfterOutput ?? true;
if (!this.settings.models[this.currentModelIndex]) {
throw new Error("No models available in settings");
}
}
if (!this.settings.models[this.currentModelIndex]) {
throw new Error('No models available in settings');
}
}
get provider(): string {
return this.getCurrentModel().provider;
}
get provider(): string {
return this.getCurrentModel().provider;
}
private getCurrentModel(): LanguageModelV2 {
const model = this.settings.models[this.currentModelIndex];
if (!model) {
throw new Error(`No model available at index ${this.currentModelIndex}`);
}
return model;
}
private getCurrentModel(): LanguageModelV2 {
const model = this.settings.models[this.currentModelIndex];
if (!model) {
throw new Error(`No model available at index ${this.currentModelIndex}`);
}
return model;
}
private checkAndResetModel() {
const now = Date.now();
if (
now - this.lastModelReset >= this.modelResetInterval &&
this.currentModelIndex !== 0
) {
this.currentModelIndex = 0;
this.lastModelReset = now;
}
}
private checkAndResetModel() {
const now = Date.now();
if (now - this.lastModelReset >= this.modelResetInterval && this.currentModelIndex !== 0) {
this.currentModelIndex = 0;
this.lastModelReset = now;
}
}
private switchToNextModel() {
this.currentModelIndex =
(this.currentModelIndex + 1) % this.settings.models.length;
}
private switchToNextModel() {
this.currentModelIndex = (this.currentModelIndex + 1) % this.settings.models.length;
}
private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
let lastError: Error | undefined;
const initialModel = this.currentModelIndex;
private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
let lastError: RetryableError | undefined;
const initialModel = this.currentModelIndex;
do {
try {
return await fn();
} catch (error) {
lastError = error as Error;
// Only retry if it's a server/capacity error
const shouldRetry =
this.settings.shouldRetryThisError || defaultShouldRetryThisError;
if (!shouldRetry(lastError)) {
throw lastError;
}
do {
try {
return await fn();
} catch (error) {
lastError = error as RetryableError;
// Only retry if it's a server/capacity error
const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
if (!shouldRetry(lastError)) {
throw lastError;
}
if (this.settings.onError) {
await this.settings.onError(lastError, this.modelId);
}
this.switchToNextModel();
if (this.settings.onError) {
await this.settings.onError(lastError, this.modelId);
}
this.switchToNextModel();
// If we've tried all models, throw the last error
if (this.currentModelIndex === initialModel) {
throw lastError;
}
}
} while (true);
}
// If we've tried all models, throw the last error
if (this.currentModelIndex === initialModel) {
throw lastError;
}
}
} while (this.currentModelIndex !== initialModel);
doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
content: LanguageModelV2Content[];
finishReason: LanguageModelV2FinishReason;
usage: LanguageModelV2Usage;
providerMetadata?: SharedV2ProviderMetadata;
request?: { body?: unknown };
response?: {
headers?: Record<string, string>;
id?: string;
timestamp?: Date;
modelId?: string;
};
warnings: LanguageModelV2CallWarning[];
}> {
this.checkAndResetModel();
return this.retry(() => this.getCurrentModel().doGenerate(options));
}
// This should never be reached, but TypeScript requires it
throw lastError || new Error('Retry failed');
}
doStream(options: LanguageModelV2CallOptions): PromiseLike<{
stream: ReadableStream<LanguageModelV2StreamPart>;
request?: { body?: unknown };
response?: { headers?: Record<string, string> };
}> {
this.checkAndResetModel();
const self = this;
const shouldRetry =
this.settings.shouldRetryThisError || defaultShouldRetryThisError;
return this.retry(async () => {
const result = await self.getCurrentModel().doStream(options);
doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
content: LanguageModelV2Content[];
finishReason: LanguageModelV2FinishReason;
usage: LanguageModelV2Usage;
providerMetadata?: SharedV2ProviderMetadata;
request?: { body?: unknown };
response?: {
headers?: Record<string, string>;
id?: string;
timestamp?: Date;
modelId?: string;
};
warnings: LanguageModelV2CallWarning[];
}> {
this.checkAndResetModel();
return this.retry(() => this.getCurrentModel().doGenerate(options));
}
let hasStreamedAny = false;
// Wrap the stream to handle errors and switch providers if needed
const wrappedStream = new ReadableStream<LanguageModelV2StreamPart>({
async start(controller) {
try {
const reader = result.stream.getReader();
doStream(options: LanguageModelV2CallOptions): PromiseLike<{
stream: ReadableStream<LanguageModelV2StreamPart>;
request?: { body?: unknown };
response?: { headers?: Record<string, string> };
}> {
this.checkAndResetModel();
const self = this;
const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
return this.retry(async () => {
const result = await self.getCurrentModel().doStream(options);
while (true) {
const result = await reader.read();
let hasStreamedAny = false;
// Wrap the stream to handle errors and switch providers if needed
const wrappedStream = new ReadableStream<LanguageModelV2StreamPart>({
async start(controller) {
try {
const reader = result.stream.getReader();
const { done, value } = result;
if (
!hasStreamedAny &&
value &&
typeof value === "object" &&
"error" in value
) {
const error = value.error as any;
if (shouldRetry(error)) {
throw error;
}
}
while (true) {
const result = await reader.read();
if (done) break;
controller.enqueue(value);
const { done, value } = result;
if (!hasStreamedAny && value && typeof value === 'object' && 'error' in value) {
const error = value.error as RetryableError;
if (shouldRetry(error)) {
throw error;
}
}
if (value?.type !== "stream-start") {
hasStreamedAny = true;
}
}
controller.close();
} catch (error) {
if (self.settings.onError) {
await self.settings.onError(error as Error, self.modelId);
}
if (!hasStreamedAny || self.retryAfterOutput) {
// If nothing was streamed yet, switch models and retry
self.switchToNextModel();
try {
const nextResult = await self.doStream(options);
const nextReader = nextResult.stream.getReader();
while (true) {
const { done, value } = await nextReader.read();
if (done) break;
controller.enqueue(value);
}
controller.close();
} catch (nextError) {
controller.error(nextError);
}
return;
}
controller.error(error);
}
},
});
if (done) break;
controller.enqueue(value);
return {
stream: wrappedStream,
...(result.request && { request: result.request }),
...(result.response && { response: result.response }),
};
});
}
if (value?.type !== 'stream-start') {
hasStreamedAny = true;
}
}
controller.close();
} catch (error) {
if (self.settings.onError) {
await self.settings.onError(error as RetryableError, self.modelId);
}
if (!hasStreamedAny || self.retryAfterOutput) {
// If nothing was streamed yet, switch models and retry
self.switchToNextModel();
try {
const nextResult = await self.doStream(options);
const nextReader = nextResult.stream.getReader();
while (true) {
const { done, value } = await nextReader.read();
if (done) break;
controller.enqueue(value);
}
controller.close();
} catch (nextError) {
controller.error(nextError);
}
return;
}
controller.error(error);
}
},
});
return {
stream: wrappedStream,
...(result.request && { request: result.request }),
...(result.response && { response: result.response }),
};
});
}
}