diff --git a/packages/ai/src/steps/extract-values-search-step.test.ts b/packages/ai/src/steps/extract-values-search-step.test.ts index f982a3d0a..1492fe09e 100644 --- a/packages/ai/src/steps/extract-values-search-step.test.ts +++ b/packages/ai/src/steps/extract-values-search-step.test.ts @@ -1,7 +1,9 @@ -import { RuntimeContext } from '@mastra/core/runtime-context'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import type { AnalystRuntimeContext } from '../workflows/analyst-workflow'; -import { extractValuesSearchStep } from './extract-values-search-step'; + +// Mock the AI models first, before any imports that might use them +vi.mock('../utils/models/haiku-3-5', () => ({ + Haiku35: 'mock-model', +})); // Mock the stored-values package vi.mock('@buster/stored-values/search', () => { @@ -11,26 +13,48 @@ vi.mock('@buster/stored-values/search', () => { }; }); -// Mock the AI models -vi.mock('../utils/models/sonnet-4', () => ({ - Sonnet4: 'mock-model', -})); - // Mock Braintrust vi.mock('braintrust', () => ({ wrapTraced: vi.fn((fn) => fn), wrapAISDKModel: vi.fn((model) => model), })); +// Create a ref object to hold the mock generate function +const mockGenerateRef = { current: vi.fn() }; + +// Mock the Agent class from Mastra with the generate function +vi.mock('@mastra/core', async () => { + const actual = await vi.importActual('@mastra/core'); + return { + ...actual, + Agent: vi.fn().mockImplementation(() => ({ + generate: (...args: any[]) => mockGenerateRef.current(...args), + })), + createStep: actual.createStep, + }; +}); + +// Now import after mocks are set up +import { RuntimeContext } from '@mastra/core/runtime-context'; +import type { AnalystRuntimeContext } from '../workflows/analyst-workflow'; +import { extractValuesSearchStep } from './extract-values-search-step'; + // Import the mocked functions import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search'; const mockGenerateEmbedding = generateEmbedding as ReturnType; const mockSearchValuesByEmbedding = searchValuesByEmbedding as ReturnType; -describe.skip('extractValuesSearchStep', () => { +// Access the mock generate function through the ref +const mockGenerate = mockGenerateRef.current; + +describe('extractValuesSearchStep', () => { beforeEach(() => { vi.clearAllMocks(); + // Set default mock behavior + mockGenerate.mockResolvedValue({ + object: { values: [] }, + }); }); afterEach(() => { @@ -48,19 +72,10 @@ describe.skip('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock the LLM response for keyword extraction - const mockAgentGenerate = vi.fn().mockResolvedValue({ + mockGenerate.mockResolvedValue({ object: { values: ['Red Bull', 'California'] }, }); - // Mock the values agent - vi.doMock('../../../src/steps/extract-values-search-step', async () => { - const actual = await vi.importActual('../../../src/steps/extract-values-search-step'); - return { - ...actual, - valuesAgent: { generate: mockAgentGenerate }, - }; - }); - mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); mockSearchValuesByEmbedding.mockResolvedValue([]); @@ -126,7 +141,7 @@ describe.skip('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock empty keyword extraction - const mockAgentGenerate = vi.fn().mockResolvedValue({ + mockGenerate.mockResolvedValue({ object: { values: [] }, }); @@ -179,6 +194,11 @@ describe.skip('extractValuesSearchStep', () => { const runtimeContext = new RuntimeContext(); runtimeContext.set('dataSourceId', 'test-datasource-id'); + // Mock successful keyword extraction + mockGenerate.mockResolvedValue({ + object: { values: ['Red Bull'] }, + }); + // Mock successful search mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); mockSearchValuesByEmbedding.mockResolvedValue(mockSearchResults); @@ -206,7 +226,7 @@ describe.skip('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock LLM extraction success but embedding failure - const mockAgentGenerate = vi.fn().mockResolvedValue({ + mockGenerate.mockResolvedValue({ object: { values: ['test keyword'] }, }); @@ -233,6 +253,11 @@ describe.skip('extractValuesSearchStep', () => { const runtimeContext = new RuntimeContext(); runtimeContext.set('dataSourceId', 'test-datasource-id'); + // Mock successful keyword extraction + mockGenerate.mockResolvedValue({ + object: { values: ['test keyword'] }, + }); + // Mock successful embedding but database failure mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); mockSearchValuesByEmbedding.mockRejectedValue(new Error('Database connection failed')); @@ -259,7 +284,7 @@ describe.skip('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock two keywords: one succeeds, one fails - const mockAgentGenerate = vi.fn().mockResolvedValue({ + mockGenerate.mockResolvedValue({ object: { values: ['keyword1', 'keyword2'] }, }); @@ -302,7 +327,7 @@ describe.skip('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock everything to fail - const mockAgentGenerate = vi.fn().mockRejectedValue(new Error('LLM failure')); + mockGenerate.mockRejectedValue(new Error('LLM failure')); mockGenerateEmbedding.mockRejectedValue(new Error('Embedding failure')); mockSearchValuesByEmbedding.mockRejectedValue(new Error('Database failure')); @@ -344,9 +369,6 @@ describe.skip('extractValuesSearchStep', () => { }, ]; - mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); - mockSearchValuesByEmbedding.mockResolvedValue(mockSearchResults); - const inputData = { prompt: 'Test prompt', conversationHistory: [], @@ -355,6 +377,14 @@ describe.skip('extractValuesSearchStep', () => { const runtimeContext = new RuntimeContext(); runtimeContext.set('dataSourceId', 'test-datasource-id'); + // Mock successful keyword extraction + mockGenerate.mockResolvedValue({ + object: { values: ['Red Bull'] }, + }); + + mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); + mockSearchValuesByEmbedding.mockResolvedValue(mockSearchResults); + const result = await extractValuesSearchStep.execute({ inputData, runtimeContext, @@ -398,9 +428,6 @@ describe.skip('extractValuesSearchStep', () => { }, ]; - mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); - mockSearchValuesByEmbedding.mockResolvedValue(mockSearchResults); - const inputData = { prompt: 'Test prompt', conversationHistory: [], @@ -409,6 +436,14 @@ describe.skip('extractValuesSearchStep', () => { const runtimeContext = new RuntimeContext(); runtimeContext.set('dataSourceId', 'test-datasource-id'); + // Mock successful keyword extraction + mockGenerate.mockResolvedValue({ + object: { values: ['test'] }, + }); + + mockGenerateEmbedding.mockResolvedValue([1, 2, 3]); + mockSearchValuesByEmbedding.mockResolvedValue(mockSearchResults); + const result = await extractValuesSearchStep.execute({ inputData, runtimeContext, diff --git a/packages/ai/src/utils/models/haiku-3-5.ts b/packages/ai/src/utils/models/haiku-3-5.ts index dba7572cf..14b215519 100644 --- a/packages/ai/src/utils/models/haiku-3-5.ts +++ b/packages/ai/src/utils/models/haiku-3-5.ts @@ -3,41 +3,72 @@ import { createFallback } from './ai-fallback'; import { anthropicModel } from './providers/anthropic'; import { vertexModel } from './providers/vertex'; -// Build models array based on available credentials -const models: LanguageModelV1[] = []; +// Lazy initialization to allow mocking in tests +let _haiku35Instance: ReturnType | null = null; -// Only include Anthropic if API key is available -if (process.env.ANTHROPIC_API_KEY) { - try { - models.push(anthropicModel('claude-3-5-haiku-20241022')); - console.info('Haiku35: Anthropic model added to fallback chain'); - } catch (error) { - console.warn('Haiku35: Failed to initialize Anthropic model:', error); +function initializeHaiku35() { + if (_haiku35Instance) { + return _haiku35Instance; } -} -// // Only include Vertex if credentials are available -if (process.env.VERTEX_CLIENT_EMAIL && process.env.VERTEX_PRIVATE_KEY) { - try { - models.push(vertexModel('claude-3-5-haiku@20241022')); - console.info('Haiku35: Vertex AI model added to fallback chain'); - } catch (error) { - console.warn('Haiku35: Failed to initialize Vertex AI model:', error); + // Build models array based on available credentials + const models: LanguageModelV1[] = []; + + // Only include Anthropic if API key is available + if (process.env.ANTHROPIC_API_KEY) { + try { + models.push(anthropicModel('claude-3-5-haiku-20241022')); + console.info('Haiku35: Anthropic model added to fallback chain'); + } catch (error) { + console.warn('Haiku35: Failed to initialize Anthropic model:', error); + } } + + // Only include Vertex if credentials are available + if (process.env.VERTEX_CLIENT_EMAIL && process.env.VERTEX_PRIVATE_KEY) { + try { + models.push(vertexModel('claude-3-5-haiku@20241022')); + console.info('Haiku35: Vertex AI model added to fallback chain'); + } catch (error) { + console.warn('Haiku35: Failed to initialize Vertex AI model:', error); + } + } + + // Ensure we have at least one model + if (models.length === 0) { + throw new Error( + 'No AI models available. Please set either Vertex AI (VERTEX_CLIENT_EMAIL and VERTEX_PRIVATE_KEY) or Anthropic (ANTHROPIC_API_KEY) credentials.' + ); + } + + console.info(`Haiku35: Initialized with ${models.length} model(s) in fallback chain`); + + _haiku35Instance = createFallback({ + models, + modelResetInterval: 60000, + retryAfterOutput: true, + onError: (err) => console.error(`FALLBACK. Here is the error: ${err}`), + }); + + return _haiku35Instance; } -// Ensure we have at least one model -if (models.length === 0) { - throw new Error( - 'No AI models available. Please set either Vertex AI (VERTEX_CLIENT_EMAIL and VERTEX_PRIVATE_KEY) or Anthropic (ANTHROPIC_API_KEY) credentials.' - ); -} - -console.info(`Haiku35: Initialized with ${models.length} model(s) in fallback chain`); - -export const Haiku35 = createFallback({ - models, - modelResetInterval: 60000, - retryAfterOutput: true, - onError: (err) => console.error(`FALLBACK. Here is the error: ${err}`), +// Export a proxy that initializes on first use +export const Haiku35 = new Proxy({} as ReturnType, { + get(_target, prop, receiver) { + const instance = initializeHaiku35(); + return Reflect.get(instance, prop, receiver); + }, + has(_target, prop) { + const instance = initializeHaiku35(); + return Reflect.has(instance, prop); + }, + ownKeys(_target) { + const instance = initializeHaiku35(); + return Reflect.ownKeys(instance); + }, + getOwnPropertyDescriptor(_target, prop) { + const instance = initializeHaiku35(); + return Reflect.getOwnPropertyDescriptor(instance, prop); + }, }); diff --git a/packages/ai/src/utils/models/sonnet-4.ts b/packages/ai/src/utils/models/sonnet-4.ts index 40d887d46..536b3033e 100644 --- a/packages/ai/src/utils/models/sonnet-4.ts +++ b/packages/ai/src/utils/models/sonnet-4.ts @@ -3,44 +3,72 @@ import { createFallback } from './ai-fallback'; import { anthropicModel } from './providers/anthropic'; import { vertexModel } from './providers/vertex'; -// Build models array based on available credentials -const models: LanguageModelV1[] = []; +// Lazy initialization to allow mocking in tests +let _sonnet4Instance: ReturnType | null = null; -// Temporary dummy key for testing - REMOVE BEFORE COMMITTING -process.env.ANTHROPIC_API_KEY = 'dummy-key-for-testing'; - -// Only include Anthropic if API key is available -if (process.env.ANTHROPIC_API_KEY) { - try { - models.push(anthropicModel('claude-4-sonnet-20250514')); - console.info('Sonnet4: Anthropic model added to fallback chain'); - } catch (error) { - console.warn('Sonnet4: Failed to initialize Anthropic model:', error); +function initializeSonnet4() { + if (_sonnet4Instance) { + return _sonnet4Instance; } -} -// Only include Vertex if credentials are available -if (process.env.VERTEX_CLIENT_EMAIL && process.env.VERTEX_PRIVATE_KEY) { - try { - models.push(vertexModel('claude-sonnet-4@20250514')); - console.info('Sonnet4: Vertex AI model added to fallback chain'); - } catch (error) { - console.warn('Sonnet4: Failed to initialize Vertex AI model:', error); + // Build models array based on available credentials + const models: LanguageModelV1[] = []; + + // Only include Anthropic if API key is available + if (process.env.ANTHROPIC_API_KEY) { + try { + models.push(anthropicModel('claude-4-sonnet-20250514')); + console.info('Sonnet4: Anthropic model added to fallback chain'); + } catch (error) { + console.warn('Sonnet4: Failed to initialize Anthropic model:', error); + } } + + // Only include Vertex if credentials are available + if (process.env.VERTEX_CLIENT_EMAIL && process.env.VERTEX_PRIVATE_KEY) { + try { + models.push(vertexModel('claude-sonnet-4@20250514')); + console.info('Sonnet4: Vertex AI model added to fallback chain'); + } catch (error) { + console.warn('Sonnet4: Failed to initialize Vertex AI model:', error); + } + } + + // Ensure we have at least one model + if (models.length === 0) { + throw new Error( + 'No AI models available. Please set either Vertex AI (VERTEX_CLIENT_EMAIL and VERTEX_PRIVATE_KEY) or Anthropic (ANTHROPIC_API_KEY) credentials.' + ); + } + + console.info(`Sonnet4: Initialized with ${models.length} model(s) in fallback chain`); + + _sonnet4Instance = createFallback({ + models, + modelResetInterval: 60000, + retryAfterOutput: true, + onError: (err) => console.error(`FALLBACK. Here is the error: ${err}`), + }); + + return _sonnet4Instance; } -// Ensure we have at least one model -if (models.length === 0) { - throw new Error( - 'No AI models available. Please set either Vertex AI (VERTEX_CLIENT_EMAIL and VERTEX_PRIVATE_KEY) or Anthropic (ANTHROPIC_API_KEY) credentials.' - ); -} - -console.info(`Sonnet4: Initialized with ${models.length} model(s) in fallback chain`); - -export const Sonnet4 = createFallback({ - models, - modelResetInterval: 60000, - retryAfterOutput: true, - onError: (err) => console.error(`FALLBACK. Here is the error: ${err}`), +// Export a proxy that initializes on first use +export const Sonnet4 = new Proxy({} as ReturnType, { + get(_target, prop, receiver) { + const instance = initializeSonnet4(); + return Reflect.get(instance, prop, receiver); + }, + has(_target, prop) { + const instance = initializeSonnet4(); + return Reflect.has(instance, prop); + }, + ownKeys(_target) { + const instance = initializeSonnet4(); + return Reflect.ownKeys(instance); + }, + getOwnPropertyDescriptor(_target, prop) { + const instance = initializeSonnet4(); + return Reflect.getOwnPropertyDescriptor(instance, prop); + }, });