Skip to content

Commit

Permalink
Run embeddings and symf retrieval in parallel and implement basic fus…
Browse files Browse the repository at this point in the history
…ion (#2804)

This PR is changing the simple chat context engine to run the symf
search together with any dense retrieval strategies (local/remote
embedding) in parallel and fuse the results. The fusion is currently
done naively by allocating up to 80% of the context window for the symf
results and the remaining 20% for embeddings.

We will follow this up with a more advanced fusion logic that will also
address some issues for Autocomplete. We, however, need to adjust the
logic we have for RRF to have a finer granularity before we can do this.

## Test plan

- Added unit test for the fusion code
- Ask questions and look at the context items being picked up - Most
importantly it passes the squirrel test.
  • Loading branch information
philipp-spiess authored Jan 25, 2024
1 parent 80caa98 commit 646d0b7
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 85 deletions.
32 changes: 18 additions & 14 deletions vscode/src/chat/chat-view/SimpleChatPanelProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ export interface ChatSession {
webviewPanel?: vscode.WebviewPanel
sessionID: string
}

/**
* SimpleChatPanelProvider is the view controller class for the chat panel.
* It handles all events sent from the view, keeps track of the underlying chat model,
Expand Down Expand Up @@ -404,16 +403,20 @@ export class SimpleChatPanelProvider implements vscode.Disposable, ChatSession {
const prompter = new DefaultPrompter(
userContextItems,
addEnhancedContext
? query =>
getEnhancedContext(
this.config.useContext,
this.editor,
this.embeddingsClient,
this.localEmbeddings,
this.config.experimentalSymfContext ? this.symf : null,
this.codebaseStatusProvider,
query
)
? (text, maxChars) =>
getEnhancedContext({
strategy: this.config.useContext,
editor: this.editor,
text,
providers: {
embeddingsClient: this.embeddingsClient,
localEmbeddings: this.localEmbeddings,
symf: this.config.experimentalSymfContext ? this.symf : null,
codebaseStatusProvider: this.codebaseStatusProvider,
},
featureFlags: this.config,
hints: { maxChars },
})
: undefined
)
const sendTelemetry = (contextSummary: any): void => {
Expand Down Expand Up @@ -765,10 +768,11 @@ export class SimpleChatPanelProvider implements vscode.Disposable, ChatSession {
prompter: IPrompter,
sendTelemetry?: (contextSummary: any) => void
): Promise<Message[]> {
const { prompt, newContextUsed } = await prompter.makePrompt(
this.chatModel,
getContextWindowForModel(this.authProvider.getAuthStatus(), this.chatModel.modelID)
const maxChars = getContextWindowForModel(
this.authProvider.getAuthStatus(),
this.chatModel.modelID
)
const { prompt, newContextUsed } = await prompter.makePrompt(this.chatModel, maxChars)

// Update UI based on prompt construction
this.chatModel.setNewContextUsed(newContextUsed)
Expand Down
71 changes: 71 additions & 0 deletions vscode/src/chat/chat-view/context.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { describe, expect, it } from 'vitest'
import { fuseContext } from './context'
import { testFileUri } from '@sourcegraph/cody-shared'
import type { ContextItem } from './SimpleChatModel'

describe('fuseContext', () => {
const uri = testFileUri('test.ts')
const keywordItems = [
{ text: '0', uri },
{ text: '1', uri },
{ text: '2', uri },
{ text: '3', uri },
{ text: '4', uri },
{ text: '5', uri },
{ text: '6', uri },
{ text: '7', uri },
{ text: '8', uri },
{ text: '9', uri },
]
const embeddingsItems = [
{ text: 'A', uri },
{ text: 'B', uri },
{ text: 'C', uri },
]

function joined(items: ContextItem[]): string {
return items.map(r => r.text).join('')
}

it('includes the right 80-20 split', () => {
const maxChars = 10
const result = fuseContext(keywordItems, embeddingsItems, maxChars)
expect(joined(result)).toEqual('01234567AB')
})

it('skips over large items in an attempt to optimize utilization', () => {
const keywordItems = [
{ text: '0', uri },
{ text: '1', uri },
{ text: '2', uri },
{ text: '3', uri },
{ text: '4', uri },
{ text: '5', uri },
{ text: 'very large keyword item', uri },
{ text: '6', uri },
{ text: '7', uri },
{ text: '8', uri },
{ text: '9', uri },
]
const embeddingsItems = [
{ text: 'A', uri },
{ text: 'very large embeddings item', uri },
{ text: 'B', uri },
{ text: 'C', uri },
]
const maxChars = 10
const result = fuseContext(keywordItems, embeddingsItems, maxChars)
expect(joined(result)).toEqual('01234567AB')
})

it('returns an empty array when maxChars is 0', () => {
const result = fuseContext(keywordItems, embeddingsItems, 0)
expect(result).toEqual([])
})

it('includes all keyword items if there are no embeddings items', () => {
const maxChars = 10
const result = fuseContext(keywordItems, [], maxChars)
expect(joined(result)).toEqual('0123456789')
})
})
193 changes: 154 additions & 39 deletions vscode/src/chat/chat-view/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,59 @@ import type { ContextItem } from './SimpleChatModel'

const isAgentTesting = process.env.CODY_SHIM_TESTING === 'true'

export async function getEnhancedContext(
useContextConfig: ConfigurationUseContext,
editor: VSCodeEditor,
embeddingsClient: CachedRemoteEmbeddingsClient,
localEmbeddings: LocalEmbeddingsController | null,
symf: SymfRunner | null,
codebaseStatusProvider: CodebaseStatusProvider,
export interface GetEnhancedContextOptions {
strategy: ConfigurationUseContext
editor: VSCodeEditor
text: string
): Promise<ContextItem[]> {
providers: {
codebaseStatusProvider: CodebaseStatusProvider
embeddingsClient: CachedRemoteEmbeddingsClient
localEmbeddings: LocalEmbeddingsController | null
symf: SymfRunner | null
}
featureFlags: {
internalUnstable: boolean
}
hints: {
maxChars: number
}
// TODO(@philipp-spiess): Add abort controller to be able to cancel expensive retrievers
}
export async function getEnhancedContext({
strategy,
editor,
text,
providers,
featureFlags,
hints,
}: GetEnhancedContextOptions): Promise<ContextItem[]> {
if (featureFlags.internalUnstable) {
return getEnhancedContextFused({
strategy,
editor,
text,
providers,
featureFlags,
hints,
})
}
const searchContext: ContextItem[] = []

// use user attention context only if config is set to none
if (useContextConfig === 'none') {
if (strategy === 'none') {
logDebug('SimpleChatPanelProvider', 'getEnhancedContext > none')
searchContext.push(...getVisibleEditorContext(editor))
return searchContext
}

let hasEmbeddingsContext = false
// Get embeddings context if useContext Config is not set to 'keyword' only
if (useContextConfig !== 'keyword') {
if (strategy !== 'keyword') {
logDebug('SimpleChatPanelProvider', 'getEnhancedContext > embeddings (start)')
const localEmbeddingsResults = searchEmbeddingsLocal(localEmbeddings, text)
const localEmbeddingsResults = searchEmbeddingsLocal(providers.localEmbeddings, text)
const remoteEmbeddingsResults = searchEmbeddingsRemote(
embeddingsClient,
codebaseStatusProvider,
providers.embeddingsClient,
providers.codebaseStatusProvider,
text
)
try {
Expand All @@ -73,42 +100,54 @@ export async function getEnhancedContext(
}

// Fallback to symf if embeddings provided no results or if useContext is set to 'keyword' specifically
if (!hasEmbeddingsContext && symf) {
if (!hasEmbeddingsContext && providers.symf) {
logDebug('SimpleChatPanelProvider', 'getEnhancedContext > search')
try {
searchContext.push(...(await searchSymf(symf, editor, text)))
searchContext.push(...(await searchSymf(providers.symf, editor, text)))
} catch (error) {
// TODO(beyang): handle this error better
logDebug('SimpleChatPanelProvider.getEnhancedContext', 'searchSymf error', error)
}
}

const priorityContext: ContextItem[] = []
const selectionContext = getCurrentSelectionContext(editor)
if (selectionContext.length > 0) {
priorityContext.push(...selectionContext)
} else if (needsUserAttentionContext(text)) {
// Query refers to current editor
priorityContext.push(...getVisibleEditorContext(editor))
} else if (needsReadmeContext(editor, text)) {
// Query refers to project, so include the README
let containsREADME = false
for (const contextItem of searchContext) {
const basename = uriBasename(contextItem.uri)
if (
basename.toLocaleLowerCase() === 'readme' ||
basename.toLocaleLowerCase().startsWith('readme.')
) {
containsREADME = true
break
}
}
if (!containsREADME) {
priorityContext.push(...(await getReadmeContext()))
}
const priorityContext = await getPriorityContext(text, editor, searchContext)
return priorityContext.concat(searchContext)
}

async function getEnhancedContextFused({
strategy,
editor,
text,
providers,
hints,
}: GetEnhancedContextOptions): Promise<ContextItem[]> {
// use user attention context only if config is set to none
if (strategy === 'none') {
logDebug('SimpleChatPanelProvider', 'getEnhancedContext > none')
return getVisibleEditorContext(editor)
}

return priorityContext.concat(searchContext)
// Get embeddings context if useContext Config is not set to 'keyword' only
const keywordContextItemsPromise =
strategy !== 'keyword'
? retrieveContextGracefully(
searchEmbeddingsLocal(providers.localEmbeddings, text),
'local-embeddings'
)
: []
const searchContextItemsPromise = providers.symf
? retrieveContextGracefully(searchSymf(providers.symf, editor, text), 'symf')
: []

const [keywordContextItems, searchContextItems] = await Promise.all([
keywordContextItemsPromise,
searchContextItemsPromise,
])

const fusedContext = fuseContext(keywordContextItems, searchContextItems, hints.maxChars)

const priorityContext = await getPriorityContext(text, editor, fusedContext)
return priorityContext.concat(fusedContext)
}

/**
Expand Down Expand Up @@ -341,6 +380,38 @@ function getVisibleEditorContext(editor: VSCodeEditor): ContextItem[] {
]
}

async function getPriorityContext(
text: string,
editor: VSCodeEditor,
retrievedContext: ContextItem[]
): Promise<ContextItem[]> {
const priorityContext: ContextItem[] = []
const selectionContext = getCurrentSelectionContext(editor)
if (selectionContext.length > 0) {
priorityContext.push(...selectionContext)
} else if (needsUserAttentionContext(text)) {
// Query refers to current editor
priorityContext.push(...getVisibleEditorContext(editor))
} else if (needsReadmeContext(editor, text)) {
// Query refers to project, so include the README
let containsREADME = false
for (const contextItem of retrievedContext) {
const basename = uriBasename(contextItem.uri)
if (
basename.toLocaleLowerCase() === 'readme' ||
basename.toLocaleLowerCase().startsWith('readme.')
) {
containsREADME = true
break
}
}
if (!containsREADME) {
priorityContext.push(...(await getReadmeContext()))
}
}
return priorityContext
}

function needsUserAttentionContext(input: string): boolean {
const inputLowerCase = input.toLowerCase()
// If the input matches any of the `editorRegexps` we assume that we have to include
Expand Down Expand Up @@ -464,3 +535,47 @@ function extractQuestion(input: string): string | undefined {
}
return undefined
}

async function retrieveContextGracefully<T>(promise: Promise<T[]>, strategy: string): Promise<T[]> {
try {
logDebug('SimpleChatPanelProvider', `getEnhancedContext > ${strategy} (start)`)
return await promise
} catch (error) {
logError('SimpleChatPanelProvider', `getEnhancedContext > ${strategy}' (error)`, error)
return []
} finally {
logDebug('SimpleChatPanelProvider', `getEnhancedContext > ${strategy} (end)`)
}
}

// A simple context fusion engine that picks the top most keyword results to fill up 80% of the
// context window and picks the top ranking embeddings items for the remainder.
export function fuseContext(
keywordItems: ContextItem[],
embeddingsItems: ContextItem[],
maxChars: number
): ContextItem[] {
let charsUsed = 0
const fused = []
const maxKeywordChars = embeddingsItems.length > 0 ? maxChars * 0.8 : maxChars

for (const item of keywordItems) {
const len = item.text.length

if (charsUsed + len <= maxKeywordChars) {
charsUsed += len
fused.push(item)
}
}

for (const item of embeddingsItems) {
const len = item.text.length

if (charsUsed + len <= maxChars) {
charsUsed += len
fused.push(item)
}
}

return fused
}
Loading

0 comments on commit 646d0b7

Please sign in to comment.