Merge branch 'main' into royh-batchembed
llm/ext_server/server.cpp (vendored): 63 changed lines
@@ -56,7 +56,6 @@ struct server_params {
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
-    std::string chat_template = "";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -427,16 +426,6 @@ struct llama_server_context
         return true;
     }

-    void validate_model_chat_template(server_params & sparams) {
-        llama_chat_message chat[] = {{"user", "test"}};
-        std::vector<char> buf(1);
-        int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
-        if (res < 0) {
-            LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "chatml";
-        }
-    }
-
     void initialize() {
         // create slots
         all_slots_are_idle = true;
@@ -1661,26 +1650,41 @@ struct llama_server_context
            }
            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

+           char buf[256];
+           llama_model_meta_val_str(model, "general.architecture", buf, 256);
+           bool gemma2 = strcmp(buf, "gemma2") == 0;
+
+           int32_t truncate_at = slot.n_ctx;
+
+           // truncate at 2/3 of the context length for gemma2 models
+           // as they do not support context shifts (from the sliding window implementation).
+           // this way, prompts that almost fit the context length can still generate a full
+           // response without a sudden stop from hitting the context limit
+           if (gemma2) {
+               truncate_at = 2 * slot.n_ctx / 3;
+           }
+
            // if input prompt is too big, truncate it, if group attention self-extend is disabled
-           if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
+           if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
            {
                const int n_left = slot.n_ctx - slot.params.n_keep;
-               const int n_block_size = n_left / 2;
-               const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+               const int n_shift = n_left / 2;
+               const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;

                std::vector<llama_token> new_tokens(
                    prompt_tokens.begin(),
                    prompt_tokens.begin() + slot.params.n_keep);
                new_tokens.insert(
                    new_tokens.end(),
-                   prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                   prompt_tokens.begin() + slot.params.n_keep + n_erase,
                    prompt_tokens.end());

-               LOG_VERBOSE("input truncated", {
-                   {"n_ctx", slot.n_ctx},
-                   {"n_keep", slot.params.n_keep},
-                   {"n_left", n_left},
-                   {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+               LOG_INFO("input truncated", {
+                   {"n_ctx", slot.n_ctx},
+                   {"n_keep", slot.params.n_keep},
+                   {"n_left", n_left},
+                   {"n_shift", n_shift},
+                   {"n_erase", n_erase},
                });
                slot.truncated = true;
                prompt_tokens = new_tokens;
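The reworked truncation above drops the old block-based erase (n_block_size / erased_blocks) in favor of a single shift: keep the first n_keep tokens, discard n_erase tokens from the middle, and keep the newest n_shift = n_left / 2 tokens at the tail. A minimal standalone sketch of that arithmetic, using plain ints instead of llama_token and a hypothetical truncate_prompt helper (illustration only, not the server code itself):

#include <cassert>
#include <cstdio>
#include <vector>

// Keep the first n_keep tokens, drop n_erase tokens from the middle, keep the rest.
std::vector<int> truncate_prompt(const std::vector<int> &prompt, int n_ctx, int n_keep) {
    const int n_prompt = (int) prompt.size();
    const int n_left   = n_ctx - n_keep;              // room left after the kept prefix
    const int n_shift  = n_left / 2;                  // newest tokens kept at the tail
    const int n_erase  = n_prompt - n_keep - n_shift; // tokens discarded from the middle
    assert(n_erase > 0 && n_keep + n_erase <= n_prompt);

    std::vector<int> out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.begin() + n_keep + n_erase, prompt.end());
    return out;                                       // size == n_keep + n_shift
}

int main() {
    // 100-token prompt, 64-token context, keep the first 10 tokens:
    // n_left = 54, n_shift = 27, n_erase = 63, result is 37 tokens.
    std::vector<int> prompt(100);
    for (int i = 0; i < 100; i++) prompt[i] = i;
    printf("%zu\n", truncate_prompt(prompt, 64, 10).size()); // prints 37
}

The truncated prompt always lands well under n_ctx, which is what lets the gemma2 branch in the next hunk cap n_predict to the remaining room.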
@@ -1689,6 +1693,19 @@ struct llama_server_context
                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
            }

+           // Models with sliding window attention do not work with context shifts, so
+           // limit their prediction to the context length
+           if (gemma2) {
+               int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
+               slot.n_predict = limit;
+               slot.params.n_predict = limit;
+               LOG_INFO("model does not support sliding window, limiting generation", {
+                   {"n_ctx", slot.n_ctx},
+                   {"n_prompt_tokens", slot.n_prompt_tokens},
+                   {"n_predict", slot.n_predict}
+               });
+           }
+
            if (!slot.params.cache_prompt)
            {
                llama_sampling_reset(slot.ctx_sampling);
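The two gemma2 branches work together: truncating early at 2/3 of the context leaves generation room even for long prompts, and n_predict is then set to exactly the room that remains, so a model that cannot context-shift never runs into its window mid-response. A worked example with illustrative numbers (n_keep = 0; the values are not taken from the commit):

#include <cstdio>

int main() {
    const int n_ctx    = 8192;                 // slot context size
    const int n_keep   = 0;
    int       n_prompt = 7000;                 // incoming prompt length

    const int truncate_at = 2 * n_ctx / 3;     // 5461 for gemma2 (n_ctx for other models)
    if (n_prompt >= truncate_at) {
        const int n_left  = n_ctx - n_keep;    // 8192
        const int n_shift = n_left / 2;        // 4096
        n_prompt = n_keep + n_shift;           // prompt truncated to 4096 tokens
    }

    const int limit = n_ctx - n_prompt;        // 4096 tokens may still be generated
    printf("n_prompt=%d, n_predict limit=%d\n", n_prompt, limit);
}

Without the 2/3 rule the same 7000-token prompt would not be truncated at all, leaving only 1192 tokens of room before the context is exhausted and generation stops abruptly.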
@@ -2535,7 +2552,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                invalid_param = true;
                break;
            }
-           sparams.chat_template = argv[i];
        }
        else if (arg == "--override-kv")
        {
@@ -3008,11 +3024,6 @@ int main(int argc, char **argv) {
    }
    const auto model_meta = llama.model_meta();

-   if (sparams.chat_template.empty()) { // custom chat template is not supplied
-       // check if the template comes with the model is supported by us
-       llama.validate_model_chat_template(sparams);
-   }
-
    // Middleware for API key validation
    auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
        // If API key is not set, skip validation