mirror of https://github.com/buster-so/buster.git
new structure for chat and values
parent 9c2e4232ab
commit 5f43b2d074
@@ -19,17 +19,16 @@ vi.mock('braintrust', () => ({
   wrapAISDKModel: vi.fn((model) => model),
 }));
 
-// Create a ref object to hold the mock generate function
-const mockGenerateRef = { current: vi.fn() };
+// Mock the AI SDK
+vi.mock('ai', () => ({
+  generateObject: vi.fn(),
+}));
 
-// Mock the Agent class from Mastra with the generate function
+// Mock Mastra
 vi.mock('@mastra/core', async () => {
   const actual = await vi.importActual('@mastra/core');
   return {
     ...actual,
-    Agent: vi.fn().mockImplementation(() => ({
-      generate: (...args: any[]) => mockGenerateRef.current(...args),
-    })),
     createStep: actual.createStep,
   };
 });
@@ -41,18 +40,17 @@ import { extractValuesSearchStep } from './extract-values-search-step';
 
 // Import the mocked functions
 import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search';
+import { generateObject } from 'ai';
 
 const mockGenerateEmbedding = generateEmbedding as ReturnType<typeof vi.fn>;
 const mockSearchValuesByEmbedding = searchValuesByEmbedding as ReturnType<typeof vi.fn>;
 
-// Access the mock generate function through the ref
-const mockGenerate = mockGenerateRef.current;
+const mockGenerateObject = generateObject as ReturnType<typeof vi.fn>;
 
 describe('extractValuesSearchStep', () => {
   beforeEach(() => {
     vi.clearAllMocks();
     // Set default mock behavior
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: [] },
     });
   });
 
@@ -72,7 +70,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock the LLM response for keyword extraction
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: ['Red Bull', 'California'] },
     });
 
@@ -141,7 +139,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock empty keyword extraction
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: [] },
     });
 
@@ -195,7 +193,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock successful keyword extraction
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: ['Red Bull'] },
     });
 
@@ -226,7 +224,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock LLM extraction success but embedding failure
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: ['test keyword'] },
     });
 
@@ -254,7 +252,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock successful keyword extraction
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
      object: { values: ['test keyword'] },
     });
 
@@ -284,7 +282,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock two keywords: one succeeds, one fails
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: ['keyword1', 'keyword2'] },
     });
 
@@ -327,7 +325,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock everything to fail
-    mockGenerate.mockRejectedValue(new Error('LLM failure'));
+    mockGenerateObject.mockRejectedValue(new Error('LLM failure'));
     mockGenerateEmbedding.mockRejectedValue(new Error('Embedding failure'));
     mockSearchValuesByEmbedding.mockRejectedValue(new Error('Database failure'));
 
@@ -378,7 +376,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock successful keyword extraction
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: ['Red Bull'] },
     });
 
@@ -437,7 +435,7 @@ describe('extractValuesSearchStep', () => {
     runtimeContext.set('dataSourceId', 'test-datasource-id');
 
     // Mock successful keyword extraction
-    mockGenerate.mockResolvedValue({
+    mockGenerateObject.mockResolvedValue({
       object: { values: ['test'] },
     });
 
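The test changes above swap the Mastra Agent mock for a direct mock of the AI SDK's generateObject. A minimal, self-contained sketch of the same pattern (assuming vitest; the describe/it names and the primed values are illustrative, not taken from the repository):

import { beforeEach, describe, expect, it, vi } from 'vitest';
import { generateObject } from 'ai';

// Replace the real AI SDK call with a vi.fn() so no model is ever invoked.
vi.mock('ai', () => ({
  generateObject: vi.fn(),
}));

const mockGenerateObject = generateObject as ReturnType<typeof vi.fn>;

describe('generateObject mocking pattern', () => {
  beforeEach(() => {
    vi.clearAllMocks();
    // Default: the "LLM" extracts nothing.
    mockGenerateObject.mockResolvedValue({ object: { values: [] } });
  });

  it('resolves with whatever the test primes', async () => {
    mockGenerateObject.mockResolvedValue({ object: { values: ['Red Bull'] } });
    const { object } = await generateObject({} as never);
    expect(object).toEqual({ values: ['Red Bull'] });
  });
});

Because the step under test imports generateObject from 'ai', priming mockGenerateObject is enough to control the keyword-extraction result without touching Mastra at all.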
@@ -1,7 +1,8 @@
 import type { StoredValueResult } from '@buster/stored-values';
 import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search';
-import { Agent, createStep } from '@mastra/core';
+import { createStep } from '@mastra/core';
 import type { RuntimeContext } from '@mastra/core/runtime-context';
+import { generateObject } from 'ai';
 import type { CoreMessage } from 'ai';
 import { wrapTraced } from 'braintrust';
 import { z } from 'zod';
@@ -12,6 +13,11 @@ import type { AnalystRuntimeContext } from '../workflows/analyst-workflow';
 
 const inputSchema = thinkAndPrepWorkflowInputSchema;
 
+// Schema for what the LLM returns
+const llmOutputSchema = z.object({
+  values: z.array(z.string()).describe('The values that the agent will search for.'),
+});
+
 // Step output schema - what the step returns after performing the search
 export const extractValuesSearchOutputSchema = z.object({
   values: z.array(z.string()).describe('The values that the agent will search for.'),
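The hunk above separates the schema the LLM must satisfy (llmOutputSchema) from the schema the step returns, which also carries pass-through context. Purely as an illustration of one way to keep such a pair in sync (the extra field below is hypothetical, not from this diff), the step schema could be derived from the LLM schema with zod's .extend:

import { z } from 'zod';

// What the LLM is asked to produce.
const llmOutputSchema = z.object({
  values: z.array(z.string()).describe('The values that the agent will search for.'),
});

// What the step returns: the LLM fields plus pass-through data (field name hypothetical).
const stepOutputSchema = llmOutputSchema.extend({
  foundValues: z.record(z.string(), z.array(z.string())).optional(),
});

type StepOutput = z.infer<typeof stepOutputSchema>;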
@@ -231,12 +237,6 @@ async function searchStoredValues(
   }
 }
 
-const valuesAgent = new Agent({
-  name: 'Extract Values',
-  instructions: extractValuesInstructions,
-  model: Haiku35,
-});
-
 const extractValuesSearchStepExecution = async ({
   inputData,
   runtimeContext,
@@ -264,12 +264,19 @@
   try {
     const tracedValuesExtraction = wrapTraced(
       async () => {
-        const response = await valuesAgent.generate(messages, {
-          maxSteps: 0,
-          output: extractValuesSearchOutputSchema,
+        const { object } = await generateObject({
+          model: Haiku35,
+          schema: llmOutputSchema,
+          messages: [
+            {
+              role: 'system',
+              content: extractValuesInstructions,
+            },
+            ...messages,
+          ],
         });
 
-        return response.object;
+        return object;
       },
       {
         name: 'Extract Values',
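For context on the hunk above: the step previously asked a Mastra Agent for structured output via an output schema; it now calls the AI SDK's generateObject directly, passing a zod schema plus a system message built from the same instructions. A standalone sketch of that call shape, assuming an AI SDK language model (the anthropic() model id below is illustrative; the step itself uses its Haiku35 wrapper):

import { anthropic } from '@ai-sdk/anthropic';
import { generateObject } from 'ai';
import type { CoreMessage } from 'ai';
import { z } from 'zod';

// Schema describing only what the LLM must return.
const llmOutputSchema = z.object({
  values: z.array(z.string()).describe('The values that the agent will search for.'),
});

async function extractValues(messages: CoreMessage[], instructions: string): Promise<string[]> {
  const { object } = await generateObject({
    model: anthropic('claude-3-5-haiku-latest'), // illustrative model id
    schema: llmOutputSchema,
    messages: [
      { role: 'system', content: instructions },
      ...messages,
    ],
  });

  // object is validated against llmOutputSchema, so values is a string[].
  return object.values;
}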
@@ -1,6 +1,7 @@
 import { updateChat, updateMessage } from '@buster/database';
-import { Agent, createStep } from '@mastra/core';
+import { createStep } from '@mastra/core';
 import type { RuntimeContext } from '@mastra/core/runtime-context';
+import { generateObject } from 'ai';
 import type { CoreMessage } from 'ai';
 import { wrapTraced } from 'braintrust';
 import { z } from 'zod';
@@ -11,6 +12,12 @@ import type { AnalystRuntimeContext } from '../workflows/analyst-workflow';
 
 const inputSchema = thinkAndPrepWorkflowInputSchema;
 
+// Schema for what the LLM returns
+const llmOutputSchema = z.object({
+  title: z.string().describe('The title for the chat.'),
+});
+
+// Schema for what the step returns (includes pass-through data)
 export const generateChatTitleOutputSchema = z.object({
   title: z.string().describe('The title for the chat.'),
   // Pass through dashboard context
@@ -28,13 +35,9 @@ export const generateChatTitleOutputSchema = z.object({
 
 const generateChatTitleInstructions = `
 I am a chat title generator that is responsible for generating a title for the chat.
-`;
-
-const todosAgent = new Agent({
-  name: 'Extract Values',
-  instructions: generateChatTitleInstructions,
-  model: Haiku35,
-});
+The title should be 3-8 words, capturing the main topic or intent of the conversation.
+`;
 
 const generateChatTitleExecution = async ({
   inputData,
@@ -63,12 +66,19 @@
   try {
     const tracedChatTitle = wrapTraced(
       async () => {
-        const response = await todosAgent.generate(messages, {
-          maxSteps: 0,
-          output: generateChatTitleOutputSchema,
+        const { object } = await generateObject({
+          model: Haiku35,
+          schema: llmOutputSchema,
+          messages: [
+            {
+              role: 'system',
+              content: generateChatTitleInstructions,
+            },
+            ...messages,
+          ],
         });
 
-        return response.object;
+        return object;
       },
       {
         name: 'Generate Chat Title',
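Both steps keep the Braintrust wrapTraced wrapper around the new generateObject call, so the structured-output request is still recorded as a named span. A rough sketch of that composition, following the shape visible in the hunks above (the model parameter and function name are illustrative):

import { generateObject } from 'ai';
import type { CoreMessage, LanguageModel } from 'ai';
import { wrapTraced } from 'braintrust';
import { z } from 'zod';

const llmOutputSchema = z.object({
  title: z.string().describe('The title for the chat.'),
});

// wrapTraced records the wrapped async function as a span named 'Generate Chat Title'.
async function generateTitle(model: LanguageModel, messages: CoreMessage[]) {
  const tracedChatTitle = wrapTraced(
    async () => {
      const { object } = await generateObject({
        model,
        schema: llmOutputSchema,
        messages,
      });
      return object;
    },
    { name: 'Generate Chat Title' }
  );

  return tracedChatTitle();
}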