Skip to content

Commit 9162f3a

Browse files
refactor: use async-openai CompletionRequest (NVIDIA#310)
1 parent 057f8f4 commit 9162f3a

File tree

11 files changed

+71
-375
lines changed

11 files changed

+71
-375
lines changed

lib/llm/src/http/service/openai.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,19 @@ async fn completions(
140140
let request_id = uuid::Uuid::new_v4().to_string();
141141

142142
// todo - decide on default
143-
let streaming = request.stream.unwrap_or(false);
143+
let streaming = request.inner.stream.unwrap_or(false);
144144

145145
// update the request to always stream
146-
let request = CompletionRequest {
146+
let inner = async_openai::types::CreateCompletionRequest {
147147
stream: Some(true),
148-
..request
148+
..request.inner
149149
};
150150

151+
let request = CompletionRequest { inner, nvext: None };
152+
151153
// todo - make the protocols be optional for model name
152154
// todo - when optional, if none, apply a default
153-
let model = &request.model;
155+
let model = &request.inner.model;
154156

155157
// todo - error handling should be more robust
156158
let engine = state

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl OAIChatLikeRequest for CompletionRequest {
6060
let message = async_openai::types::ChatCompletionRequestMessage::User(
6161
async_openai::types::ChatCompletionRequestUserMessage {
6262
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
63-
self.prompt.clone(),
63+
crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
6464
),
6565
name: None,
6666
},

lib/llm/src/protocols/openai.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@ pub mod nvext;
2222
use anyhow::Result;
2323
use serde::{Deserialize, Serialize};
2424
use std::{
25-
collections::HashMap,
2625
fmt::Display,
2726
ops::{Add, Div, Mul, Sub},
2827
};
29-
use validator::ValidationError;
3028

3129
use super::{
3230
common::{self, SamplingOptionsProvider, StopConditionsProvider},
@@ -263,17 +261,6 @@ pub struct GenericCompletionResponse<C>
263261
// TODO() - add NvResponseExtention
264262
}
265263

266-
fn validate_logit_bias(logit_bias: &HashMap<String, i32>) -> Result<(), ValidationError> {
267-
for key in logit_bias.keys() {
268-
if key.parse::<i32>().is_err() {
269-
return Err(
270-
ValidationError::new("logit_bias").with_message("Keys must be integers".into())
271-
);
272-
}
273-
}
274-
Ok(())
275-
}
276-
277264
// todo - move to common location
278265
fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
279266
where

0 commit comments

Comments
 (0)