diff --git a/packages/ai/src/steps/extract-values-search-step.test.ts b/packages/ai/src/steps/extract-values-search-step.test.ts index 1492fe09e..31722a570 100644 --- a/packages/ai/src/steps/extract-values-search-step.test.ts +++ b/packages/ai/src/steps/extract-values-search-step.test.ts @@ -19,17 +19,16 @@ vi.mock('braintrust', () => ({ wrapAISDKModel: vi.fn((model) => model), })); -// Create a ref object to hold the mock generate function -const mockGenerateRef = { current: vi.fn() }; +// Mock the AI SDK +vi.mock('ai', () => ({ + generateObject: vi.fn(), +})); -// Mock the Agent class from Mastra with the generate function +// Mock Mastra vi.mock('@mastra/core', async () => { const actual = await vi.importActual('@mastra/core'); return { ...actual, - Agent: vi.fn().mockImplementation(() => ({ - generate: (...args: any[]) => mockGenerateRef.current(...args), - })), createStep: actual.createStep, }; }); @@ -41,18 +40,17 @@ import { extractValuesSearchStep } from './extract-values-search-step'; // Import the mocked functions import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search'; +import { generateObject } from 'ai'; const mockGenerateEmbedding = generateEmbedding as ReturnType; const mockSearchValuesByEmbedding = searchValuesByEmbedding as ReturnType; - -// Access the mock generate function through the ref -const mockGenerate = mockGenerateRef.current; +const mockGenerateObject = generateObject as ReturnType; describe('extractValuesSearchStep', () => { beforeEach(() => { vi.clearAllMocks(); // Set default mock behavior - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: [] }, }); }); @@ -72,7 +70,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock the LLM response for keyword extraction - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['Red Bull', 'California'] }, }); @@ -141,7 +139,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock empty keyword extraction - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: [] }, }); @@ -195,7 +193,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock successful keyword extraction - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['Red Bull'] }, }); @@ -226,7 +224,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock LLM extraction success but embedding failure - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['test keyword'] }, }); @@ -254,7 +252,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock successful keyword extraction - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['test keyword'] }, }); @@ -284,7 +282,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock two keywords: one succeeds, one fails - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['keyword1', 'keyword2'] }, }); @@ -327,7 +325,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock everything to fail - mockGenerate.mockRejectedValue(new Error('LLM failure')); + mockGenerateObject.mockRejectedValue(new Error('LLM failure')); mockGenerateEmbedding.mockRejectedValue(new Error('Embedding failure')); mockSearchValuesByEmbedding.mockRejectedValue(new Error('Database failure')); @@ -378,7 +376,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock successful keyword extraction - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['Red Bull'] }, }); @@ -437,7 +435,7 @@ describe('extractValuesSearchStep', () => { runtimeContext.set('dataSourceId', 'test-datasource-id'); // Mock successful keyword extraction - mockGenerate.mockResolvedValue({ + mockGenerateObject.mockResolvedValue({ object: { values: ['test'] }, }); diff --git a/packages/ai/src/steps/extract-values-search-step.ts b/packages/ai/src/steps/extract-values-search-step.ts index ed7e49894..c5e38aae7 100644 --- a/packages/ai/src/steps/extract-values-search-step.ts +++ b/packages/ai/src/steps/extract-values-search-step.ts @@ -1,7 +1,8 @@ import type { StoredValueResult } from '@buster/stored-values'; import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search'; -import { Agent, createStep } from '@mastra/core'; +import { createStep } from '@mastra/core'; import type { RuntimeContext } from '@mastra/core/runtime-context'; +import { generateObject } from 'ai'; import type { CoreMessage } from 'ai'; import { wrapTraced } from 'braintrust'; import { z } from 'zod'; @@ -12,6 +13,11 @@ import type { AnalystRuntimeContext } from '../workflows/analyst-workflow'; const inputSchema = thinkAndPrepWorkflowInputSchema; +// Schema for what the LLM returns +const llmOutputSchema = z.object({ + values: z.array(z.string()).describe('The values that the agent will search for.'), +}); + // Step output schema - what the step returns after performing the search export const extractValuesSearchOutputSchema = z.object({ values: z.array(z.string()).describe('The values that the agent will search for.'), @@ -231,12 +237,6 @@ async function searchStoredValues( } } -const valuesAgent = new Agent({ - name: 'Extract Values', - instructions: extractValuesInstructions, - model: Haiku35, -}); - const extractValuesSearchStepExecution = async ({ inputData, runtimeContext, @@ -264,12 +264,19 @@ const extractValuesSearchStepExecution = async ({ try { const tracedValuesExtraction = wrapTraced( async () => { - const response = await valuesAgent.generate(messages, { - maxSteps: 0, - output: extractValuesSearchOutputSchema, + const { object } = await generateObject({ + model: Haiku35, + schema: llmOutputSchema, + messages: [ + { + role: 'system', + content: extractValuesInstructions, + }, + ...messages, + ], }); - return response.object; + return object; }, { name: 'Extract Values', diff --git a/packages/ai/src/steps/generate-chat-title-step.ts b/packages/ai/src/steps/generate-chat-title-step.ts index 728a36094..b43ae9f11 100644 --- a/packages/ai/src/steps/generate-chat-title-step.ts +++ b/packages/ai/src/steps/generate-chat-title-step.ts @@ -1,6 +1,7 @@ import { updateChat, updateMessage } from '@buster/database'; -import { Agent, createStep } from '@mastra/core'; +import { createStep } from '@mastra/core'; import type { RuntimeContext } from '@mastra/core/runtime-context'; +import { generateObject } from 'ai'; import type { CoreMessage } from 'ai'; import { wrapTraced } from 'braintrust'; import { z } from 'zod'; @@ -11,6 +12,12 @@ import type { AnalystRuntimeContext } from '../workflows/analyst-workflow'; const inputSchema = thinkAndPrepWorkflowInputSchema; +// Schema for what the LLM returns +const llmOutputSchema = z.object({ + title: z.string().describe('The title for the chat.'), +}); + +// Schema for what the step returns (includes pass-through data) export const generateChatTitleOutputSchema = z.object({ title: z.string().describe('The title for the chat.'), // Pass through dashboard context @@ -28,13 +35,9 @@ export const generateChatTitleOutputSchema = z.object({ const generateChatTitleInstructions = ` I am a chat title generator that is responsible for generating a title for the chat. -`; -const todosAgent = new Agent({ - name: 'Extract Values', - instructions: generateChatTitleInstructions, - model: Haiku35, -}); +The title should be 3-8 words, capturing the main topic or intent of the conversation. +`; const generateChatTitleExecution = async ({ inputData, @@ -63,12 +66,19 @@ const generateChatTitleExecution = async ({ try { const tracedChatTitle = wrapTraced( async () => { - const response = await todosAgent.generate(messages, { - maxSteps: 0, - output: generateChatTitleOutputSchema, + const { object } = await generateObject({ + model: Haiku35, + schema: llmOutputSchema, + messages: [ + { + role: 'system', + content: generateChatTitleInstructions, + }, + ...messages, + ], }); - return response.object; + return object; }, { name: 'Generate Chat Title',