mirror of https://github.com/buster-so/buster.git
Merge pull request #617 from buster-so/dallin/hotfix-chat-title-and-extract-values
new structure for chat and values
This commit is contained in:
commit
ab1efe1475
|
@ -19,17 +19,16 @@ vi.mock('braintrust', () => ({
|
||||||
wrapAISDKModel: vi.fn((model) => model),
|
wrapAISDKModel: vi.fn((model) => model),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Create a ref object to hold the mock generate function
|
// Mock the AI SDK
|
||||||
const mockGenerateRef = { current: vi.fn() };
|
vi.mock('ai', () => ({
|
||||||
|
generateObject: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
// Mock the Agent class from Mastra with the generate function
|
// Mock Mastra
|
||||||
vi.mock('@mastra/core', async () => {
|
vi.mock('@mastra/core', async () => {
|
||||||
const actual = await vi.importActual('@mastra/core');
|
const actual = await vi.importActual('@mastra/core');
|
||||||
return {
|
return {
|
||||||
...actual,
|
...actual,
|
||||||
Agent: vi.fn().mockImplementation(() => ({
|
|
||||||
generate: (...args: any[]) => mockGenerateRef.current(...args),
|
|
||||||
})),
|
|
||||||
createStep: actual.createStep,
|
createStep: actual.createStep,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
@ -41,18 +40,17 @@ import { extractValuesSearchStep } from './extract-values-search-step';
|
||||||
|
|
||||||
// Import the mocked functions
|
// Import the mocked functions
|
||||||
import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search';
|
import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search';
|
||||||
|
import { generateObject } from 'ai';
|
||||||
|
|
||||||
const mockGenerateEmbedding = generateEmbedding as ReturnType<typeof vi.fn>;
|
const mockGenerateEmbedding = generateEmbedding as ReturnType<typeof vi.fn>;
|
||||||
const mockSearchValuesByEmbedding = searchValuesByEmbedding as ReturnType<typeof vi.fn>;
|
const mockSearchValuesByEmbedding = searchValuesByEmbedding as ReturnType<typeof vi.fn>;
|
||||||
|
const mockGenerateObject = generateObject as ReturnType<typeof vi.fn>;
|
||||||
// Access the mock generate function through the ref
|
|
||||||
const mockGenerate = mockGenerateRef.current;
|
|
||||||
|
|
||||||
describe('extractValuesSearchStep', () => {
|
describe('extractValuesSearchStep', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
// Set default mock behavior
|
// Set default mock behavior
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: [] },
|
object: { values: [] },
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@ -72,7 +70,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock the LLM response for keyword extraction
|
// Mock the LLM response for keyword extraction
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['Red Bull', 'California'] },
|
object: { values: ['Red Bull', 'California'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -141,7 +139,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock empty keyword extraction
|
// Mock empty keyword extraction
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: [] },
|
object: { values: [] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -195,7 +193,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock successful keyword extraction
|
// Mock successful keyword extraction
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['Red Bull'] },
|
object: { values: ['Red Bull'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -226,7 +224,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock LLM extraction success but embedding failure
|
// Mock LLM extraction success but embedding failure
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['test keyword'] },
|
object: { values: ['test keyword'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -254,7 +252,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock successful keyword extraction
|
// Mock successful keyword extraction
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['test keyword'] },
|
object: { values: ['test keyword'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -284,7 +282,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock two keywords: one succeeds, one fails
|
// Mock two keywords: one succeeds, one fails
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['keyword1', 'keyword2'] },
|
object: { values: ['keyword1', 'keyword2'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -327,7 +325,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock everything to fail
|
// Mock everything to fail
|
||||||
mockGenerate.mockRejectedValue(new Error('LLM failure'));
|
mockGenerateObject.mockRejectedValue(new Error('LLM failure'));
|
||||||
mockGenerateEmbedding.mockRejectedValue(new Error('Embedding failure'));
|
mockGenerateEmbedding.mockRejectedValue(new Error('Embedding failure'));
|
||||||
mockSearchValuesByEmbedding.mockRejectedValue(new Error('Database failure'));
|
mockSearchValuesByEmbedding.mockRejectedValue(new Error('Database failure'));
|
||||||
|
|
||||||
|
@ -378,7 +376,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock successful keyword extraction
|
// Mock successful keyword extraction
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['Red Bull'] },
|
object: { values: ['Red Bull'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -437,7 +435,7 @@ describe('extractValuesSearchStep', () => {
|
||||||
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
runtimeContext.set('dataSourceId', 'test-datasource-id');
|
||||||
|
|
||||||
// Mock successful keyword extraction
|
// Mock successful keyword extraction
|
||||||
mockGenerate.mockResolvedValue({
|
mockGenerateObject.mockResolvedValue({
|
||||||
object: { values: ['test'] },
|
object: { values: ['test'] },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import type { StoredValueResult } from '@buster/stored-values';
|
import type { StoredValueResult } from '@buster/stored-values';
|
||||||
import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search';
|
import { generateEmbedding, searchValuesByEmbedding } from '@buster/stored-values/search';
|
||||||
import { Agent, createStep } from '@mastra/core';
|
import { createStep } from '@mastra/core';
|
||||||
import type { RuntimeContext } from '@mastra/core/runtime-context';
|
import type { RuntimeContext } from '@mastra/core/runtime-context';
|
||||||
|
import { generateObject } from 'ai';
|
||||||
import type { CoreMessage } from 'ai';
|
import type { CoreMessage } from 'ai';
|
||||||
import { wrapTraced } from 'braintrust';
|
import { wrapTraced } from 'braintrust';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
@ -12,6 +13,11 @@ import type { AnalystRuntimeContext } from '../workflows/analyst-workflow';
|
||||||
|
|
||||||
const inputSchema = thinkAndPrepWorkflowInputSchema;
|
const inputSchema = thinkAndPrepWorkflowInputSchema;
|
||||||
|
|
||||||
|
// Schema for what the LLM returns
|
||||||
|
const llmOutputSchema = z.object({
|
||||||
|
values: z.array(z.string()).describe('The values that the agent will search for.'),
|
||||||
|
});
|
||||||
|
|
||||||
// Step output schema - what the step returns after performing the search
|
// Step output schema - what the step returns after performing the search
|
||||||
export const extractValuesSearchOutputSchema = z.object({
|
export const extractValuesSearchOutputSchema = z.object({
|
||||||
values: z.array(z.string()).describe('The values that the agent will search for.'),
|
values: z.array(z.string()).describe('The values that the agent will search for.'),
|
||||||
|
@ -231,12 +237,6 @@ async function searchStoredValues(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const valuesAgent = new Agent({
|
|
||||||
name: 'Extract Values',
|
|
||||||
instructions: extractValuesInstructions,
|
|
||||||
model: Haiku35,
|
|
||||||
});
|
|
||||||
|
|
||||||
const extractValuesSearchStepExecution = async ({
|
const extractValuesSearchStepExecution = async ({
|
||||||
inputData,
|
inputData,
|
||||||
runtimeContext,
|
runtimeContext,
|
||||||
|
@ -264,12 +264,19 @@ const extractValuesSearchStepExecution = async ({
|
||||||
try {
|
try {
|
||||||
const tracedValuesExtraction = wrapTraced(
|
const tracedValuesExtraction = wrapTraced(
|
||||||
async () => {
|
async () => {
|
||||||
const response = await valuesAgent.generate(messages, {
|
const { object } = await generateObject({
|
||||||
maxSteps: 0,
|
model: Haiku35,
|
||||||
output: extractValuesSearchOutputSchema,
|
schema: llmOutputSchema,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'system',
|
||||||
|
content: extractValuesInstructions,
|
||||||
|
},
|
||||||
|
...messages,
|
||||||
|
],
|
||||||
});
|
});
|
||||||
|
|
||||||
return response.object;
|
return object;
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'Extract Values',
|
name: 'Extract Values',
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import { updateChat, updateMessage } from '@buster/database';
|
import { updateChat, updateMessage } from '@buster/database';
|
||||||
import { Agent, createStep } from '@mastra/core';
|
import { createStep } from '@mastra/core';
|
||||||
import type { RuntimeContext } from '@mastra/core/runtime-context';
|
import type { RuntimeContext } from '@mastra/core/runtime-context';
|
||||||
|
import { generateObject } from 'ai';
|
||||||
import type { CoreMessage } from 'ai';
|
import type { CoreMessage } from 'ai';
|
||||||
import { wrapTraced } from 'braintrust';
|
import { wrapTraced } from 'braintrust';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
@ -11,6 +12,12 @@ import type { AnalystRuntimeContext } from '../workflows/analyst-workflow';
|
||||||
|
|
||||||
const inputSchema = thinkAndPrepWorkflowInputSchema;
|
const inputSchema = thinkAndPrepWorkflowInputSchema;
|
||||||
|
|
||||||
|
// Schema for what the LLM returns
|
||||||
|
const llmOutputSchema = z.object({
|
||||||
|
title: z.string().describe('The title for the chat.'),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Schema for what the step returns (includes pass-through data)
|
||||||
export const generateChatTitleOutputSchema = z.object({
|
export const generateChatTitleOutputSchema = z.object({
|
||||||
title: z.string().describe('The title for the chat.'),
|
title: z.string().describe('The title for the chat.'),
|
||||||
// Pass through dashboard context
|
// Pass through dashboard context
|
||||||
|
@ -28,13 +35,9 @@ export const generateChatTitleOutputSchema = z.object({
|
||||||
|
|
||||||
const generateChatTitleInstructions = `
|
const generateChatTitleInstructions = `
|
||||||
I am a chat title generator that is responsible for generating a title for the chat.
|
I am a chat title generator that is responsible for generating a title for the chat.
|
||||||
`;
|
|
||||||
|
|
||||||
const todosAgent = new Agent({
|
The title should be 3-8 words, capturing the main topic or intent of the conversation.
|
||||||
name: 'Extract Values',
|
`;
|
||||||
instructions: generateChatTitleInstructions,
|
|
||||||
model: Haiku35,
|
|
||||||
});
|
|
||||||
|
|
||||||
const generateChatTitleExecution = async ({
|
const generateChatTitleExecution = async ({
|
||||||
inputData,
|
inputData,
|
||||||
|
@ -63,12 +66,19 @@ const generateChatTitleExecution = async ({
|
||||||
try {
|
try {
|
||||||
const tracedChatTitle = wrapTraced(
|
const tracedChatTitle = wrapTraced(
|
||||||
async () => {
|
async () => {
|
||||||
const response = await todosAgent.generate(messages, {
|
const { object } = await generateObject({
|
||||||
maxSteps: 0,
|
model: Haiku35,
|
||||||
output: generateChatTitleOutputSchema,
|
schema: llmOutputSchema,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'system',
|
||||||
|
content: generateChatTitleInstructions,
|
||||||
|
},
|
||||||
|
...messages,
|
||||||
|
],
|
||||||
});
|
});
|
||||||
|
|
||||||
return response.object;
|
return object;
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'Generate Chat Title',
|
name: 'Generate Chat Title',
|
||||||
|
|
Loading…
Reference in New Issue