Skip to content

Commit

Permalink
Bump version 0.2.0
Browse files Browse the repository at this point in the history
Signed-off-by: Paolo Di Tommaso <[email protected]>
  • Loading branch information
pditommaso committed Mar 20, 2024
1 parent 3c54558 commit afeeb27
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 11 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ nextflow run <my script>

See the folder [examples] for more examples

### Options

The `prompt` operator supports the following options:

| name | description |
|---------------|-------------|
| model | The AI model to be used (default: `gpt-3.5-turbo`) |
| maxTokens | The maximum number of tokens that can be generated in the chat completion |
| schema | The expected structure of the result object, represented as a map in which each key is an attribute name and the corresponding value is the attribute type |
| temperature | What sampling temperature to use, between 0 and 2 (default: `0.7`) |

### Configuration file

The following config options can be specified in the `nextflow.config` file:


| name | description |
|---------------|-------------|
| gpt.apiKey | Your OpenAI API key. If missing it uses the `OPENAI_API_KEY` env variable |
| gpt.endpoint | The OpenAI endpoint (default: `https://api.openai.com`) |
| gpt.model | The AI model to be used (default: `gpt-3.5-turbo`) |
| gpt.maxTokens | The maximum number of tokens that can be generated in the chat completion |
| gpt.temperature | What sampling temperature to use, between 0 and 2 (default: `0.7`) |


## Testing and debugging

Expand Down
2 changes: 1 addition & 1 deletion examples/example2.nf
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Who won most gold medals in swimming and Athletics categories during Barcelona 1
'''

channel .of(query)
.prompt(schema: [athlete: 'string', numberOfMedals: 'number', location:'string'])
.prompt(schema: [athlete: 'string', numberOfMedals: 'number', location:'string', sport:'string'])
.view()
2 changes: 1 addition & 1 deletion examples/example3.nf
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ include { prompt } from 'plugin/nf-gpt'
channel
.fromList(['Barcelona, 1992', 'London, 2012'])
.combine(['Swimming', 'Athletics'])
.prompt(schema: [athlete: 'string', numberOfMedals: 'number']) { edition, sport ->
.prompt(schema: [athlete: 'string', numberOfMedals: 'number', location: 'string', sport: 'string']) { edition, sport ->
"Who won most gold medals in $sport category during $edition olympic games?"
}
.view()
12 changes: 12 additions & 0 deletions plugins/nf-gpt/src/main/nextflow/gpt/config/GptConfig.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class GptConfig {

static final String DEFAULT_ENDPOINT = 'https://api.openai.com'
static final String DEFAULT_MODEL = 'gpt-3.5-turbo'
static final Double DEFAULT_TEMPERATURE = 0.7d

private String endpoint
private String apiKey
private String model
private Double temperature
private Integer maxTokens

static GptConfig config(Session session) {
new GptConfig(session.config.ai as Map ?: Collections.emptyMap(), SysEnv.get())
Expand All @@ -45,6 +48,7 @@ class GptConfig {
this.endpoint = opts.endpoint ?: DEFAULT_ENDPOINT
this.model = opts.model ?: DEFAULT_MODEL
this.apiKey = opts.apiKey ?: env.get('OPENAI_API_KEY')
this.temperature = opts.temperature!=null ? temperature as Double : DEFAULT_TEMPERATURE
}

String endpoint() {
Expand All @@ -58,4 +62,12 @@ class GptConfig {
String model() {
return model
}

Double temperature() {
return temperature
}

Integer maxTokens() {
return maxTokens
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class GptPromptExtension extends PluginExtensionPoint {
static final private Map VALID_PROMPT_OPTS = [
model: String,
schema: Map,
debug: Boolean
debug: Boolean,
temperature: Double,
maxTokens: Integer
]

private Session session
Expand Down
38 changes: 31 additions & 7 deletions plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package nextflow.gpt.prompt

import dev.langchain4j.data.message.ChatMessage
import dev.langchain4j.data.message.SystemMessage
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.openai.OpenAiChatModel
import groovy.json.JsonSlurper
Expand All @@ -40,6 +41,8 @@ class GptPromptModel {

private String model
private boolean debug
private Double temperature
private Integer maxTokens

GptPromptModel(Session session) {
this.config = GptConfig.config(session)
Expand All @@ -55,32 +58,53 @@ class GptPromptModel {
return this
}

GptPromptModel withTemperature(Double d) {
this.temperature = d
return this
}

GptPromptModel withMaxToken(Integer i) {
this.maxTokens = i
return this
}

GptPromptModel build() {
final modelName = model ?: config.model()
log.debug "Creating OpenAI chat model: $modelName; api-key: ${StringUtils.redact(config.apiKey())}"
final temp = temperature ?: config.temperature()
final tokens = maxTokens ?: config.maxTokens()
log.debug "Creating OpenAI chat model: $modelName; api-key: ${StringUtils.redact(config.apiKey())}; temperature: $temp; maxTokens: ${maxTokens}"
client = OpenAiChatModel.builder()
.apiKey(config.apiKey())
.modelName(modelName)
.logRequests(debug)
.logResponses(debug)
.temperature(temperature)
.maxTokens(maxTokens)
.responseFormat("json_object")
.build();
return this
}

List<Map<String,Object>> prompt(String query, Map schema) {
if( !query )
List<Map<String,Object>> prompt(List<ChatMessage> messages, Map schema) {
if( !messages )
throw new IllegalArgumentException("Missing AI prompt")
final content = query + '. ' + renderSchema(schema)
final msg = UserMessage.from(content)
final all = new ArrayList(messages)
all.add(SystemMessage.from(renderSchema(schema)))
if( debug )
log.debug "AI message: $msg"
final json = client.generate(List.<ChatMessage>of(msg)).content().text()
log.debug "AI message: $all"
final json = client.generate(all).content().text()
if( debug )
log.debug "AI response: $json"
return decodeResponse(new JsonSlurper().parseText(json), schema)
}

List<Map<String,Object>> prompt(String query, Map schema) {
if( !query )
throw new IllegalArgumentException("Missing AI prompt")
final msg = UserMessage.from(query)
return prompt(List.<ChatMessage>of(msg), schema)
}

static protected String renderSchema(Map schema) {
return 'You must answer strictly in the following JSON format: {"result": [' + schema0(schema) + '] }'
}
Expand Down
2 changes: 1 addition & 1 deletion plugins/nf-gpt/src/resources/META-INF/MANIFEST.MF
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Manifest-Version: 1.0
Plugin-Class: nextflow.gpt.GptPlugin
Plugin-Id: nf-gpt
Plugin-Version: 0.1.0
Plugin-Version: 0.2.0
Plugin-Provider: Seqera Labs
Plugin-Requires: >=24.01.0-edge

0 comments on commit afeeb27

Please sign in to comment.