Skip to content

Commit

Permalink
do not block on fetching the user's Cody Pro subscription status when…
Browse files Browse the repository at this point in the history
… authing
  • Loading branch information
sqs committed Oct 4, 2024
1 parent f14f12b commit bf594bf
Show file tree
Hide file tree
Showing 21 changed files with 233 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,5 @@ data class AuthStatus(
val primaryEmail: String? = null,
val displayName: String? = null,
val avatarURL: String? = null,
val userCanUpgrade: Boolean? = null,
)

18 changes: 6 additions & 12 deletions lib/shared/src/auth/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { isDotCom } from '../sourcegraph-api/environments'
import type { UserProductSubscription } from '../sourcegraph-api/userProductSubscription'

/**
* The authentication status, which includes representing the state when authentication failed or
Expand Down Expand Up @@ -28,14 +29,7 @@ export interface AuthenticatedAuthStatus {
primaryEmail?: string
displayName?: string
avatarURL?: string
/**
* Whether the users account can be upgraded.
*
* This is `true` if the user is on dotCom and has not already upgraded. It
* is used to customize rate limit messages and show additional upgrade
* buttons in the UI.
*/
userCanUpgrade?: boolean

pendingValidation: boolean
}

Expand Down Expand Up @@ -69,12 +63,12 @@ export const AUTH_STATUS_FIXTURE_AUTHED_DOTCOM: AuthenticatedAuthStatus = {
endpoint: 'https://sourcegraph.com',
}

export function isCodyProUser(authStatus: AuthStatus): boolean {
return isDotCom(authStatus) && authStatus.authenticated && !authStatus.userCanUpgrade
export function isCodyProUser(authStatus: AuthStatus, sub: UserProductSubscription | null): boolean {
return isDotCom(authStatus) && authStatus.authenticated && sub !== null && !sub.userCanUpgrade
}

export function isFreeUser(authStatus: AuthStatus): boolean {
return isDotCom(authStatus) && authStatus.authenticated && !!authStatus.userCanUpgrade
export function isFreeUser(authStatus: AuthStatus, sub: UserProductSubscription | null): boolean {
return isDotCom(authStatus) && authStatus.authenticated && sub !== null && !!sub.userCanUpgrade
}

export function isEnterpriseUser(authStatus: AuthStatus): boolean {
Expand Down
5 changes: 5 additions & 0 deletions lib/shared/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,8 @@ export { fetchLocalOllamaModels } from './llm-providers/ollama/utils'
export * from './editor/editorState'
export { configOverwrites } from './models/configOverwrites'
export { siteVersion, currentSiteVersion } from './sourcegraph-api/siteVersion'
export {
currentUserProductSubscription,
type UserProductSubscription,
cachedUserProductSubscription,
} from './sourcegraph-api/userProductSubscription'
8 changes: 7 additions & 1 deletion lib/shared/src/misc/rpc/webviewAPI.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Observable } from 'observable-fns'
import type { AuthStatus, ModelsData, ResolvedConfiguration } from '../..'
import type { AuthStatus, ModelsData, ResolvedConfiguration, UserProductSubscription } from '../..'
import type { ChatMessage, UserLocalHistory } from '../../chat/transcript/messages'
import type { ContextItem } from '../../codebase-context/messages'
import type { CodyCommand } from '../../commands/types'
Expand Down Expand Up @@ -76,6 +76,11 @@ export interface WebviewToExtensionAPI {
* The current user's chat history.
*/
userHistory(): Observable<UserLocalHistory | null>

/**
* The current user's product subscription information (Cody Free/Pro).
*/
userProductSubscription(): Observable<UserProductSubscription | null>
}

export function createExtensionAPI(
Expand All @@ -100,6 +105,7 @@ export function createExtensionAPI(
authStatus: proxyExtensionAPI(messageAPI, 'authStatus'),
transcript: proxyExtensionAPI(messageAPI, 'transcript'),
userHistory: proxyExtensionAPI(messageAPI, 'userHistory'),
userProductSubscription: proxyExtensionAPI(messageAPI, 'userProductSubscription'),
}
}

Expand Down
27 changes: 27 additions & 0 deletions lib/shared/src/models/modelsService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { currentAuthStatus, mockAuthStatus } from '../auth/authStatus'
import { AUTH_STATUS_FIXTURE_AUTHED, type AuthenticatedAuthStatus } from '../auth/types'
import { firstValueFrom } from '../misc/observable'
import { DOTCOM_URL } from '../sourcegraph-api/environments'
import * as userProductSubscriptionModule from '../sourcegraph-api/userProductSubscription'
import type { UserProductSubscription } from '../sourcegraph-api/userProductSubscription'
import { CHAT_INPUT_TOKEN_BUDGET, CHAT_OUTPUT_TOKEN_BUDGET } from '../token/constants'
import { getMockedDotComClientModels } from './dotcom'
import type { Model } from './model'
Expand Down Expand Up @@ -31,11 +33,15 @@ describe('modelsService', () => {
...AUTH_STATUS_FIXTURE_AUTHED,
endpoint: DOTCOM_URL.toString(),
authenticated: true,
}
const freeUserSub: UserProductSubscription = {
userCanUpgrade: true,
}

const codyProAuthStatus: AuthenticatedAuthStatus = {
...freeUserAuthStatus,
}
const codyProSub: UserProductSubscription = {
userCanUpgrade: false,
}

Expand Down Expand Up @@ -158,6 +164,9 @@ describe('modelsService', () => {
let modelsService: ModelsService
beforeEach(() => {
mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
modelsService = modelsServiceWithModels([model1chat, model2chat, model3all, model4edit])
})

Expand Down Expand Up @@ -245,36 +254,54 @@ describe('modelsService', () => {

it('returns false for unknown model', async () => {
mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable('unknown-model'))).toBe(false)
})

it('allows enterprise user to use any model', async () => {
mockAuthStatus(enterpriseAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(null)
)
expect(await firstValueFrom(modelsService.isModelAvailable(enterpriseModel))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel))).toBe(true)
})

it('allows Cody Pro user to use Pro and Free models', async () => {
mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(enterpriseModel))).toBe(false)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel))).toBe(true)
})

it('allows free user to use only Free models', async () => {
mockAuthStatus(freeUserAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(freeUserSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(enterpriseModel))).toBe(false)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel))).toBe(false)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel))).toBe(true)
})

it('handles model passed as string', async () => {
mockAuthStatus(freeUserAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(freeUserSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(freeModel.id))).toBe(true)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel.id))).toBe(false)

mockAuthStatus(codyProAuthStatus)
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of(codyProSub)
)
expect(await firstValueFrom(modelsService.isModelAvailable(proModel.id))).toBe(true)
})
})
Expand Down
44 changes: 34 additions & 10 deletions lib/shared/src/models/modelsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import {
skipPendingOperation,
} from '../misc/observableOperation'
import { ClientConfigSingleton } from '../sourcegraph-api/clientConfig'
import {
type UserProductSubscription,
userProductSubscription,
} from '../sourcegraph-api/userProductSubscription'
import { CHAT_INPUT_TOKEN_BUDGET, CHAT_OUTPUT_TOKEN_BUDGET } from '../token/constants'
import { configOverwrites } from './configOverwrites'
import { type Model, type ServerModel, modelTier } from './model'
Expand Down Expand Up @@ -325,15 +329,26 @@ export class ModelsService {
}

public getDefaultModel(type: ModelUsage): Observable<Model | undefined | typeof pendingOperation> {
return combineLatest(this.getModelsByType(type), this.modelsChanges, authStatus).pipe(
map(([models, modelsData, authStatus]) => {
if (models === pendingOperation || modelsData === pendingOperation) {
return combineLatest(
this.getModelsByType(type),
this.modelsChanges,
authStatus,
userProductSubscription
).pipe(
map(([models, modelsData, authStatus, userProductSubscription]) => {
if (
models === pendingOperation ||
modelsData === pendingOperation ||
userProductSubscription === pendingOperation
) {
return pendingOperation
}

// Free users can only use the default free model, so we just find the first model they can use
const firstModelUserCanUse = models.find(
m => this._isModelAvailable(modelsData, authStatus, m) === true
m =>
this._isModelAvailable(modelsData, authStatus, userProductSubscription, m) ===
true
)

if (modelsData.preferences) {
Expand All @@ -343,7 +358,15 @@ export class ModelsService {
modelsData,
modelsData.preferences.selected[type] ?? modelsData.preferences.defaults[type]
)
if (selected && this._isModelAvailable(modelsData, authStatus, selected) === true) {
if (
selected &&
this._isModelAvailable(
modelsData,
authStatus,
userProductSubscription,
selected
) === true
) {
return selected
}
}
Expand Down Expand Up @@ -404,11 +427,11 @@ export class ModelsService {
}

public isModelAvailable(model: string | Model): Observable<boolean | typeof pendingOperation> {
return combineLatest(authStatus, this.modelsChanges).pipe(
map(([authStatus, modelsData]) =>
modelsData === pendingOperation
return combineLatest(authStatus, this.modelsChanges, userProductSubscription).pipe(
map(([authStatus, modelsData, userProductSubscription]) =>
modelsData === pendingOperation || userProductSubscription === pendingOperation
? pendingOperation
: this._isModelAvailable(modelsData, authStatus, model)
: this._isModelAvailable(modelsData, authStatus, userProductSubscription, model)
),
distinctUntilChanged()
)
Expand All @@ -417,6 +440,7 @@ export class ModelsService {
private _isModelAvailable(
modelsData: ModelsData,
authStatus: AuthStatus,
sub: UserProductSubscription | null,
model: string | Model
): boolean {
const resolved = this.resolveModel(modelsData, model)
Expand All @@ -432,7 +456,7 @@ export class ModelsService {
// A Cody Pro user can use any Free or Pro model, but not Enterprise.
// (But in reality, Sourcegraph.com wouldn't serve any Enterprise-only models to
// Cody Pro users anyways.)
if (isCodyProUser(authStatus)) {
if (isCodyProUser(authStatus, sub)) {
return (
tier !== 'enterprise' &&
!resolved.tags.includes(ModelTag.Waitlist) &&
Expand Down
4 changes: 4 additions & 0 deletions lib/shared/src/models/sync.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import { pendingOperation, skipPendingOperation } from '../misc/observableOperation'
import type { CodyClientConfig } from '../sourcegraph-api/clientConfig'
import type { CodyLLMSiteConfiguration } from '../sourcegraph-api/graphql/client'
import * as userProductSubscriptionModule from '../sourcegraph-api/userProductSubscription'
import type { PartialDeep } from '../utils'
import {
type Model,
Expand Down Expand Up @@ -216,6 +217,9 @@ describe('server sent models', async () => {
})

it("sets server models and default models if they're not already set", async () => {
vi.spyOn(userProductSubscriptionModule, 'userProductSubscription', 'get').mockReturnValue(
Observable.of({ userCanUpgrade: true })
)
// expect all defaults to be set
expect(await firstValueFrom(modelsService.getDefaultChatModel())).toBe(opus.id)
expect(await firstValueFrom(modelsService.getDefaultEditModel())).toBe(opus.id)
Expand Down
97 changes: 97 additions & 0 deletions lib/shared/src/sourcegraph-api/userProductSubscription.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import { Observable, map } from 'observable-fns'
import { authStatus } from '../auth/authStatus'
import { logError } from '../logger'
import {
debounceTime,
distinctUntilChanged,
pick,
promiseFactoryToObservable,
storeLastValue,
} from '../misc/observable'
import {
firstResultFromOperation,
pendingOperation,
switchMapReplayOperation,
} from '../misc/observableOperation'
import { isError } from '../utils'
import { isDotCom } from './environments'
import { graphqlClient } from './graphql'

export interface UserProductSubscription {
// TODO(sqs): this is the only field related to the user's subscription we were using previously
// in AuthStatus, so start with just it and we can add more.

/**
* Whether the user is on Cody Free (i.e., can upgrade to Cody Pro). This is `false` for
* enterprise users because they already have a higher degree of access than Cody Free/Pro.
*
* It's used to customize rate limit messages and show upgrade buttons in the UI.
*/
userCanUpgrade: boolean
}

/**
* Observe the currently authenticated user's Cody subscription status (for Sourcegraph.com Cody
* Free/Pro users only).
*/
export const userProductSubscription: Observable<
UserProductSubscription | null | typeof pendingOperation
> = authStatus.pipe(
pick('authenticated', 'endpoint', 'pendingValidation'),
distinctUntilChanged(),
debounceTime(0),
switchMapReplayOperation(
(authStatus): Observable<UserProductSubscription | Error | null | typeof pendingOperation> => {
if (authStatus.pendingValidation) {
return Observable.of(pendingOperation)
}

if (!authStatus.authenticated) {
return Observable.of(null)
}

if (!isDotCom(authStatus)) {
return Observable.of(null)
}

return promiseFactoryToObservable(signal =>
graphqlClient.getCurrentUserCodySubscription(signal)
).pipe(
map((sub): UserProductSubscription | null | typeof pendingOperation => {
if (isError(sub)) {
logError(
'userProductSubscription',
`Failed to get the Cody product subscription info from ${authStatus.endpoint}: ${sub}`
)
return null
}
const isActiveProUser =
sub !== null && 'plan' in sub && sub.plan === 'PRO' && sub.status !== 'PENDING'
return {
userCanUpgrade: !isActiveProUser,
}
})
)
}
),
map(result => (isError(result) ? null : result)) // the operation catches its own errors, so errors will never get here
)

const userProductSubscriptionStorage = storeLastValue(userProductSubscription)

/**
* Get the current user's product subscription info. If authentication is pending, it awaits
* successful authentication.
*/
export function currentUserProductSubscription(): Promise<UserProductSubscription | null> {
return firstResultFromOperation(userProductSubscriptionStorage.observable)
}

/**
* Get the current user's last-known product subscription info. Using this introduce a race
* condition if auth is pending.
*/
export function cachedUserProductSubscription(): UserProductSubscription | null {
const value = userProductSubscriptionStorage.value.last
return value === pendingOperation || !value ? null : value
}
Loading

0 comments on commit bf594bf

Please sign in to comment.