Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce blocking HTTP requests during authentication #5799

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ data class AuthStatus(
val isFireworksTracingEnabled: Boolean? = null,
val hasVerifiedEmail: Boolean? = null,
val requiresVerifiedEmail: Boolean? = null,
val siteVersion: String,
val codyApiVersion: Long,
val configOverwrites: CodyLLMSiteConfiguration? = null,
val primaryEmail: String? = null,
val displayName: String? = null,
val avatarURL: String? = null,
val userCanUpgrade: Boolean? = null,
)

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package com.sourcegraph.cody.agent.protocol_generated;
data class ServerInfo(
val name: String,
val authenticated: Boolean? = null,
val codyVersion: String? = null,
val authStatus: AuthStatus? = null,
)

1 change: 0 additions & 1 deletion agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,6 @@ export class Agent extends MessageHandler implements ExtensionClient {
return {
name: 'cody-agent',
authenticated: authStatus?.authenticated ?? false,
codyVersion: authStatus?.authenticated ? authStatus.siteVersion : undefined,
authStatus,
}
} catch (error) {
Expand Down
2 changes: 1 addition & 1 deletion lib/shared/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"@types/isomorphic-fetch": "^0.0.39",
"@types/lodash": "^4.14.195",
"@types/node-fetch": "^2.6.4",
"@types/semver": "^7.5.0",
"@types/semver": "^7.5.8",
"@types/vscode": "^1.80.0",
"type-fest": "^4.26.1"
}
Expand Down
28 changes: 6 additions & 22 deletions lib/shared/src/auth/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { isDotCom } from '../sourcegraph-api/environments'
import type { CodyLLMSiteConfiguration } from '../sourcegraph-api/graphql/client'
import type { UserProductSubscription } from '../sourcegraph-api/userProductSubscription'

/**
* The authentication status, which includes representing the state when authentication failed or
Expand All @@ -25,21 +25,11 @@ export interface AuthenticatedAuthStatus {

hasVerifiedEmail?: boolean
requiresVerifiedEmail?: boolean
siteVersion: string
codyApiVersion: number
configOverwrites?: CodyLLMSiteConfiguration

primaryEmail?: string
displayName?: string
avatarURL?: string
/**
* Whether the users account can be upgraded.
*
* This is `true` if the user is on dotCom and has not already upgraded. It
* is used to customize rate limit messages and show additional upgrade
* buttons in the UI.
*/
userCanUpgrade?: boolean

pendingValidation: boolean
}

Expand All @@ -59,8 +49,6 @@ export const AUTH_STATUS_FIXTURE_AUTHED: AuthenticatedAuthStatus = {
endpoint: 'https://example.com',
authenticated: true,
username: 'alice',
codyApiVersion: 1,
siteVersion: '9999',
pendingValidation: false,
}

Expand All @@ -73,18 +61,14 @@ export const AUTH_STATUS_FIXTURE_UNAUTHED: AuthStatus & { authenticated: false }
export const AUTH_STATUS_FIXTURE_AUTHED_DOTCOM: AuthenticatedAuthStatus = {
...AUTH_STATUS_FIXTURE_AUTHED,
endpoint: 'https://sourcegraph.com',
configOverwrites: {
provider: 'sourcegraph',
completionModel: 'fireworks/starcoder-hybrid',
},
}

export function isCodyProUser(authStatus: AuthStatus): boolean {
return isDotCom(authStatus) && authStatus.authenticated && !authStatus.userCanUpgrade
export function isCodyProUser(authStatus: AuthStatus, sub: UserProductSubscription | null): boolean {
return isDotCom(authStatus) && authStatus.authenticated && sub !== null && !sub.userCanUpgrade
}

export function isFreeUser(authStatus: AuthStatus): boolean {
return isDotCom(authStatus) && authStatus.authenticated && !!authStatus.userCanUpgrade
export function isFreeUser(authStatus: AuthStatus, sub: UserProductSubscription | null): boolean {
return isDotCom(authStatus) && authStatus.authenticated && sub !== null && !!sub.userCanUpgrade
}

export function isEnterpriseUser(authStatus: AuthStatus): boolean {
Expand Down
41 changes: 22 additions & 19 deletions lib/shared/src/chat/chat.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import type { AuthenticatedAuthStatus } from '../auth/types'
import { authStatus } from '../auth/authStatus'
import { firstValueFrom } from '../misc/observable'
import type { Message } from '../sourcegraph-api'
import type { SourcegraphCompletionsClient } from '../sourcegraph-api/completions/client'
import type {
CompletionGeneratorValue,
CompletionParameters,
} from '../sourcegraph-api/completions/types'
import { currentSiteVersion } from '../sourcegraph-api/siteVersion'

type ChatParameters = Omit<CompletionParameters, 'messages'>

Expand All @@ -15,25 +17,25 @@ const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit<ChatParameters, 'maxTokensToSampl
}

export class ChatClient {
constructor(
private completions: SourcegraphCompletionsClient,
private getAuthStatus: () => Pick<
AuthenticatedAuthStatus,
| 'authenticated'
| 'userCanUpgrade'
| 'endpoint'
| 'codyApiVersion'
| 'isFireworksTracingEnabled'
>
) {}

public chat(
constructor(private completions: SourcegraphCompletionsClient) {}

public async chat(
messages: Message[],
params: Partial<ChatParameters> & Pick<ChatParameters, 'maxTokensToSample'>,
abortSignal?: AbortSignal
): AsyncGenerator<CompletionGeneratorValue> {
const authStatus = this.getAuthStatus()
const useApiV1 = authStatus.codyApiVersion >= 1 && params.model?.includes('claude-3')
): Promise<AsyncGenerator<CompletionGeneratorValue>> {
const [versions, authStatus_] = await Promise.all([
currentSiteVersion(),
await firstValueFrom(authStatus),
])
if (!versions) {
throw new Error('unable to determine Cody API version')
}
if (!authStatus_.authenticated) {
throw new Error('not authenticated')
}

const useApiV1 = versions.codyAPIVersion >= 1 && params.model?.includes('claude-3')
const isLastMessageFromHuman = messages.length > 0 && messages.at(-1)!.speaker === 'human'

const isFireworks = params?.model?.startsWith('fireworks/')
Expand All @@ -59,13 +61,14 @@ export class ChatClient {

// Enabled Fireworks tracing for Sourcegraph teammates.
// https://readme.fireworks.ai/docs/enabling-tracing

const customHeaders: Record<string, string> =
isFireworks && authStatus.isFireworksTracingEnabled ? { 'X-Fireworks-Genie': 'true' } : {}
isFireworks && authStatus_.isFireworksTracingEnabled ? { 'X-Fireworks-Genie': 'true' } : {}

return this.completions.stream(
completionParams,
{
apiVersion: useApiV1 ? authStatus.codyApiVersion : 0,
apiVersion: useApiV1 ? versions.codyAPIVersion : 0,
customHeaders,
},
abortSignal
Expand Down
7 changes: 2 additions & 5 deletions lib/shared/src/cody-ignore/context-filters-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ describe('ContextFiltersProvider', () => {
provider = new ContextFiltersProvider()
vi.useFakeTimers()

vi.spyOn(graphqlClient, 'isCodyEnabled').mockResolvedValue({ enabled: true, version: '6.0.0' })
vi.spyOn(graphqlClient, 'getSiteVersion').mockResolvedValue('6.0.0')
})

afterEach(() => {
Expand Down Expand Up @@ -298,10 +298,7 @@ describe('ContextFiltersProvider', () => {
}

it('should handle the case when version is older than the supported version', async () => {
vi.spyOn(graphqlClient, 'isCodyEnabled').mockResolvedValue({
enabled: true,
version: '5.3.2',
})
vi.spyOn(graphqlClient, 'getSiteVersion').mockResolvedValue('5.3.2')
await initProviderWithContextFilters({
include: [{ repoNamePattern: '^github\\.com/sourcegraph/cody' }],
exclude: [{ repoNamePattern: '^github\\.com/sourcegraph/sourcegraph' }],
Expand Down
7 changes: 7 additions & 0 deletions lib/shared/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,10 @@ export * from './singletons'
export * from './auth/authStatus'
export { fetchLocalOllamaModels } from './llm-providers/ollama/utils'
export * from './editor/editorState'
export { configOverwrites } from './models/configOverwrites'
export { siteVersion, currentSiteVersion } from './sourcegraph-api/siteVersion'
export {
currentUserProductSubscription,
type UserProductSubscription,
cachedUserProductSubscription,
} from './sourcegraph-api/userProductSubscription'
15 changes: 9 additions & 6 deletions lib/shared/src/misc/observable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -546,13 +546,16 @@ export function pluck<T>(...keyPath: any[]): (input: ObservableLike<T>) => Obser
}

export function pick<T, K extends keyof T>(
key: K
...keys: K[]
): (input: ObservableLike<T>) => Observable<Pick<T, K>> {
return map(
value =>
({
[key]: value[key],
}) as Pick<T, K>
return map(value =>
keys.reduce(
(acc, key) => {
acc[key] = value[key]
return acc
},
{} as Pick<T, K>
)
)
}

Expand Down
8 changes: 7 additions & 1 deletion lib/shared/src/misc/rpc/webviewAPI.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Observable } from 'observable-fns'
import type { AuthStatus, ModelsData, ResolvedConfiguration } from '../..'
import type { AuthStatus, ModelsData, ResolvedConfiguration, UserProductSubscription } from '../..'
import type { ChatMessage, UserLocalHistory } from '../../chat/transcript/messages'
import type { ContextItem } from '../../codebase-context/messages'
import type { CodyCommand } from '../../commands/types'
Expand Down Expand Up @@ -76,6 +76,11 @@ export interface WebviewToExtensionAPI {
* The current user's chat history.
*/
userHistory(): Observable<UserLocalHistory | null>

/**
* The current user's product subscription information (Cody Free/Pro).
*/
userProductSubscription(): Observable<UserProductSubscription | null>
}

export function createExtensionAPI(
Expand All @@ -100,6 +105,7 @@ export function createExtensionAPI(
authStatus: proxyExtensionAPI(messageAPI, 'authStatus'),
transcript: proxyExtensionAPI(messageAPI, 'transcript'),
userHistory: proxyExtensionAPI(messageAPI, 'userHistory'),
userProductSubscription: proxyExtensionAPI(messageAPI, 'userProductSubscription'),
}
}

Expand Down
45 changes: 45 additions & 0 deletions lib/shared/src/models/configOverwrites.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { Observable, map } from 'observable-fns'
import { authStatus } from '../auth/authStatus'
import { logError } from '../logger'
import { distinctUntilChanged, pick, promiseFactoryToObservable } from '../misc/observable'
import { pendingOperation, switchMapReplayOperation } from '../misc/observableOperation'
import { type CodyLLMSiteConfiguration, graphqlClient } from '../sourcegraph-api/graphql/client'
import { isError } from '../utils'

/**
* Observe the model-related config overwrites on the server for the currently authenticated user.
*/
export const configOverwrites: Observable<CodyLLMSiteConfiguration | null | typeof pendingOperation> =
authStatus.pipe(
pick('authenticated', 'endpoint', 'pendingValidation'),
distinctUntilChanged(),
switchMapReplayOperation(
(
authStatus
): Observable<CodyLLMSiteConfiguration | Error | null | typeof pendingOperation> => {
if (authStatus.pendingValidation) {
return Observable.of(pendingOperation)
}

if (!authStatus.authenticated) {
return Observable.of(null)
}

return promiseFactoryToObservable(signal =>
graphqlClient.getCodyLLMConfiguration(signal)
).pipe(
map((result): CodyLLMSiteConfiguration | null | typeof pendingOperation => {
if (isError(result)) {
logError(
'configOverwrites',
`Failed to get Cody LLM configuration from ${authStatus.endpoint}: ${result}`
)
return null
}
return result ?? null
})
)
}
),
map(result => (isError(result) ? null : result)) // the operation catches its own errors, so errors will never get here
)
27 changes: 27 additions & 0 deletions lib/shared/src/models/modelsService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { currentAuthStatus, mockAuthStatus } from '../auth/authStatus'
import { AUTH_STATUS_FIXTURE_AUTHED, type AuthenticatedAuthStatus } from '../auth/types'
import { firstValueFrom } from '../misc/observable'
import { DOTCOM_URL } from '../sourcegraph-api/environments'
import * as userProductSubscriptionModule from '../sourcegraph-api/userProductSubscription'
import type { UserProductSubscription } from '../sourcegraph-api/userProductSubscription'
import { CHAT_INPUT_TOKEN_BUDGET, CHAT_OUTPUT_TOKEN_BUDGET } from '../token/constants'
import { getMockedDotComClientModels } from './dotcom'
import type { Model } from './model'
Expand Down Expand Up @@ -31,11 +33,15 @@ describe('modelsService', () => {
...AUTH_STATUS_FIXTURE_AUTHED,
endpoint: DOTCOM_URL.toString(),
authenticated: true,
}
const freeUserSub: UserProductSubscription = {
userCanUpgrade: true,
}

const codyProAuthStatus: AuthenticatedAuthStatus = {
...freeUserAuthStatus,
}
const codyProSub: UserProductSubscription = {
userCanUpgrade: false,
}

Expand Down Expand Up @@ -158,6 +164,9 @@ describe('modelsService', () => {
let modelsService: ModelsService
beforeEach(() => {
mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
modelsService = modelsServiceWithModels([model1chat, model2chat, model3all, model4edit])
})

Expand Down Expand Up @@ -245,36 +254,54 @@ describe('modelsService', () => {

it('returns false for unknown model', async () => {
mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable('unknown-model'))).toBe(false)
})

it('allows enterprise user to use any model', async () => {
mockAuthStatus(enterpriseAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(null)
)
expect(await firstValueFrom(modelsService.isModelAvailable(enterpriseModel))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel))).toBe(true)
})

it('allows Cody Pro user to use Pro and Free models', async () => {
mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(enterpriseModel))).toBe(false)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel))).toBe(true)
})

it('allows free user to use only Free models', async () => {
mockAuthStatus(freeUserAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(freeUserSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(enterpriseModel))).toBe(false)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel))).toBe(false)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel))).toBe(true)
})

it('handles model passed as string', async () => {
mockAuthStatus(freeUserAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(freeUserSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel.id))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel.id))).toBe(false)

mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel.id))).toBe(true)
})
})
Expand Down
Loading
Loading