mirror of https://github.com/buster-so/buster.git
ai fallback on ai sdk v5
This commit is contained in:
parent
fcbe1838a1
commit
a2d90f01a7
|
@ -19,11 +19,11 @@ test('doStream switches models on error', async () => {
|
||||||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||||
start(controller) {
|
start(controller) {
|
||||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||||
controller.enqueue({ type: 'text-delta', textDelta: 'Hello from fallback' });
|
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'finish',
|
type: 'finish',
|
||||||
finishReason: 'stop',
|
finishReason: 'stop',
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
});
|
});
|
||||||
controller.close();
|
controller.close();
|
||||||
},
|
},
|
||||||
|
@ -38,8 +38,7 @@ test('doStream switches models on error', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doStream({
|
const result = await fallback.doStream({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Read the stream
|
// Read the stream
|
||||||
|
@ -53,9 +52,9 @@ test('doStream switches models on error', async () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks).toHaveLength(3);
|
expect(chunks).toHaveLength(3);
|
||||||
expect(chunks[0].type).toBe('stream-start');
|
expect(chunks[0]!.type).toBe('stream-start');
|
||||||
expect(chunks[1]).toEqual({ type: 'text-delta', textDelta: 'Hello from fallback' });
|
expect(chunks[1]).toEqual({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
|
||||||
expect(chunks[2].type).toBe('finish');
|
expect(chunks[2]!.type).toBe('finish');
|
||||||
|
|
||||||
expect(fallback.currentModelIndex).toBe(1);
|
expect(fallback.currentModelIndex).toBe(1);
|
||||||
expect(onError).toHaveBeenCalledWith(
|
expect(onError).toHaveBeenCalledWith(
|
||||||
|
@ -85,11 +84,11 @@ test('doStream handles error during streaming', async () => {
|
||||||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||||
start(controller) {
|
start(controller) {
|
||||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||||
controller.enqueue({ type: 'text-delta', textDelta: 'Fallback response' });
|
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback response' });
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'finish',
|
type: 'finish',
|
||||||
finishReason: 'stop',
|
finishReason: 'stop',
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
});
|
});
|
||||||
controller.close();
|
controller.close();
|
||||||
},
|
},
|
||||||
|
@ -104,8 +103,7 @@ test('doStream handles error during streaming', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doStream({
|
const result = await fallback.doStream({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Read the stream
|
// Read the stream
|
||||||
|
@ -118,9 +116,7 @@ test('doStream handles error during streaming', async () => {
|
||||||
chunks.push(value);
|
chunks.push(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Fallback response')).toBe(
|
expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Fallback response')).toBe(true);
|
||||||
true
|
|
||||||
);
|
|
||||||
expect(fallback.currentModelIndex).toBe(1);
|
expect(fallback.currentModelIndex).toBe(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -133,7 +129,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
||||||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||||
start(controller) {
|
start(controller) {
|
||||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||||
controller.enqueue({ type: 'text-delta', textDelta: 'Partial output' });
|
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Partial output' });
|
||||||
controller.error(new Error('Stream interrupted'));
|
controller.error(new Error('Stream interrupted'));
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
@ -147,11 +143,11 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
||||||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||||
start(controller) {
|
start(controller) {
|
||||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||||
controller.enqueue({ type: 'text-delta', textDelta: 'Fallback continuation' });
|
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback continuation' });
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'finish',
|
type: 'finish',
|
||||||
finishReason: 'stop',
|
finishReason: 'stop',
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
});
|
});
|
||||||
controller.close();
|
controller.close();
|
||||||
},
|
},
|
||||||
|
@ -167,8 +163,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doStream({
|
const result = await fallback.doStream({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Read the stream
|
// Read the stream
|
||||||
|
@ -185,7 +180,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
||||||
// The partial output from the failed model is lost
|
// The partial output from the failed model is lost
|
||||||
const textChunks = chunks.filter((c) => c.type === 'text-delta');
|
const textChunks = chunks.filter((c) => c.type === 'text-delta');
|
||||||
expect(textChunks).toHaveLength(1);
|
expect(textChunks).toHaveLength(1);
|
||||||
expect(textChunks[0].textDelta).toBe('Fallback continuation');
|
expect(textChunks[0]!.delta).toBe('Fallback continuation');
|
||||||
|
|
||||||
// Should switch models after error
|
// Should switch models after error
|
||||||
expect(fallback.currentModelIndex).toBe(1);
|
expect(fallback.currentModelIndex).toBe(1);
|
||||||
|
@ -218,11 +213,11 @@ test('doStream handles error in stream part', async () => {
|
||||||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||||
start(controller) {
|
start(controller) {
|
||||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||||
controller.enqueue({ type: 'text-delta', textDelta: 'Success' });
|
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Success' });
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'finish',
|
type: 'finish',
|
||||||
finishReason: 'stop',
|
finishReason: 'stop',
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
});
|
});
|
||||||
controller.close();
|
controller.close();
|
||||||
},
|
},
|
||||||
|
@ -240,8 +235,7 @@ test('doStream handles error in stream part', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doStream({
|
const result = await fallback.doStream({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Read the stream
|
// Read the stream
|
||||||
|
@ -255,7 +249,7 @@ test('doStream handles error in stream part', async () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have switched to model 2 and gotten success
|
// Should have switched to model 2 and gotten success
|
||||||
expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Success')).toBe(true);
|
expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Success')).toBe(true);
|
||||||
expect(encounteredErrors).toHaveLength(1);
|
expect(encounteredErrors).toHaveLength(1);
|
||||||
expect(encounteredErrors[0]).toBe('Overloaded');
|
expect(encounteredErrors[0]).toBe('Overloaded');
|
||||||
expect(fallback.currentModelIndex).toBe(1);
|
expect(fallback.currentModelIndex).toBe(1);
|
||||||
|
@ -276,7 +270,7 @@ test('doGenerate switches models on error', async () => {
|
||||||
doGenerate: async () => ({
|
doGenerate: async () => ({
|
||||||
content: [{ type: 'text', text: 'Response from fallback model' }],
|
content: [{ type: 'text', text: 'Response from fallback model' }],
|
||||||
finishReason: 'stop' as const,
|
finishReason: 'stop' as const,
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
warnings: [],
|
warnings: [],
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
@ -289,8 +283,7 @@ test('doGenerate switches models on error', async () => {
|
||||||
expect(fallback.modelId).toBe('failing-model');
|
expect(fallback.modelId).toBe('failing-model');
|
||||||
|
|
||||||
const result = await fallback.doGenerate({
|
const result = await fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Response from fallback model' });
|
expect(result.content[0]).toEqual({ type: 'text', text: 'Response from fallback model' });
|
||||||
|
@ -322,7 +315,7 @@ test('cycles through all models until one works', async () => {
|
||||||
doGenerate: async () => ({
|
doGenerate: async () => ({
|
||||||
content: [{ type: 'text', text: 'Success from model 3' }],
|
content: [{ type: 'text', text: 'Success from model 3' }],
|
||||||
finishReason: 'stop' as const,
|
finishReason: 'stop' as const,
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
warnings: [],
|
warnings: [],
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
@ -332,8 +325,7 @@ test('cycles through all models until one works', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doGenerate({
|
const result = await fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Success from model 3' });
|
expect(result.content[0]).toEqual({ type: 'text', text: 'Success from model 3' });
|
||||||
|
@ -362,8 +354,7 @@ test('throws error when all models fail', async () => {
|
||||||
|
|
||||||
await expect(
|
await expect(
|
||||||
fallback.doGenerate({
|
fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
})
|
})
|
||||||
).rejects.toThrow('Model 2 capacity reached');
|
).rejects.toThrow('Model 2 capacity reached');
|
||||||
|
|
||||||
|
@ -391,8 +382,7 @@ test('model reset interval resets to first model', async () => {
|
||||||
|
|
||||||
// Trigger reset check
|
// Trigger reset check
|
||||||
await fallback.doGenerate({
|
await fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(fallback.currentModelIndex).toBe(0);
|
expect(fallback.currentModelIndex).toBe(0);
|
||||||
|
@ -419,7 +409,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
|
||||||
doGenerate: async () => ({
|
doGenerate: async () => ({
|
||||||
content: [{ type: 'text', text: 'Should not reach here' }],
|
content: [{ type: 'text', text: 'Should not reach here' }],
|
||||||
finishReason: 'stop' as const,
|
finishReason: 'stop' as const,
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
warnings: [],
|
warnings: [],
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
@ -432,8 +422,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
|
||||||
// Should not retry because error doesn't match
|
// Should not retry because error doesn't match
|
||||||
await expect(
|
await expect(
|
||||||
fallback.doGenerate({
|
fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
})
|
})
|
||||||
).rejects.toThrow('specific-error-that-should-not-retry');
|
).rejects.toThrow('specific-error-that-should-not-retry');
|
||||||
|
|
||||||
|
@ -460,7 +449,7 @@ test('handles non-existent model error', async () => {
|
||||||
doGenerate: async () => ({
|
doGenerate: async () => ({
|
||||||
content: [{ type: 'text', text: 'Fallback response' }],
|
content: [{ type: 'text', text: 'Fallback response' }],
|
||||||
finishReason: 'stop' as const,
|
finishReason: 'stop' as const,
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
warnings: [],
|
warnings: [],
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
@ -475,8 +464,7 @@ test('handles non-existent model error', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doGenerate({
|
const result = await fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Fallback response' });
|
expect(result.content[0]).toEqual({ type: 'text', text: 'Fallback response' });
|
||||||
|
@ -502,7 +490,7 @@ test('handles API key errors', async () => {
|
||||||
doGenerate: async () => ({
|
doGenerate: async () => ({
|
||||||
content: [{ type: 'text', text: 'Success with correct key' }],
|
content: [{ type: 'text', text: 'Success with correct key' }],
|
||||||
finishReason: 'stop' as const,
|
finishReason: 'stop' as const,
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
warnings: [],
|
warnings: [],
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
@ -512,8 +500,7 @@ test('handles API key errors', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doGenerate({
|
const result = await fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Success with correct key' });
|
expect(result.content[0]).toEqual({ type: 'text', text: 'Success with correct key' });
|
||||||
|
@ -538,7 +525,7 @@ test('handles rate limit errors', async () => {
|
||||||
doGenerate: async () => ({
|
doGenerate: async () => ({
|
||||||
content: [{ type: 'text', text: 'Response from available model' }],
|
content: [{ type: 'text', text: 'Response from available model' }],
|
||||||
finishReason: 'stop' as const,
|
finishReason: 'stop' as const,
|
||||||
usage: { inputTokens: 10, outputTokens: 5 },
|
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||||
warnings: [],
|
warnings: [],
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
@ -548,8 +535,7 @@ test('handles rate limit errors', async () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await fallback.doGenerate({
|
const result = await fallback.doGenerate({
|
||||||
prompt: { system: 'test', messages: [] },
|
prompt: [],
|
||||||
mode: { type: 'regular' },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(attempts).toBe(1);
|
expect(attempts).toBe(1);
|
||||||
|
|
|
@ -7,14 +7,18 @@ import type {
|
||||||
LanguageModelV2StreamPart,
|
LanguageModelV2StreamPart,
|
||||||
LanguageModelV2Usage,
|
LanguageModelV2Usage,
|
||||||
SharedV2ProviderMetadata,
|
SharedV2ProviderMetadata,
|
||||||
} from "@ai-sdk/provider";
|
} from '@ai-sdk/provider';
|
||||||
|
|
||||||
|
interface RetryableError extends Error {
|
||||||
|
statusCode?: number;
|
||||||
|
}
|
||||||
|
|
||||||
interface Settings {
|
interface Settings {
|
||||||
models: LanguageModelV2[];
|
models: LanguageModelV2[];
|
||||||
retryAfterOutput?: boolean;
|
retryAfterOutput?: boolean;
|
||||||
modelResetInterval?: number;
|
modelResetInterval?: number;
|
||||||
shouldRetryThisError?: (error: Error) => boolean;
|
shouldRetryThisError?: (error: RetryableError) => boolean;
|
||||||
onError?: (error: Error, modelId: string) => void | Promise<void>;
|
onError?: (error: RetryableError, modelId: string) => void | Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createFallback(settings: Settings): FallbackModel {
|
export function createFallback(settings: Settings): FallbackModel {
|
||||||
|
@ -32,52 +36,47 @@ const retryableStatusCodes = [
|
||||||
];
|
];
|
||||||
// Common error messages/codes that indicate server overload or temporary issues
|
// Common error messages/codes that indicate server overload or temporary issues
|
||||||
const retryableErrors = [
|
const retryableErrors = [
|
||||||
"overloaded",
|
'overloaded',
|
||||||
"service unavailable",
|
'service unavailable',
|
||||||
"bad gateway",
|
'bad gateway',
|
||||||
"too many requests",
|
'too many requests',
|
||||||
"internal server error",
|
'internal server error',
|
||||||
"gateway timeout",
|
'gateway timeout',
|
||||||
"rate_limit",
|
'rate_limit',
|
||||||
"wrong-key",
|
'wrong-key',
|
||||||
"unexpected",
|
'unexpected',
|
||||||
"capacity",
|
'capacity',
|
||||||
"timeout",
|
'timeout',
|
||||||
"server_error",
|
'server_error',
|
||||||
"429", // Too Many Requests
|
'429', // Too Many Requests
|
||||||
"500", // Internal Server Error
|
'500', // Internal Server Error
|
||||||
"502", // Bad Gateway
|
'502', // Bad Gateway
|
||||||
"503", // Service Unavailable
|
'503', // Service Unavailable
|
||||||
"504", // Gateway Timeout
|
'504', // Gateway Timeout
|
||||||
];
|
];
|
||||||
|
|
||||||
function defaultShouldRetryThisError(error: any): boolean {
|
function defaultShouldRetryThisError(error: RetryableError): boolean {
|
||||||
const statusCode = error?.statusCode;
|
const statusCode = error?.statusCode;
|
||||||
|
|
||||||
if (
|
if (statusCode && (retryableStatusCodes.includes(statusCode) || statusCode > 500)) {
|
||||||
statusCode &&
|
|
||||||
(retryableStatusCodes.includes(statusCode) || statusCode > 500)
|
|
||||||
) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error?.message) {
|
if (error?.message) {
|
||||||
const errorString = error.message.toLowerCase() || "";
|
const errorString = error.message.toLowerCase() || '';
|
||||||
return retryableErrors.some((errType) => errorString.includes(errType));
|
return retryableErrors.some((errType) => errorString.includes(errType));
|
||||||
}
|
}
|
||||||
if (error && typeof error === "object") {
|
if (error && typeof error === 'object') {
|
||||||
const errorString = JSON.stringify(error).toLowerCase() || "";
|
const errorString = JSON.stringify(error).toLowerCase() || '';
|
||||||
return retryableErrors.some((errType) => errorString.includes(errType));
|
return retryableErrors.some((errType) => errorString.includes(errType));
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class FallbackModel implements LanguageModelV2 {
|
export class FallbackModel implements LanguageModelV2 {
|
||||||
readonly specificationVersion = "v2";
|
readonly specificationVersion = 'v2';
|
||||||
|
|
||||||
get supportedUrls():
|
get supportedUrls(): Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>> {
|
||||||
| Record<string, RegExp[]>
|
|
||||||
| PromiseLike<Record<string, RegExp[]>> {
|
|
||||||
return this.getCurrentModel().supportedUrls;
|
return this.getCurrentModel().supportedUrls;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ export class FallbackModel implements LanguageModelV2 {
|
||||||
this.retryAfterOutput = settings.retryAfterOutput ?? true;
|
this.retryAfterOutput = settings.retryAfterOutput ?? true;
|
||||||
|
|
||||||
if (!this.settings.models[this.currentModelIndex]) {
|
if (!this.settings.models[this.currentModelIndex]) {
|
||||||
throw new Error("No models available in settings");
|
throw new Error('No models available in settings');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,32 +113,27 @@ export class FallbackModel implements LanguageModelV2 {
|
||||||
|
|
||||||
private checkAndResetModel() {
|
private checkAndResetModel() {
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
if (
|
if (now - this.lastModelReset >= this.modelResetInterval && this.currentModelIndex !== 0) {
|
||||||
now - this.lastModelReset >= this.modelResetInterval &&
|
|
||||||
this.currentModelIndex !== 0
|
|
||||||
) {
|
|
||||||
this.currentModelIndex = 0;
|
this.currentModelIndex = 0;
|
||||||
this.lastModelReset = now;
|
this.lastModelReset = now;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private switchToNextModel() {
|
private switchToNextModel() {
|
||||||
this.currentModelIndex =
|
this.currentModelIndex = (this.currentModelIndex + 1) % this.settings.models.length;
|
||||||
(this.currentModelIndex + 1) % this.settings.models.length;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
|
private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
|
||||||
let lastError: Error | undefined;
|
let lastError: RetryableError | undefined;
|
||||||
const initialModel = this.currentModelIndex;
|
const initialModel = this.currentModelIndex;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
try {
|
try {
|
||||||
return await fn();
|
return await fn();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
lastError = error as Error;
|
lastError = error as RetryableError;
|
||||||
// Only retry if it's a server/capacity error
|
// Only retry if it's a server/capacity error
|
||||||
const shouldRetry =
|
const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
||||||
this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
|
||||||
if (!shouldRetry(lastError)) {
|
if (!shouldRetry(lastError)) {
|
||||||
throw lastError;
|
throw lastError;
|
||||||
}
|
}
|
||||||
|
@ -154,7 +148,10 @@ export class FallbackModel implements LanguageModelV2 {
|
||||||
throw lastError;
|
throw lastError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} while (true);
|
} while (this.currentModelIndex !== initialModel);
|
||||||
|
|
||||||
|
// This should never be reached, but TypeScript requires it
|
||||||
|
throw lastError || new Error('Retry failed');
|
||||||
}
|
}
|
||||||
|
|
||||||
doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
|
doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
|
||||||
|
@ -182,8 +179,7 @@ export class FallbackModel implements LanguageModelV2 {
|
||||||
}> {
|
}> {
|
||||||
this.checkAndResetModel();
|
this.checkAndResetModel();
|
||||||
const self = this;
|
const self = this;
|
||||||
const shouldRetry =
|
const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
||||||
this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
|
||||||
return this.retry(async () => {
|
return this.retry(async () => {
|
||||||
const result = await self.getCurrentModel().doStream(options);
|
const result = await self.getCurrentModel().doStream(options);
|
||||||
|
|
||||||
|
@ -198,13 +194,8 @@ export class FallbackModel implements LanguageModelV2 {
|
||||||
const result = await reader.read();
|
const result = await reader.read();
|
||||||
|
|
||||||
const { done, value } = result;
|
const { done, value } = result;
|
||||||
if (
|
if (!hasStreamedAny && value && typeof value === 'object' && 'error' in value) {
|
||||||
!hasStreamedAny &&
|
const error = value.error as RetryableError;
|
||||||
value &&
|
|
||||||
typeof value === "object" &&
|
|
||||||
"error" in value
|
|
||||||
) {
|
|
||||||
const error = value.error as any;
|
|
||||||
if (shouldRetry(error)) {
|
if (shouldRetry(error)) {
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
|
@ -213,14 +204,14 @@ export class FallbackModel implements LanguageModelV2 {
|
||||||
if (done) break;
|
if (done) break;
|
||||||
controller.enqueue(value);
|
controller.enqueue(value);
|
||||||
|
|
||||||
if (value?.type !== "stream-start") {
|
if (value?.type !== 'stream-start') {
|
||||||
hasStreamedAny = true;
|
hasStreamedAny = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
controller.close();
|
controller.close();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (self.settings.onError) {
|
if (self.settings.onError) {
|
||||||
await self.settings.onError(error as Error, self.modelId);
|
await self.settings.onError(error as RetryableError, self.modelId);
|
||||||
}
|
}
|
||||||
if (!hasStreamedAny || self.retryAfterOutput) {
|
if (!hasStreamedAny || self.retryAfterOutput) {
|
||||||
// If nothing was streamed yet, switch models and retry
|
// If nothing was streamed yet, switch models and retry
|
||||||
|
|
Loading…
Reference in New Issue