diff --git a/vscode/src/chat/chat-view/SimpleChatPanelProvider.ts b/vscode/src/chat/chat-view/SimpleChatPanelProvider.ts index 3a054074a4a..18184e669eb 100644 --- a/vscode/src/chat/chat-view/SimpleChatPanelProvider.ts +++ b/vscode/src/chat/chat-view/SimpleChatPanelProvider.ts @@ -105,7 +105,6 @@ export interface ChatSession { webviewPanel?: vscode.WebviewPanel sessionID: string } - /** * SimpleChatPanelProvider is the view controller class for the chat panel. * It handles all events sent from the view, keeps track of the underlying chat model, @@ -404,16 +403,20 @@ export class SimpleChatPanelProvider implements vscode.Disposable, ChatSession { const prompter = new DefaultPrompter( userContextItems, addEnhancedContext - ? query => - getEnhancedContext( - this.config.useContext, - this.editor, - this.embeddingsClient, - this.localEmbeddings, - this.config.experimentalSymfContext ? this.symf : null, - this.codebaseStatusProvider, - query - ) + ? (text, maxChars) => + getEnhancedContext({ + strategy: this.config.useContext, + editor: this.editor, + text, + providers: { + embeddingsClient: this.embeddingsClient, + localEmbeddings: this.localEmbeddings, + symf: this.config.experimentalSymfContext ? this.symf : null, + codebaseStatusProvider: this.codebaseStatusProvider, + }, + featureFlags: this.config, + hints: { maxChars }, + }) : undefined ) const sendTelemetry = (contextSummary: any): void => { @@ -765,10 +768,11 @@ export class SimpleChatPanelProvider implements vscode.Disposable, ChatSession { prompter: IPrompter, sendTelemetry?: (contextSummary: any) => void ): Promise { - const { prompt, newContextUsed } = await prompter.makePrompt( - this.chatModel, - getContextWindowForModel(this.authProvider.getAuthStatus(), this.chatModel.modelID) + const maxChars = getContextWindowForModel( + this.authProvider.getAuthStatus(), + this.chatModel.modelID ) + const { prompt, newContextUsed } = await prompter.makePrompt(this.chatModel, maxChars) // Update UI based on prompt construction this.chatModel.setNewContextUsed(newContextUsed) diff --git a/vscode/src/chat/chat-view/context.test.ts b/vscode/src/chat/chat-view/context.test.ts new file mode 100644 index 00000000000..7dc83b14b99 --- /dev/null +++ b/vscode/src/chat/chat-view/context.test.ts @@ -0,0 +1,71 @@ +import { describe, expect, it } from 'vitest' +import { fuseContext } from './context' +import { testFileUri } from '@sourcegraph/cody-shared' +import type { ContextItem } from './SimpleChatModel' + +describe('fuseContext', () => { + const uri = testFileUri('test.ts') + const keywordItems = [ + { text: '0', uri }, + { text: '1', uri }, + { text: '2', uri }, + { text: '3', uri }, + { text: '4', uri }, + { text: '5', uri }, + { text: '6', uri }, + { text: '7', uri }, + { text: '8', uri }, + { text: '9', uri }, + ] + const embeddingsItems = [ + { text: 'A', uri }, + { text: 'B', uri }, + { text: 'C', uri }, + ] + + function joined(items: ContextItem[]): string { + return items.map(r => r.text).join('') + } + + it('includes the right 80-20 split', () => { + const maxChars = 10 + const result = fuseContext(keywordItems, embeddingsItems, maxChars) + expect(joined(result)).toEqual('01234567AB') + }) + + it('skips over large items in an attempt to optimize utilization', () => { + const keywordItems = [ + { text: '0', uri }, + { text: '1', uri }, + { text: '2', uri }, + { text: '3', uri }, + { text: '4', uri }, + { text: '5', uri }, + { text: 'very large keyword item', uri }, + { text: '6', uri }, + { text: '7', uri }, + { text: '8', uri }, + { text: '9', uri }, + ] + const embeddingsItems = [ + { text: 'A', uri }, + { text: 'very large embeddings item', uri }, + { text: 'B', uri }, + { text: 'C', uri }, + ] + const maxChars = 10 + const result = fuseContext(keywordItems, embeddingsItems, maxChars) + expect(joined(result)).toEqual('01234567AB') + }) + + it('returns an empty array when maxChars is 0', () => { + const result = fuseContext(keywordItems, embeddingsItems, 0) + expect(result).toEqual([]) + }) + + it('includes all keyword items if there are no embeddings items', () => { + const maxChars = 10 + const result = fuseContext(keywordItems, [], maxChars) + expect(joined(result)).toEqual('0123456789') + }) +}) diff --git a/vscode/src/chat/chat-view/context.ts b/vscode/src/chat/chat-view/context.ts index b9e1af37d7f..2f9489302e7 100644 --- a/vscode/src/chat/chat-view/context.ts +++ b/vscode/src/chat/chat-view/context.ts @@ -27,19 +27,46 @@ import type { ContextItem } from './SimpleChatModel' const isAgentTesting = process.env.CODY_SHIM_TESTING === 'true' -export async function getEnhancedContext( - useContextConfig: ConfigurationUseContext, - editor: VSCodeEditor, - embeddingsClient: CachedRemoteEmbeddingsClient, - localEmbeddings: LocalEmbeddingsController | null, - symf: SymfRunner | null, - codebaseStatusProvider: CodebaseStatusProvider, +export interface GetEnhancedContextOptions { + strategy: ConfigurationUseContext + editor: VSCodeEditor text: string -): Promise { + providers: { + codebaseStatusProvider: CodebaseStatusProvider + embeddingsClient: CachedRemoteEmbeddingsClient + localEmbeddings: LocalEmbeddingsController | null + symf: SymfRunner | null + } + featureFlags: { + internalUnstable: boolean + } + hints: { + maxChars: number + } + // TODO(@philipp-spiess): Add abort controller to be able to cancel expensive retrievers +} +export async function getEnhancedContext({ + strategy, + editor, + text, + providers, + featureFlags, + hints, +}: GetEnhancedContextOptions): Promise { + if (featureFlags.internalUnstable) { + return getEnhancedContextFused({ + strategy, + editor, + text, + providers, + featureFlags, + hints, + }) + } const searchContext: ContextItem[] = [] // use user attention context only if config is set to none - if (useContextConfig === 'none') { + if (strategy === 'none') { logDebug('SimpleChatPanelProvider', 'getEnhancedContext > none') searchContext.push(...getVisibleEditorContext(editor)) return searchContext @@ -47,12 +74,12 @@ export async function getEnhancedContext( let hasEmbeddingsContext = false // Get embeddings context if useContext Config is not set to 'keyword' only - if (useContextConfig !== 'keyword') { + if (strategy !== 'keyword') { logDebug('SimpleChatPanelProvider', 'getEnhancedContext > embeddings (start)') - const localEmbeddingsResults = searchEmbeddingsLocal(localEmbeddings, text) + const localEmbeddingsResults = searchEmbeddingsLocal(providers.localEmbeddings, text) const remoteEmbeddingsResults = searchEmbeddingsRemote( - embeddingsClient, - codebaseStatusProvider, + providers.embeddingsClient, + providers.codebaseStatusProvider, text ) try { @@ -73,42 +100,54 @@ export async function getEnhancedContext( } // Fallback to symf if embeddings provided no results or if useContext is set to 'keyword' specifically - if (!hasEmbeddingsContext && symf) { + if (!hasEmbeddingsContext && providers.symf) { logDebug('SimpleChatPanelProvider', 'getEnhancedContext > search') try { - searchContext.push(...(await searchSymf(symf, editor, text))) + searchContext.push(...(await searchSymf(providers.symf, editor, text))) } catch (error) { // TODO(beyang): handle this error better logDebug('SimpleChatPanelProvider.getEnhancedContext', 'searchSymf error', error) } } - const priorityContext: ContextItem[] = [] - const selectionContext = getCurrentSelectionContext(editor) - if (selectionContext.length > 0) { - priorityContext.push(...selectionContext) - } else if (needsUserAttentionContext(text)) { - // Query refers to current editor - priorityContext.push(...getVisibleEditorContext(editor)) - } else if (needsReadmeContext(editor, text)) { - // Query refers to project, so include the README - let containsREADME = false - for (const contextItem of searchContext) { - const basename = uriBasename(contextItem.uri) - if ( - basename.toLocaleLowerCase() === 'readme' || - basename.toLocaleLowerCase().startsWith('readme.') - ) { - containsREADME = true - break - } - } - if (!containsREADME) { - priorityContext.push(...(await getReadmeContext())) - } + const priorityContext = await getPriorityContext(text, editor, searchContext) + return priorityContext.concat(searchContext) +} + +async function getEnhancedContextFused({ + strategy, + editor, + text, + providers, + hints, +}: GetEnhancedContextOptions): Promise { + // use user attention context only if config is set to none + if (strategy === 'none') { + logDebug('SimpleChatPanelProvider', 'getEnhancedContext > none') + return getVisibleEditorContext(editor) } - return priorityContext.concat(searchContext) + // Get embeddings context if useContext Config is not set to 'keyword' only + const keywordContextItemsPromise = + strategy !== 'keyword' + ? retrieveContextGracefully( + searchEmbeddingsLocal(providers.localEmbeddings, text), + 'local-embeddings' + ) + : [] + const searchContextItemsPromise = providers.symf + ? retrieveContextGracefully(searchSymf(providers.symf, editor, text), 'symf') + : [] + + const [keywordContextItems, searchContextItems] = await Promise.all([ + keywordContextItemsPromise, + searchContextItemsPromise, + ]) + + const fusedContext = fuseContext(keywordContextItems, searchContextItems, hints.maxChars) + + const priorityContext = await getPriorityContext(text, editor, fusedContext) + return priorityContext.concat(fusedContext) } /** @@ -341,6 +380,38 @@ function getVisibleEditorContext(editor: VSCodeEditor): ContextItem[] { ] } +async function getPriorityContext( + text: string, + editor: VSCodeEditor, + retrievedContext: ContextItem[] +): Promise { + const priorityContext: ContextItem[] = [] + const selectionContext = getCurrentSelectionContext(editor) + if (selectionContext.length > 0) { + priorityContext.push(...selectionContext) + } else if (needsUserAttentionContext(text)) { + // Query refers to current editor + priorityContext.push(...getVisibleEditorContext(editor)) + } else if (needsReadmeContext(editor, text)) { + // Query refers to project, so include the README + let containsREADME = false + for (const contextItem of retrievedContext) { + const basename = uriBasename(contextItem.uri) + if ( + basename.toLocaleLowerCase() === 'readme' || + basename.toLocaleLowerCase().startsWith('readme.') + ) { + containsREADME = true + break + } + } + if (!containsREADME) { + priorityContext.push(...(await getReadmeContext())) + } + } + return priorityContext +} + function needsUserAttentionContext(input: string): boolean { const inputLowerCase = input.toLowerCase() // If the input matches any of the `editorRegexps` we assume that we have to include @@ -464,3 +535,47 @@ function extractQuestion(input: string): string | undefined { } return undefined } + +async function retrieveContextGracefully(promise: Promise, strategy: string): Promise { + try { + logDebug('SimpleChatPanelProvider', `getEnhancedContext > ${strategy} (start)`) + return await promise + } catch (error) { + logError('SimpleChatPanelProvider', `getEnhancedContext > ${strategy}' (error)`, error) + return [] + } finally { + logDebug('SimpleChatPanelProvider', `getEnhancedContext > ${strategy} (end)`) + } +} + +// A simple context fusion engine that picks the top most keyword results to fill up 80% of the +// context window and picks the top ranking embeddings items for the remainder. +export function fuseContext( + keywordItems: ContextItem[], + embeddingsItems: ContextItem[], + maxChars: number +): ContextItem[] { + let charsUsed = 0 + const fused = [] + const maxKeywordChars = embeddingsItems.length > 0 ? maxChars * 0.8 : maxChars + + for (const item of keywordItems) { + const len = item.text.length + + if (charsUsed + len <= maxKeywordChars) { + charsUsed += len + fused.push(item) + } + } + + for (const item of embeddingsItems) { + const len = item.text.length + + if (charsUsed + len <= maxChars) { + charsUsed += len + fused.push(item) + } + } + + return fused +} diff --git a/vscode/src/chat/chat-view/prompt.ts b/vscode/src/chat/chat-view/prompt.ts index dc44a48b1ba..c050c0ffa24 100644 --- a/vscode/src/chat/chat-view/prompt.ts +++ b/vscode/src/chat/chat-view/prompt.ts @@ -21,24 +21,22 @@ import { type SimpleChatModel, } from './SimpleChatModel' -export interface IContextProvider { - // Relevant context pulled from the editor state and broader repository - getEnhancedContext(query: string): Promise -} - interface PromptInfo { prompt: Message[] newContextUsed: ContextItem[] } export interface IPrompter { - makePrompt(chat: SimpleChatModel, byteLimit: number): Promise + makePrompt(chat: SimpleChatModel, charLimit: number): Promise } +const ENHANCED_CONTEXT_ALLOCATION = 0.6 // Enhanced context should take up 60% of the context window + export class CommandPrompter implements IPrompter { - constructor(private getContextItems: () => Promise) {} - public async makePrompt(chat: SimpleChatModel, byteLimit: number): Promise { - const promptBuilder = new PromptBuilder(byteLimit) + constructor(private getContextItems: (maxChars: number) => Promise) {} + public async makePrompt(chat: SimpleChatModel, charLimit: number): Promise { + const enhancedContextCharLimit = Math.floor(charLimit * ENHANCED_CONTEXT_ALLOCATION) + const promptBuilder = new PromptBuilder(charLimit) const newContextUsed: ContextItem[] = [] const preInstruction: string | undefined = vscode.workspace .getConfiguration('cody.chat') @@ -47,7 +45,7 @@ export class CommandPrompter implements IPrompter { const preambleMessages = getSimplePreamble(preInstruction) const preambleSucceeded = promptBuilder.tryAddToPrefix(preambleMessages) if (!preambleSucceeded) { - throw new Error(`Preamble length exceeded context window size ${byteLimit}`) + throw new Error(`Preamble length exceeded context window size ${charLimit}`) } // Add existing transcript messages @@ -67,10 +65,10 @@ export class CommandPrompter implements IPrompter { } } - const contextItems = await this.getContextItems() + const contextItems = await this.getContextItems(enhancedContextCharLimit) const { limitReached, used, ignored } = promptBuilder.tryAddContext( contextItems, - Math.floor(byteLimit * 0.6) // Allocate no more than 60% of context window to enhanced context + enhancedContextCharLimit ) newContextUsed.push(...used) if (limitReached) { @@ -93,7 +91,7 @@ export class CommandPrompter implements IPrompter { export class DefaultPrompter implements IPrompter { constructor( private explicitContext: ContextItem[], - private getEnhancedContext?: (query: string) => Promise + private getEnhancedContext?: (query: string, charLimit: number) => Promise ) {} // Constructs the raw prompt to send to the LLM, with message order reversed, so we can construct // an array with the most important messages (which appear most important first in the reverse-prompt. @@ -102,12 +100,13 @@ export class DefaultPrompter implements IPrompter { // prompt for the current message. public async makePrompt( chat: SimpleChatModel, - byteLimit: number + charLimit: number ): Promise<{ prompt: Message[] newContextUsed: ContextItem[] }> { - const promptBuilder = new PromptBuilder(byteLimit) + const enhancedContextCharLimit = Math.floor(charLimit * ENHANCED_CONTEXT_ALLOCATION) + const promptBuilder = new PromptBuilder(charLimit) const newContextUsed: ContextItem[] = [] const preInstruction: string | undefined = vscode.workspace .getConfiguration('cody.chat') @@ -116,7 +115,7 @@ export class DefaultPrompter implements IPrompter { const preambleMessages = getSimplePreamble(preInstruction) const preambleSucceeded = promptBuilder.tryAddToPrefix(preambleMessages) if (!preambleSucceeded) { - throw new Error(`Preamble length exceeded context window size ${byteLimit}`) + throw new Error(`Preamble length exceeded context window size ${charLimit}`) } // Add existing transcript messages @@ -174,10 +173,13 @@ export class DefaultPrompter implements IPrompter { } if (this.getEnhancedContext) { // Add additional context from current editor or broader search - const additionalContextItems = await this.getEnhancedContext(lastMessage.message.text) + const additionalContextItems = await this.getEnhancedContext( + lastMessage.message.text, + enhancedContextCharLimit + ) const { limitReached, used, ignored } = promptBuilder.tryAddContext( additionalContextItems, - Math.floor(byteLimit * 0.6) // Allocate no more than 60% of context window to enhanced context + enhancedContextCharLimit ) newContextUsed.push(...used) if (limitReached) { @@ -222,7 +224,7 @@ function renderContextItem(contextItem: ContextItem): Message[] { } /** - * PromptBuilder constructs a full prompt given a byteLimit constraint. + * PromptBuilder constructs a full prompt given a charLimit constraint. * The final prompt is constructed by concatenating the following fields: * - prefixMessages * - the reverse of reverseMessages @@ -230,24 +232,24 @@ function renderContextItem(contextItem: ContextItem): Message[] { class PromptBuilder { private prefixMessages: Message[] = [] private reverseMessages: Message[] = [] - private bytesUsed = 0 + private charsUsed = 0 private seenContext = new Set() - constructor(private readonly byteLimit: number) {} + constructor(private readonly charLimit: number) {} public build(): Message[] { return this.prefixMessages.concat([...this.reverseMessages].reverse()) } public tryAddToPrefix(messages: Message[]): boolean { - let numBytes = 0 + let numChars = 0 for (const message of messages) { - numBytes += message.speaker.length + (message.text?.length || 0) + 3 // space and 2 newlines + numChars += message.speaker.length + (message.text?.length || 0) + 3 // space and 2 newlines } - if (numBytes + this.bytesUsed > this.byteLimit) { + if (numChars + this.charsUsed > this.charLimit) { return false } this.prefixMessages.push(...messages) - this.bytesUsed += numBytes + this.charsUsed += numChars return true } @@ -258,28 +260,28 @@ class PromptBuilder { } const msgLen = message.speaker.length + (message.text?.length || 0) + 3 // space and 2 newlines - if (this.bytesUsed + msgLen > this.byteLimit) { + if (this.charsUsed + msgLen > this.charLimit) { return false } this.reverseMessages.push(message) - this.bytesUsed += msgLen + this.charsUsed += msgLen return true } /** - * Tries to add context items to the prompt, tracking bytes used. + * Tries to add context items to the prompt, tracking characters used. * Returns info about which items were used vs. ignored. */ public tryAddContext( contextItems: ContextItem[], - byteLimit?: number + charLimit?: number ): { limitReached: boolean used: ContextItem[] ignored: ContextItem[] duplicate: ContextItem[] } { - const effectiveByteLimit = byteLimit ? this.bytesUsed + byteLimit : this.byteLimit + const effectiveCharLimit = charLimit ? this.charsUsed + charLimit : this.charLimit let limitReached = false const used: ContextItem[] = [] const ignored: ContextItem[] = [] @@ -299,14 +301,14 @@ class PromptBuilder { (acc, msg) => acc + msg.speaker.length + (msg.text?.length || 0) + 3, 0 ) - if (this.bytesUsed + contextLen > effectiveByteLimit) { + if (this.charsUsed + contextLen > effectiveCharLimit) { ignored.push(contextItem) limitReached = true continue } this.seenContext.add(id) this.reverseMessages.push(...contextMessages) - this.bytesUsed += contextLen + this.charsUsed += contextLen used.push(contextItem) } return {