Skip to content

Commit

Permalink
Various
Browse files Browse the repository at this point in the history
  • Loading branch information
xyproto committed Aug 21, 2024
1 parent 29c3c4a commit 90c0207
Showing 1 changed file with 83 additions and 42 deletions.
125 changes: 83 additions & 42 deletions v2/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
var (
fixLineMut sync.Mutex
geminiEnabled = env.Has("GCP_PROJECT") || env.Has("PROJECT_ID")
geminiModel = env.Str("GEMINI_MODEL", "gemini-1.5-flash")
)

func (e *Editor) ProgrammingLanguage() bool {
Expand All @@ -49,18 +50,41 @@ func (e *Editor) GenerateTokens(geminiClient *simplegemini.GeminiClient, prompt
if geminiClient == nil {
return errors.New("no Gemini client")
}
ctx, cancel := context.WithTimeout(context.Background(), geminiClient.Timeout)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

_, err := geminiClient.SubmitToClientStreaming(ctx, func(token string) {
if !(len(token) <= 6 && strings.HasPrefix(token, "```")) {
newToken(token)
}
if !e.generatingTokens {
cancel()
}
})
return err
streamEnded := make(chan struct{})
tokenBuffer := []string{}

go func() {
geminiClient.SubmitToClientStreaming(ctx, func(token string) {
if !(e.mode != mode.Markdown && strings.Contains(token, "```")) {
// Log each token for debugging
logf("Received token: %s", token)

// Append token to buffer
tokenBuffer = append(tokenBuffer, token)
newToken(token)
}
if !e.generatingTokens {
cancel()
}
})
close(streamEnded)
}()

// Wait for the stream to end
<-streamEnded

// Ensure remaining tokens are processed
for _, token := range tokenBuffer {
newToken(token)
}

// Set generatingTokens to false when done
e.generatingTokens = false

return nil
}

func (e *Editor) FixLine(c *vt100.Canvas, status *StatusBar, lineIndex LineIndex, disableFixAsYouTypeOnError bool) {
Expand All @@ -75,9 +99,9 @@ func (e *Editor) FixLine(c *vt100.Canvas, status *StatusBar, lineIndex LineIndex
return
}

temperature := float32(0.0)
temperature := env.Float32("GEMINI_TEMPERATURE", 0.0)

geminiClient, err := simplegemini.NewWithTimeout("gemini-1.5-pro", temperature, 10*time.Second)
geminiClient, err := simplegemini.NewWithTimeout(geminiModel, temperature, 10*time.Second)
if err != nil {
status.SetErrorMessage("Failed to create Gemini client")
status.Show(c, e)
Expand All @@ -95,7 +119,7 @@ func (e *Editor) FixLine(c *vt100.Canvas, status *StatusBar, lineIndex LineIndex
status.Show(c, e)
return
}
maxTokens := 16000 - amountOfPromptTokens
maxTokens := 8192 - amountOfPromptTokens
if maxTokens < 1 {
status.SetErrorMessage("Gemini API request is too long")
status.Show(c, e)
Expand All @@ -119,6 +143,12 @@ func (e *Editor) FixLine(c *vt100.Canvas, status *StatusBar, lineIndex LineIndex
}
generatedLine += line
e.SetCurrentLine(currentLeadingWhitespace + e.AddSpaceAfterComments(generatedLine))
e.MakeConsistent()
e.DrawLines(c, true, false)
e.redrawCursor = true

// Log each line insertion for debugging
logf("Inserted line: %s", generatedLine)
}
}); err != nil {
e.redrawCursor = true
Expand All @@ -132,6 +162,11 @@ func (e *Editor) FixLine(c *vt100.Canvas, status *StatusBar, lineIndex LineIndex
return
}
}

// Final refresh to ensure all tokens are drawn
e.MakeConsistent()
e.DrawLines(c, true, false)
e.redrawCursor = true
}

func (e *Editor) FixCodeOrText(c *vt100.Canvas, status *StatusBar, disableFixAsYouTypeOnError bool) {
Expand All @@ -150,13 +185,6 @@ func (e *Editor) GenerateCodeOrText(c *vt100.Canvas, status *StatusBar, bookmark
return
}

geminiClient, err := simplegemini.NewWithTimeout("gemini-1.5-pro", 0.8, 10*time.Second)
if err != nil {
status.SetErrorMessage("Failed to create Gemini client")
status.Show(c, e)
return
}

trimmedLine := e.TrimmedLine()

go func() {
Expand All @@ -171,12 +199,19 @@ func (e *Editor) GenerateCodeOrText(c *vt100.Canvas, status *StatusBar, bookmark
generationType := generateText
if e.ProgrammingLanguage() {
generationType = generateCode
if prompt == "" {
generationType = continueCode
}
}

temperature := env.Float32("GEMINI_TEMPERATURE", 0.8)
if generationType == generateCode || generationType == continueCode {
temperature = env.Float32("GEMINI_TEMPERATURE", 0.0)
}

geminiClient, err := simplegemini.NewWithTimeout(geminiModel, temperature, 10*time.Second)
if err != nil {
status.SetErrorMessage("Failed to create Gemini client")
status.Show(c, e)
return
}

amountOfPromptTokens, err := geminiClient.CountTextTokens(prompt)
if err != nil {
Expand All @@ -185,13 +220,20 @@ func (e *Editor) GenerateCodeOrText(c *vt100.Canvas, status *StatusBar, bookmark
return
}

maxTokens := 8192 - amountOfPromptTokens
if maxTokens < 1 {
status.SetErrorMessage("Gemini API request is too long")
status.Show(c, e)
return
}

switch generationType {
case generateCode:
prompt += ". " + fmt.Sprintf(codePrompt, e.mode.String())
case continueCode:
prompt += ". " + fmt.Sprintf(continuePrompt, e.mode.String()) + "\n"
startTokens := strings.Fields(e.String())
gatherNTokens := 16000 - amountOfPromptTokens
gatherNTokens := 8192 - amountOfPromptTokens
if len(startTokens) > gatherNTokens {
startTokens = startTokens[len(startTokens)-gatherNTokens:]
}
Expand All @@ -214,13 +256,6 @@ func (e *Editor) GenerateCodeOrText(c *vt100.Canvas, status *StatusBar, bookmark
}
status.Show(c, e)

maxTokens := 16000 - amountOfPromptTokens
if maxTokens < 1 {
status.SetErrorMessage("Gemini API request is too long")
status.Show(c, e)
return
}

currentLeadingWhitespace := e.LeadingWhitespace()
e.generatingTokens = true
first := true
Expand All @@ -236,14 +271,17 @@ func (e *Editor) GenerateCodeOrText(c *vt100.Canvas, status *StatusBar, bookmark
}
generatedLine += line
e.SetCurrentLine(currentLeadingWhitespace + e.AddSpaceAfterComments(generatedLine))
e.MakeConsistent()
e.DrawLines(c, true, false)
e.redrawCursor = true
if first {
e.DeleteCurrentLineMoveBookmark(bookmark)
first = false
}

// Log each line insertion for debugging
logf("Inserted line: %s", generatedLine)
}
e.MakeConsistent()
e.DrawLines(c, true, false)
e.redrawCursor = true
}); err != nil {
e.redrawCursor = true
if !strings.Contains(err.Error(), "context") {
Expand All @@ -253,16 +291,19 @@ func (e *Editor) GenerateCodeOrText(c *vt100.Canvas, status *StatusBar, bookmark
return
}
}
e.End(c)

if e.generatingTokens {
if first {
status.SetMessageAfterRedraw("Nothing was generated")
} else {
status.SetMessageAfterRedraw("Done")
}
// Final refresh to ensure all tokens are drawn
e.MakeConsistent()
e.DrawLines(c, true, false)
e.redrawCursor = true

// Ensure e.generatingTokens is set to false when done
e.generatingTokens = false

if first {
status.SetMessageAfterRedraw("Nothing was generated")
} else {
status.SetMessageAfterRedraw("Stopped")
status.SetMessageAfterRedraw("Done")
}

e.RedrawAtEndOfKeyLoop(c, status)
Expand Down

0 comments on commit 90c0207

Please sign in to comment.