mirror of https://github.com/buster-so/buster.git
ai fallback on ai sdk v5
This commit is contained in:
parent
fcbe1838a1
commit
a2d90f01a7
|
@ -19,11 +19,11 @@ test('doStream switches models on error', async () => {
|
|||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||
start(controller) {
|
||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||
controller.enqueue({ type: 'text-delta', textDelta: 'Hello from fallback' });
|
||||
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
|
||||
controller.enqueue({
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
});
|
||||
controller.close();
|
||||
},
|
||||
|
@ -38,8 +38,7 @@ test('doStream switches models on error', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doStream({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
// Read the stream
|
||||
|
@ -53,9 +52,9 @@ test('doStream switches models on error', async () => {
|
|||
}
|
||||
|
||||
expect(chunks).toHaveLength(3);
|
||||
expect(chunks[0].type).toBe('stream-start');
|
||||
expect(chunks[1]).toEqual({ type: 'text-delta', textDelta: 'Hello from fallback' });
|
||||
expect(chunks[2].type).toBe('finish');
|
||||
expect(chunks[0]!.type).toBe('stream-start');
|
||||
expect(chunks[1]).toEqual({ type: 'text-delta', id: '1', delta: 'Hello from fallback' });
|
||||
expect(chunks[2]!.type).toBe('finish');
|
||||
|
||||
expect(fallback.currentModelIndex).toBe(1);
|
||||
expect(onError).toHaveBeenCalledWith(
|
||||
|
@ -85,11 +84,11 @@ test('doStream handles error during streaming', async () => {
|
|||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||
start(controller) {
|
||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||
controller.enqueue({ type: 'text-delta', textDelta: 'Fallback response' });
|
||||
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback response' });
|
||||
controller.enqueue({
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
});
|
||||
controller.close();
|
||||
},
|
||||
|
@ -104,8 +103,7 @@ test('doStream handles error during streaming', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doStream({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
// Read the stream
|
||||
|
@ -118,9 +116,7 @@ test('doStream handles error during streaming', async () => {
|
|||
chunks.push(value);
|
||||
}
|
||||
|
||||
expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Fallback response')).toBe(
|
||||
true
|
||||
);
|
||||
expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Fallback response')).toBe(true);
|
||||
expect(fallback.currentModelIndex).toBe(1);
|
||||
});
|
||||
|
||||
|
@ -133,7 +129,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
|||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||
start(controller) {
|
||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||
controller.enqueue({ type: 'text-delta', textDelta: 'Partial output' });
|
||||
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Partial output' });
|
||||
controller.error(new Error('Stream interrupted'));
|
||||
},
|
||||
}),
|
||||
|
@ -147,11 +143,11 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
|||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||
start(controller) {
|
||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||
controller.enqueue({ type: 'text-delta', textDelta: 'Fallback continuation' });
|
||||
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Fallback continuation' });
|
||||
controller.enqueue({
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
});
|
||||
controller.close();
|
||||
},
|
||||
|
@ -167,8 +163,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doStream({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
// Read the stream
|
||||
|
@ -185,7 +180,7 @@ test('doStream with partial output and retryAfterOutput=true', async () => {
|
|||
// The partial output from the failed model is lost
|
||||
const textChunks = chunks.filter((c) => c.type === 'text-delta');
|
||||
expect(textChunks).toHaveLength(1);
|
||||
expect(textChunks[0].textDelta).toBe('Fallback continuation');
|
||||
expect(textChunks[0]!.delta).toBe('Fallback continuation');
|
||||
|
||||
// Should switch models after error
|
||||
expect(fallback.currentModelIndex).toBe(1);
|
||||
|
@ -218,11 +213,11 @@ test('doStream handles error in stream part', async () => {
|
|||
stream: new ReadableStream<LanguageModelV2StreamPart>({
|
||||
start(controller) {
|
||||
controller.enqueue({ type: 'stream-start', warnings: [] });
|
||||
controller.enqueue({ type: 'text-delta', textDelta: 'Success' });
|
||||
controller.enqueue({ type: 'text-delta', id: '1', delta: 'Success' });
|
||||
controller.enqueue({
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
});
|
||||
controller.close();
|
||||
},
|
||||
|
@ -240,8 +235,7 @@ test('doStream handles error in stream part', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doStream({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
// Read the stream
|
||||
|
@ -255,7 +249,7 @@ test('doStream handles error in stream part', async () => {
|
|||
}
|
||||
|
||||
// Should have switched to model 2 and gotten success
|
||||
expect(chunks.some((c) => c.type === 'text-delta' && c.textDelta === 'Success')).toBe(true);
|
||||
expect(chunks.some((c) => c.type === 'text-delta' && c.delta === 'Success')).toBe(true);
|
||||
expect(encounteredErrors).toHaveLength(1);
|
||||
expect(encounteredErrors[0]).toBe('Overloaded');
|
||||
expect(fallback.currentModelIndex).toBe(1);
|
||||
|
@ -276,7 +270,7 @@ test('doGenerate switches models on error', async () => {
|
|||
doGenerate: async () => ({
|
||||
content: [{ type: 'text', text: 'Response from fallback model' }],
|
||||
finishReason: 'stop' as const,
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
warnings: [],
|
||||
}),
|
||||
});
|
||||
|
@ -289,8 +283,7 @@ test('doGenerate switches models on error', async () => {
|
|||
expect(fallback.modelId).toBe('failing-model');
|
||||
|
||||
const result = await fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Response from fallback model' });
|
||||
|
@ -322,7 +315,7 @@ test('cycles through all models until one works', async () => {
|
|||
doGenerate: async () => ({
|
||||
content: [{ type: 'text', text: 'Success from model 3' }],
|
||||
finishReason: 'stop' as const,
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
warnings: [],
|
||||
}),
|
||||
});
|
||||
|
@ -332,8 +325,7 @@ test('cycles through all models until one works', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Success from model 3' });
|
||||
|
@ -362,8 +354,7 @@ test('throws error when all models fail', async () => {
|
|||
|
||||
await expect(
|
||||
fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
})
|
||||
).rejects.toThrow('Model 2 capacity reached');
|
||||
|
||||
|
@ -391,8 +382,7 @@ test('model reset interval resets to first model', async () => {
|
|||
|
||||
// Trigger reset check
|
||||
await fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
expect(fallback.currentModelIndex).toBe(0);
|
||||
|
@ -419,7 +409,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
|
|||
doGenerate: async () => ({
|
||||
content: [{ type: 'text', text: 'Should not reach here' }],
|
||||
finishReason: 'stop' as const,
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
warnings: [],
|
||||
}),
|
||||
});
|
||||
|
@ -432,8 +422,7 @@ test('shouldRetryThisError callback controls retry behavior', async () => {
|
|||
// Should not retry because error doesn't match
|
||||
await expect(
|
||||
fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
})
|
||||
).rejects.toThrow('specific-error-that-should-not-retry');
|
||||
|
||||
|
@ -460,7 +449,7 @@ test('handles non-existent model error', async () => {
|
|||
doGenerate: async () => ({
|
||||
content: [{ type: 'text', text: 'Fallback response' }],
|
||||
finishReason: 'stop' as const,
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
warnings: [],
|
||||
}),
|
||||
});
|
||||
|
@ -475,8 +464,7 @@ test('handles non-existent model error', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Fallback response' });
|
||||
|
@ -502,7 +490,7 @@ test('handles API key errors', async () => {
|
|||
doGenerate: async () => ({
|
||||
content: [{ type: 'text', text: 'Success with correct key' }],
|
||||
finishReason: 'stop' as const,
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
warnings: [],
|
||||
}),
|
||||
});
|
||||
|
@ -512,8 +500,7 @@ test('handles API key errors', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
expect(result.content[0]).toEqual({ type: 'text', text: 'Success with correct key' });
|
||||
|
@ -538,7 +525,7 @@ test('handles rate limit errors', async () => {
|
|||
doGenerate: async () => ({
|
||||
content: [{ type: 'text', text: 'Response from available model' }],
|
||||
finishReason: 'stop' as const,
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
|
||||
warnings: [],
|
||||
}),
|
||||
});
|
||||
|
@ -548,8 +535,7 @@ test('handles rate limit errors', async () => {
|
|||
});
|
||||
|
||||
const result = await fallback.doGenerate({
|
||||
prompt: { system: 'test', messages: [] },
|
||||
mode: { type: 'regular' },
|
||||
prompt: [],
|
||||
});
|
||||
|
||||
expect(attempts).toBe(1);
|
||||
|
|
|
@ -1,254 +1,245 @@
|
|||
import type {
|
||||
LanguageModelV2,
|
||||
LanguageModelV2CallOptions,
|
||||
LanguageModelV2CallWarning,
|
||||
LanguageModelV2Content,
|
||||
LanguageModelV2FinishReason,
|
||||
LanguageModelV2StreamPart,
|
||||
LanguageModelV2Usage,
|
||||
SharedV2ProviderMetadata,
|
||||
} from "@ai-sdk/provider";
|
||||
LanguageModelV2,
|
||||
LanguageModelV2CallOptions,
|
||||
LanguageModelV2CallWarning,
|
||||
LanguageModelV2Content,
|
||||
LanguageModelV2FinishReason,
|
||||
LanguageModelV2StreamPart,
|
||||
LanguageModelV2Usage,
|
||||
SharedV2ProviderMetadata,
|
||||
} from '@ai-sdk/provider';
|
||||
|
||||
interface RetryableError extends Error {
|
||||
statusCode?: number;
|
||||
}
|
||||
|
||||
interface Settings {
|
||||
models: LanguageModelV2[];
|
||||
retryAfterOutput?: boolean;
|
||||
modelResetInterval?: number;
|
||||
shouldRetryThisError?: (error: Error) => boolean;
|
||||
onError?: (error: Error, modelId: string) => void | Promise<void>;
|
||||
models: LanguageModelV2[];
|
||||
retryAfterOutput?: boolean;
|
||||
modelResetInterval?: number;
|
||||
shouldRetryThisError?: (error: RetryableError) => boolean;
|
||||
onError?: (error: RetryableError, modelId: string) => void | Promise<void>;
|
||||
}
|
||||
|
||||
export function createFallback(settings: Settings): FallbackModel {
|
||||
return new FallbackModel(settings);
|
||||
return new FallbackModel(settings);
|
||||
}
|
||||
|
||||
const retryableStatusCodes = [
|
||||
401, // wrong API key
|
||||
403, // permission error, like cannot access model or from a non accessible region
|
||||
408, // request timeout
|
||||
409, // conflict
|
||||
413, // payload too large
|
||||
429, // too many requests/rate limits
|
||||
500, // server error (and above)
|
||||
401, // wrong API key
|
||||
403, // permission error, like cannot access model or from a non accessible region
|
||||
408, // request timeout
|
||||
409, // conflict
|
||||
413, // payload too large
|
||||
429, // too many requests/rate limits
|
||||
500, // server error (and above)
|
||||
];
|
||||
// Common error messages/codes that indicate server overload or temporary issues
|
||||
const retryableErrors = [
|
||||
"overloaded",
|
||||
"service unavailable",
|
||||
"bad gateway",
|
||||
"too many requests",
|
||||
"internal server error",
|
||||
"gateway timeout",
|
||||
"rate_limit",
|
||||
"wrong-key",
|
||||
"unexpected",
|
||||
"capacity",
|
||||
"timeout",
|
||||
"server_error",
|
||||
"429", // Too Many Requests
|
||||
"500", // Internal Server Error
|
||||
"502", // Bad Gateway
|
||||
"503", // Service Unavailable
|
||||
"504", // Gateway Timeout
|
||||
'overloaded',
|
||||
'service unavailable',
|
||||
'bad gateway',
|
||||
'too many requests',
|
||||
'internal server error',
|
||||
'gateway timeout',
|
||||
'rate_limit',
|
||||
'wrong-key',
|
||||
'unexpected',
|
||||
'capacity',
|
||||
'timeout',
|
||||
'server_error',
|
||||
'429', // Too Many Requests
|
||||
'500', // Internal Server Error
|
||||
'502', // Bad Gateway
|
||||
'503', // Service Unavailable
|
||||
'504', // Gateway Timeout
|
||||
];
|
||||
|
||||
function defaultShouldRetryThisError(error: any): boolean {
|
||||
const statusCode = error?.statusCode;
|
||||
function defaultShouldRetryThisError(error: RetryableError): boolean {
|
||||
const statusCode = error?.statusCode;
|
||||
|
||||
if (
|
||||
statusCode &&
|
||||
(retryableStatusCodes.includes(statusCode) || statusCode > 500)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
if (statusCode && (retryableStatusCodes.includes(statusCode) || statusCode > 500)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (error?.message) {
|
||||
const errorString = error.message.toLowerCase() || "";
|
||||
return retryableErrors.some((errType) => errorString.includes(errType));
|
||||
}
|
||||
if (error && typeof error === "object") {
|
||||
const errorString = JSON.stringify(error).toLowerCase() || "";
|
||||
return retryableErrors.some((errType) => errorString.includes(errType));
|
||||
}
|
||||
return false;
|
||||
if (error?.message) {
|
||||
const errorString = error.message.toLowerCase() || '';
|
||||
return retryableErrors.some((errType) => errorString.includes(errType));
|
||||
}
|
||||
if (error && typeof error === 'object') {
|
||||
const errorString = JSON.stringify(error).toLowerCase() || '';
|
||||
return retryableErrors.some((errType) => errorString.includes(errType));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export class FallbackModel implements LanguageModelV2 {
|
||||
readonly specificationVersion = "v2";
|
||||
readonly specificationVersion = 'v2';
|
||||
|
||||
get supportedUrls():
|
||||
| Record<string, RegExp[]>
|
||||
| PromiseLike<Record<string, RegExp[]>> {
|
||||
return this.getCurrentModel().supportedUrls;
|
||||
}
|
||||
get supportedUrls(): Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>> {
|
||||
return this.getCurrentModel().supportedUrls;
|
||||
}
|
||||
|
||||
get modelId(): string {
|
||||
return this.getCurrentModel().modelId;
|
||||
}
|
||||
readonly settings: Settings;
|
||||
get modelId(): string {
|
||||
return this.getCurrentModel().modelId;
|
||||
}
|
||||
readonly settings: Settings;
|
||||
|
||||
currentModelIndex = 0;
|
||||
private lastModelReset: number = Date.now();
|
||||
private readonly modelResetInterval: number;
|
||||
retryAfterOutput: boolean;
|
||||
constructor(settings: Settings) {
|
||||
this.settings = settings;
|
||||
this.modelResetInterval = settings.modelResetInterval ?? 3 * 60 * 1000; // Default 3 minutes in ms
|
||||
this.retryAfterOutput = settings.retryAfterOutput ?? true;
|
||||
currentModelIndex = 0;
|
||||
private lastModelReset: number = Date.now();
|
||||
private readonly modelResetInterval: number;
|
||||
retryAfterOutput: boolean;
|
||||
constructor(settings: Settings) {
|
||||
this.settings = settings;
|
||||
this.modelResetInterval = settings.modelResetInterval ?? 3 * 60 * 1000; // Default 3 minutes in ms
|
||||
this.retryAfterOutput = settings.retryAfterOutput ?? true;
|
||||
|
||||
if (!this.settings.models[this.currentModelIndex]) {
|
||||
throw new Error("No models available in settings");
|
||||
}
|
||||
}
|
||||
if (!this.settings.models[this.currentModelIndex]) {
|
||||
throw new Error('No models available in settings');
|
||||
}
|
||||
}
|
||||
|
||||
get provider(): string {
|
||||
return this.getCurrentModel().provider;
|
||||
}
|
||||
get provider(): string {
|
||||
return this.getCurrentModel().provider;
|
||||
}
|
||||
|
||||
private getCurrentModel(): LanguageModelV2 {
|
||||
const model = this.settings.models[this.currentModelIndex];
|
||||
if (!model) {
|
||||
throw new Error(`No model available at index ${this.currentModelIndex}`);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
private getCurrentModel(): LanguageModelV2 {
|
||||
const model = this.settings.models[this.currentModelIndex];
|
||||
if (!model) {
|
||||
throw new Error(`No model available at index ${this.currentModelIndex}`);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
private checkAndResetModel() {
|
||||
const now = Date.now();
|
||||
if (
|
||||
now - this.lastModelReset >= this.modelResetInterval &&
|
||||
this.currentModelIndex !== 0
|
||||
) {
|
||||
this.currentModelIndex = 0;
|
||||
this.lastModelReset = now;
|
||||
}
|
||||
}
|
||||
private checkAndResetModel() {
|
||||
const now = Date.now();
|
||||
if (now - this.lastModelReset >= this.modelResetInterval && this.currentModelIndex !== 0) {
|
||||
this.currentModelIndex = 0;
|
||||
this.lastModelReset = now;
|
||||
}
|
||||
}
|
||||
|
||||
private switchToNextModel() {
|
||||
this.currentModelIndex =
|
||||
(this.currentModelIndex + 1) % this.settings.models.length;
|
||||
}
|
||||
private switchToNextModel() {
|
||||
this.currentModelIndex = (this.currentModelIndex + 1) % this.settings.models.length;
|
||||
}
|
||||
|
||||
private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
|
||||
let lastError: Error | undefined;
|
||||
const initialModel = this.currentModelIndex;
|
||||
private async retry<T>(fn: () => PromiseLike<T>): Promise<T> {
|
||||
let lastError: RetryableError | undefined;
|
||||
const initialModel = this.currentModelIndex;
|
||||
|
||||
do {
|
||||
try {
|
||||
return await fn();
|
||||
} catch (error) {
|
||||
lastError = error as Error;
|
||||
// Only retry if it's a server/capacity error
|
||||
const shouldRetry =
|
||||
this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
||||
if (!shouldRetry(lastError)) {
|
||||
throw lastError;
|
||||
}
|
||||
do {
|
||||
try {
|
||||
return await fn();
|
||||
} catch (error) {
|
||||
lastError = error as RetryableError;
|
||||
// Only retry if it's a server/capacity error
|
||||
const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
||||
if (!shouldRetry(lastError)) {
|
||||
throw lastError;
|
||||
}
|
||||
|
||||
if (this.settings.onError) {
|
||||
await this.settings.onError(lastError, this.modelId);
|
||||
}
|
||||
this.switchToNextModel();
|
||||
if (this.settings.onError) {
|
||||
await this.settings.onError(lastError, this.modelId);
|
||||
}
|
||||
this.switchToNextModel();
|
||||
|
||||
// If we've tried all models, throw the last error
|
||||
if (this.currentModelIndex === initialModel) {
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
} while (true);
|
||||
}
|
||||
// If we've tried all models, throw the last error
|
||||
if (this.currentModelIndex === initialModel) {
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
} while (this.currentModelIndex !== initialModel);
|
||||
|
||||
doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
|
||||
content: LanguageModelV2Content[];
|
||||
finishReason: LanguageModelV2FinishReason;
|
||||
usage: LanguageModelV2Usage;
|
||||
providerMetadata?: SharedV2ProviderMetadata;
|
||||
request?: { body?: unknown };
|
||||
response?: {
|
||||
headers?: Record<string, string>;
|
||||
id?: string;
|
||||
timestamp?: Date;
|
||||
modelId?: string;
|
||||
};
|
||||
warnings: LanguageModelV2CallWarning[];
|
||||
}> {
|
||||
this.checkAndResetModel();
|
||||
return this.retry(() => this.getCurrentModel().doGenerate(options));
|
||||
}
|
||||
// This should never be reached, but TypeScript requires it
|
||||
throw lastError || new Error('Retry failed');
|
||||
}
|
||||
|
||||
doStream(options: LanguageModelV2CallOptions): PromiseLike<{
|
||||
stream: ReadableStream<LanguageModelV2StreamPart>;
|
||||
request?: { body?: unknown };
|
||||
response?: { headers?: Record<string, string> };
|
||||
}> {
|
||||
this.checkAndResetModel();
|
||||
const self = this;
|
||||
const shouldRetry =
|
||||
this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
||||
return this.retry(async () => {
|
||||
const result = await self.getCurrentModel().doStream(options);
|
||||
doGenerate(options: LanguageModelV2CallOptions): PromiseLike<{
|
||||
content: LanguageModelV2Content[];
|
||||
finishReason: LanguageModelV2FinishReason;
|
||||
usage: LanguageModelV2Usage;
|
||||
providerMetadata?: SharedV2ProviderMetadata;
|
||||
request?: { body?: unknown };
|
||||
response?: {
|
||||
headers?: Record<string, string>;
|
||||
id?: string;
|
||||
timestamp?: Date;
|
||||
modelId?: string;
|
||||
};
|
||||
warnings: LanguageModelV2CallWarning[];
|
||||
}> {
|
||||
this.checkAndResetModel();
|
||||
return this.retry(() => this.getCurrentModel().doGenerate(options));
|
||||
}
|
||||
|
||||
let hasStreamedAny = false;
|
||||
// Wrap the stream to handle errors and switch providers if needed
|
||||
const wrappedStream = new ReadableStream<LanguageModelV2StreamPart>({
|
||||
async start(controller) {
|
||||
try {
|
||||
const reader = result.stream.getReader();
|
||||
doStream(options: LanguageModelV2CallOptions): PromiseLike<{
|
||||
stream: ReadableStream<LanguageModelV2StreamPart>;
|
||||
request?: { body?: unknown };
|
||||
response?: { headers?: Record<string, string> };
|
||||
}> {
|
||||
this.checkAndResetModel();
|
||||
const self = this;
|
||||
const shouldRetry = this.settings.shouldRetryThisError || defaultShouldRetryThisError;
|
||||
return this.retry(async () => {
|
||||
const result = await self.getCurrentModel().doStream(options);
|
||||
|
||||
while (true) {
|
||||
const result = await reader.read();
|
||||
let hasStreamedAny = false;
|
||||
// Wrap the stream to handle errors and switch providers if needed
|
||||
const wrappedStream = new ReadableStream<LanguageModelV2StreamPart>({
|
||||
async start(controller) {
|
||||
try {
|
||||
const reader = result.stream.getReader();
|
||||
|
||||
const { done, value } = result;
|
||||
if (
|
||||
!hasStreamedAny &&
|
||||
value &&
|
||||
typeof value === "object" &&
|
||||
"error" in value
|
||||
) {
|
||||
const error = value.error as any;
|
||||
if (shouldRetry(error)) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
while (true) {
|
||||
const result = await reader.read();
|
||||
|
||||
if (done) break;
|
||||
controller.enqueue(value);
|
||||
const { done, value } = result;
|
||||
if (!hasStreamedAny && value && typeof value === 'object' && 'error' in value) {
|
||||
const error = value.error as RetryableError;
|
||||
if (shouldRetry(error)) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
if (value?.type !== "stream-start") {
|
||||
hasStreamedAny = true;
|
||||
}
|
||||
}
|
||||
controller.close();
|
||||
} catch (error) {
|
||||
if (self.settings.onError) {
|
||||
await self.settings.onError(error as Error, self.modelId);
|
||||
}
|
||||
if (!hasStreamedAny || self.retryAfterOutput) {
|
||||
// If nothing was streamed yet, switch models and retry
|
||||
self.switchToNextModel();
|
||||
try {
|
||||
const nextResult = await self.doStream(options);
|
||||
const nextReader = nextResult.stream.getReader();
|
||||
while (true) {
|
||||
const { done, value } = await nextReader.read();
|
||||
if (done) break;
|
||||
controller.enqueue(value);
|
||||
}
|
||||
controller.close();
|
||||
} catch (nextError) {
|
||||
controller.error(nextError);
|
||||
}
|
||||
return;
|
||||
}
|
||||
controller.error(error);
|
||||
}
|
||||
},
|
||||
});
|
||||
if (done) break;
|
||||
controller.enqueue(value);
|
||||
|
||||
return {
|
||||
stream: wrappedStream,
|
||||
...(result.request && { request: result.request }),
|
||||
...(result.response && { response: result.response }),
|
||||
};
|
||||
});
|
||||
}
|
||||
if (value?.type !== 'stream-start') {
|
||||
hasStreamedAny = true;
|
||||
}
|
||||
}
|
||||
controller.close();
|
||||
} catch (error) {
|
||||
if (self.settings.onError) {
|
||||
await self.settings.onError(error as RetryableError, self.modelId);
|
||||
}
|
||||
if (!hasStreamedAny || self.retryAfterOutput) {
|
||||
// If nothing was streamed yet, switch models and retry
|
||||
self.switchToNextModel();
|
||||
try {
|
||||
const nextResult = await self.doStream(options);
|
||||
const nextReader = nextResult.stream.getReader();
|
||||
while (true) {
|
||||
const { done, value } = await nextReader.read();
|
||||
if (done) break;
|
||||
controller.enqueue(value);
|
||||
}
|
||||
controller.close();
|
||||
} catch (nextError) {
|
||||
controller.error(nextError);
|
||||
}
|
||||
return;
|
||||
}
|
||||
controller.error(error);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
stream: wrappedStream,
|
||||
...(result.request && { request: result.request }),
|
||||
...(result.response && { response: result.response }),
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue