mirror of https://github.com/buster-so/buster.git
Merge pull request #869 from buster-so/wells-bus-1707-user-personalization-endpoints
Personalization
This commit is contained in:
commit
9f31ab4ff3
|
@ -0,0 +1,67 @@
|
|||
import { getUserInformation, updateUser } from '@buster/database';
|
||||
import {
|
||||
GetUserByIdRequestSchema,
|
||||
type GetUserByIdResponse,
|
||||
UserPatchRequestSchema,
|
||||
type UserPatchResponse,
|
||||
} from '@buster/server-shared/user';
|
||||
import { zValidator } from '@hono/zod-validator';
|
||||
import { Hono } from 'hono';
|
||||
import { HTTPException } from 'hono/http-exception';
|
||||
import { standardErrorHandler } from '../../../../utils/response';
|
||||
|
||||
const app = new Hono()
|
||||
.patch(
|
||||
'/',
|
||||
zValidator('param', GetUserByIdRequestSchema),
|
||||
zValidator('json', UserPatchRequestSchema),
|
||||
async (c) => {
|
||||
const userId = c.req.param('id');
|
||||
const authenticatedUser = c.get('busterUser');
|
||||
|
||||
if (authenticatedUser.id !== userId) {
|
||||
throw new HTTPException(403, {
|
||||
message: 'You are not authorized to update this user',
|
||||
});
|
||||
}
|
||||
|
||||
const { personalizationEnabled, personalizationConfig, name } = c.req.valid('json');
|
||||
|
||||
// Check for undefined because empty strings are valid updates
|
||||
if (
|
||||
personalizationEnabled === undefined &&
|
||||
personalizationConfig === undefined &&
|
||||
name === undefined
|
||||
) {
|
||||
throw new HTTPException(400, {
|
||||
message: 'No fields to update',
|
||||
});
|
||||
}
|
||||
|
||||
const currentUser: GetUserByIdResponse = await getUserInformation(userId);
|
||||
const updatedPersonalizationConfig = currentUser.personalizationConfig;
|
||||
|
||||
if (personalizationConfig?.currentRole !== undefined) {
|
||||
updatedPersonalizationConfig.currentRole = personalizationConfig.currentRole;
|
||||
}
|
||||
if (personalizationConfig?.customInstructions !== undefined) {
|
||||
updatedPersonalizationConfig.customInstructions = personalizationConfig.customInstructions;
|
||||
}
|
||||
if (personalizationConfig?.additionalInformation !== undefined) {
|
||||
updatedPersonalizationConfig.additionalInformation =
|
||||
personalizationConfig.additionalInformation;
|
||||
}
|
||||
|
||||
const updatedUser: UserPatchResponse = await updateUser({
|
||||
userId,
|
||||
name,
|
||||
personalizationEnabled,
|
||||
personalizationConfig: updatedPersonalizationConfig,
|
||||
});
|
||||
|
||||
return c.json(updatedUser);
|
||||
}
|
||||
)
|
||||
.onError(standardErrorHandler);
|
||||
|
||||
export default app;
|
|
@ -2,7 +2,9 @@ import { Hono } from 'hono';
|
|||
import { requireAuth } from '../../../middleware/auth';
|
||||
import GET from './GET';
|
||||
import POST from './POST';
|
||||
import userIdRoute from './[id]/GET';
|
||||
import userIdGet from './[id]/GET';
|
||||
import userIdPatch from './[id]/PATCH';
|
||||
import userIdSuggestedPrompts from './[id]/suggested-prompts/GET';
|
||||
|
||||
const app = new Hono()
|
||||
// Apply authentication globally to ALL routes in this router
|
||||
|
@ -10,6 +12,8 @@ const app = new Hono()
|
|||
// Mount the modular routes
|
||||
.route('/', GET)
|
||||
.route('/', POST)
|
||||
.route('/:id', userIdRoute);
|
||||
.route('/:id', userIdGet)
|
||||
.route('/:id', userIdPatch)
|
||||
.route('/:id/suggested-prompts', userIdSuggestedPrompts);
|
||||
|
||||
export default app;
|
||||
|
|
|
@ -11,6 +11,7 @@ import {
|
|||
getOrganizationAnalystDoc,
|
||||
getOrganizationDataSource,
|
||||
getOrganizationDocs,
|
||||
getUserPersonalization,
|
||||
} from '@buster/database';
|
||||
|
||||
// Access control imports
|
||||
|
@ -306,6 +307,11 @@ export const analystAgentTask: ReturnType<
|
|||
}
|
||||
});
|
||||
|
||||
// Fetch user personalization config
|
||||
const userPersonalizationConfigPromise = messageContextPromise.then((context) =>
|
||||
getUserPersonalization(context.userId)
|
||||
);
|
||||
|
||||
// Fetch Braintrust metadata in parallel
|
||||
const braintrustMetadataPromise = getBraintrustMetadata({ messageId: payload.message_id });
|
||||
|
||||
|
@ -328,6 +334,7 @@ export const analystAgentTask: ReturnType<
|
|||
braintrustMetadata,
|
||||
analystInstructions,
|
||||
organizationDocs,
|
||||
userPersonalizationConfig,
|
||||
] = await Promise.all([
|
||||
messageContextPromise,
|
||||
conversationHistoryPromise,
|
||||
|
@ -336,6 +343,7 @@ export const analystAgentTask: ReturnType<
|
|||
braintrustMetadataPromise,
|
||||
analystInstructionsPromise,
|
||||
organizationDocsPromise,
|
||||
userPersonalizationConfigPromise,
|
||||
]);
|
||||
|
||||
const dataLoadEnd = Date.now();
|
||||
|
@ -387,6 +395,7 @@ export const analystAgentTask: ReturnType<
|
|||
datasets,
|
||||
analystInstructions: analystInstructions || undefined,
|
||||
organizationDocs,
|
||||
userPersonalizationConfig,
|
||||
};
|
||||
|
||||
logger.log('Workflow input prepared', {
|
||||
|
|
|
@ -16,6 +16,7 @@ describe('Analyst Agent Integration Tests', () => {
|
|||
messageId: '123',
|
||||
workflowStartTime: Date.now(),
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
});
|
||||
|
||||
const streamResult = await analystAgent.stream({
|
||||
|
|
|
@ -49,6 +49,9 @@ export const AnalystAgentOptionsSchema = z.object({
|
|||
})
|
||||
)
|
||||
.optional(),
|
||||
userPersonalizationMessageContent: z
|
||||
.string()
|
||||
.describe('Custom user personalization in message content'),
|
||||
});
|
||||
|
||||
export const AnalystStreamOptionsSchema = z.object({
|
||||
|
@ -61,7 +64,8 @@ export type AnalystAgentOptions = z.infer<typeof AnalystAgentOptionsSchema>;
|
|||
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
|
||||
|
||||
export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
||||
const { datasets, analystInstructions, organizationDocs } = analystAgentOptions;
|
||||
const { datasets, analystInstructions, organizationDocs, userPersonalizationMessageContent } =
|
||||
analystAgentOptions;
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
|
@ -131,6 +135,15 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
} as ModelMessage)
|
||||
: null;
|
||||
|
||||
// Create user personalization system message
|
||||
const userPersonalizationSystemMessage = userPersonalizationMessageContent
|
||||
? ({
|
||||
role: 'system',
|
||||
content: userPersonalizationMessageContent,
|
||||
providerOptions: DEFAULT_ANTHROPIC_OPTIONS,
|
||||
} as ModelMessage)
|
||||
: null;
|
||||
|
||||
return wrapTraced(
|
||||
() =>
|
||||
streamText({
|
||||
|
@ -154,6 +167,7 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
datasetsSystemMessage,
|
||||
...(docsSystemMessage ? [docsSystemMessage] : []),
|
||||
...(analystInstructionsMessage ? [analystInstructionsMessage] : []),
|
||||
...(userPersonalizationSystemMessage ? [userPersonalizationSystemMessage] : []),
|
||||
...messages,
|
||||
],
|
||||
stopWhen: STOP_CONDITIONS,
|
||||
|
|
|
@ -19,6 +19,7 @@ describe('Think and Prep Agent Integration Tests', () => {
|
|||
dataSourceId: 'test-data-source-123',
|
||||
dataSourceSyntax: 'postgresql',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
});
|
||||
|
||||
const streamResult = await thinkAndPrepAgent.stream({
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import type { PermissionedDataset } from '@buster/access-controls';
|
||||
import { messageAnalysisModeEnum } from '@buster/database';
|
||||
import {
|
||||
UserPersonalizationConfigSchema,
|
||||
type UserPersonalizationConfigType,
|
||||
messageAnalysisModeEnum,
|
||||
} from '@buster/database';
|
||||
import { type ModelMessage, hasToolCall, stepCountIs, streamText } from 'ai';
|
||||
import { wrapTraced } from 'braintrust';
|
||||
import z from 'zod';
|
||||
|
@ -70,6 +74,9 @@ export const ThinkAndPrepAgentOptionsSchema = z.object({
|
|||
)
|
||||
.optional()
|
||||
.describe('Organization data catalog documentation.'),
|
||||
userPersonalizationMessageContent: z
|
||||
.string()
|
||||
.describe('Custom user personalization in message content'),
|
||||
});
|
||||
|
||||
export const ThinkAndPrepStreamOptionsSchema = z.object({
|
||||
|
@ -82,8 +89,14 @@ export type ThinkAndPrepAgentOptions = z.infer<typeof ThinkAndPrepAgentOptionsSc
|
|||
export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>;
|
||||
|
||||
export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAgentOptions) {
|
||||
const { messageId, datasets, workflowStartTime, analystInstructions, organizationDocs } =
|
||||
thinkAndPrepAgentSchema;
|
||||
const {
|
||||
messageId,
|
||||
datasets,
|
||||
workflowStartTime,
|
||||
analystInstructions,
|
||||
organizationDocs,
|
||||
userPersonalizationMessageContent,
|
||||
} = thinkAndPrepAgentSchema;
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
|
@ -172,6 +185,15 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
|
|||
} as ModelMessage)
|
||||
: null;
|
||||
|
||||
// Create user personalization system message
|
||||
const userPersonalizationSystemMessage = userPersonalizationMessageContent
|
||||
? ({
|
||||
role: 'system',
|
||||
content: userPersonalizationMessageContent,
|
||||
providerOptions: DEFAULT_ANTHROPIC_OPTIONS,
|
||||
} as ModelMessage)
|
||||
: null;
|
||||
|
||||
return wrapTraced(
|
||||
() =>
|
||||
streamText({
|
||||
|
@ -193,6 +215,7 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
|
|||
datasetsSystemMessage,
|
||||
...(docsSystemMessage ? [docsSystemMessage] : []),
|
||||
...(analystInstructionsMessage ? [analystInstructionsMessage] : []),
|
||||
...(userPersonalizationSystemMessage ? [userPersonalizationSystemMessage] : []),
|
||||
...messages,
|
||||
],
|
||||
stopWhen: STOP_CONDITIONS,
|
||||
|
|
|
@ -43,6 +43,7 @@ describe('runAnalystAgentStep', () => {
|
|||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -79,6 +80,7 @@ describe('runAnalystAgentStep', () => {
|
|||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -110,6 +112,7 @@ describe('runAnalystAgentStep', () => {
|
|||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -139,6 +142,7 @@ describe('runAnalystAgentStep', () => {
|
|||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
|
|
@ -9,12 +9,23 @@ import {
|
|||
* Factory function that creates a finish handler for TODO creation
|
||||
* Called when streaming completes to finalize the reasoning message
|
||||
*/
|
||||
export function createTodosStepFinish(todosState: CreateTodosState, context: CreateTodosContext) {
|
||||
export function createTodosStepFinish(
|
||||
todosState: CreateTodosState,
|
||||
context: CreateTodosContext,
|
||||
injectPersonalizationTodo: boolean
|
||||
) {
|
||||
return async function todosStepFinish(result: CreateTodosInput): Promise<void> {
|
||||
const personalizationStaticToDo = `\n[ ] Determine if any of the user's personalized instructions are relevant to this question`;
|
||||
|
||||
// Update state with final values
|
||||
todosState.todos = result.todos;
|
||||
todosState.is_complete = true;
|
||||
|
||||
// Inject the personalization todo if needed
|
||||
if (injectPersonalizationTodo) {
|
||||
todosState.todos += personalizationStaticToDo;
|
||||
}
|
||||
|
||||
// Create final reasoning message with completed status
|
||||
const todosReasoningEntry = createTodosReasoningMessage(todosState);
|
||||
const todosRawMessages = createTodosRawLlmMessageEntry(todosState);
|
||||
|
|
|
@ -10,6 +10,9 @@ import { getCreateTodosSystemMessage } from './get-create-todos-system-message';
|
|||
export const createTodosParamsSchema = z.object({
|
||||
messages: z.array(z.custom<ModelMessage>()).describe('The conversation history'),
|
||||
messageId: z.string().describe('The message ID for database updates'),
|
||||
shouldInjectUserPersonalizationTodo: z
|
||||
.boolean()
|
||||
.describe('Whether to inject the user personalization todo'),
|
||||
});
|
||||
|
||||
export const createTodosResultSchema = z.object({
|
||||
|
@ -49,6 +52,7 @@ export type CreateTodosContext = z.infer<typeof createTodosContextSchema>;
|
|||
export type CreateTodosState = z.infer<typeof createTodosStateSchema>;
|
||||
export type CreateTodosInput = z.infer<typeof llmOutputSchema>;
|
||||
|
||||
import { UserPersonalizationConfigSchema } from '@buster/database';
|
||||
import { createTodosStepDelta } from './create-todos-step-delta';
|
||||
import { createTodosStepFinish } from './create-todos-step-finish';
|
||||
import { createTodosStepStart } from './create-todos-step-start';
|
||||
|
@ -58,7 +62,8 @@ import { createTodosStepStart } from './create-todos-step-start';
|
|||
*/
|
||||
async function generateTodosWithLLM(
|
||||
messages: ModelMessage[],
|
||||
context: CreateTodosContext
|
||||
context: CreateTodosContext,
|
||||
injectPersonalizationTodo: boolean
|
||||
): Promise<string> {
|
||||
try {
|
||||
// Prepare messages for the LLM
|
||||
|
@ -80,7 +85,7 @@ async function generateTodosWithLLM(
|
|||
// Create streaming handlers
|
||||
const onStreamStart = createTodosStepStart(state, context);
|
||||
const onTextDelta = createTodosStepDelta(state, context);
|
||||
const onStreamFinish = createTodosStepFinish(state, context);
|
||||
const onStreamFinish = createTodosStepFinish(state, context, injectPersonalizationTodo);
|
||||
|
||||
const tracedTodosGeneration = wrapTraced(
|
||||
async () => {
|
||||
|
@ -140,7 +145,11 @@ export async function runCreateTodosStep(params: CreateTodosParams): Promise<Cre
|
|||
messageId: params.messageId,
|
||||
};
|
||||
|
||||
const todos = await generateTodosWithLLM(params.messages, context);
|
||||
const todos = await generateTodosWithLLM(
|
||||
params.messages,
|
||||
context,
|
||||
params.shouldInjectUserPersonalizationTodo
|
||||
);
|
||||
|
||||
// Generate a unique ID for this tool call
|
||||
const toolCallId = `create_todos_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||
|
|
|
@ -14,7 +14,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -34,7 +38,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -56,7 +64,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -75,7 +87,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -98,7 +114,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -117,7 +137,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -135,7 +159,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -155,7 +183,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -175,7 +207,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -194,7 +230,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -222,7 +262,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -248,7 +292,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -268,7 +316,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -287,7 +339,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -306,7 +362,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -325,7 +385,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -358,7 +422,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -377,7 +445,11 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -397,7 +469,11 @@ describe('create-todos-step integration', () => {
|
|||
];
|
||||
|
||||
// Even if internally aborted, should return valid structure
|
||||
const result = await runCreateTodosStep({ messages, messageId: testMessageId });
|
||||
const result = await runCreateTodosStep({
|
||||
messages,
|
||||
messageId: testMessageId,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.todos).toBeDefined();
|
||||
|
@ -417,6 +493,7 @@ describe('create-todos-step integration', () => {
|
|||
},
|
||||
],
|
||||
messageId: `${testMessageId}-${i}`,
|
||||
shouldInjectUserPersonalizationTodo: false,
|
||||
})
|
||||
);
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -85,6 +86,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -121,6 +123,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -173,6 +176,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -224,6 +228,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -257,6 +262,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
userPersonalizationMessageContent: '',
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
// input for the workflow
|
||||
|
||||
import type { PermissionedDataset } from '@buster/access-controls';
|
||||
import { messageAnalysisModeEnum } from '@buster/database';
|
||||
import {
|
||||
UserPersonalizationConfigSchema,
|
||||
type UserPersonalizationConfigType,
|
||||
messageAnalysisModeEnum,
|
||||
} from '@buster/database';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
|
@ -52,6 +56,7 @@ const AnalystWorkflowInputSchema = z.object({
|
|||
})
|
||||
)
|
||||
.optional(),
|
||||
userPersonalizationConfig: UserPersonalizationConfigSchema.optional(),
|
||||
});
|
||||
|
||||
export type AnalystWorkflowInput = z.infer<typeof AnalystWorkflowInputSchema>;
|
||||
|
@ -62,7 +67,10 @@ export async function runAnalystWorkflow(
|
|||
const workflowStartTime = Date.now();
|
||||
const workflowId = `workflow_${input.chatId}_${input.messageId}`;
|
||||
|
||||
const { messages, analystInstructions, organizationDocs } = input;
|
||||
const { messages, analystInstructions, organizationDocs, userPersonalizationConfig } = input;
|
||||
|
||||
const userPersonalizationMessageContent =
|
||||
generatePersonalizationMessageContent(userPersonalizationConfig);
|
||||
|
||||
const { todos, values, analysisType } = await runAnalystPrepSteps(input);
|
||||
|
||||
|
@ -86,6 +94,7 @@ export async function runAnalystWorkflow(
|
|||
analysisMode: analysisType,
|
||||
analystInstructions,
|
||||
organizationDocs,
|
||||
userPersonalizationMessageContent,
|
||||
},
|
||||
streamOptions: {
|
||||
messages,
|
||||
|
@ -123,6 +132,7 @@ export async function runAnalystWorkflow(
|
|||
workflowStartTime,
|
||||
analystInstructions,
|
||||
organizationDocs,
|
||||
userPersonalizationMessageContent,
|
||||
},
|
||||
streamOptions: {
|
||||
messages,
|
||||
|
@ -232,6 +242,7 @@ const AnalystPrepStepSchema = z.object({
|
|||
chatId: z.string().uuid(),
|
||||
messageId: z.string().uuid(),
|
||||
messageAnalysisMode: z.enum(messageAnalysisModeEnum.enumValues).optional(),
|
||||
userPersonalizationConfig: UserPersonalizationConfigSchema.optional(),
|
||||
});
|
||||
|
||||
type AnalystPrepStepInput = z.infer<typeof AnalystPrepStepSchema>;
|
||||
|
@ -242,15 +253,18 @@ async function runAnalystPrepSteps({
|
|||
chatId,
|
||||
messageId,
|
||||
messageAnalysisMode,
|
||||
userPersonalizationConfig,
|
||||
}: AnalystPrepStepInput): Promise<{
|
||||
todos: CreateTodosResult;
|
||||
values: ExtractValuesSearchResult;
|
||||
analysisType: AnalysisTypeRouterResult['analysisType'];
|
||||
}> {
|
||||
const shouldInjectUserPersonalizationTodo = Boolean(userPersonalizationConfig);
|
||||
const [todos, values, , analysisType] = await Promise.all([
|
||||
runCreateTodosStep({
|
||||
messages,
|
||||
messageId,
|
||||
shouldInjectUserPersonalizationTodo,
|
||||
}),
|
||||
runExtractValuesAndSearchStep({
|
||||
messages,
|
||||
|
@ -269,3 +283,31 @@ async function runAnalystPrepSteps({
|
|||
|
||||
return { todos, values, analysisType: analysisType.analysisType };
|
||||
}
|
||||
|
||||
function generatePersonalizationMessageContent(
|
||||
userPersonalizationConfig: UserPersonalizationConfigType | undefined
|
||||
): string {
|
||||
const userPersonalizationMessageContent: string[] = [];
|
||||
|
||||
if (userPersonalizationConfig) {
|
||||
if (userPersonalizationConfig.currentRole) {
|
||||
userPersonalizationMessageContent.push('<user_current_role>');
|
||||
userPersonalizationMessageContent.push(`${userPersonalizationConfig.currentRole}`);
|
||||
userPersonalizationMessageContent.push('</user_current_role>');
|
||||
}
|
||||
|
||||
if (userPersonalizationConfig.customInstructions) {
|
||||
userPersonalizationMessageContent.push('<custom_instructions>');
|
||||
userPersonalizationMessageContent.push(`${userPersonalizationConfig.customInstructions}`);
|
||||
userPersonalizationMessageContent.push('</custom_instructions>');
|
||||
}
|
||||
|
||||
if (userPersonalizationConfig.additionalInformation) {
|
||||
userPersonalizationMessageContent.push('<additional_information>');
|
||||
userPersonalizationMessageContent.push(`${userPersonalizationConfig.additionalInformation}`);
|
||||
userPersonalizationMessageContent.push('</additional_information>');
|
||||
}
|
||||
}
|
||||
|
||||
return userPersonalizationMessageContent.join('\n');
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
ALTER TABLE "users" ALTER COLUMN "suggested_prompts" SET DEFAULT '{"suggestedPrompts":{"report":["provide a trend analysis of quarterly profits","evaluate product performance across regions"],"dashboard":["create a sales performance dashboard","design a revenue forecast dashboard"],"visualization":["create a metric for monthly sales","show top vendors by purchase volume"],"help":["what types of analyses can you perform?","what questions can I as buster?","what data models are available for queries?","can you explain your forecasting capabilities?"]},"updatedAt":"2025-09-11T23:39:21.533Z"}'::jsonb;--> statement-breakpoint
|
||||
ALTER TABLE "users" ADD COLUMN "personalization_enabled" boolean DEFAULT false NOT NULL;--> statement-breakpoint
|
||||
ALTER TABLE "users" ADD COLUMN "personalization_config" jsonb DEFAULT '{}'::jsonb NOT NULL;
|
File diff suppressed because it is too large
Load Diff
|
@ -666,6 +666,13 @@
|
|||
"when": 1757601195877,
|
||||
"tag": "0094_military_zarek",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 95,
|
||||
"version": "7",
|
||||
"when": 1757633961558,
|
||||
"tag": "0095_quiet_penance",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
import { and, eq } from 'drizzle-orm';
|
||||
import { db } from '../../connection';
|
||||
import { users } from '../../schema';
|
||||
import type { UserPersonalizationConfigType } from '../../schema-types';
|
||||
|
||||
export async function getUserPersonalization(
|
||||
userId: string
|
||||
): Promise<UserPersonalizationConfigType | undefined> {
|
||||
const result = await db
|
||||
.select({
|
||||
personalizationConfig: users.personalizationConfig,
|
||||
})
|
||||
.from(users)
|
||||
.where(and(eq(users.id, userId), eq(users.personalizationEnabled, true)))
|
||||
.limit(1);
|
||||
|
||||
if (result.length === 0 || !result[0]) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const user = result[0];
|
||||
return user.personalizationConfig;
|
||||
}
|
|
@ -4,3 +4,4 @@ export * from './find-user-by-email';
|
|||
export * from './get-user-organizations';
|
||||
export * from './user-queries';
|
||||
export * from './user-suggested-prompts';
|
||||
export * from './get-user-personalization';
|
||||
|
|
|
@ -2,6 +2,7 @@ import { and, eq, isNull } from 'drizzle-orm';
|
|||
import { z } from 'zod';
|
||||
import { db } from '../../connection';
|
||||
import { users, usersToOrganizations } from '../../schema';
|
||||
import { UserPersonalizationConfigSchema } from '../../schema-types';
|
||||
import type { User } from './user';
|
||||
|
||||
// Use the full User type from the schema internally
|
||||
|
@ -14,7 +15,27 @@ export const UserInfoByIdResponseSchema = z.object({
|
|||
role: z.string(),
|
||||
status: z.string(),
|
||||
organizationId: z.string().uuid(),
|
||||
personalizationEnabled: z.boolean(),
|
||||
personalizationConfig: UserPersonalizationConfigSchema,
|
||||
});
|
||||
|
||||
export const UpdateUserInputSchema = z.object({
|
||||
userId: z.string().uuid(),
|
||||
name: z.string().optional(),
|
||||
personalizationEnabled: z.boolean().optional(),
|
||||
personalizationConfig: UserPersonalizationConfigSchema.optional(),
|
||||
});
|
||||
|
||||
export const UpdateUserResponseSchema = z.object({
|
||||
userId: z.string().uuid(),
|
||||
name: z.string().optional(),
|
||||
personalizationEnabled: z.boolean().optional(),
|
||||
personalizationConfig: UserPersonalizationConfigSchema.optional(),
|
||||
updatedAt: z.string().optional(),
|
||||
});
|
||||
|
||||
export type UpdateUserInput = z.infer<typeof UpdateUserInputSchema>;
|
||||
export type UpdateUserResponse = z.infer<typeof UpdateUserResponseSchema>;
|
||||
export type UserInfoByIdResponse = z.infer<typeof UserInfoByIdResponseSchema>;
|
||||
|
||||
/**
|
||||
|
@ -154,6 +175,53 @@ export async function addUserToOrganization(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates user information
|
||||
* @param input The user update parameters
|
||||
* @returns The updated user information
|
||||
*/
|
||||
export async function updateUser(input: UpdateUserInput): Promise<UpdateUserResponse> {
|
||||
const validated = UpdateUserInputSchema.parse(input);
|
||||
|
||||
const updateData: Pick<
|
||||
UpdateUserResponse,
|
||||
'name' | 'personalizationEnabled' | 'personalizationConfig' | 'updatedAt'
|
||||
> = {};
|
||||
|
||||
if (validated.name !== undefined) {
|
||||
updateData.name = validated.name;
|
||||
}
|
||||
|
||||
if (validated.personalizationEnabled !== undefined) {
|
||||
updateData.personalizationEnabled = validated.personalizationEnabled;
|
||||
}
|
||||
|
||||
if (validated.personalizationConfig !== undefined) {
|
||||
updateData.personalizationConfig = validated.personalizationConfig;
|
||||
}
|
||||
|
||||
updateData.updatedAt = new Date().toISOString();
|
||||
|
||||
const result = await db
|
||||
.update(users)
|
||||
.set(updateData)
|
||||
.where(eq(users.id, validated.userId))
|
||||
.returning();
|
||||
|
||||
if (result.length === 0 || !result[0]) {
|
||||
throw new Error(`User not found: ${validated.userId}`);
|
||||
}
|
||||
const updatedUser = result[0];
|
||||
|
||||
return {
|
||||
userId: updatedUser.id,
|
||||
name: updatedUser.name || undefined,
|
||||
personalizationEnabled: updatedUser.personalizationEnabled,
|
||||
personalizationConfig: updatedUser.personalizationConfig,
|
||||
updatedAt: updatedUser.updatedAt,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get comprehensive user information including datasets and permissions
|
||||
* This function replaces the complex Rust implementation with TypeScript
|
||||
|
@ -165,6 +233,8 @@ export async function getUserInformation(userId: string): Promise<UserInfoByIdRe
|
|||
id: users.id,
|
||||
email: users.email,
|
||||
name: users.name,
|
||||
personalizationEnabled: users.personalizationEnabled,
|
||||
personalizationConfig: users.personalizationConfig,
|
||||
role: usersToOrganizations.role,
|
||||
status: usersToOrganizations.status,
|
||||
organizationId: usersToOrganizations.organizationId,
|
||||
|
|
|
@ -10,8 +10,14 @@ export const UserSuggestedPromptsSchema = z.object({
|
|||
updatedAt: z.string(),
|
||||
});
|
||||
|
||||
// User Suggested Prompts Types
|
||||
export const UserPersonalizationConfigSchema = z.object({
|
||||
currentRole: z.string().optional(),
|
||||
customInstructions: z.string().optional(),
|
||||
additionalInformation: z.string().optional(),
|
||||
});
|
||||
|
||||
export type UserSuggestedPromptsType = z.infer<typeof UserSuggestedPromptsSchema>;
|
||||
export type UserPersonalizationConfigType = z.infer<typeof UserPersonalizationConfigSchema>;
|
||||
|
||||
export const DEFAULT_USER_SUGGESTED_PROMPTS: UserSuggestedPromptsType = {
|
||||
suggestedPrompts: {
|
||||
|
|
|
@ -19,7 +19,11 @@ import {
|
|||
uuid,
|
||||
varchar,
|
||||
} from 'drizzle-orm/pg-core';
|
||||
import type { OrganizationColorPalettes, UserSuggestedPromptsType } from './schema-types';
|
||||
import type {
|
||||
OrganizationColorPalettes,
|
||||
UserPersonalizationConfigType,
|
||||
UserSuggestedPromptsType,
|
||||
} from './schema-types';
|
||||
import { DEFAULT_USER_SUGGESTED_PROMPTS } from './schema-types/user';
|
||||
|
||||
export const assetPermissionRoleEnum = pgEnum('asset_permission_role_enum', [
|
||||
|
@ -870,6 +874,11 @@ export const users = pgTable(
|
|||
.$type<UserSuggestedPromptsType>()
|
||||
.default(DEFAULT_USER_SUGGESTED_PROMPTS)
|
||||
.notNull(),
|
||||
personalizationEnabled: boolean('personalization_enabled').default(false).notNull(),
|
||||
personalizationConfig: jsonb('personalization_config')
|
||||
.$type<UserPersonalizationConfigType>()
|
||||
.default({})
|
||||
.notNull(),
|
||||
},
|
||||
(table) => [unique('users_email_key').on(table.email)]
|
||||
);
|
||||
|
|
|
@ -23,7 +23,7 @@ export type {
|
|||
} from './schemas/message-schemas';
|
||||
|
||||
// Export schema-types to use across the codebase
|
||||
export type { UserSuggestedPromptsType } from './schema-types';
|
||||
export type { UserSuggestedPromptsType, UserPersonalizationConfigType } from './schema-types';
|
||||
|
||||
// Export default user suggested prompts
|
||||
export { DEFAULT_USER_SUGGESTED_PROMPTS } from './schema-types/user';
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import type { UserInfoByIdResponse } from '@buster/database';
|
||||
import type { UpdateUserResponse, UserInfoByIdResponse } from '@buster/database';
|
||||
import { UserPersonalizationConfigSchema } from '@buster/database';
|
||||
import { z } from 'zod';
|
||||
import type { UserFavorite } from './favorites.types';
|
||||
import type { UserOrganizationRole } from './roles.types';
|
||||
|
@ -19,11 +20,18 @@ export const UserSchema = z.object({
|
|||
updated_at: z.string(),
|
||||
});
|
||||
|
||||
export type User = z.infer<typeof UserSchema>;
|
||||
export const UserPatchRequestSchema = z.object({
|
||||
name: z.string().optional(),
|
||||
personalizationEnabled: z.boolean().optional(),
|
||||
personalizationConfig: UserPersonalizationConfigSchema.optional(),
|
||||
});
|
||||
|
||||
export const GetUserByIdRequestSchema = z.object({
|
||||
id: z.string().uuid(),
|
||||
});
|
||||
|
||||
export type UserPatchRequest = z.infer<typeof UserPatchRequestSchema>;
|
||||
export type UserPatchResponse = UpdateUserResponse;
|
||||
export type User = z.infer<typeof UserSchema>;
|
||||
export type GetUserByIdRequest = z.infer<typeof GetUserByIdRequestSchema>;
|
||||
export type GetUserByIdResponse = UserInfoByIdResponse;
|
||||
|
|
Loading…
Reference in New Issue