Compare commits
2 Commits
main
...
jmorganca/
| Author | SHA1 | Date |
|---|---|---|
|
|
f5d663e370 | |
|
|
e5209778f1 |
|
|
@ -1,6 +1,6 @@
|
||||||
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
|
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
|
||||||
WORKDIR=llama/vendor
|
WORKDIR=llama/vendor
|
||||||
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
|
FETCH_HEAD=1caae7fc6c77551cb1066515e0f414713eebb367
|
||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help:
|
help:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
int LLAMA_BUILD_NUMBER = 0;
|
int LLAMA_BUILD_NUMBER = 0;
|
||||||
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
|
char const *LLAMA_COMMIT = "1caae7fc6c77551cb1066515e0f414713eebb367";
|
||||||
char const *LLAMA_COMPILER = "";
|
char const *LLAMA_COMPILER = "";
|
||||||
char const *LLAMA_BUILD_TARGET = "";
|
char const *LLAMA_BUILD_TARGET = "";
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,14 @@
|
||||||
protect **/*.go
|
protect **/*.go
|
||||||
include common/
|
include common/
|
||||||
|
include common/arg.*
|
||||||
|
include common/chat.*
|
||||||
|
include common/chat-parser.*
|
||||||
|
include common/console.*
|
||||||
include common/base64.*
|
include common/base64.*
|
||||||
include common/common.*
|
include common/common.*
|
||||||
include common/json-schema-to-grammar.*
|
include common/json-schema-to-grammar.*
|
||||||
include common/json.*
|
include common/json-partial.*
|
||||||
|
include common/regex-partial.*
|
||||||
include common/log.*
|
include common/log.*
|
||||||
include common/sampling.*
|
include common/sampling.*
|
||||||
include common/stb_image.*
|
include common/stb_image.*
|
||||||
|
|
@ -12,12 +17,23 @@ include include/llama.*
|
||||||
include include/llama-*.*
|
include include/llama-*.*
|
||||||
include tools/
|
include tools/
|
||||||
include tools/mtmd/
|
include tools/mtmd/
|
||||||
|
include tools/mtmd/mtmd.*
|
||||||
|
include tools/mtmd/mtmd-helper.*
|
||||||
|
include tools/mtmd/mtmd-audio.*
|
||||||
include tools/mtmd/clip.*
|
include tools/mtmd/clip.*
|
||||||
include tools/mtmd/clip-impl.*
|
include tools/mtmd/clip-impl.*
|
||||||
include tools/mtmd/llava.*
|
|
||||||
include src/
|
include src/
|
||||||
include src/llama.*
|
include src/llama.*
|
||||||
include src/llama-*.*
|
include src/llama-*.*
|
||||||
include src/unicode-data.*
|
include src/unicode-data.*
|
||||||
include src/unicode.*
|
include src/unicode.*
|
||||||
|
include vendor/
|
||||||
|
include vendor/nlohmann
|
||||||
|
include vendor/nlohmann/*
|
||||||
|
include vendor/miniaudio
|
||||||
|
include vendor/miniaudio/*
|
||||||
|
include vendor/stb
|
||||||
|
include vendor/stb/stb_image.*
|
||||||
|
include vendor/minja
|
||||||
|
include vendor/minja/*
|
||||||
exclude *
|
exclude *
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,89 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
//
|
||||||
|
// CLI argument parsing
|
||||||
|
//
|
||||||
|
|
||||||
|
struct common_arg {
|
||||||
|
std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
|
||||||
|
std::set<enum llama_example> excludes = {};
|
||||||
|
std::vector<const char *> args;
|
||||||
|
const char * value_hint = nullptr; // help text or example for arg value
|
||||||
|
const char * value_hint_2 = nullptr; // for second arg value
|
||||||
|
const char * env = nullptr;
|
||||||
|
std::string help;
|
||||||
|
bool is_sparam = false; // is current arg a sampling param?
|
||||||
|
void (*handler_void) (common_params & params) = nullptr;
|
||||||
|
void (*handler_string) (common_params & params, const std::string &) = nullptr;
|
||||||
|
void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
|
||||||
|
void (*handler_int) (common_params & params, int) = nullptr;
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const char * value_hint,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, const std::string &)
|
||||||
|
) : args(args), value_hint(value_hint), help(help), handler_string(handler) {}
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const char * value_hint,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, int)
|
||||||
|
) : args(args), value_hint(value_hint), help(help), handler_int(handler) {}
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params)
|
||||||
|
) : args(args), help(help), handler_void(handler) {}
|
||||||
|
|
||||||
|
// support 2 values for arg
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const char * value_hint,
|
||||||
|
const char * value_hint_2,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, const std::string &, const std::string &)
|
||||||
|
) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
|
||||||
|
|
||||||
|
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
|
||||||
|
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
|
||||||
|
common_arg & set_env(const char * env);
|
||||||
|
common_arg & set_sparam();
|
||||||
|
bool in_example(enum llama_example ex);
|
||||||
|
bool is_exclude(enum llama_example ex);
|
||||||
|
bool get_value_from_env(std::string & output);
|
||||||
|
bool has_value_from_env();
|
||||||
|
std::string to_string();
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_params_context {
|
||||||
|
enum llama_example ex = LLAMA_EXAMPLE_COMMON;
|
||||||
|
common_params & params;
|
||||||
|
std::vector<common_arg> options;
|
||||||
|
void(*print_usage)(int, char **) = nullptr;
|
||||||
|
common_params_context(common_params & params) : params(params) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// parse input arguments from CLI
|
||||||
|
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
|
||||||
|
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||||
|
|
||||||
|
// function to be used by test-arg-parser
|
||||||
|
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||||
|
bool common_has_curl();
|
||||||
|
|
||||||
|
struct common_remote_params {
|
||||||
|
std::vector<std::string> headers;
|
||||||
|
long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
|
||||||
|
long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB
|
||||||
|
};
|
||||||
|
// get remote file content, returns <http_code, raw_response_body>
|
||||||
|
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
|
||||||
|
|
@ -0,0 +1,380 @@
|
||||||
|
#include "chat-parser.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||||
|
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||||
|
{
|
||||||
|
result_.role = "assistant";
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
std::string id = std::to_string(std::rand());
|
||||||
|
if (input.find(id) == std::string::npos) {
|
||||||
|
healing_marker_ = id;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
|
||||||
|
GGML_ASSERT(rng.begin <= rng.end);
|
||||||
|
return input_.substr(rng.begin, rng.end - rng.begin);
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_msg_parser::add_content(const std::string &content) {
|
||||||
|
result_.content += content;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
|
||||||
|
result_.reasoning_content += reasoning_content;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
|
||||||
|
if (name.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_tool_call tool_call;
|
||||||
|
tool_call.name = name;
|
||||||
|
tool_call.arguments = arguments;
|
||||||
|
tool_call.id = id;
|
||||||
|
|
||||||
|
// LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
|
||||||
|
result_.tool_calls.emplace_back(tool_call);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
|
||||||
|
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
|
||||||
|
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
|
||||||
|
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
|
||||||
|
return add_tool_call(name, id, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
|
||||||
|
for (const auto & item : arr) {
|
||||||
|
if (!add_tool_call(item)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
void common_chat_msg_parser::finish() {
|
||||||
|
if (!is_partial_ && pos_ != input_.size()) {
|
||||||
|
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_chat_msg_parser::consume_spaces() {
|
||||||
|
const auto length = input_.size();
|
||||||
|
auto consumed = false;
|
||||||
|
while (pos_ < length && std::isspace(input_[pos_])) {
|
||||||
|
++pos_;
|
||||||
|
consumed = true;
|
||||||
|
}
|
||||||
|
return consumed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
|
||||||
|
auto pos = pos_;
|
||||||
|
for (auto i = 0u; i < literal.size(); ++i) {
|
||||||
|
if (pos >= input_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (input_[pos] != literal[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
++pos;
|
||||||
|
}
|
||||||
|
pos_ = pos;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
|
||||||
|
auto idx = input_.find(literal, pos_);
|
||||||
|
if (idx != std::string::npos) {
|
||||||
|
find_regex_result res;
|
||||||
|
res.prelude = input_.substr(pos_, idx - pos_);
|
||||||
|
auto end = idx + literal.size();
|
||||||
|
res.groups.emplace_back(common_string_range{idx, end});
|
||||||
|
move_to(end);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
if (is_partial_) {
|
||||||
|
idx = string_find_partial_stop(input_, literal);
|
||||||
|
if (idx != std::string::npos && idx >= pos_) {
|
||||||
|
find_regex_result res;
|
||||||
|
res.prelude = input_.substr(pos_, idx - pos_);
|
||||||
|
auto end = input_.size();
|
||||||
|
res.groups.emplace_back(common_string_range{idx, end});
|
||||||
|
move_to(end);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
||||||
|
if (!try_consume_literal(literal)) {
|
||||||
|
throw common_chat_msg_partial_exception(literal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
|
||||||
|
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
|
||||||
|
auto stripped_reasoning = string_strip(reasoning);
|
||||||
|
if (stripped_reasoning.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (syntax_.reasoning_in_content) {
|
||||||
|
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
|
||||||
|
add_content(stripped_reasoning);
|
||||||
|
if (closed) {
|
||||||
|
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
add_reasoning_content(stripped_reasoning);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
|
||||||
|
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
|
||||||
|
if (auto res = try_find_literal(end_think)) {
|
||||||
|
handle_reasoning(res->prelude, /* closed */ true);
|
||||||
|
consume_spaces();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto rest = consume_rest();
|
||||||
|
if (!rest.empty()) {
|
||||||
|
handle_reasoning(rest, /* closed */ !is_partial());
|
||||||
|
}
|
||||||
|
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
|
||||||
|
// if (!syntax_.thinking_forced_open) {
|
||||||
|
// throw common_chat_msg_partial_exception(end_think);
|
||||||
|
// }
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string common_chat_msg_parser::consume_rest() {
|
||||||
|
auto rest = input_.substr(pos_);
|
||||||
|
pos_ = input_.size();
|
||||||
|
return rest;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
|
||||||
|
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
|
||||||
|
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
|
||||||
|
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
|
||||||
|
pos_ = m.groups[0].end;
|
||||||
|
|
||||||
|
if (add_prelude_to_content) {
|
||||||
|
add_content(prelude);
|
||||||
|
}
|
||||||
|
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||||
|
if (is_partial()) {
|
||||||
|
throw common_chat_msg_partial_exception(regex.str());
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
return find_regex_result{prelude, m.groups};
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
|
||||||
|
if (auto result = try_consume_regex(regex)) {
|
||||||
|
return *result;
|
||||||
|
}
|
||||||
|
throw common_chat_msg_partial_exception(regex.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
|
||||||
|
auto m = regex.search(input_, pos_);
|
||||||
|
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||||
|
if (is_partial()) {
|
||||||
|
throw common_chat_msg_partial_exception(regex.str());
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
if (m.groups[0].begin != pos_) {
|
||||||
|
// Didn't match at the current position.
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
pos_ = m.groups[0].end;
|
||||||
|
|
||||||
|
return find_regex_result {
|
||||||
|
/* .prelude = */ "",
|
||||||
|
m.groups,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
|
||||||
|
auto it = input_.cbegin() + pos_;
|
||||||
|
const auto end = input_.cend();
|
||||||
|
common_json result;
|
||||||
|
if (!common_json_parse(it, end, healing_marker_, result)) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
pos_ = std::distance(input_.cbegin(), it);
|
||||||
|
if (result.healing_marker.marker.empty()) {
|
||||||
|
// No healing marker, just return the parsed json
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
if (!is_partial()) {
|
||||||
|
throw common_chat_msg_partial_exception("JSON");
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_json common_chat_msg_parser::consume_json() {
|
||||||
|
if (auto result = try_consume_json()) {
|
||||||
|
return *result;
|
||||||
|
}
|
||||||
|
throw common_chat_msg_partial_exception("JSON");
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
|
||||||
|
const std::vector<std::vector<std::string>> & args_paths,
|
||||||
|
const std::vector<std::vector<std::string>> & content_paths
|
||||||
|
) {
|
||||||
|
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
|
||||||
|
return *result;
|
||||||
|
}
|
||||||
|
throw common_chat_msg_partial_exception("JSON");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
|
||||||
|
const std::vector<std::vector<std::string>> & args_paths,
|
||||||
|
const std::vector<std::vector<std::string>> & content_paths
|
||||||
|
) {
|
||||||
|
auto partial = try_consume_json();
|
||||||
|
if (!partial) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
auto is_arguments_path = [&](const std::vector<std::string> & path) {
|
||||||
|
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
|
||||||
|
};
|
||||||
|
auto is_content_path = [&](const std::vector<std::string> & path) {
|
||||||
|
return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (partial->healing_marker.marker.empty()) {
|
||||||
|
if (args_paths.empty()) {
|
||||||
|
// No arguments to dump, and JSON was parsed fully.
|
||||||
|
return consume_json_result {
|
||||||
|
partial->json,
|
||||||
|
/* .is_partial = */ false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (is_arguments_path({})) {
|
||||||
|
// Entire JSON is the arguments and was parsed fully.
|
||||||
|
return consume_json_result {
|
||||||
|
partial->json.dump(),
|
||||||
|
/* .is_partial = */ false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||||
|
|
||||||
|
auto found_healing_marker = false;
|
||||||
|
std::vector<std::string> path;
|
||||||
|
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||||
|
if (is_arguments_path(path)) {
|
||||||
|
auto arguments = j.dump();
|
||||||
|
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||||
|
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||||
|
if (idx != std::string::npos) {
|
||||||
|
arguments.resize(idx);
|
||||||
|
found_healing_marker = true;
|
||||||
|
}
|
||||||
|
if (arguments == "\"") {
|
||||||
|
// This happens because of completing `:"$magic` after `"arguments"`
|
||||||
|
arguments = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return arguments;
|
||||||
|
}
|
||||||
|
if (is_content_path(path)) {
|
||||||
|
if (!j.is_string()) {
|
||||||
|
throw std::runtime_error("Content path must be a string");
|
||||||
|
}
|
||||||
|
std::string str = j;
|
||||||
|
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
|
||||||
|
if (idx != std::string::npos) {
|
||||||
|
str.resize(idx);
|
||||||
|
found_healing_marker = true;
|
||||||
|
}
|
||||||
|
return str;
|
||||||
|
}
|
||||||
|
if (j.is_object()) {
|
||||||
|
auto obj = json::object();
|
||||||
|
for (const auto & p : j.items()) {
|
||||||
|
const auto & key = p.key();
|
||||||
|
const auto & value = p.value();
|
||||||
|
const std::string key_str = key; // NOLINT
|
||||||
|
auto idx = key_str.find(healing_marker_);
|
||||||
|
if (idx != std::string::npos) {
|
||||||
|
found_healing_marker = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
path.push_back(key_str);
|
||||||
|
if (value.is_string()) {
|
||||||
|
const std::string value_str = value;
|
||||||
|
if (value_str.find(healing_marker_) != std::string::npos) {
|
||||||
|
found_healing_marker = true;
|
||||||
|
if (is_content_path(path)) {
|
||||||
|
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
|
||||||
|
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
|
||||||
|
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
obj[key] = value;
|
||||||
|
} else {
|
||||||
|
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||||
|
}
|
||||||
|
path.pop_back();
|
||||||
|
}
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
if (j.is_array()) {
|
||||||
|
auto arr = json::array();
|
||||||
|
for (const auto & value : j) {
|
||||||
|
if (value.is_string()) {
|
||||||
|
std::string str = value;
|
||||||
|
auto idx = str.find(healing_marker_);
|
||||||
|
if (idx != std::string::npos) {
|
||||||
|
// Don't heal array values that aren't in the arguments.
|
||||||
|
found_healing_marker = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
arr.push_back(remove_unsupported_healings_and_dump_args(value));
|
||||||
|
}
|
||||||
|
return arr;
|
||||||
|
}
|
||||||
|
return j;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
|
||||||
|
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||||
|
return consume_json_result {
|
||||||
|
cleaned,
|
||||||
|
/* .is_partial = */ found_healing_marker,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,118 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "chat.h"
|
||||||
|
#include "json-partial.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_chat_msg_parser {
|
||||||
|
std::string input_;
|
||||||
|
bool is_partial_;
|
||||||
|
common_chat_syntax syntax_;
|
||||||
|
std::string healing_marker_;
|
||||||
|
|
||||||
|
size_t pos_ = 0;
|
||||||
|
common_chat_msg result_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
const std::string & input() const { return input_; }
|
||||||
|
size_t pos() const { return pos_; }
|
||||||
|
const std::string & healing_marker() const { return healing_marker_; }
|
||||||
|
const bool & is_partial() const { return is_partial_; }
|
||||||
|
const common_chat_msg & result() const { return result_; }
|
||||||
|
const common_chat_syntax & syntax() const { return syntax_; }
|
||||||
|
|
||||||
|
void move_to(size_t pos) {
|
||||||
|
if (pos > input_.size()) {
|
||||||
|
throw std::runtime_error("Invalid position!");
|
||||||
|
}
|
||||||
|
pos_ = pos;
|
||||||
|
}
|
||||||
|
void move_back(size_t n) {
|
||||||
|
if (pos_ < n) {
|
||||||
|
throw std::runtime_error("Can't move back that far!");
|
||||||
|
}
|
||||||
|
pos_ -= n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the substring of the input at the given range
|
||||||
|
std::string str(const common_string_range & rng) const;
|
||||||
|
|
||||||
|
// Appends to the result.content field
|
||||||
|
void add_content(const std::string & content);
|
||||||
|
|
||||||
|
// Appends to the result.reasoning_content field
|
||||||
|
void add_reasoning_content(const std::string & reasoning_content);
|
||||||
|
|
||||||
|
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
||||||
|
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||||
|
|
||||||
|
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
||||||
|
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
||||||
|
|
||||||
|
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||||
|
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||||
|
|
||||||
|
void finish();
|
||||||
|
|
||||||
|
bool consume_spaces();
|
||||||
|
|
||||||
|
void consume_literal(const std::string & literal);
|
||||||
|
|
||||||
|
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||||
|
|
||||||
|
std::string consume_rest();
|
||||||
|
|
||||||
|
struct find_regex_result {
|
||||||
|
std::string prelude;
|
||||||
|
std::vector<common_string_range> groups;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
||||||
|
|
||||||
|
bool try_consume_literal(const std::string & literal);
|
||||||
|
|
||||||
|
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||||
|
|
||||||
|
find_regex_result consume_regex(const common_regex & regex);
|
||||||
|
|
||||||
|
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||||
|
|
||||||
|
std::optional<common_json> try_consume_json();
|
||||||
|
common_json consume_json();
|
||||||
|
|
||||||
|
struct consume_json_result {
|
||||||
|
nlohmann::ordered_json value;
|
||||||
|
bool is_partial;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
||||||
|
|
||||||
|
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
||||||
|
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
||||||
|
|
||||||
|
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
||||||
|
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
||||||
|
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
||||||
|
*/
|
||||||
|
consume_json_result consume_json_with_dumped_args(
|
||||||
|
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||||
|
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||||
|
);
|
||||||
|
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||||
|
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||||
|
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||||
|
);
|
||||||
|
};
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,202 @@
|
||||||
|
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <chrono>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct common_chat_templates;
|
||||||
|
|
||||||
|
struct common_chat_tool_call {
|
||||||
|
std::string name;
|
||||||
|
std::string arguments;
|
||||||
|
std::string id;
|
||||||
|
|
||||||
|
bool operator==(const common_chat_tool_call & other) const {
|
||||||
|
return name == other.name && arguments == other.arguments && id == other.id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_msg_content_part {
|
||||||
|
std::string type;
|
||||||
|
std::string text;
|
||||||
|
|
||||||
|
bool operator==(const common_chat_msg_content_part & other) const {
|
||||||
|
return type == other.type && text == other.text;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_msg {
|
||||||
|
std::string role;
|
||||||
|
std::string content;
|
||||||
|
std::vector<common_chat_msg_content_part> content_parts = {};
|
||||||
|
std::vector<common_chat_tool_call> tool_calls = {};
|
||||||
|
std::string reasoning_content;
|
||||||
|
std::string tool_name;
|
||||||
|
std::string tool_call_id;
|
||||||
|
|
||||||
|
template <class T> T to_json_oaicompat() const;
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||||
|
}
|
||||||
|
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||||
|
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||||
|
if (ids_cache.size() <= i) {
|
||||||
|
auto id = tool_calls[i].id;
|
||||||
|
if (id.empty()) {
|
||||||
|
id = gen_tool_call_id();
|
||||||
|
}
|
||||||
|
ids_cache.push_back(id);
|
||||||
|
}
|
||||||
|
tool_calls[i].id = ids_cache[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool operator==(const common_chat_msg & other) const {
|
||||||
|
return role == other.role
|
||||||
|
&& content == other.content
|
||||||
|
&& content_parts == other.content_parts
|
||||||
|
&& tool_calls == other.tool_calls
|
||||||
|
&& reasoning_content == other.reasoning_content
|
||||||
|
&& tool_name == other.tool_name
|
||||||
|
&& tool_call_id == other.tool_call_id;
|
||||||
|
}
|
||||||
|
bool operator!=(const common_chat_msg & other) const {
|
||||||
|
return !(*this == other);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_msg_diff {
|
||||||
|
std::string reasoning_content_delta;
|
||||||
|
std::string content_delta;
|
||||||
|
size_t tool_call_index = std::string::npos;
|
||||||
|
common_chat_tool_call tool_call_delta;
|
||||||
|
|
||||||
|
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
|
||||||
|
|
||||||
|
bool operator==(const common_chat_msg_diff & other) const {
|
||||||
|
return content_delta == other.content_delta
|
||||||
|
&& tool_call_index == other.tool_call_index
|
||||||
|
&& tool_call_delta == other.tool_call_delta;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_tool {
|
||||||
|
std::string name;
|
||||||
|
std::string description;
|
||||||
|
std::string parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_chat_tool_choice {
|
||||||
|
COMMON_CHAT_TOOL_CHOICE_AUTO,
|
||||||
|
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
|
||||||
|
COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_chat_format {
|
||||||
|
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||||
|
COMMON_CHAT_FORMAT_GENERIC,
|
||||||
|
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||||
|
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||||
|
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||||
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||||
|
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||||
|
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||||
|
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||||
|
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||||
|
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||||
|
|
||||||
|
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_templates_inputs {
|
||||||
|
std::vector<common_chat_msg> messages;
|
||||||
|
std::string grammar;
|
||||||
|
std::string json_schema;
|
||||||
|
bool add_generation_prompt = true;
|
||||||
|
bool use_jinja = true;
|
||||||
|
// Parameters below only supported when use_jinja is true
|
||||||
|
std::vector<common_chat_tool> tools;
|
||||||
|
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||||
|
bool parallel_tool_calls = false;
|
||||||
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
|
bool enable_thinking = true;
|
||||||
|
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_params {
|
||||||
|
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||||
|
std::string prompt;
|
||||||
|
std::string grammar;
|
||||||
|
bool grammar_lazy = false;
|
||||||
|
bool thinking_forced_open = false;
|
||||||
|
std::vector<common_grammar_trigger> grammar_triggers;
|
||||||
|
std::vector<std::string> preserved_tokens;
|
||||||
|
std::vector<std::string> additional_stops;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_syntax {
|
||||||
|
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||||
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
|
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||||
|
bool reasoning_in_content = false;
|
||||||
|
bool thinking_forced_open = false;
|
||||||
|
bool parse_tool_calls = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||||
|
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||||
|
|
||||||
|
void common_chat_templates_free(struct common_chat_templates * tmpls);
|
||||||
|
|
||||||
|
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
|
||||||
|
|
||||||
|
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
|
||||||
|
|
||||||
|
common_chat_templates_ptr common_chat_templates_init(
|
||||||
|
const struct llama_model * model,
|
||||||
|
const std::string & chat_template_override,
|
||||||
|
const std::string & bos_token_override = "",
|
||||||
|
const std::string & eos_token_override = "");
|
||||||
|
|
||||||
|
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||||
|
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
||||||
|
|
||||||
|
|
||||||
|
struct common_chat_params common_chat_templates_apply(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const struct common_chat_templates_inputs & inputs);
|
||||||
|
|
||||||
|
// Format single message, while taking into account the position of that message in chat history
|
||||||
|
std::string common_chat_format_single(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const std::vector<common_chat_msg> & past_msg,
|
||||||
|
const common_chat_msg & new_msg,
|
||||||
|
bool add_ass,
|
||||||
|
bool use_jinja);
|
||||||
|
|
||||||
|
// Returns an example of formatted chat
|
||||||
|
std::string common_chat_format_example(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
bool use_jinja);
|
||||||
|
|
||||||
|
const char* common_chat_format_name(common_chat_format format);
|
||||||
|
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||||
|
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
|
||||||
|
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||||
|
|
||||||
|
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||||
|
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||||
|
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||||
|
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||||
|
|
||||||
|
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||||
|
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||||
|
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||||
|
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||||
|
|
||||||
|
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||||
|
|
@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||||
|
|
||||||
DWORD p = NORMAL_PRIORITY_CLASS;
|
DWORD p = NORMAL_PRIORITY_CLASS;
|
||||||
switch (prio) {
|
switch (prio) {
|
||||||
|
case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break;
|
||||||
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
|
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
|
||||||
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
|
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
|
||||||
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
|
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
|
||||||
|
|
@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||||
|
|
||||||
int p = 0;
|
int p = 0;
|
||||||
switch (prio) {
|
switch (prio) {
|
||||||
|
case GGML_SCHED_PRIO_LOW: p = 5; break;
|
||||||
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
|
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
|
||||||
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
|
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
|
||||||
case GGML_SCHED_PRIO_HIGH: p = -10; break;
|
case GGML_SCHED_PRIO_HIGH: p = -10; break;
|
||||||
|
|
@ -443,6 +445,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||||
s = std::move(builder);
|
s = std::move(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||||
|
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||||
|
}
|
||||||
|
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||||
|
if (!str.empty() && !stop.empty()) {
|
||||||
|
const char text_last_char = str.back();
|
||||||
|
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||||
|
if (stop[char_index] == text_last_char) {
|
||||||
|
const auto current_partial = stop.substr(0, char_index + 1);
|
||||||
|
if (string_ends_with(str, current_partial)) {
|
||||||
|
return str.size() - char_index - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
std::string regex_escape(const std::string & s) {
|
std::string regex_escape(const std::string & s) {
|
||||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||||
return std::regex_replace(s, special_chars, "\\$0");
|
return std::regex_replace(s, special_chars, "\\$0");
|
||||||
|
|
@ -830,7 +851,7 @@ std::string fs_get_cache_directory() {
|
||||||
if (getenv("LLAMA_CACHE")) {
|
if (getenv("LLAMA_CACHE")) {
|
||||||
cache_directory = std::getenv("LLAMA_CACHE");
|
cache_directory = std::getenv("LLAMA_CACHE");
|
||||||
} else {
|
} else {
|
||||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||||
if (std::getenv("XDG_CACHE_HOME")) {
|
if (std::getenv("XDG_CACHE_HOME")) {
|
||||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -884,13 +905,16 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||||
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
|
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
|
||||||
ok = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
|
if (!has_eos && !has_sep) {
|
||||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
|
||||||
|
ok = false;
|
||||||
|
} else if (!has_eos) {
|
||||||
|
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||||
|
} else if (!has_sep) {
|
||||||
|
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1083,6 +1107,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||||
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
|
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mparams.progress_callback = params.load_progress_callback;
|
||||||
|
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
|
||||||
|
|
||||||
return mparams;
|
return mparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1114,6 +1141,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||||
cparams.flash_attn = params.flash_attn;
|
cparams.flash_attn = params.flash_attn;
|
||||||
cparams.no_perf = params.no_perf;
|
cparams.no_perf = params.no_perf;
|
||||||
cparams.op_offload = !params.no_op_offload;
|
cparams.op_offload = !params.no_op_offload;
|
||||||
|
cparams.swa_full = params.swa_full;
|
||||||
|
|
||||||
if (params.reranking) {
|
if (params.reranking) {
|
||||||
cparams.embeddings = true;
|
cparams.embeddings = true;
|
||||||
|
|
@ -1306,81 +1334,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// KV cache utils
|
|
||||||
//
|
|
||||||
|
|
||||||
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
|
|
||||||
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
|
|
||||||
|
|
||||||
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
|
|
||||||
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
|
||||||
|
|
||||||
llama_kv_cache_view_cell * c_curr = view.cells;
|
|
||||||
llama_seq_id * cs_curr = view.cells_sequences;
|
|
||||||
|
|
||||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
|
||||||
if (i % row_size == 0) {
|
|
||||||
printf("\n%5d: ", i);
|
|
||||||
}
|
|
||||||
int seq_count = 0;
|
|
||||||
for (int j = 0; j < view.n_seq_max; j++) {
|
|
||||||
if (cs_curr[j] >= 0) { seq_count++; }
|
|
||||||
}
|
|
||||||
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("\n=== Done dumping\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
|
|
||||||
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
|
||||||
|
|
||||||
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
|
|
||||||
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, size_t> seqs;
|
|
||||||
llama_kv_cache_view_cell * c_curr = view.cells;
|
|
||||||
llama_seq_id * cs_curr = view.cells_sequences;
|
|
||||||
|
|
||||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
|
||||||
for (int j = 0; j < view.n_seq_max; j++) {
|
|
||||||
if (cs_curr[j] < 0) { continue; }
|
|
||||||
if (seqs.find(cs_curr[j]) == seqs.end()) {
|
|
||||||
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
|
||||||
const size_t sz = seqs.size();
|
|
||||||
seqs[cs_curr[j]] = sz;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("=== Sequence legend: ");
|
|
||||||
for (const auto & it : seqs) {
|
|
||||||
printf("%zu=%d, ", it.second, it.first);
|
|
||||||
}
|
|
||||||
printf("'+'=other sequence ids");
|
|
||||||
|
|
||||||
c_curr = view.cells;
|
|
||||||
cs_curr = view.cells_sequences;
|
|
||||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
|
||||||
if (i % row_size == 0) {
|
|
||||||
printf("\n%5d: ", i);
|
|
||||||
}
|
|
||||||
for (int j = 0; j < view.n_seq_max; j++) {
|
|
||||||
if (cs_curr[j] >= 0) {
|
|
||||||
const auto & it = seqs.find(cs_curr[j]);
|
|
||||||
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
|
|
||||||
} else {
|
|
||||||
putchar('.');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
putchar(' ');
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("\n=== Done dumping\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Embedding utils
|
// Embedding utils
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
// #cgo CXXFLAGS: -std=c++11
|
// #cgo CXXFLAGS: -std=c++17
|
||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include
|
// #cgo CPPFLAGS: -I${SRCDIR}/../include
|
||||||
|
// #cgo CPPFLAGS: -I${SRCDIR}/../vendor
|
||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
||||||
import "C"
|
import "C"
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
|
@ -75,7 +76,7 @@ enum llama_example {
|
||||||
LLAMA_EXAMPLE_SERVER,
|
LLAMA_EXAMPLE_SERVER,
|
||||||
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
||||||
LLAMA_EXAMPLE_EXPORT_LORA,
|
LLAMA_EXAMPLE_EXPORT_LORA,
|
||||||
LLAMA_EXAMPLE_LLAVA,
|
LLAMA_EXAMPLE_MTMD,
|
||||||
LLAMA_EXAMPLE_LOOKUP,
|
LLAMA_EXAMPLE_LOOKUP,
|
||||||
LLAMA_EXAMPLE_PARALLEL,
|
LLAMA_EXAMPLE_PARALLEL,
|
||||||
LLAMA_EXAMPLE_TTS,
|
LLAMA_EXAMPLE_TTS,
|
||||||
|
|
@ -114,7 +115,7 @@ enum common_grammar_trigger_type {
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_grammar_trigger {
|
struct common_grammar_trigger {
|
||||||
|
|
@ -214,7 +215,8 @@ struct common_params_vocoder {
|
||||||
|
|
||||||
enum common_reasoning_format {
|
enum common_reasoning_format {
|
||||||
COMMON_REASONING_FORMAT_NONE,
|
COMMON_REASONING_FORMAT_NONE,
|
||||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
|
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
||||||
|
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_params {
|
struct common_params {
|
||||||
|
|
@ -290,6 +292,7 @@ struct common_params {
|
||||||
int32_t verbosity = 0;
|
int32_t verbosity = 0;
|
||||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||||
|
bool offline = false;
|
||||||
|
|
||||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||||
|
|
@ -322,13 +325,13 @@ struct common_params {
|
||||||
bool flash_attn = false; // flash attention
|
bool flash_attn = false; // flash attention
|
||||||
bool no_perf = false; // disable performance metrics
|
bool no_perf = false; // disable performance metrics
|
||||||
bool ctx_shift = true; // context shift on inifinite text generation
|
bool ctx_shift = true; // context shift on inifinite text generation
|
||||||
|
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||||
|
|
||||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||||
bool use_mmap = true; // use mmap for faster loads
|
bool use_mmap = true; // use mmap for faster loads
|
||||||
bool use_mlock = false; // use mlock to keep model in memory
|
bool use_mlock = false; // use mlock to keep model in memory
|
||||||
bool verbose_prompt = false; // print prompt tokens before generation
|
bool verbose_prompt = false; // print prompt tokens before generation
|
||||||
bool display_prompt = true; // print prompt before generation
|
bool display_prompt = true; // print prompt before generation
|
||||||
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
|
|
||||||
bool no_kv_offload = false; // disable KV offloading
|
bool no_kv_offload = false; // disable KV offloading
|
||||||
bool warmup = true; // warmup run
|
bool warmup = true; // warmup run
|
||||||
bool check_tensors = false; // validate tensor data
|
bool check_tensors = false; // validate tensor data
|
||||||
|
|
@ -367,6 +370,8 @@ struct common_params {
|
||||||
bool use_jinja = false; // NOLINT
|
bool use_jinja = false; // NOLINT
|
||||||
bool enable_chat_template = true;
|
bool enable_chat_template = true;
|
||||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||||
|
int reasoning_budget = -1;
|
||||||
|
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||||
|
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
|
|
||||||
|
|
@ -426,6 +431,11 @@ struct common_params {
|
||||||
|
|
||||||
// common params
|
// common params
|
||||||
std::string out_file; // output filename for all example programs
|
std::string out_file; // output filename for all example programs
|
||||||
|
// optional callback for model loading progress and cancellation:
|
||||||
|
// called with a progress value between 0.0 and 1.0.
|
||||||
|
// return false from callback to abort model loading or true to continue
|
||||||
|
llama_progress_callback load_progress_callback = NULL;
|
||||||
|
void * load_progress_callback_user_data = NULL;
|
||||||
};
|
};
|
||||||
|
|
||||||
// call once at the start of a program if it uses libcommon
|
// call once at the start of a program if it uses libcommon
|
||||||
|
|
@ -503,10 +513,9 @@ static bool string_starts_with(const std::string & str,
|
||||||
return str.rfind(prefix, 0) == 0;
|
return str.rfind(prefix, 0) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool string_ends_with(const std::string & str,
|
// While we wait for C++20's std::string::ends_with...
|
||||||
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||||
}
|
|
||||||
|
|
||||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||||
void string_process_escapes(std::string & input);
|
void string_process_escapes(std::string & input);
|
||||||
|
|
@ -615,16 +624,6 @@ std::string common_detokenize(
|
||||||
const std::vector<llama_token> & tokens,
|
const std::vector<llama_token> & tokens,
|
||||||
bool special = true);
|
bool special = true);
|
||||||
|
|
||||||
//
|
|
||||||
// KV cache utils
|
|
||||||
//
|
|
||||||
|
|
||||||
// Dump the KV cache view with the number of sequences per cell.
|
|
||||||
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
|
|
||||||
|
|
||||||
// Dump the KV cache view showing individual sequences in each cell (long output).
|
|
||||||
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Embedding utils
|
// Embedding utils
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,504 @@
|
||||||
|
#include "console.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#ifndef NOMINMAX
|
||||||
|
#define NOMINMAX
|
||||||
|
#endif
|
||||||
|
#include <windows.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <io.h>
|
||||||
|
#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING
|
||||||
|
#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#include <climits>
|
||||||
|
#include <sys/ioctl.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <wchar.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <signal.h>
|
||||||
|
#include <termios.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define ANSI_COLOR_RED "\x1b[31m"
|
||||||
|
#define ANSI_COLOR_GREEN "\x1b[32m"
|
||||||
|
#define ANSI_COLOR_YELLOW "\x1b[33m"
|
||||||
|
#define ANSI_COLOR_BLUE "\x1b[34m"
|
||||||
|
#define ANSI_COLOR_MAGENTA "\x1b[35m"
|
||||||
|
#define ANSI_COLOR_CYAN "\x1b[36m"
|
||||||
|
#define ANSI_COLOR_RESET "\x1b[0m"
|
||||||
|
#define ANSI_BOLD "\x1b[1m"
|
||||||
|
|
||||||
|
namespace console {
|
||||||
|
|
||||||
|
//
|
||||||
|
// Console state
|
||||||
|
//
|
||||||
|
|
||||||
|
static bool advanced_display = false;
|
||||||
|
static bool simple_io = true;
|
||||||
|
static display_t current_display = reset;
|
||||||
|
|
||||||
|
static FILE* out = stdout;
|
||||||
|
|
||||||
|
#if defined (_WIN32)
|
||||||
|
static void* hConsole;
|
||||||
|
#else
|
||||||
|
static FILE* tty = nullptr;
|
||||||
|
static termios initial_state;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//
|
||||||
|
// Init and cleanup
|
||||||
|
//
|
||||||
|
|
||||||
|
void init(bool use_simple_io, bool use_advanced_display) {
|
||||||
|
advanced_display = use_advanced_display;
|
||||||
|
simple_io = use_simple_io;
|
||||||
|
#if defined(_WIN32)
|
||||||
|
// Windows-specific console initialization
|
||||||
|
DWORD dwMode = 0;
|
||||||
|
hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
|
||||||
|
if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
|
||||||
|
hConsole = GetStdHandle(STD_ERROR_HANDLE);
|
||||||
|
if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) {
|
||||||
|
hConsole = nullptr;
|
||||||
|
simple_io = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hConsole) {
|
||||||
|
// Check conditions combined to reduce nesting
|
||||||
|
if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) &&
|
||||||
|
!SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
|
||||||
|
advanced_display = false;
|
||||||
|
}
|
||||||
|
// Set console output codepage to UTF8
|
||||||
|
SetConsoleOutputCP(CP_UTF8);
|
||||||
|
}
|
||||||
|
HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
|
||||||
|
if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
|
||||||
|
// Set console input codepage to UTF16
|
||||||
|
_setmode(_fileno(stdin), _O_WTEXT);
|
||||||
|
|
||||||
|
// Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
|
||||||
|
if (simple_io) {
|
||||||
|
dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
|
||||||
|
} else {
|
||||||
|
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
|
||||||
|
}
|
||||||
|
if (!SetConsoleMode(hConIn, dwMode)) {
|
||||||
|
simple_io = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (simple_io) {
|
||||||
|
_setmode(_fileno(stdin), _O_U8TEXT);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// POSIX-specific console initialization
|
||||||
|
if (!simple_io) {
|
||||||
|
struct termios new_termios;
|
||||||
|
tcgetattr(STDIN_FILENO, &initial_state);
|
||||||
|
new_termios = initial_state;
|
||||||
|
new_termios.c_lflag &= ~(ICANON | ECHO);
|
||||||
|
new_termios.c_cc[VMIN] = 1;
|
||||||
|
new_termios.c_cc[VTIME] = 0;
|
||||||
|
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
|
||||||
|
|
||||||
|
tty = fopen("/dev/tty", "w+");
|
||||||
|
if (tty != nullptr) {
|
||||||
|
out = tty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setlocale(LC_ALL, "");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void cleanup() {
|
||||||
|
// Reset console display
|
||||||
|
set_display(reset);
|
||||||
|
|
||||||
|
#if !defined(_WIN32)
|
||||||
|
// Restore settings on POSIX systems
|
||||||
|
if (!simple_io) {
|
||||||
|
if (tty != nullptr) {
|
||||||
|
out = stdout;
|
||||||
|
fclose(tty);
|
||||||
|
tty = nullptr;
|
||||||
|
}
|
||||||
|
tcsetattr(STDIN_FILENO, TCSANOW, &initial_state);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Display and IO
|
||||||
|
//
|
||||||
|
|
||||||
|
// Keep track of current display and only emit ANSI code if it changes
|
||||||
|
void set_display(display_t display) {
|
||||||
|
if (advanced_display && current_display != display) {
|
||||||
|
fflush(stdout);
|
||||||
|
switch(display) {
|
||||||
|
case reset:
|
||||||
|
fprintf(out, ANSI_COLOR_RESET);
|
||||||
|
break;
|
||||||
|
case prompt:
|
||||||
|
fprintf(out, ANSI_COLOR_YELLOW);
|
||||||
|
break;
|
||||||
|
case user_input:
|
||||||
|
fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
|
||||||
|
break;
|
||||||
|
case error:
|
||||||
|
fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
|
||||||
|
}
|
||||||
|
current_display = display;
|
||||||
|
fflush(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static char32_t getchar32() {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE);
|
||||||
|
wchar_t high_surrogate = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
INPUT_RECORD record;
|
||||||
|
DWORD count;
|
||||||
|
if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) {
|
||||||
|
return WEOF;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
|
||||||
|
wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
|
||||||
|
if (wc == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
|
||||||
|
high_surrogate = wc;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate
|
||||||
|
if (high_surrogate != 0) { // Check if we have a high surrogate
|
||||||
|
return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
high_surrogate = 0; // Reset the high surrogate
|
||||||
|
return static_cast<char32_t>(wc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
wchar_t wc = getwchar();
|
||||||
|
if (static_cast<wint_t>(wc) == WEOF) {
|
||||||
|
return WEOF;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if WCHAR_MAX == 0xFFFF
|
||||||
|
if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
|
||||||
|
wchar_t low_surrogate = getwchar();
|
||||||
|
if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
|
||||||
|
return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
|
||||||
|
return 0xFFFD; // Return the replacement character U+FFFD
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return static_cast<char32_t>(wc);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
static void pop_cursor() {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
if (hConsole != NULL) {
|
||||||
|
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
|
||||||
|
GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
|
||||||
|
|
||||||
|
COORD newCursorPosition = bufferInfo.dwCursorPosition;
|
||||||
|
if (newCursorPosition.X == 0) {
|
||||||
|
newCursorPosition.X = bufferInfo.dwSize.X - 1;
|
||||||
|
newCursorPosition.Y -= 1;
|
||||||
|
} else {
|
||||||
|
newCursorPosition.X -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
SetConsoleCursorPosition(hConsole, newCursorPosition);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
putc('\b', out);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int estimateWidth(char32_t codepoint) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
(void)codepoint;
|
||||||
|
return 1;
|
||||||
|
#else
|
||||||
|
return wcwidth(codepoint);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write one UTF-8 encoded codepoint (`length` bytes at `utf8_codepoint`) to
// the console and return the number of columns the cursor actually advanced,
// so backspace can later erase multi-column glyphs. Falls back to
// `expectedWidth` when the width cannot be measured.
static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) {
#if defined(_WIN32)
    CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
    if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
        // go with the default
        return expectedWidth;
    }
    COORD initialPosition = bufferInfo.dwCursorPosition;
    DWORD nNumberOfChars = length;
    WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);

    CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
    GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);

    // Figure out our real position if we're in the last column
    // (tabs excluded; the space+backspace write nudges the console so the
    // cursor position reflects the pending line wrap)
    if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
        DWORD nNumberOfChars; // intentionally shadows the outer counter
        WriteConsole(hConsole, &"  \b"[1], 2, &nNumberOfChars, NULL);
        GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
    }

    int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
    if (width < 0) {
        // cursor wrapped to the next line: width is modulo the console width
        width += newBufferInfo.dwSize.X;
    }
    return width;
#else
    // We can trust expectedWidth if we've got one
    if (expectedWidth >= 0 || tty == nullptr) {
        fwrite(utf8_codepoint, length, 1, out);
        return expectedWidth;
    }

    // Width unknown: measure it by querying the cursor position (DSR,
    // ESC[6n) before and after writing the character.
    fputs("\033[6n", tty); // Query cursor position
    int x1;
    int y1;
    int x2;
    int y2;
    int results = 0;
    results = fscanf(tty, "\033[%d;%dR", &y1, &x1);

    fwrite(utf8_codepoint, length, 1, tty);

    fputs("\033[6n", tty); // Query cursor position
    results += fscanf(tty, "\033[%d;%dR", &y2, &x2);

    if (results != 4) {
        // one of the position reports failed to parse
        return expectedWidth;
    }

    int width = x2 - x1;
    if (width < 0) {
        // Calculate the width considering text wrapping
        struct winsize w;
        ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
        width += w.ws_col;
    }
    return width;
#endif
}
|
||||||
|
|
||||||
|
// Overwrite the character cell immediately to the left of the cursor with
// `ch`, leaving the cursor just after the replacement.
static void replace_last(char ch) {
#if defined(_WIN32)
    // No backspace-and-overwrite escape on the Windows console: step the
    // cursor back explicitly, then emit the single replacement byte.
    pop_cursor();
    put_codepoint(&ch, 1, 1);
#else
    fputc('\b', out);
    fputc(ch, out);
#endif
}
|
||||||
|
|
||||||
|
// Append the UTF-8 encoding of the Unicode code point `ch` to `out`.
// Code points that are not Unicode scalar values — UTF-16 surrogate halves
// (U+D800..U+DFFF) and anything above U+10FFFF — have no valid UTF-8
// encoding (RFC 3629) and are silently dropped.
static void append_utf8(char32_t ch, std::string & out) {
    if (ch >= 0xD800 && ch <= 0xDFFF) {
        // Lone surrogates would produce invalid UTF-8; drop them like other
        // invalid code points.
        return;
    }
    if (ch <= 0x7F) {
        out.push_back(static_cast<unsigned char>(ch));
    } else if (ch <= 0x7FF) {
        out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
    } else if (ch <= 0xFFFF) {
        out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
    } else if (ch <= 0x10FFFF) {
        out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
    } else {
        // Invalid Unicode code point (> U+10FFFF): drop it.
    }
}
|
||||||
|
|
||||||
|
// Helper function to remove the last UTF-8 character from a string
|
||||||
|
// Remove the last UTF-8 encoded character (1-4 bytes) from `line`.
// No-op on an empty string.
static void pop_back_utf8_char(std::string & line) {
    if (line.empty()) {
        return;
    }

    // Walk backwards over trailing continuation bytes (0b10xxxxxx) — at most
    // three of them — to land on the leading byte of the final character.
    size_t pos   = line.length() - 1;
    size_t steps = 0;
    while (steps < 3 && pos > 0 && (line[pos] & 0xC0) == 0x80) {
        --pos;
        ++steps;
    }
    line.erase(pos);
}
|
||||||
|
|
||||||
|
// Interactive line editor used when advanced console I/O is available.
// Reads codepoints one at a time, echoing them while recording each one's
// on-screen width so that backspace can erase multi-column glyphs correctly.
// A trailing '\' continues input on the next line (toggling multiline mode),
// a trailing '/' returns control immediately.
// Returns true if the caller should keep collecting more input lines.
static bool readline_advanced(std::string & line, bool multiline_input) {
    if (out != stdout) {
        fflush(stdout);
    }

    line.clear();
    std::vector<int> widths; // per-codepoint on-screen widths, parallel to `line`'s codepoints
    bool is_special_char = false; // line currently ends with a highlighted '\' or '/'
    bool end_of_stream = false;

    char32_t input_char;
    while (true) {
        fflush(out); // Ensure all output is displayed before waiting for input
        input_char = getchar32();

        if (input_char == '\r' || input_char == '\n') {
            break;
        }

        if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) {
            end_of_stream = true;
            break;
        }

        if (is_special_char) {
            // Un-highlight the previously echoed special character
            set_display(user_input);
            replace_last(line.back());
            is_special_char = false;
        }

        if (input_char == '\033') { // Escape sequence
            char32_t code = getchar32();
            if (code == '[' || code == 0x1B) {
                // Discard the rest of the escape sequence
                while ((code = getchar32()) != (char32_t) WEOF) {
                    if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
                        break;
                    }
                }
            }
        } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
            if (!widths.empty()) {
                int count;
                do {
                    count = widths.back();
                    widths.pop_back();
                    // Move cursor back, print space, and move cursor back again
                    for (int i = 0; i < count; i++) {
                        replace_last(' ');
                        pop_cursor();
                    }
                    pop_back_utf8_char(line);
                } while (count == 0 && !widths.empty()); // also consume zero-width codepoints
            }
        } else {
            // Regular character: append its UTF-8 bytes and echo them,
            // recording the width actually consumed on screen.
            int offset = line.length();
            append_utf8(input_char, line);
            int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
            if (width < 0) {
                width = 0; // a non-displayable character generated zero width
            }
            widths.push_back(width);
        }

        if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
            // Highlight the potential control character until the next key
            set_display(prompt);
            replace_last(line.back());
            is_special_char = true;
        }
    }

    bool has_more = multiline_input;
    if (is_special_char) {
        // Erase the control character from the screen and from `line`
        replace_last(' ');
        pop_cursor();

        char last = line.back();
        line.pop_back();
        if (last == '\\') {
            line += '\n';
            fputc('\n', out);
            has_more = !has_more;
        } else {
            // llama will just eat the single space, it won't act as a space
            if (line.length() == 1 && line.back() == ' ') {
                line.clear();
                pop_cursor();
            }
            has_more = false;
        }
    } else {
        if (end_of_stream) {
            has_more = false;
        } else {
            line += '\n';
            fputc('\n', out);
        }
    }

    fflush(out);
    return has_more;
}
|
||||||
|
|
||||||
|
// Fallback line reader used when simple console I/O was requested.
// Reads a whole line at once (converting UTF-16 input to UTF-8 on Windows)
// and applies the trailing-'/' (return control now) and trailing-'\'
// (toggle multiline continuation) conventions.
// Returns true if the caller should keep reading more lines.
static bool readline_simple(std::string & line, bool multiline_input) {
#if defined(_WIN32)
    std::wstring wline;
    if (!std::getline(std::wcin, wline)) {
        // Input stream is bad or EOF received
        line.clear();
        GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
        return false;
    }

    // Convert the wide (UTF-16) line to UTF-8: first query the required
    // byte count, then perform the actual conversion in place.
    int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
    line.resize(size_needed);
    WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
#else
    if (!std::getline(std::cin, line)) {
        // Input stream is bad or EOF received
        line.clear();
        return false;
    }
#endif
    if (!line.empty()) {
        char last = line.back();
        if (last == '/') { // Always return control on '/' symbol
            line.pop_back();
            return false;
        }
        if (last == '\\') { // '\\' changes the default action
            line.pop_back();
            multiline_input = !multiline_input;
        }
    }
    line += '\n';

    // By default, continue input if multiline_input is set
    return multiline_input;
}
|
||||||
|
|
||||||
|
bool readline(std::string & line, bool multiline_input) {
|
||||||
|
set_display(user_input);
|
||||||
|
|
||||||
|
if (simple_io) {
|
||||||
|
return readline_simple(line, multiline_input);
|
||||||
|
}
|
||||||
|
return readline_advanced(line, multiline_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
// Console functions
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace console {
    // Display states used to style terminal output.
    enum display_t {
        reset = 0,   // default display state
        prompt,      // prompt text
        user_input,  // text typed by the user
        error        // error messages
    };

    // One-time console setup; `use_simple_io` selects the line-based reader,
    // `use_advanced_display` presumably enables styled output — see the
    // implementation in console.cpp for details.
    void init(bool use_simple_io, bool use_advanced_display);
    // Undo init() (restore the prior console/terminal state).
    void cleanup();
    // Switch subsequent output to the given display style.
    void set_display(display_t display);
    // Read one line of user input into `line`; returns true when more input
    // is expected (multiline continuation).
    bool readline(std::string & line, bool multiline_input);
}
|
||||||
|
|
@ -0,0 +1,256 @@
|
||||||
|
#include "json-partial.h"
|
||||||
|
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
// Kind of a single level in the SAX parser's nesting stack.
enum common_json_stack_element_type {
    COMMON_JSON_STACK_ELEMENT_OBJECT,
    COMMON_JSON_STACK_ELEMENT_KEY,
    COMMON_JSON_STACK_ELEMENT_ARRAY,
};

// One level of JSON nesting observed during SAX parsing. `key` is only
// populated for KEY elements (the object key whose value is being parsed).
struct common_json_stack_element {
    common_json_stack_element_type type;
    std::string key;
};
|
||||||
|
|
||||||
|
bool common_json_parse(
|
||||||
|
const std::string & input,
|
||||||
|
const std::string & healing_marker,
|
||||||
|
common_json & out)
|
||||||
|
{
|
||||||
|
std::string::const_iterator it = input.begin();
|
||||||
|
const auto end = input.end();
|
||||||
|
return common_json_parse(it, end, healing_marker, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the JSON in [it, end), healing (closing) a truncated document when
// `healing_marker` is non-empty. On success, `it` is advanced past the
// consumed input, `out.json` holds the parsed value, and
// `out.healing_marker` is filled only when healing actually happened.
// Returns false when the input fails to parse and cannot be healed
// (e.g. a truncated top-level primitive). Throws std::runtime_error when a
// truncation inside a structure stops in an unhealable location.
bool common_json_parse(
    std::string::const_iterator & it,
    const std::string::const_iterator & end,
    const std::string & healing_marker,
    common_json & out)
{
    // // https://json.nlohmann.me/features/parsing/sax_interface/
    // SAX handler that records where parsing failed and the stack of open
    // objects/arrays/keys at that point, so we know what needs closing.
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;
        bool found_error;
        std::string last_token;
        std::string exception_message;
        std::vector<common_json_stack_element> stack;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
            this->position = position - 1;
            this->found_error = true;
            this->last_token = last_token;
            this->exception_message = ex.what();
            return false;
        }
        // A completed value closes the pending KEY entry (if any): the
        // object key now has its value.
        void close_value() {
            if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
                stack.pop_back();
            }
        }
        bool null() override { // NOLINT
            close_value();
            return true;
        }
        bool boolean(bool) override { // NOLINT
            close_value();
            return true;
        }
        bool number_integer(number_integer_t) override { // NOLINT
            close_value();
            return true;
        }
        bool number_unsigned(number_unsigned_t) override { // NOLINT
            close_value();
            return true;
        }
        bool number_float(number_float_t, const string_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool string(string_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool binary(binary_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool start_object(std::size_t) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
            return true;
        }
        bool end_object() override {
            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
            stack.pop_back();
            close_value();
            return true;
        }
        bool key(string_t & key) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
            return true;
        }
        bool start_array(std::size_t) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
            return true;
        }
        bool end_array() override {
            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
            stack.pop_back();
            close_value();
            return true;
        }
    };
    json_error_locator err_loc;
    auto start = it;
    json::sax_parse(it, end, &err_loc);

    if (err_loc.found_error) {
        it = start;
        auto temptative_end = it + err_loc.position;
        // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());

        // First try parsing just the prefix up to the error location — it
        // may already be a complete JSON value.
        auto input = std::string(it, temptative_end);
        try {
            out.json = json::parse(input);
            // out.json = json::parse(it, temptative_end);
            it = temptative_end;
            return true;
        } catch (const std::exception & ex) {
            // No, needs healing.
            LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
        }
        auto can_parse = [](const std::string & str) {
            try {
                auto _ = json::parse(str); // NOLINT
                return true;
            } catch (const std::exception &) {
                return false;
            }
        };
        if (!healing_marker.empty() && !err_loc.stack.empty()) {
            std::string str(it, temptative_end);
            auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
            if (last_non_sp_pos == std::string::npos) {
                throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
            }
            auto last_non_sp_char = str[last_non_sp_pos];
            // Used to detect stops on a number, which may not be complete.
            auto was_maybe_number = [&]() {
                if (!str.empty() && std::isspace(str.back())) {
                    return false;
                }
                return std::isdigit(last_non_sp_char) ||
                    last_non_sp_char == '.' ||
                    last_non_sp_char == 'e' ||
                    last_non_sp_char == 'E' ||
                    last_non_sp_char == '-';
            };

            // Build the sequence of closing brackets/braces for every open
            // container, innermost first.
            std::string closing;
            for (size_t i = err_loc.stack.size(); i > 0; i--) {
                auto & el = err_loc.stack[i - 1];
                if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
                    closing += "}";
                } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
                    closing += "]";
                } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
                    throw std::runtime_error("Unexpected stack element type");
                }
            }

            const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";

            // The innermost stack element tells us where the truncation
            // happened; each branch below tries progressively different
            // repair candidates, validating them with can_parse().
            if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
                // We're inside an object value
                if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
                    // Was about to create an object value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + ": 1" + closing)) {
                    str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
                } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
                    // Was about to create an object
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + "\"" + closing)) {
                    // Was inside an object value string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
                    // Was inside an object value string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
                } else {
                    // find last :
                    auto last_pos = str.find_last_of(':');
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
                    }
                    // Cutting back to opening : for object value
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
                if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
                    // Was about to create an array value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + "\"" + closing)) {
                    // Was inside an array value string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
                    // Was inside an array value string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
                } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
                    // Had just finished a value
                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
                } else {
                    auto last_pos = str.find_last_of("[,");
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
                    }
                    // Cutting back to last [ or , for array value
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
                if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
                        (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
                    // Was about to create an object key+value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
                } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
                    // Was about to create an object key+value
                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + "\": 1" + closing)) {
                    // Was inside an object key string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
                    // Was inside an object key string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
                } else {
                    auto last_pos = str.find_last_of(':');
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
                    }
                    // fprintf(stderr, "Cutting back to last : for object key+value\n");
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else {
                throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
            }
            // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
            out.json = json::parse(str);
            it = temptative_end;
            return true;
        }
        // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
        // fprintf(stderr, "Closing: TODO\n");
        return false;
    }
    // No error: the whole range is valid JSON.
    out.json = json::parse(it, end);
    it = end;
    return true;
}
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||||
|
struct common_healing_marker {
    // Raw marker.
    std::string marker;

    // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
    // May carry extra syntax around the raw marker (e.g. a leading quote,
    // colon or comma) depending on where the healing was applied.
    std::string json_dump_marker;
};

// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
    nlohmann::ordered_json json;

    // Empty markers when the input parsed fully without healing.
    common_healing_marker healing_marker;
};

// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
// Returns false when the input could not be parsed or healed.
bool common_json_parse(
    const std::string & input,
    const std::string & healing_marker,
    common_json & out);

// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
    std::string::const_iterator & it,
    const std::string::const_iterator & end,
    const std::string & healing_marker,
    common_json & out);
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ggml.h"
|
#include <nlohmann/json_fwd.hpp>
|
||||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
#include <functional>
|
||||||
#include "json.hpp"
|
#include <string>
|
||||||
|
|
||||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||||
bool force_gbnf = false);
|
bool force_gbnf = false);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,204 @@
|
||||||
|
#include "regex-partial.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
// Compile both the forward regex and the reversed partial-match regex
// (see regex_to_reversed_partial_regex) for the given pattern.
common_regex::common_regex(const std::string & pattern) :
    pattern(pattern),
    rx(pattern),
    rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||||
|
|
||||||
|
// Search `input` starting at byte offset `pos`. With `as_match` the pattern
// must match starting exactly at `pos` (regex_match semantics), otherwise it
// may match anywhere after (regex_search). If no full match is found, tries
// the reversed-partial regex to detect a partial match that extends to the
// end of the input. Group offsets are absolute offsets into `input`.
// Throws std::runtime_error when `pos` is past the end of `input`.
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
    std::smatch match;
    if (pos > input.size()) {
        throw std::runtime_error("Position out of bounds");
    }
    auto start = input.begin() + pos;
    auto found = as_match
        ? std::regex_match(start, input.end(), match, rx)
        : std::regex_search(start, input.end(), match, rx);
    if (found) {
        common_regex_match res;
        res.type = COMMON_REGEX_MATCH_TYPE_FULL;
        for (size_t i = 0; i < match.size(); ++i) {
            // match positions are relative to `start`; shift to absolute
            auto begin = pos + match.position(i);
            res.groups.emplace_back(begin, begin + match.length(i));
        }
        return res;
    }
    // No full match: run the reversed partial regex over the reversed input.
    // Its single capturing group covers the reversed partial-match prefix,
    // so the group's end (in reverse) marks where the partial match begins.
    std::match_results<std::string::const_reverse_iterator> srmatch;
    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
        auto group = srmatch[1].str();
        if (group.length() != 0) {
            auto it = srmatch[1].second.base();
            // auto position = static_cast<size_t>(std::distance(input.begin(), it));
            // For as_match, the partial match must start exactly at `pos`
            // (which is input.begin() when the check below applies).
            if ((!as_match) || it == input.begin()) {
                common_regex_match res;
                res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
                const size_t begin = std::distance(input.begin(), it);
                const size_t end = input.size();
                if (begin == std::string::npos || end == std::string::npos || begin > end) {
                    throw std::runtime_error("Invalid range");
                }
                res.groups.push_back({begin, end});
                return res;
            }
        }
    }
    return {};
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||||
|
|
||||||
|
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||||
|
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||||
|
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||||
|
|
||||||
|
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
|
||||||
|
- /a|b/ -> (a|b).*
|
||||||
|
- /a*?/ -> error, could match ""
|
||||||
|
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
|
||||||
|
- /.*?ab/ -> ((?:b)?a).* (merge .*)
|
||||||
|
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
|
||||||
|
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
|
||||||
|
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
|
||||||
|
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
|
||||||
|
|
||||||
|
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
|
||||||
|
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
|
||||||
|
*/
|
||||||
|
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
    auto it = pattern.begin();
    const auto end = pattern.end();

    // Recursive-descent walk over the pattern. Each call parses one
    // alternation level (until ')' or end of input) and returns the
    // reversed, partial-match form of that sub-pattern.
    std::function<std::string()> process = [&]() {
        // alternatives[i] is the sequence of atoms of the i-th '|' branch
        std::vector<std::vector<std::string>> alternatives(1);
        std::vector<std::string> * sequence = &alternatives.back();

        while (it != end) {
            if (*it == '[') {
                // Character class: copied through verbatim (escapes honored)
                auto start = it;
                ++it;
                while (it != end) {
                    if ((*it == '\\') && (++it != end)) {
                        ++it;
                    } else if ((it != end) && (*it == ']')) {
                        break;
                    } else {
                        ++it;
                    }
                }
                if (it == end) {
                    throw std::runtime_error("Unmatched '[' in pattern");
                }
                ++it;
                sequence->push_back(std::string(start, it));
            } else if (*it == '*' || *it == '?' || *it == '+') {
                if (sequence->empty()) {
                    throw std::runtime_error("Quantifier without preceding element");
                }
                sequence->back() += *it;
                auto is_star = *it == '*';
                ++it;
                if (is_star) {
                    // Drop a reluctant '*?' suffix: final repetitions become eager
                    if (*it == '?') {
                        ++it;
                    }
                }
            } else if (*it == '{') {
                // Bounded repetition {m}, {m,}, {m,n}
                if (sequence->empty()) {
                    throw std::runtime_error("Repetition without preceding element");
                }
                ++it;
                auto start = it;
                while (it != end && *it != '}') {
                    ++it;
                }
                if (it == end) {
                    throw std::runtime_error("Unmatched '{' in pattern");
                }
                auto parts = string_split(std::string(start, it), ",");
                ++it;
                if (parts.size() > 2) {
                    throw std::runtime_error("Invalid repetition range in pattern");
                }

                auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
                    if (s.empty()) {
                        return def;
                    }
                    return std::stoi(s);
                };
                auto min = parseOptInt(parts[0], 0);
                auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
                if (min && max && *max < *min) {
                    throw std::runtime_error("Invalid repetition range in pattern");
                }
                // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
                auto part = sequence->back();
                sequence->pop_back();
                for (int i = 0; i < *min; i++) {
                    sequence->push_back(part);
                }
                if (max) {
                    for (int i = *min; i < *max; i++) {
                        sequence->push_back(part + "?");
                    }
                } else {
                    sequence->push_back(part + "*");
                }
            } else if (*it == '(') {
                ++it;
                // Skip a '?:' non-capturing prefix; all groups end up
                // non-capturing in the output anyway.
                if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
                    it += 2;
                }
                auto sub = process();
                if (*it != ')') {
                    throw std::runtime_error("Unmatched '(' in pattern");
                }
                ++it;
                auto & part = sequence->emplace_back("(?:");
                part += sub;
                part += ")";
            } else if (*it == ')') {
                // End of this group: let the caller consume the ')'
                break;
            } else if (*it == '|') {
                ++it;
                alternatives.emplace_back();
                sequence = &alternatives.back();
            } else if (*it == '\\' && (++it != end)) {
                auto str = std::string("\\") + *it;
                sequence->push_back(str);
                ++it;
            } else if (it != end) {
                sequence->push_back(std::string(1, *it));
                ++it;
            }
        }

        // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
        // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
        // We'll do the outermost capturing group and final .* in the enclosing function.
        std::vector<std::string> res_alts;
        for (const auto & parts : alternatives) {
            auto & res = res_alts.emplace_back();
            for (size_t i = 0; i < parts.size() - 1; i++) {
                res += "(?:";
            }
            // Emit atoms in reverse order, each (but the last) optional
            for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
                res += *it;
                if (it != parts.rend() - 1) {
                    res += ")?";
                }
            }
        }
        return string_join(res_alts, "|");
    };
    auto res = process();
    if (it != end) {
        // process() returned on a ')' we never opened
        throw std::runtime_error("Unmatched '(' in pattern");
    }

    // Wrap in the single capturing group; [\s\S]* swallows the rest of the
    // reversed input so the whole string matches.
    return "(" + res + ")[\\s\\S]*";
}
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <regex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
enum common_regex_match_type {
|
||||||
|
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||||
|
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||||
|
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_string_range {
|
||||||
|
size_t begin;
|
||||||
|
size_t end;
|
||||||
|
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||||
|
if (begin > end) {
|
||||||
|
throw std::runtime_error("Invalid range");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// prevent default ctor
|
||||||
|
common_string_range() = delete;
|
||||||
|
bool empty() const {
|
||||||
|
return begin == end;
|
||||||
|
}
|
||||||
|
bool operator==(const common_string_range & other) const {
|
||||||
|
return begin == other.begin && end == other.end;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_regex_match {
|
||||||
|
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||||
|
std::vector<common_string_range> groups;
|
||||||
|
|
||||||
|
bool operator==(const common_regex_match & other) const {
|
||||||
|
return type == other.type && groups == other.groups;
|
||||||
|
}
|
||||||
|
bool operator!=(const common_regex_match & other) const {
|
||||||
|
return !(*this == other);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_regex {
|
||||||
|
std::string pattern;
|
||||||
|
std::regex rx;
|
||||||
|
std::regex rx_reversed_partial;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit common_regex(const std::string & pattern);
|
||||||
|
|
||||||
|
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||||
|
|
||||||
|
const std::string & str() const { return pattern; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// For testing only (pretty print of failures).
|
||||||
|
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||||
|
|
@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
} else {
|
} else {
|
||||||
std::vector<std::string> patterns_at_start;
|
std::vector<std::string> trigger_patterns;
|
||||||
std::vector<std::string> patterns_anywhere;
|
std::vector<std::string> patterns_anywhere;
|
||||||
std::vector<llama_token> trigger_tokens;
|
std::vector<llama_token> trigger_tokens;
|
||||||
for (const auto & trigger : params.grammar_triggers) {
|
for (const auto & trigger : params.grammar_triggers) {
|
||||||
|
|
@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
|
||||||
{
|
{
|
||||||
const auto & pattern = trigger.value;
|
patterns_anywhere.push_back(trigger.value);
|
||||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
break;
|
||||||
|
}
|
||||||
|
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||||
|
{
|
||||||
|
trigger_patterns.push_back(trigger.value);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||||
|
|
@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> trigger_patterns;
|
|
||||||
if (!patterns_at_start.empty()) {
|
|
||||||
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
|
||||||
}
|
|
||||||
if (!patterns_anywhere.empty()) {
|
if (!patterns_anywhere.empty()) {
|
||||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,10 @@ extern "C" {
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_context;
|
struct llama_context;
|
||||||
struct llama_sampler;
|
struct llama_sampler;
|
||||||
struct llama_kv_cache;
|
|
||||||
|
typedef struct llama_memory_i * llama_memory_t;
|
||||||
|
|
||||||
|
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
|
||||||
|
|
||||||
typedef int32_t llama_pos;
|
typedef int32_t llama_pos;
|
||||||
typedef int32_t llama_token;
|
typedef int32_t llama_token;
|
||||||
|
|
@ -259,9 +262,9 @@ extern "C" {
|
||||||
llama_token * token;
|
llama_token * token;
|
||||||
float * embd;
|
float * embd;
|
||||||
llama_pos * pos;
|
llama_pos * pos;
|
||||||
int32_t * n_seq_id;
|
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
|
||||||
llama_seq_id ** seq_id;
|
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
|
||||||
int8_t * logits; // TODO: rename this to "output"
|
int8_t * logits; // TODO: rename this to "output"
|
||||||
} llama_batch;
|
} llama_batch;
|
||||||
|
|
||||||
enum llama_model_kv_override_type {
|
enum llama_model_kv_override_type {
|
||||||
|
|
@ -345,7 +348,7 @@ extern "C" {
|
||||||
float yarn_beta_fast; // YaRN low correction dim
|
float yarn_beta_fast; // YaRN low correction dim
|
||||||
float yarn_beta_slow; // YaRN high correction dim
|
float yarn_beta_slow; // YaRN high correction dim
|
||||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||||
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
|
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval;
|
ggml_backend_sched_eval_callback cb_eval;
|
||||||
void * cb_eval_user_data;
|
void * cb_eval_user_data;
|
||||||
|
|
@ -361,10 +364,13 @@ extern "C" {
|
||||||
|
|
||||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||||
bool embeddings; // if true, extract embeddings (together with logits)
|
bool embeddings; // if true, extract embeddings (together with logits)
|
||||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
|
||||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
bool flash_attn; // use flash attention [EXPERIMENTAL]
|
||||||
bool no_perf; // whether to measure performance timings
|
bool no_perf; // measure performance timings
|
||||||
bool op_offload; // whether to offload host tensor operations to device
|
bool op_offload; // offload host tensor operations to device
|
||||||
|
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||||
|
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
||||||
};
|
};
|
||||||
|
|
||||||
// model quantization parameters
|
// model quantization parameters
|
||||||
|
|
@ -470,6 +476,7 @@ extern "C" {
|
||||||
LLAMA_API int64_t llama_time_us(void);
|
LLAMA_API int64_t llama_time_us(void);
|
||||||
|
|
||||||
LLAMA_API size_t llama_max_devices(void);
|
LLAMA_API size_t llama_max_devices(void);
|
||||||
|
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||||
|
|
||||||
LLAMA_API bool llama_supports_mmap (void);
|
LLAMA_API bool llama_supports_mmap (void);
|
||||||
LLAMA_API bool llama_supports_mlock (void);
|
LLAMA_API bool llama_supports_mlock (void);
|
||||||
|
|
@ -489,9 +496,11 @@ extern "C" {
|
||||||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||||
|
|
||||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||||
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
|
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
|
||||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
||||||
|
|
||||||
|
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
|
||||||
|
|
||||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||||
|
|
||||||
|
|
@ -500,6 +509,7 @@ extern "C" {
|
||||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||||
|
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||||
|
|
||||||
// Get the model's RoPE frequency scaling factor
|
// Get the model's RoPE frequency scaling factor
|
||||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||||
|
|
@ -604,78 +614,92 @@ extern "C" {
|
||||||
int32_t il_end);
|
int32_t il_end);
|
||||||
|
|
||||||
//
|
//
|
||||||
// KV cache
|
// Memory
|
||||||
//
|
//
|
||||||
|
|
||||||
// TODO: start using struct llama_kv_cache
|
// Clear the memory contents
|
||||||
|
LLAMA_API void llama_memory_clear(llama_memory_t mem);
|
||||||
|
|
||||||
// Information associated with an individual cell in the KV cache view.
|
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
struct llama_kv_cache_view_cell {
|
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||||
// The position for this cell. Takes KV cache shifts into account.
|
// seq_id < 0 : match any sequence
|
||||||
// May be negative if the cell is not populated.
|
// p0 < 0 : [0, p1]
|
||||||
llama_pos pos;
|
// p1 < 0 : [p0, inf)
|
||||||
};
|
LLAMA_API bool llama_memory_seq_rm(
|
||||||
|
llama_memory_t mem,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1);
|
||||||
|
|
||||||
// An updateable view of the KV cache.
|
// Copy all tokens that belong to the specified sequence to another sequence
|
||||||
struct llama_kv_cache_view {
|
// p0 < 0 : [0, p1]
|
||||||
// Number of KV cache cells. This will be the same as the context size.
|
// p1 < 0 : [p0, inf)
|
||||||
int32_t n_cells;
|
LLAMA_API void llama_memory_seq_cp(
|
||||||
|
llama_memory_t mem,
|
||||||
|
llama_seq_id seq_id_src,
|
||||||
|
llama_seq_id seq_id_dst,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1);
|
||||||
|
|
||||||
// Maximum number of sequences that can exist in a cell. It's not an error
|
// Removes all tokens that do not belong to the specified sequence
|
||||||
// if there are more sequences in a cell than this value, however they will
|
LLAMA_API void llama_memory_seq_keep(
|
||||||
// not be visible in the view cells_sequences.
|
llama_memory_t mem,
|
||||||
int32_t n_seq_max;
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Number of tokens in the cache. For example, if there are two populated
|
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
// cells, the first with 1 sequence id in it and the second with 2 sequence
|
// p0 < 0 : [0, p1]
|
||||||
// ids then you'll have 3 tokens.
|
// p1 < 0 : [p0, inf)
|
||||||
int32_t token_count;
|
LLAMA_API void llama_memory_seq_add(
|
||||||
|
llama_memory_t mem,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
llama_pos delta);
|
||||||
|
|
||||||
// Number of populated cache cells.
|
// Integer division of the positions by factor of `d > 1`
|
||||||
int32_t used_cells;
|
// p0 < 0 : [0, p1]
|
||||||
|
// p1 < 0 : [p0, inf)
|
||||||
|
LLAMA_API void llama_memory_seq_div(
|
||||||
|
llama_memory_t mem,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
int d);
|
||||||
|
|
||||||
// Maximum contiguous empty slots in the cache.
|
// Returns the smallest position present in the memory for the specified sequence
|
||||||
int32_t max_contiguous;
|
// This is typically non-zero only for SWA caches
|
||||||
|
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||||
|
// Return -1 if the sequence is empty
|
||||||
|
LLAMA_API llama_pos llama_memory_seq_pos_min(
|
||||||
|
llama_memory_t mem,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Index to the start of the max_contiguous slot range. Can be negative
|
// Returns the largest position present in the memory for the specified sequence
|
||||||
// when cache is full.
|
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||||
int32_t max_contiguous_idx;
|
// Return -1 if the sequence is empty
|
||||||
|
LLAMA_API llama_pos llama_memory_seq_pos_max(
|
||||||
|
llama_memory_t mem,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Information for an individual cell.
|
// Check if the memory supports shifting
|
||||||
struct llama_kv_cache_view_cell * cells;
|
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||||
|
|
||||||
// The sequences for each cell. There will be n_seq_max items per cell.
|
//
|
||||||
llama_seq_id * cells_sequences;
|
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
|
||||||
};
|
//
|
||||||
|
|
||||||
// Create an empty KV cache view. (use only for debugging purposes)
|
|
||||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
|
||||||
|
|
||||||
// Free a KV cache view. (use only for debugging purposes)
|
|
||||||
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
|
||||||
|
|
||||||
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
|
||||||
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
|
|
||||||
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
|
||||||
|
|
||||||
///
|
|
||||||
|
|
||||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||||
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
|
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
|
||||||
|
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
|
||||||
"use llama_kv_self_n_tokens instead");
|
|
||||||
|
|
||||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||||
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
|
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
|
||||||
|
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
|
|
||||||
"use llama_kv_self_used_cells instead");
|
|
||||||
|
|
||||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||||
LLAMA_API void llama_kv_self_clear(
|
LLAMA_API void llama_kv_self_clear(
|
||||||
struct llama_context * ctx);
|
struct llama_context * ctx);
|
||||||
|
|
||||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||||
|
|
@ -707,7 +731,6 @@ extern "C" {
|
||||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||||
// - lazily on next llama_decode()
|
// - lazily on next llama_decode()
|
||||||
// - explicitly with llama_kv_self_update()
|
|
||||||
// p0 < 0 : [0, p1]
|
// p0 < 0 : [0, p1]
|
||||||
// p1 < 0 : [p0, inf)
|
// p1 < 0 : [p0, inf)
|
||||||
LLAMA_API void llama_kv_self_seq_add(
|
LLAMA_API void llama_kv_self_seq_add(
|
||||||
|
|
@ -720,7 +743,6 @@ extern "C" {
|
||||||
// Integer division of the positions by factor of `d > 1`
|
// Integer division of the positions by factor of `d > 1`
|
||||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||||
// - lazily on next llama_decode()
|
// - lazily on next llama_decode()
|
||||||
// - explicitly with llama_kv_self_update()
|
|
||||||
// p0 < 0 : [0, p1]
|
// p0 < 0 : [0, p1]
|
||||||
// p1 < 0 : [p0, inf)
|
// p1 < 0 : [p0, inf)
|
||||||
LLAMA_API void llama_kv_self_seq_div(
|
LLAMA_API void llama_kv_self_seq_div(
|
||||||
|
|
@ -730,84 +752,40 @@ extern "C" {
|
||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
int d);
|
int d);
|
||||||
|
|
||||||
|
// Returns the smallest position present in the KV cache for the specified sequence
|
||||||
|
// This is typically non-zero only for SWA caches
|
||||||
|
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||||
|
// Return -1 if the sequence is empty
|
||||||
|
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Returns the largest position present in the KV cache for the specified sequence
|
// Returns the largest position present in the KV cache for the specified sequence
|
||||||
|
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||||
|
// Return -1 if the sequence is empty
|
||||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Defragment the KV cache
|
// Defragment the KV cache
|
||||||
// This will be applied:
|
// This will be applied:
|
||||||
// - lazily on next llama_decode()
|
// - lazily on next llama_decode()
|
||||||
// - explicitly with llama_kv_self_update()
|
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
|
||||||
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
|
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||||
|
|
||||||
// Check if the context supports KV cache shifting
|
// Check if the context supports KV cache shifting
|
||||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||||
|
|
||||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||||
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
|
||||||
|
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
|
|
||||||
struct llama_context * ctx),
|
|
||||||
"use llama_kv_self_clear instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1),
|
|
||||||
"use llama_kv_self_seq_rm instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_seq_id seq_id_src,
|
|
||||||
llama_seq_id seq_id_dst,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1),
|
|
||||||
"use llama_kv_self_seq_cp instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_seq_id seq_id),
|
|
||||||
"use llama_kv_self_seq_keep instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
llama_pos delta),
|
|
||||||
"use llama_kv_self_seq_add instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
int d),
|
|
||||||
"use llama_kv_self_seq_div instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_seq_id seq_id),
|
|
||||||
"use llama_kv_self_seq_pos_max instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
|
|
||||||
"use llama_kv_self_defrag instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
|
|
||||||
"use llama_kv_self_can_shift instead");
|
|
||||||
|
|
||||||
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
|
|
||||||
"use llama_kv_self_update instead");
|
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// State / sessions
|
// State / sessions
|
||||||
//
|
//
|
||||||
|
|
||||||
// Returns the *actual* size in bytes of the state
|
// Returns the *actual* size in bytes of the state
|
||||||
// (logits, embedding and kv_cache)
|
// (logits, embedding and memory)
|
||||||
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
|
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
|
||||||
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
|
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
|
||||||
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
|
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
|
||||||
|
|
@ -863,12 +841,12 @@ extern "C" {
|
||||||
size_t n_token_count),
|
size_t n_token_count),
|
||||||
"use llama_state_save_file instead");
|
"use llama_state_save_file instead");
|
||||||
|
|
||||||
// Get the exact size needed to copy the KV cache of a single sequence
|
// Get the exact size needed to copy the state of a single sequence
|
||||||
LLAMA_API size_t llama_state_seq_get_size(
|
LLAMA_API size_t llama_state_seq_get_size(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Copy the KV cache of a single sequence into the specified buffer
|
// Copy the state of a single sequence into the specified buffer
|
||||||
LLAMA_API size_t llama_state_seq_get_data(
|
LLAMA_API size_t llama_state_seq_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst,
|
uint8_t * dst,
|
||||||
|
|
@ -934,18 +912,21 @@ extern "C" {
|
||||||
// For encode-decoder contexts, processes the batch using the encoder.
|
// For encode-decoder contexts, processes the batch using the encoder.
|
||||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||||
// 0 - success
|
// 0 - success
|
||||||
// < 0 - error. the KV cache state is restored to the state before this call
|
// < 0 - error. the memory state is restored to the state before this call
|
||||||
LLAMA_API int32_t llama_encode(
|
LLAMA_API int32_t llama_encode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch);
|
struct llama_batch batch);
|
||||||
|
|
||||||
// Process a batch of tokens.
|
// Process a batch of tokens.
|
||||||
// Requires KV cache.
|
// Requires the context to have a memory.
|
||||||
// For encode-decoder contexts, processes the batch using the decoder.
|
// For encode-decoder contexts, processes the batch using the decoder.
|
||||||
// Positive return values does not mean a fatal error, but rather a warning.
|
// Positive return values does not mean a fatal error, but rather a warning.
|
||||||
// 0 - success
|
// Upon non-zero return values, the memory state is restored to the state before this call
|
||||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
// 0 - success
|
||||||
// < 0 - error. the KV cache state is restored to the state before this call
|
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||||
|
// 2 - aborted
|
||||||
|
// -1 - invalid input batch
|
||||||
|
// < -1 - error
|
||||||
LLAMA_API int32_t llama_decode(
|
LLAMA_API int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch);
|
struct llama_batch batch);
|
||||||
|
|
|
||||||
|
|
@ -176,6 +176,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||||
{ LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
|
{ LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
|
||||||
{ LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
|
{ LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
|
||||||
|
|
||||||
|
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
|
||||||
|
|
||||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||||
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
||||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||||
|
|
@ -450,6 +452,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||||
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
||||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
|
@ -1483,6 +1486,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,8 @@ enum llm_kv {
|
||||||
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
|
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
|
||||||
LLM_KV_CONVNEXT_BLOCK_COUNT,
|
LLM_KV_CONVNEXT_BLOCK_COUNT,
|
||||||
|
|
||||||
|
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
|
||||||
|
|
||||||
// deprecated:
|
// deprecated:
|
||||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
#include "llama-batch.h"
|
#include "llama-batch.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
@ -14,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ubatch_token.resize(!has_embd ? n_ubatch : 0);
|
|
||||||
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
udatas.push_back({});
|
||||||
ubatch_pos.resize(n_ubatch);
|
|
||||||
ubatch_n_seq_id.resize(n_ubatch);
|
auto & udata = udatas.back();
|
||||||
ubatch_seq_id.resize(n_ubatch);
|
|
||||||
ubatch_output.resize(n_ubatch);
|
udata.token.resize(!has_embd ? n_ubatch : 0);
|
||||||
|
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
||||||
|
udata.pos.resize(n_ubatch);
|
||||||
|
udata.n_seq_id.resize(n_ubatch);
|
||||||
|
udata.seq_id.resize(n_ubatch);
|
||||||
|
udata.output.resize(n_ubatch);
|
||||||
|
|
||||||
llama_ubatch ubatch = {
|
llama_ubatch ubatch = {
|
||||||
/*equal_seqs =*/ true,
|
/*equal_seqs =*/ true,
|
||||||
/*n_tokens =*/ 0,
|
/*n_tokens =*/ 0,
|
||||||
/*n_seq_tokens =*/ 0,
|
/*n_seq_tokens =*/ 0,
|
||||||
/*n_seqs =*/ 0,
|
/*n_seqs =*/ 0,
|
||||||
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
|
/*token =*/ !has_embd ? udata.token.data() : nullptr,
|
||||||
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
|
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
|
||||||
/*pos =*/ ubatch_pos.data(),
|
/*pos =*/ udata.pos.data(),
|
||||||
/*n_seq_id =*/ ubatch_n_seq_id.data(),
|
/*n_seq_id =*/ udata.n_seq_id.data(),
|
||||||
/*seq_id =*/ ubatch_seq_id.data(),
|
/*seq_id =*/ udata.seq_id.data(),
|
||||||
/*output =*/ ubatch_output.data(),
|
/*output =*/ udata.output.data(),
|
||||||
};
|
};
|
||||||
|
|
||||||
return ubatch;
|
return ubatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -281,9 +289,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
||||||
batch = in_batch;
|
batch = in_batch;
|
||||||
GGML_ASSERT(batch.n_tokens > 0);
|
GGML_ASSERT(batch.n_tokens > 0);
|
||||||
if (!batch.pos) {
|
if (!batch.pos) {
|
||||||
|
assert(p0 >= 0);
|
||||||
pos.resize(batch.n_tokens);
|
pos.resize(batch.n_tokens);
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||||
pos[i] = i + p0;
|
pos[i] = p0 + i;
|
||||||
}
|
}
|
||||||
batch.pos = pos.data();
|
batch.pos = pos.data();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,15 +11,15 @@ struct llama_ubatch {
|
||||||
bool equal_seqs;
|
bool equal_seqs;
|
||||||
// TODO: whole_seqs for embeddings?
|
// TODO: whole_seqs for embeddings?
|
||||||
|
|
||||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||||
uint32_t n_seq_tokens; // tokens per sequence
|
uint32_t n_seq_tokens; // tokens per sequence
|
||||||
uint32_t n_seqs;
|
uint32_t n_seqs;
|
||||||
|
|
||||||
llama_token * token; // [n_tokens]
|
llama_token * token; // [n_tokens]
|
||||||
float * embd; // [n_embd, n_tokens]
|
float * embd; // [n_embd, n_tokens]
|
||||||
llama_pos * pos; // [n_tokens]
|
llama_pos * pos; // [n_tokens]
|
||||||
int32_t * n_seq_id; // [n_seqs]
|
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
|
||||||
llama_seq_id ** seq_id; // [n_seqs]
|
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
|
||||||
int8_t * output; // [n_tokens]
|
int8_t * output; // [n_tokens]
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -49,13 +49,18 @@ struct llama_sbatch {
|
||||||
|
|
||||||
const llama_batch * batch = nullptr;
|
const llama_batch * batch = nullptr;
|
||||||
|
|
||||||
// buffers for the ubatch
|
// buffers for the ubatches
|
||||||
std::vector<llama_token> ubatch_token;
|
// TODO: very hacky, this needs a complete rework
|
||||||
std::vector<float> ubatch_embd;
|
struct ubatch_data {
|
||||||
std::vector<llama_pos> ubatch_pos;
|
std::vector<llama_token> token;
|
||||||
std::vector<int32_t> ubatch_n_seq_id;
|
std::vector<float> embd;
|
||||||
std::vector<llama_seq_id *> ubatch_seq_id;
|
std::vector<llama_pos> pos;
|
||||||
std::vector<int8_t> ubatch_output;
|
std::vector<int32_t> n_seq_id;
|
||||||
|
std::vector<llama_seq_id *> seq_id;
|
||||||
|
std::vector<int8_t> output;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<ubatch_data> udatas;
|
||||||
|
|
||||||
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -5,7 +5,6 @@
|
||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
#include "llama-graph.h"
|
#include "llama-graph.h"
|
||||||
#include "llama-adapter.h"
|
#include "llama-adapter.h"
|
||||||
#include "llama-kv-cache.h"
|
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
#include "ggml-cpp.h"
|
||||||
#include "ggml-opt.h"
|
#include "ggml-opt.h"
|
||||||
|
|
@ -14,11 +13,13 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_kv_cache;
|
|
||||||
|
|
||||||
class llama_io_read_i;
|
class llama_io_read_i;
|
||||||
class llama_io_write_i;
|
class llama_io_write_i;
|
||||||
|
|
||||||
|
struct llama_memory_i;
|
||||||
|
struct llama_memory_state_i;
|
||||||
|
|
||||||
struct llama_context {
|
struct llama_context {
|
||||||
// init scheduler and compute buffers, reserve worst-case graphs
|
// init scheduler and compute buffers, reserve worst-case graphs
|
||||||
llama_context(
|
llama_context(
|
||||||
|
|
@ -45,10 +46,12 @@ struct llama_context {
|
||||||
uint32_t n_threads() const;
|
uint32_t n_threads() const;
|
||||||
uint32_t n_threads_batch() const;
|
uint32_t n_threads_batch() const;
|
||||||
|
|
||||||
llama_kv_cache * get_kv_self();
|
llama_memory_t get_memory() const;
|
||||||
const llama_kv_cache * get_kv_self() const;
|
|
||||||
|
|
||||||
void kv_self_update();
|
// return true of the KV cache was updated
|
||||||
|
// TODO: remove
|
||||||
|
bool kv_self_update(bool optimize);
|
||||||
|
void kv_self_defrag_sched();
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type() const;
|
enum llama_pooling_type pooling_type() const;
|
||||||
|
|
||||||
|
|
@ -89,6 +92,16 @@ struct llama_context {
|
||||||
int32_t il_start,
|
int32_t il_start,
|
||||||
int32_t il_end);
|
int32_t il_end);
|
||||||
|
|
||||||
|
// process a single ubatch with a specific graph type
|
||||||
|
// if memory_state is provided, it will be applied first to the context's memory
|
||||||
|
// ret contains the status of the graph computation
|
||||||
|
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
||||||
|
llm_graph_result_ptr process_ubatch(
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
llm_graph_type gtype,
|
||||||
|
llama_memory_state_i * mstate,
|
||||||
|
ggml_status & ret);
|
||||||
|
|
||||||
int encode(llama_batch & inp_batch);
|
int encode(llama_batch & inp_batch);
|
||||||
int decode(llama_batch & inp_batch);
|
int decode(llama_batch & inp_batch);
|
||||||
|
|
||||||
|
|
@ -181,16 +194,18 @@ public:
|
||||||
ggml_cgraph * graph_init();
|
ggml_cgraph * graph_init();
|
||||||
|
|
||||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||||
ggml_status graph_compute(
|
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||||
ggml_cgraph * gf,
|
|
||||||
bool batched);
|
// reserve a graph with a dummy ubatch of the specified size
|
||||||
|
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llm_graph_result_ptr graph_build(
|
llm_graph_result_ptr graph_build(
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype);
|
llm_graph_type gtype,
|
||||||
|
const llama_memory_state_i * mstate);
|
||||||
|
|
||||||
llm_graph_cb graph_get_cb() const;
|
llm_graph_cb graph_get_cb() const;
|
||||||
|
|
||||||
|
|
@ -215,6 +230,9 @@ private:
|
||||||
|
|
||||||
std::unique_ptr<llama_memory_i> memory;
|
std::unique_ptr<llama_memory_i> memory;
|
||||||
|
|
||||||
|
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
||||||
|
bool memory_force_optimize = false;
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||||
size_t logits_size = 0; // capacity (of floats) for logits
|
size_t logits_size = 0; // capacity (of floats) for logits
|
||||||
float * logits = nullptr;
|
float * logits = nullptr;
|
||||||
|
|
|
||||||
|
|
@ -1 +1,5 @@
|
||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
|
|
||||||
|
size_t llama_max_parallel_sequences(void) {
|
||||||
|
return LLAMA_MAX_PARALLEL_SEQUENCES;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
|
||||||
|
|
||||||
struct llama_cparams {
|
struct llama_cparams {
|
||||||
uint32_t n_ctx; // context size used during inference
|
uint32_t n_ctx; // context size used during inference
|
||||||
uint32_t n_batch;
|
uint32_t n_batch;
|
||||||
|
|
|
||||||
|
|
@ -1186,8 +1186,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||||
grammar.awaiting_trigger = false;
|
grammar.awaiting_trigger = false;
|
||||||
// get from the first match to the end of the string
|
// get from the first matched capturing group to the end of the string
|
||||||
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
|
size_t start = std::string::npos;
|
||||||
|
for (auto i = 1u; i < match.size(); i++) {
|
||||||
|
if (match.length(i) > 0) {
|
||||||
|
start = match.position(i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (start == std::string::npos) {
|
||||||
|
start = match.position(0);
|
||||||
|
}
|
||||||
|
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||||
grammar.trigger_buffer.clear();
|
grammar.trigger_buffer.clear();
|
||||||
llama_grammar_accept_str(grammar, constrained_str);
|
llama_grammar_accept_str(grammar, constrained_str);
|
||||||
|
|
|
||||||
|
|
@ -3,39 +3,15 @@
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
#include "llama-batch.h"
|
#include "llama-batch.h"
|
||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
#include "llama-kv-cache.h"
|
|
||||||
|
#include "llama-kv-cache-unified.h"
|
||||||
|
#include "llama-kv-cache-unified-iswa.h"
|
||||||
|
#include "llama-kv-cache-recurrent.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
|
||||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
|
||||||
const int64_t max_distance = 128;
|
|
||||||
|
|
||||||
if (bidirectional) {
|
|
||||||
n_buckets >>= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t max_exact = n_buckets >> 1;
|
|
||||||
|
|
||||||
int32_t relative_position = x - y;
|
|
||||||
int32_t relative_bucket = 0;
|
|
||||||
|
|
||||||
if (bidirectional) {
|
|
||||||
relative_bucket += (relative_position > 0) * n_buckets;
|
|
||||||
relative_position = abs(relative_position);
|
|
||||||
} else {
|
|
||||||
relative_position = -std::min<int32_t>(relative_position, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
|
||||||
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
|
||||||
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
|
||||||
|
|
||||||
return relative_bucket;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||||
if (ubatch->token) {
|
if (ubatch->token) {
|
||||||
const int64_t n_tokens = ubatch->n_tokens;
|
const int64_t n_tokens = ubatch->n_tokens;
|
||||||
|
|
@ -110,22 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
||||||
|
|
||||||
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
||||||
if (pos_bucket) {
|
if (pos_bucket) {
|
||||||
const int64_t n_tokens = ubatch->n_tokens;
|
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
|
||||||
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
|
||||||
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
|
||||||
|
|
||||||
int32_t * data = (int32_t *) pos_bucket->data;
|
|
||||||
|
|
||||||
const int64_t n_kv = kv_self->n;
|
|
||||||
|
|
||||||
for (int h = 0; h < 1; ++h) {
|
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
|
||||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -276,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||||
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||||
GGML_UNUSED(ubatch);
|
GGML_UNUSED(ubatch);
|
||||||
|
|
||||||
const int64_t n_kv = kv_self->n;
|
const int64_t n_kv = kv_state->get_n_kv();
|
||||||
|
|
||||||
if (s_copy) {
|
if (s_copy) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
||||||
|
|
@ -284,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||||
|
|
||||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
data[i] = kv_self->s_copy(i);
|
data[i] = kv_state->s_copy(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -292,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||||
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
||||||
GGML_UNUSED(ubatch);
|
GGML_UNUSED(ubatch);
|
||||||
|
|
||||||
const int64_t n_kv = kv_self->n;
|
const int64_t n_kv = kv_state->get_n_kv();
|
||||||
|
|
||||||
if (s_mask) {
|
if (s_mask) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
||||||
|
|
@ -300,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
||||||
|
|
||||||
// clear unused states
|
// clear unused states
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
data[i] = kv_self->s_mask(i);
|
data[i] = kv_state->s_mask(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -403,99 +364,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||||
if (self_kq_mask || self_kq_mask_swa) {
|
if (self_kq_mask) {
|
||||||
const int64_t n_kv = kv_self->n;
|
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||||
const int64_t n_tokens = ubatch->n_tokens;
|
}
|
||||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
}
|
||||||
const int64_t n_seqs = ubatch->n_seqs;
|
|
||||||
|
|
||||||
float * data = nullptr;
|
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||||
float * data_swa = nullptr;
|
if (self_kq_mask) {
|
||||||
|
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||||
|
}
|
||||||
|
|
||||||
if (self_kq_mask) {
|
if (self_kq_mask_swa) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||||
data = (float *) self_kq_mask->data;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (self_kq_mask_swa) {
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
|
||||||
data_swa = (float *) self_kq_mask_swa->data;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
|
||||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
||||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
|
||||||
// Causal mask:
|
|
||||||
// xxx-------
|
|
||||||
// xxxx------
|
|
||||||
// xxxxx-----
|
|
||||||
// Non-causal mask:
|
|
||||||
// xxxxx-----
|
|
||||||
// xxxxx-----
|
|
||||||
// xxxxx-----
|
|
||||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
|
||||||
for (int h = 0; h < 1; ++h) {
|
|
||||||
for (int s = 0; s < n_seqs; ++s) {
|
|
||||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
||||||
|
|
||||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
|
||||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
|
||||||
float f;
|
|
||||||
// mask the token if:
|
|
||||||
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|
|
||||||
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
|
|
||||||
) {
|
|
||||||
f = -INFINITY;
|
|
||||||
} else {
|
|
||||||
if (hparams.use_alibi) {
|
|
||||||
f = -std::abs(kv_self->cells[i].pos - pos);
|
|
||||||
} else {
|
|
||||||
f = 0.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data) {
|
|
||||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
|
||||||
}
|
|
||||||
|
|
||||||
// may need to cut off old tokens for sliding window
|
|
||||||
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
|
|
||||||
if (data_swa) {
|
|
||||||
if (hparams.n_attn_chunk) {
|
|
||||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
|
||||||
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
|
||||||
f = -INFINITY;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
|
||||||
f = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// mask padded tokens
|
|
||||||
if (data) {
|
|
||||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
|
||||||
for (int j = 0; j < n_kv; ++j) {
|
|
||||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// mask padded tokens
|
|
||||||
if (data_swa) {
|
|
||||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
|
||||||
for (int j = 0; j < n_kv; ++j) {
|
|
||||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -545,7 +425,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||||
n_layer (hparams.n_layer),
|
n_layer (hparams.n_layer),
|
||||||
n_rot (hparams.n_rot),
|
n_rot (hparams.n_rot),
|
||||||
n_ctx (cparams.n_ctx),
|
n_ctx (cparams.n_ctx),
|
||||||
n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
|
|
||||||
n_head (hparams.n_head()),
|
n_head (hparams.n_head()),
|
||||||
n_head_kv (hparams.n_head_kv()),
|
n_head_kv (hparams.n_head_kv()),
|
||||||
n_embd_head_k (hparams.n_embd_head_k),
|
n_embd_head_k (hparams.n_embd_head_k),
|
||||||
|
|
@ -572,14 +451,14 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||||
backend_cpu (params.backend_cpu),
|
backend_cpu (params.backend_cpu),
|
||||||
cvec (params.cvec),
|
cvec (params.cvec),
|
||||||
loras (params.loras),
|
loras (params.loras),
|
||||||
memory (params.memory),
|
mstate (params.mstate),
|
||||||
cross (params.cross),
|
cross (params.cross),
|
||||||
cb_func (params.cb),
|
cb_func (params.cb),
|
||||||
res (std::make_unique<llm_graph_result>()) {
|
res (std::make_unique<llm_graph_result>()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t llm_graph_context::n_pos_per_embd() const {
|
int64_t llm_graph_context::n_pos_per_embd() const {
|
||||||
return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
|
return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
||||||
|
|
@ -890,9 +769,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
||||||
|
|
||||||
if (weight_before_ffn) {
|
if (weight_before_ffn) {
|
||||||
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
|
// repeat cur to [n_embd, n_expert_used, n_tokens]
|
||||||
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
|
ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
|
||||||
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
|
|
||||||
cur = ggml_mul(ctx0, repeated, weights);
|
cur = ggml_mul(ctx0, repeated, weights);
|
||||||
cb(cur, "ffn_moe_weighted", il);
|
cb(cur, "ffn_moe_weighted", il);
|
||||||
}
|
}
|
||||||
|
|
@ -1078,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
|
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
|
||||||
|
|
||||||
const auto n_kv = kv_self->n;
|
const auto n_kv = kv_state->get_n_kv();
|
||||||
|
|
||||||
auto & cur = inp->s_copy;
|
auto & cur = inp->s_copy;
|
||||||
|
|
||||||
|
|
@ -1095,11 +973,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
||||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
|
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
|
||||||
|
|
||||||
const auto n_kv = kv_self->n;
|
const auto n_kv = kv_state->get_n_kv();
|
||||||
|
|
||||||
auto & cur = inp->s_mask;
|
auto & cur = inp->s_mask;
|
||||||
|
|
||||||
|
|
@ -1149,11 +1027,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
||||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
|
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
|
||||||
|
|
||||||
const auto n_kv = kv_self->n;
|
const auto n_kv = kv_state->get_n_kv();
|
||||||
|
|
||||||
auto & cur = inp->pos_bucket;
|
auto & cur = inp->pos_bucket;
|
||||||
|
|
||||||
|
|
@ -1188,16 +1066,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||||
ggml_tensor * kq_b,
|
ggml_tensor * kq_b,
|
||||||
ggml_tensor * kq_mask,
|
ggml_tensor * kq_mask,
|
||||||
ggml_tensor * v_mla,
|
ggml_tensor * v_mla,
|
||||||
bool v_trans,
|
|
||||||
float kq_scale) const {
|
float kq_scale) const {
|
||||||
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
const bool v_trans = v->nb[1] > v->nb[2];
|
||||||
//const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
||||||
|
|
||||||
//const int64_t n_head = hparams.n_head(il);
|
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||||
//const int64_t n_head_kv = hparams.n_head_kv(il);
|
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||||
|
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||||
//const auto & n_embd_head_k = hparams.n_embd_head_k;
|
|
||||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
|
||||||
|
|
||||||
const auto n_tokens = q->ne[1];
|
const auto n_tokens = q->ne[1];
|
||||||
const auto n_head = q->ne[2];
|
const auto n_head = q->ne[2];
|
||||||
|
|
@ -1336,17 +1210,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||||
|
|
||||||
const auto & kq_mask = inp->get_kq_mask();
|
const auto & kq_mask = inp->get_kq_mask();
|
||||||
|
|
||||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
ggml_tensor * q = q_cur;
|
||||||
//cb(q, "q", il);
|
ggml_tensor * k = k_cur;
|
||||||
|
ggml_tensor * v = v_cur;
|
||||||
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
|
||||||
//cb(k, "k", il);
|
|
||||||
|
|
||||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
|
||||||
//cb(k, "v", il);
|
|
||||||
|
|
||||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
|
|
||||||
|
|
||||||
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
if (wo) {
|
if (wo) {
|
||||||
|
|
@ -1365,26 +1233,20 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
|
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
|
||||||
|
|
||||||
const auto n_kv = kv_self->n;
|
{
|
||||||
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||||
|
|
||||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
const auto n_kv = kv_state->get_n_kv();
|
||||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
||||||
ggml_set_input(inp->self_kq_mask);
|
|
||||||
|
|
||||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||||
|
ggml_set_input(inp->self_kq_mask);
|
||||||
|
|
||||||
if (hparams.n_swa_pattern > 1) {
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||||
GGML_ASSERT(hparams.n_swa > 0);
|
|
||||||
|
|
||||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
||||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
|
||||||
ggml_set_input(inp->self_kq_mask_swa);
|
|
||||||
|
|
||||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
||||||
|
|
@ -1408,82 +1270,105 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||||
ggml_build_forward_expand(gf, k_cur);
|
ggml_build_forward_expand(gf, k_cur);
|
||||||
ggml_build_forward_expand(gf, v_cur);
|
ggml_build_forward_expand(gf, v_cur);
|
||||||
|
|
||||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||||
const auto & n_ctx = cparams.n_ctx;
|
|
||||||
|
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
||||||
|
|
||||||
const auto n_tokens = q_cur->ne[2];
|
|
||||||
|
|
||||||
const bool v_trans = !cparams.flash_attn;
|
|
||||||
|
|
||||||
// store to KV cache
|
// store to KV cache
|
||||||
{
|
{
|
||||||
const auto kv_head = kv_self->head;
|
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||||
|
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||||
GGML_ASSERT(kv_self->size == n_ctx);
|
|
||||||
|
|
||||||
ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
|
|
||||||
//cb(k_cache_view, "k_cache_view", il);
|
|
||||||
|
|
||||||
// note: storing RoPE-ed version of K in the KV cache
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
|
|
||||||
|
|
||||||
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
|
|
||||||
|
|
||||||
ggml_tensor * v_cache_view = nullptr;
|
|
||||||
|
|
||||||
if (!v_trans) {
|
|
||||||
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
|
|
||||||
} else {
|
|
||||||
// note: the V cache is transposed when not using flash attention
|
|
||||||
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
|
|
||||||
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
|
|
||||||
(kv_head)*ggml_element_size(kv_self->v_l[il]));
|
|
||||||
|
|
||||||
v_cur = ggml_transpose(ctx0, v_cur);
|
|
||||||
}
|
|
||||||
//cb(v_cache_view, "v_cache_view", il);
|
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto & kq_mask = inp->get_kq_mask();
|
||||||
|
|
||||||
|
ggml_tensor * q = q_cur;
|
||||||
|
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||||
|
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||||
|
|
||||||
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
|
if (wo) {
|
||||||
|
cur = build_lora_mm(wo, cur);
|
||||||
|
if (arch == LLM_ARCH_GLM4) {
|
||||||
|
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||||
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wo_b) {
|
||||||
|
cur = ggml_add(ctx0, cur, wo_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||||
|
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
||||||
|
|
||||||
|
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
||||||
|
|
||||||
|
{
|
||||||
|
const auto n_kv = kv_state->get_base()->get_n_kv();
|
||||||
|
|
||||||
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||||
|
ggml_set_input(inp->self_kq_mask);
|
||||||
|
|
||||||
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
||||||
|
|
||||||
|
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
||||||
|
|
||||||
|
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
|
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||||
|
ggml_set_input(inp->self_kq_mask_swa);
|
||||||
|
|
||||||
|
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llm_graph_context::build_attn(
|
||||||
|
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * q_cur,
|
||||||
|
ggml_tensor * k_cur,
|
||||||
|
ggml_tensor * v_cur,
|
||||||
|
ggml_tensor * kq_b,
|
||||||
|
ggml_tensor * v_mla,
|
||||||
|
float kq_scale,
|
||||||
|
int il) const {
|
||||||
|
// these nodes are added to the graph together so that they are not reordered
|
||||||
|
// by doing so, the number of splits in the graph is reduced
|
||||||
|
ggml_build_forward_expand(gf, q_cur);
|
||||||
|
ggml_build_forward_expand(gf, k_cur);
|
||||||
|
ggml_build_forward_expand(gf, v_cur);
|
||||||
|
|
||||||
|
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
||||||
|
|
||||||
const bool is_swa = hparams.is_swa(il);
|
const bool is_swa = hparams.is_swa(il);
|
||||||
|
|
||||||
|
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
|
||||||
|
|
||||||
|
// store to KV cache
|
||||||
|
{
|
||||||
|
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||||
|
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||||
|
}
|
||||||
|
|
||||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||||
|
|
||||||
const auto n_kv = kv_self->n;
|
ggml_tensor * q = q_cur;
|
||||||
|
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||||
|
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||||
|
|
||||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
|
|
||||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
|
||||||
const auto & n_embd_head_v = hparams.n_embd_head_v;
|
|
||||||
|
|
||||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
|
||||||
//cb(q, "q", il);
|
|
||||||
|
|
||||||
ggml_tensor * k =
|
|
||||||
ggml_view_3d(ctx0, kv_self->k_l[il],
|
|
||||||
n_embd_head_k, n_kv, n_head_kv,
|
|
||||||
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
||||||
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
|
||||||
0);
|
|
||||||
//cb(k, "k", il);
|
|
||||||
|
|
||||||
ggml_tensor * v = !v_trans ?
|
|
||||||
ggml_view_3d(ctx0, kv_self->v_l[il],
|
|
||||||
n_embd_head_v, n_kv, n_head_kv,
|
|
||||||
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
|
||||||
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
|
|
||||||
0) :
|
|
||||||
ggml_view_3d(ctx0, kv_self->v_l[il],
|
|
||||||
n_kv, n_embd_head_v, n_head_kv,
|
|
||||||
ggml_element_size(kv_self->v_l[il])*n_ctx,
|
|
||||||
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
|
|
||||||
0);
|
|
||||||
|
|
||||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
|
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
if (wo) {
|
if (wo) {
|
||||||
|
|
@ -1534,17 +1419,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||||
|
|
||||||
const auto & kq_mask = inp->get_kq_mask_cross();
|
const auto & kq_mask = inp->get_kq_mask_cross();
|
||||||
|
|
||||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
ggml_tensor * q = q_cur;
|
||||||
//cb(q, "q", il);
|
ggml_tensor * k = k_cur;
|
||||||
|
ggml_tensor * v = v_cur;
|
||||||
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
|
||||||
//cb(k, "k", il);
|
|
||||||
|
|
||||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
|
||||||
//cb(k, "v", il);
|
|
||||||
|
|
||||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
|
|
||||||
|
|
||||||
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
if (wo) {
|
if (wo) {
|
||||||
|
|
@ -1569,12 +1448,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
|
||||||
ggml_tensor * state_mask,
|
ggml_tensor * state_mask,
|
||||||
int32_t n_state,
|
int32_t n_state,
|
||||||
int32_t n_seqs) const {
|
int32_t n_seqs) const {
|
||||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||||
|
|
||||||
const auto n_kv = kv_self->n;
|
const auto n_kv = kv_state->get_n_kv();
|
||||||
const auto kv_head = kv_self->head;
|
const auto kv_head = kv_state->get_head();
|
||||||
|
|
||||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
|
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
|
||||||
|
|
||||||
// copy states
|
// copy states
|
||||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||||
|
|
@ -1601,13 +1480,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||||
ggml_tensor * state_mask,
|
ggml_tensor * state_mask,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||||
|
|
||||||
const auto token_shift_count = hparams.token_shift_count;
|
const auto token_shift_count = hparams.token_shift_count;
|
||||||
|
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
ggml_tensor * token_shift_all = kv_self->k_l[il];
|
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
||||||
|
|
||||||
ggml_tensor * token_shift = build_copy_mask_state(
|
ggml_tensor * token_shift = build_copy_mask_state(
|
||||||
gf, token_shift_all, state_copy, state_mask,
|
gf, token_shift_all, state_copy, state_mask,
|
||||||
|
|
@ -1622,19 +1501,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
||||||
ggml_tensor * token_shift,
|
ggml_tensor * token_shift,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||||
|
|
||||||
const auto token_shift_count = hparams.token_shift_count;
|
const auto token_shift_count = hparams.token_shift_count;
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
|
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
const auto kv_head = kv_self->head;
|
const auto kv_head = kv_state->get_head();
|
||||||
|
|
||||||
return ggml_cpy(
|
return ggml_cpy(
|
||||||
ctx0,
|
ctx0,
|
||||||
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
||||||
ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
|
ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1685,20 +1564,25 @@ void llm_graph_context::build_pooling(
|
||||||
ggml_tensor * inp_cls = build_inp_cls();
|
ggml_tensor * inp_cls = build_inp_cls();
|
||||||
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
||||||
|
|
||||||
// classification head
|
if (cls != nullptr && cls_b != nullptr) {
|
||||||
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
// classification head
|
||||||
GGML_ASSERT(cls != nullptr);
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||||
GGML_ASSERT(cls_b != nullptr);
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
||||||
|
cur = ggml_tanh(ctx0, cur);
|
||||||
|
|
||||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||||
cur = ggml_tanh(ctx0, cur);
|
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
||||||
|
if (cls_out) {
|
||||||
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
GGML_ASSERT(cls_out_b != nullptr);
|
||||||
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
||||||
if (cls_out) {
|
}
|
||||||
|
} else if (cls_out) {
|
||||||
|
// Single layer classification head (direct projection)
|
||||||
|
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
||||||
GGML_ASSERT(cls_out_b != nullptr);
|
GGML_ASSERT(cls_out_b != nullptr);
|
||||||
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
|
||||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
} else {
|
||||||
|
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
|
@ -1712,3 +1596,30 @@ void llm_graph_context::build_pooling(
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||||
|
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||||
|
const int64_t max_distance = 128;
|
||||||
|
|
||||||
|
if (bidirectional) {
|
||||||
|
n_buckets >>= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t max_exact = n_buckets >> 1;
|
||||||
|
|
||||||
|
int32_t relative_position = x - y;
|
||||||
|
int32_t relative_bucket = 0;
|
||||||
|
|
||||||
|
if (bidirectional) {
|
||||||
|
relative_bucket += (relative_position > 0) * n_buckets;
|
||||||
|
relative_position = abs(relative_position);
|
||||||
|
} else {
|
||||||
|
relative_position = -std::min<int32_t>(relative_position, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
||||||
|
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
||||||
|
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
||||||
|
|
||||||
|
return relative_bucket;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,11 @@ struct ggml_tensor;
|
||||||
struct llama_ubatch;
|
struct llama_ubatch;
|
||||||
struct llama_cparams;
|
struct llama_cparams;
|
||||||
|
|
||||||
class llama_memory_i;
|
struct llama_memory_state_i;
|
||||||
class llama_kv_cache_unified;
|
|
||||||
class llama_kv_cache_recurrent;
|
class llama_kv_cache_unified_state;
|
||||||
|
class llama_kv_cache_unified_iswa_state;
|
||||||
|
class llama_kv_cache_recurrent_state;
|
||||||
|
|
||||||
// certain models (typically multi-modal) can produce different types of graphs
|
// certain models (typically multi-modal) can produce different types of graphs
|
||||||
enum llm_graph_type {
|
enum llm_graph_type {
|
||||||
|
|
@ -132,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_pos_bucket_kv(
|
llm_graph_input_pos_bucket_kv(
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
|
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
|
||||||
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
@ -140,7 +142,7 @@ public:
|
||||||
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
||||||
|
|
||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
const llama_kv_cache_unified * kv_self;
|
const llama_kv_cache_unified_state * kv_state;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_out_ids : public llm_graph_input_i {
|
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||||
|
|
@ -187,26 +189,26 @@ public:
|
||||||
|
|
||||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||||
virtual ~llm_graph_input_s_copy() = default;
|
virtual ~llm_graph_input_s_copy() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
ggml_tensor * s_copy; // I32 [kv_size]
|
ggml_tensor * s_copy; // I32 [kv_size]
|
||||||
|
|
||||||
const llama_kv_cache_recurrent * kv_self;
|
const llama_kv_cache_recurrent_state * kv_state;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||||
virtual ~llm_graph_input_s_mask() = default;
|
virtual ~llm_graph_input_s_mask() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||||
|
|
||||||
const llama_kv_cache_recurrent * kv_self;
|
const llama_kv_cache_recurrent_state * kv_state;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||||
|
|
@ -246,15 +248,40 @@ public:
|
||||||
llm_graph_input_attn_kv_unified(
|
llm_graph_input_attn_kv_unified(
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
const llama_kv_cache_unified * kv_self) :
|
const llama_kv_cache_unified_state * kv_state) :
|
||||||
hparams(hparams),
|
hparams(hparams),
|
||||||
cparams(cparams),
|
cparams(cparams),
|
||||||
kv_self(kv_self) {
|
kv_state(kv_state) {
|
||||||
}
|
}
|
||||||
~llm_graph_input_attn_kv_unified() = default;
|
~llm_graph_input_attn_kv_unified() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||||
|
|
||||||
|
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||||
|
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
|
const llama_kv_cache_unified_state * kv_state;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_attn_kv_unified_iswa(
|
||||||
|
const llama_hparams & hparams,
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
const llama_kv_cache_unified_iswa_state * kv_state) :
|
||||||
|
hparams(hparams),
|
||||||
|
cparams(cparams),
|
||||||
|
kv_state(kv_state) {
|
||||||
|
}
|
||||||
|
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||||
|
|
||||||
|
|
@ -266,7 +293,7 @@ public:
|
||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
const llama_cparams & cparams;
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
const llama_kv_cache_unified * kv_self;
|
const llama_kv_cache_unified_iswa_state * kv_state;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||||
|
|
@ -357,10 +384,10 @@ struct llm_graph_params {
|
||||||
ggml_backend_sched_t sched;
|
ggml_backend_sched_t sched;
|
||||||
ggml_backend_t backend_cpu;
|
ggml_backend_t backend_cpu;
|
||||||
|
|
||||||
const llama_adapter_cvec * cvec;
|
const llama_adapter_cvec * cvec;
|
||||||
const llama_adapter_loras * loras;
|
const llama_adapter_loras * loras;
|
||||||
const llama_memory_i * memory;
|
const llama_memory_state_i * mstate;
|
||||||
const llama_cross * cross;
|
const llama_cross * cross;
|
||||||
|
|
||||||
int32_t n_outputs;
|
int32_t n_outputs;
|
||||||
|
|
||||||
|
|
@ -378,7 +405,6 @@ struct llm_graph_context {
|
||||||
const int64_t n_layer;
|
const int64_t n_layer;
|
||||||
const int64_t n_rot;
|
const int64_t n_rot;
|
||||||
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
||||||
const int64_t n_ctx_per_seq;
|
|
||||||
const int64_t n_head;
|
const int64_t n_head;
|
||||||
const int64_t n_head_kv;
|
const int64_t n_head_kv;
|
||||||
const int64_t n_embd_head_k;
|
const int64_t n_embd_head_k;
|
||||||
|
|
@ -410,10 +436,10 @@ struct llm_graph_context {
|
||||||
|
|
||||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||||
|
|
||||||
const llama_adapter_cvec * cvec;
|
const llama_adapter_cvec * cvec;
|
||||||
const llama_adapter_loras * loras;
|
const llama_adapter_loras * loras;
|
||||||
const llama_memory_i * memory;
|
const llama_memory_state_i * mstate;
|
||||||
const llama_cross * cross;
|
const llama_cross * cross;
|
||||||
|
|
||||||
const llm_graph_cb & cb_func;
|
const llm_graph_cb & cb_func;
|
||||||
|
|
||||||
|
|
@ -507,13 +533,12 @@ struct llm_graph_context {
|
||||||
|
|
||||||
ggml_tensor * build_attn_mha(
|
ggml_tensor * build_attn_mha(
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
|
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||||
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
|
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||||
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
|
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||||
ggml_tensor * kq_b,
|
ggml_tensor * kq_b,
|
||||||
ggml_tensor * kq_mask,
|
ggml_tensor * kq_mask,
|
||||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||||
bool v_trans,
|
|
||||||
float kq_scale) const;
|
float kq_scale) const;
|
||||||
|
|
||||||
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||||
|
|
@ -546,6 +571,21 @@ struct llm_graph_context {
|
||||||
float kq_scale,
|
float kq_scale,
|
||||||
int il) const;
|
int il) const;
|
||||||
|
|
||||||
|
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
|
||||||
|
|
||||||
|
ggml_tensor * build_attn(
|
||||||
|
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||||
|
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||||
|
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||||
|
ggml_tensor * kq_b,
|
||||||
|
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||||
|
float kq_scale,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||||
|
|
||||||
ggml_tensor * build_attn(
|
ggml_tensor * build_attn(
|
||||||
|
|
@ -596,3 +636,6 @@ struct llm_graph_context {
|
||||||
ggml_tensor * cls_out,
|
ggml_tensor * cls_out,
|
||||||
ggml_tensor * cls_out_b) const;
|
ggml_tensor * cls_out_b) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: better name
|
||||||
|
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,22 @@
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
|
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
|
||||||
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_hparams::is_swa_any() const {
|
||||||
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
if (swa_layers[il]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t llama_hparams::n_head(uint32_t il) const {
|
uint32_t llama_hparams::n_head(uint32_t il) const {
|
||||||
if (il < n_layer) {
|
if (il < n_layer) {
|
||||||
return n_head_arr[il];
|
return n_head_arr[il];
|
||||||
|
|
@ -80,7 +96,7 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
|
||||||
|
|
||||||
bool llama_hparams::is_swa(uint32_t il) const {
|
bool llama_hparams::is_swa(uint32_t il) const {
|
||||||
if (il < n_layer) {
|
if (il < n_layer) {
|
||||||
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
|
return swa_layers[il];
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
|
||||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum llama_swa_type {
|
||||||
|
LLAMA_SWA_TYPE_NONE = 0,
|
||||||
|
LLAMA_SWA_TYPE_STANDARD = 1,
|
||||||
|
LLAMA_SWA_TYPE_CHUNKED = 2,
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_hparams_posnet {
|
struct llama_hparams_posnet {
|
||||||
uint32_t n_embd;
|
uint32_t n_embd;
|
||||||
uint32_t n_layer;
|
uint32_t n_layer;
|
||||||
|
|
@ -35,8 +41,6 @@ struct llama_hparams {
|
||||||
uint32_t n_embd_features = 0;
|
uint32_t n_embd_features = 0;
|
||||||
uint32_t n_layer;
|
uint32_t n_layer;
|
||||||
uint32_t n_rot;
|
uint32_t n_rot;
|
||||||
uint32_t n_swa = 0; // sliding window attention (SWA)
|
|
||||||
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
|
|
||||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||||
uint32_t n_expert = 0;
|
uint32_t n_expert = 0;
|
||||||
|
|
@ -98,6 +102,15 @@ struct llama_hparams {
|
||||||
|
|
||||||
std::array<int, 4> rope_sections;
|
std::array<int, 4> rope_sections;
|
||||||
|
|
||||||
|
// Sliding Window Attention (SWA)
|
||||||
|
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||||
|
// the size of the sliding window (0 - no SWA)
|
||||||
|
uint32_t n_swa = 0;
|
||||||
|
// if swa_layers[il] == true, then layer il is SWA
|
||||||
|
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||||
|
// by default, all layers are dense
|
||||||
|
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||||
|
|
||||||
// for State Space Models
|
// for State Space Models
|
||||||
uint32_t ssm_d_conv = 0;
|
uint32_t ssm_d_conv = 0;
|
||||||
uint32_t ssm_d_inner = 0;
|
uint32_t ssm_d_inner = 0;
|
||||||
|
|
@ -118,11 +131,13 @@ struct llama_hparams {
|
||||||
bool causal_attn = true;
|
bool causal_attn = true;
|
||||||
bool use_alibi = false;
|
bool use_alibi = false;
|
||||||
bool attn_soft_cap = false;
|
bool attn_soft_cap = false;
|
||||||
|
bool use_kq_norm = true;
|
||||||
|
|
||||||
|
// for Classifiers
|
||||||
|
uint32_t n_cls_out = 1;
|
||||||
|
|
||||||
|
// llama4
|
||||||
uint32_t n_moe_layer_step = 0;
|
uint32_t n_moe_layer_step = 0;
|
||||||
bool use_kq_norm = true;
|
|
||||||
uint32_t n_attn_chunk = 0;
|
|
||||||
// values below seems to be fixed on llama4
|
|
||||||
uint32_t n_no_rope_layer_step = 4;
|
uint32_t n_no_rope_layer_step = 4;
|
||||||
uint32_t n_attn_temp_floor_scale = 8192;
|
uint32_t n_attn_temp_floor_scale = 8192;
|
||||||
float f_attn_temp_scale = 0.1;
|
float f_attn_temp_scale = 0.1;
|
||||||
|
|
@ -135,6 +150,23 @@ struct llama_hparams {
|
||||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||||
|
|
||||||
|
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
|
||||||
|
// note that if n_pattern == 0, all layers are SWA
|
||||||
|
// if n_pattern == 1, all layers are dense
|
||||||
|
// example: n_pattern = 3
|
||||||
|
// il == 0: swa
|
||||||
|
// il == 1: swa
|
||||||
|
// il == 2: dense
|
||||||
|
// il == 3: swa
|
||||||
|
// il == 4: swa
|
||||||
|
// il == 5: dense
|
||||||
|
// il == 6: swa
|
||||||
|
// etc ...
|
||||||
|
void set_swa_pattern(uint32_t n_pattern);
|
||||||
|
|
||||||
|
// return true if one of the layers is SWA
|
||||||
|
bool is_swa_any() const;
|
||||||
|
|
||||||
uint32_t n_head(uint32_t il = 0) const;
|
uint32_t n_head(uint32_t il = 0) const;
|
||||||
|
|
||||||
uint32_t n_head_kv(uint32_t il = 0) const;
|
uint32_t n_head_kv(uint32_t il = 0) const;
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,185 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-batch.h"
|
||||||
|
#include "llama-graph.h"
|
||||||
|
#include "llama-memory.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_recurrent
|
||||||
|
//
|
||||||
|
|
||||||
|
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
|
||||||
|
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
||||||
|
class llama_kv_cache_recurrent : public llama_memory_i {
|
||||||
|
public:
|
||||||
|
llama_kv_cache_recurrent(
|
||||||
|
const llama_model & model,
|
||||||
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
|
bool offload,
|
||||||
|
uint32_t kv_size,
|
||||||
|
uint32_t n_seq_max);
|
||||||
|
|
||||||
|
~llama_kv_cache_recurrent() = default;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_i
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_batch(
|
||||||
|
const llama_batch & batch,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
bool embd_pooled,
|
||||||
|
bool logits_all) override;
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
|
void clear() override;
|
||||||
|
|
||||||
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||||
|
|
||||||
|
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||||
|
|
||||||
|
// find a contiguous slot of kv cells and emplace the ubatch there
|
||||||
|
bool find_slot(const llama_ubatch & ubatch);
|
||||||
|
|
||||||
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
|
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
||||||
|
int32_t s_copy(int i) const;
|
||||||
|
float s_mask(int i) const;
|
||||||
|
|
||||||
|
// state write/load
|
||||||
|
|
||||||
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||||
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||||
|
|
||||||
|
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||||
|
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||||
|
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||||
|
|
||||||
|
// computed before each graph build
|
||||||
|
uint32_t n = 0;
|
||||||
|
|
||||||
|
// TODO: optimize for recurrent state needs
|
||||||
|
struct kv_cell {
|
||||||
|
llama_pos pos = -1;
|
||||||
|
int32_t src = -1; // used to copy states
|
||||||
|
int32_t tail = -1;
|
||||||
|
|
||||||
|
std::set<llama_seq_id> seq_id;
|
||||||
|
|
||||||
|
bool has_seq_id(const llama_seq_id & id) const {
|
||||||
|
return seq_id.find(id) != seq_id.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_empty() const {
|
||||||
|
return seq_id.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_same_seq(const kv_cell & other) const {
|
||||||
|
return seq_id == other.seq_id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<kv_cell> cells;
|
||||||
|
|
||||||
|
std::vector<ggml_tensor *> k_l; // per layer
|
||||||
|
std::vector<ggml_tensor *> v_l;
|
||||||
|
|
||||||
|
private:
|
||||||
|
//const llama_model & model;
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
|
||||||
|
const uint32_t n_seq_max = 1;
|
||||||
|
|
||||||
|
std::vector<ggml_context_ptr> ctxs;
|
||||||
|
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||||
|
|
||||||
|
size_t total_size() const;
|
||||||
|
|
||||||
|
size_t size_k_bytes() const;
|
||||||
|
size_t size_v_bytes() const;
|
||||||
|
|
||||||
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||||
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||||
|
|
||||||
|
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||||
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
|
||||||
|
public:
|
||||||
|
// used for errors
|
||||||
|
llama_kv_cache_recurrent_state(llama_memory_status status);
|
||||||
|
|
||||||
|
// used to create a full-cache state
|
||||||
|
llama_kv_cache_recurrent_state(
|
||||||
|
llama_memory_status status,
|
||||||
|
llama_kv_cache_recurrent * kv);
|
||||||
|
|
||||||
|
// used to create a state from a batch
|
||||||
|
llama_kv_cache_recurrent_state(
|
||||||
|
llama_memory_status status,
|
||||||
|
llama_kv_cache_recurrent * kv,
|
||||||
|
llama_sbatch sbatch,
|
||||||
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
|
virtual ~llama_kv_cache_recurrent_state();
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_state_i
|
||||||
|
//
|
||||||
|
|
||||||
|
bool next() override;
|
||||||
|
bool apply() override;
|
||||||
|
|
||||||
|
std::vector<int64_t> & out_ids() override;
|
||||||
|
|
||||||
|
llama_memory_status get_status() const override;
|
||||||
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_recurrent_state specific API
|
||||||
|
//
|
||||||
|
|
||||||
|
uint32_t get_n_kv() const;
|
||||||
|
uint32_t get_head() const;
|
||||||
|
uint32_t get_size() const;
|
||||||
|
|
||||||
|
ggml_tensor * get_k_l(int32_t il) const;
|
||||||
|
ggml_tensor * get_v_l(int32_t il) const;
|
||||||
|
|
||||||
|
int32_t s_copy(int i) const;
|
||||||
|
float s_mask(int i) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llama_memory_status status;
|
||||||
|
|
||||||
|
llama_kv_cache_recurrent * kv;
|
||||||
|
|
||||||
|
llama_sbatch sbatch;
|
||||||
|
|
||||||
|
size_t i_next = 0;
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
//
|
||||||
|
// data needed for building the compute graph for the current ubatch:
|
||||||
|
// TODO: extract all the state like `head` and `n` here
|
||||||
|
//
|
||||||
|
|
||||||
|
const bool is_full = false;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,252 @@
|
||||||
|
#include "llama-kv-cache-unified-iswa.h"
|
||||||
|
|
||||||
|
#include "llama-impl.h"
|
||||||
|
#include "llama-batch.h"
|
||||||
|
#include "llama-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified_iswa
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||||
|
const llama_model & model,
|
||||||
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
|
bool v_trans,
|
||||||
|
bool offload,
|
||||||
|
bool swa_full,
|
||||||
|
uint32_t kv_size,
|
||||||
|
uint32_t n_seq_max,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
uint32_t n_pad) : hparams(model.hparams) {
|
||||||
|
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||||
|
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||||
|
|
||||||
|
const uint32_t size_base = kv_size;
|
||||||
|
|
||||||
|
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||||
|
|
||||||
|
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||||
|
if (swa_full) {
|
||||||
|
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
|
||||||
|
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||||
|
|
||||||
|
size_swa = size_base;
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||||
|
|
||||||
|
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||||
|
model, std::move(filter_base), type_k, type_v,
|
||||||
|
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||||
|
0, LLAMA_SWA_TYPE_NONE);
|
||||||
|
|
||||||
|
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||||
|
|
||||||
|
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||||
|
model, std::move(filter_swa), type_k, type_v,
|
||||||
|
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||||
|
hparams.n_swa, hparams.swa_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::clear() {
|
||||||
|
kv_base->clear();
|
||||||
|
kv_swa ->clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||||
|
bool res = true;
|
||||||
|
|
||||||
|
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
||||||
|
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||||
|
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||||
|
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
||||||
|
kv_base->seq_keep(seq_id);
|
||||||
|
kv_swa ->seq_keep(seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||||
|
kv_base->seq_add(seq_id, p0, p1, shift);
|
||||||
|
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||||
|
kv_base->seq_div(seq_id, p0, p1, d);
|
||||||
|
kv_swa ->seq_div(seq_id, p0, p1, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
||||||
|
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
|
||||||
|
return kv_swa->seq_pos_min(seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||||
|
return kv_swa->seq_pos_max(seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
|
||||||
|
GGML_UNUSED(embd_pooled);
|
||||||
|
|
||||||
|
// TODO: if we fail with split_simple, we should attempt different splitting strategies
|
||||||
|
// but to do that properly, we first have to refactor the batches to be more flexible
|
||||||
|
|
||||||
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
while (sbatch.n_tokens > 0) {
|
||||||
|
auto ubatch = sbatch.split_simple(n_ubatch);
|
||||||
|
|
||||||
|
ubatches.push_back(ubatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto heads_base = kv_base->prepare(ubatches);
|
||||||
|
if (heads_base.empty()) {
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto heads_swa = kv_swa->prepare(ubatches);
|
||||||
|
if (heads_swa.empty()) {
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(heads_base.size() == heads_swa.size());
|
||||||
|
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||||
|
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||||
|
return kv_base->get_size() == kv_swa->get_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
||||||
|
kv_base->state_write(io, seq_id);
|
||||||
|
kv_swa ->state_write(io, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||||
|
kv_base->state_read(io, seq_id);
|
||||||
|
kv_swa ->state_read(io, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
|
||||||
|
return kv_base.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
||||||
|
return kv_swa.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified_iswa_state
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
|
state_base = kv->get_base()->init_full();
|
||||||
|
state_swa = kv->get_swa ()->init_full();
|
||||||
|
|
||||||
|
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
|
state_base = kv->get_base()->init_update(lctx, optimize);
|
||||||
|
state_swa = kv->get_swa ()->init_update(lctx, optimize);
|
||||||
|
|
||||||
|
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv,
|
||||||
|
llama_sbatch sbatch,
|
||||||
|
std::vector<uint32_t> heads_base,
|
||||||
|
std::vector<uint32_t> heads_swa,
|
||||||
|
std::vector<llama_ubatch> ubatches)
|
||||||
|
: status(LLAMA_MEMORY_STATUS_SUCCESS),
|
||||||
|
sbatch(std::move(sbatch)),
|
||||||
|
ubatches(std::move(ubatches)) {
|
||||||
|
// note: here we copy the ubatches. not sure if this is ideal
|
||||||
|
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||||
|
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||||
|
|
||||||
|
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
||||||
|
|
||||||
|
bool llama_kv_cache_unified_iswa_state::next() {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
state_base->next();
|
||||||
|
state_swa ->next();
|
||||||
|
|
||||||
|
if (++i_next >= ubatches.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache_unified_iswa_state::apply() {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
bool res = true;
|
||||||
|
|
||||||
|
res = res & state_base->apply();
|
||||||
|
res = res & state_swa ->apply();
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
return sbatch.out_ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
return ubatches[i_next];
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-kv-cache-unified.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified_iswa
|
||||||
|
//
|
||||||
|
|
||||||
|
// utilizes two instances of llama_kv_cache_unified
|
||||||
|
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
||||||
|
|
||||||
|
class llama_kv_cache_unified_iswa : public llama_memory_i {
|
||||||
|
public:
|
||||||
|
llama_kv_cache_unified_iswa(
|
||||||
|
const llama_model & model,
|
||||||
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
|
bool v_trans,
|
||||||
|
bool offload,
|
||||||
|
bool swa_full,
|
||||||
|
uint32_t kv_size,
|
||||||
|
uint32_t n_seq_max,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
uint32_t n_pad);
|
||||||
|
|
||||||
|
~llama_kv_cache_unified_iswa() = default;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_i
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_batch(
|
||||||
|
const llama_batch & batch,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
bool embd_pooled,
|
||||||
|
bool logits_all) override;
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
|
void clear() override;
|
||||||
|
|
||||||
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||||
|
|
||||||
|
// state write/load
|
||||||
|
|
||||||
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||||
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified_iswa specific API
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_kv_cache_unified * get_base() const;
|
||||||
|
llama_kv_cache_unified * get_swa () const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
|
||||||
|
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||||
|
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
|
||||||
|
public:
|
||||||
|
// used for errors
|
||||||
|
llama_kv_cache_unified_iswa_state(llama_memory_status status);
|
||||||
|
|
||||||
|
// used to create a full-cache state
|
||||||
|
llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv);
|
||||||
|
|
||||||
|
// used to create an update state
|
||||||
|
llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool optimize);
|
||||||
|
|
||||||
|
// used to create a state from a batch
|
||||||
|
llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv,
|
||||||
|
llama_sbatch sbatch,
|
||||||
|
std::vector<uint32_t> heads_base,
|
||||||
|
std::vector<uint32_t> heads_swa,
|
||||||
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
|
virtual ~llama_kv_cache_unified_iswa_state();
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_state_i
|
||||||
|
//
|
||||||
|
|
||||||
|
bool next() override;
|
||||||
|
bool apply() override;
|
||||||
|
|
||||||
|
std::vector<int64_t> & out_ids() override;
|
||||||
|
|
||||||
|
llama_memory_status get_status() const override;
|
||||||
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified_iswa_state specific API
|
||||||
|
//
|
||||||
|
|
||||||
|
const llama_kv_cache_unified_state * get_base() const;
|
||||||
|
const llama_kv_cache_unified_state * get_swa() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
llama_memory_status status;
|
||||||
|
|
||||||
|
//llama_kv_cache_unified_iswa * kv;
|
||||||
|
|
||||||
|
llama_sbatch sbatch;
|
||||||
|
|
||||||
|
// the index of the next ubatch to process
|
||||||
|
size_t i_next = 0;
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
llama_memory_state_ptr state_base;
|
||||||
|
llama_memory_state_ptr state_swa;
|
||||||
|
};
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,307 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-batch.h"
|
||||||
|
#include "llama-graph.h"
|
||||||
|
#include "llama-kv-cells.h"
|
||||||
|
#include "llama-memory.h"
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct llama_cparams;
|
||||||
|
struct llama_hparams;
|
||||||
|
struct llama_model;
|
||||||
|
struct llama_context;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified
|
||||||
|
//
|
||||||
|
|
||||||
|
class llama_kv_cache_unified : public llama_memory_i {
|
||||||
|
public:
|
||||||
|
static uint32_t get_padding(const llama_cparams & cparams);
|
||||||
|
|
||||||
|
// this callback is used to filter out layers that should not be included in the cache
|
||||||
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||||
|
|
||||||
|
using ubatch_heads = std::vector<uint32_t>;
|
||||||
|
|
||||||
|
struct defrag_info {
|
||||||
|
bool empty() const {
|
||||||
|
return ids.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// contains information about which cell moves where:
|
||||||
|
// - cell i moves to ids[i]
|
||||||
|
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
||||||
|
std::vector<uint32_t> ids;
|
||||||
|
};
|
||||||
|
|
||||||
|
llama_kv_cache_unified(
|
||||||
|
const llama_model & model,
|
||||||
|
layer_filter_cb && filter,
|
||||||
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
|
bool v_trans,
|
||||||
|
bool offload,
|
||||||
|
uint32_t kv_size,
|
||||||
|
uint32_t n_seq_max,
|
||||||
|
uint32_t n_pad,
|
||||||
|
uint32_t n_swa,
|
||||||
|
llama_swa_type swa_type);
|
||||||
|
|
||||||
|
~llama_kv_cache_unified() = default;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_i
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_batch(
|
||||||
|
const llama_batch & batch,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
bool embd_pooled,
|
||||||
|
bool logits_all) override;
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
|
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
|
void clear() override;
|
||||||
|
|
||||||
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||||
|
|
||||||
|
// state write/load
|
||||||
|
|
||||||
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||||
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified specific API
|
||||||
|
//
|
||||||
|
|
||||||
|
uint32_t get_size() const;
|
||||||
|
|
||||||
|
bool get_has_shift() const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// graph_build API
|
||||||
|
//
|
||||||
|
|
||||||
|
uint32_t get_n_kv() const;
|
||||||
|
|
||||||
|
// get views of the current state of the cache
|
||||||
|
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||||
|
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||||
|
|
||||||
|
// store k_cur and v_cur in the cache based on the provided head location
|
||||||
|
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
||||||
|
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// preparation API
|
||||||
|
//
|
||||||
|
|
||||||
|
// find places for the provided ubatches in the cache, returns the head locations
|
||||||
|
// return empty vector on failure
|
||||||
|
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
||||||
|
|
||||||
|
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||||
|
|
||||||
|
// return the cell position where we can insert the ubatch
|
||||||
|
// return -1 on failure to find a contiguous slot of kv cells
|
||||||
|
int32_t find_slot(const llama_ubatch & ubatch) const;
|
||||||
|
|
||||||
|
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
||||||
|
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
||||||
|
|
||||||
|
//
|
||||||
|
// set_input API
|
||||||
|
//
|
||||||
|
|
||||||
|
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||||
|
void set_input_k_shift (ggml_tensor * dst) const;
|
||||||
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llama_model & model;
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
|
||||||
|
struct kv_layer {
|
||||||
|
// layer index in the model
|
||||||
|
// note: can be different from the layer index in the KV cache
|
||||||
|
uint32_t il;
|
||||||
|
|
||||||
|
ggml_tensor * k;
|
||||||
|
ggml_tensor * v;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool v_trans = true; // the value tensor is transposed
|
||||||
|
|
||||||
|
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||||
|
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||||
|
uint32_t head = 0;
|
||||||
|
|
||||||
|
const uint32_t n_seq_max = 1;
|
||||||
|
|
||||||
|
// required padding
|
||||||
|
const uint32_t n_pad = 1;
|
||||||
|
|
||||||
|
// SWA
|
||||||
|
const uint32_t n_swa = 0;
|
||||||
|
|
||||||
|
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||||
|
|
||||||
|
std::vector<ggml_context_ptr> ctxs;
|
||||||
|
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||||
|
|
||||||
|
llama_kv_cells_unified cells;
|
||||||
|
|
||||||
|
std::vector<kv_layer> layers;
|
||||||
|
|
||||||
|
// model layer id -> KV cache layer id
|
||||||
|
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||||
|
|
||||||
|
// return non-empty vector if cells have been moved
|
||||||
|
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
||||||
|
|
||||||
|
size_t total_size() const;
|
||||||
|
|
||||||
|
size_t size_k_bytes() const;
|
||||||
|
size_t size_v_bytes() const;
|
||||||
|
|
||||||
|
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
||||||
|
|
||||||
|
ggml_tensor * build_rope_shift(
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
ggml_context * ctx,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * shift,
|
||||||
|
ggml_tensor * factors,
|
||||||
|
float freq_base,
|
||||||
|
float freq_scale) const;
|
||||||
|
|
||||||
|
llm_graph_result_ptr build_graph_shift(
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
ggml_context * ctx,
|
||||||
|
ggml_cgraph * gf) const;
|
||||||
|
|
||||||
|
llm_graph_result_ptr build_graph_defrag(
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
ggml_context * ctx,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
const defrag_info & dinfo) const;
|
||||||
|
|
||||||
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||||
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||||
|
|
||||||
|
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||||
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
||||||
|
public:
|
||||||
|
// some shorthands
|
||||||
|
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||||
|
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||||
|
|
||||||
|
// used for errors
|
||||||
|
llama_kv_cache_unified_state(llama_memory_status status);
|
||||||
|
|
||||||
|
// used to create a full-cache state
|
||||||
|
llama_kv_cache_unified_state(
|
||||||
|
llama_kv_cache_unified * kv);
|
||||||
|
|
||||||
|
// used to create an update state
|
||||||
|
llama_kv_cache_unified_state(
|
||||||
|
llama_kv_cache_unified * kv,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool do_shift,
|
||||||
|
defrag_info dinfo);
|
||||||
|
|
||||||
|
// used to create a decode state from a batch
|
||||||
|
llama_kv_cache_unified_state(
|
||||||
|
llama_kv_cache_unified * kv,
|
||||||
|
llama_sbatch sbatch,
|
||||||
|
ubatch_heads heads,
|
||||||
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
|
virtual ~llama_kv_cache_unified_state();
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_state_i
|
||||||
|
//
|
||||||
|
|
||||||
|
bool next() override;
|
||||||
|
bool apply() override;
|
||||||
|
|
||||||
|
std::vector<int64_t> & out_ids() override;
|
||||||
|
|
||||||
|
llama_memory_status get_status() const override;
|
||||||
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_kv_cache_unified_state specific API
|
||||||
|
//
|
||||||
|
|
||||||
|
uint32_t get_n_kv() const;
|
||||||
|
|
||||||
|
// get views of the current state of the cache
|
||||||
|
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||||
|
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||||
|
|
||||||
|
// store k_cur and v_cur in the cache based on the provided head location
|
||||||
|
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||||
|
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||||
|
|
||||||
|
void set_input_k_shift(ggml_tensor * dst) const;
|
||||||
|
|
||||||
|
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||||
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
llama_memory_status status;
|
||||||
|
|
||||||
|
llama_kv_cache_unified * kv;
|
||||||
|
llama_context * lctx;
|
||||||
|
|
||||||
|
//
|
||||||
|
// update state
|
||||||
|
//
|
||||||
|
|
||||||
|
bool do_shift = false;
|
||||||
|
|
||||||
|
defrag_info dinfo;
|
||||||
|
|
||||||
|
//
|
||||||
|
// batch processing state
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_sbatch sbatch;
|
||||||
|
|
||||||
|
// the index of the next ubatch to process
|
||||||
|
size_t i_next = 0;
|
||||||
|
|
||||||
|
ubatch_heads heads;
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
//
|
||||||
|
// data needed for building the compute graph for the current ubatch:
|
||||||
|
//
|
||||||
|
|
||||||
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
|
// as the cache gets filled, the benefit from this heuristic disappears
|
||||||
|
int32_t n_kv;
|
||||||
|
|
||||||
|
// the beginning of the current slot in which the ubatch will be inserted
|
||||||
|
int32_t head;
|
||||||
|
};
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,413 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "llama.h"
|
|
||||||
#include "llama-io.h"
|
|
||||||
#include "llama-graph.h"
|
|
||||||
#include "llama-memory.h"
|
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
struct llama_cparams;
|
|
||||||
struct llama_hparams;
|
|
||||||
struct llama_ubatch;
|
|
||||||
struct llama_sbatch;
|
|
||||||
struct llama_model;
|
|
||||||
struct llama_context;
|
|
||||||
|
|
||||||
struct llama_kv_cache : public llama_memory_i {
|
|
||||||
virtual ~llama_kv_cache() = default;
|
|
||||||
|
|
||||||
// call if batch processing fails - restores the cache state
|
|
||||||
virtual void restore() = 0;
|
|
||||||
|
|
||||||
// call after successful batch processing - clears any pending state
|
|
||||||
virtual void commit() = 0;
|
|
||||||
|
|
||||||
// process any pending defrag/shift/etc. operations
|
|
||||||
// optionally call once before processing a new batch
|
|
||||||
virtual bool update(llama_context & lctx) = 0;
|
|
||||||
|
|
||||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
|
||||||
virtual void defrag_sched(float thold) = 0;
|
|
||||||
|
|
||||||
// simulate full cache, used for allocating worst-case compute buffers
|
|
||||||
virtual void set_full() = 0;
|
|
||||||
|
|
||||||
//
|
|
||||||
// batch processing
|
|
||||||
//
|
|
||||||
|
|
||||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
|
||||||
|
|
||||||
// different KV caches require different batch splitting strategies
|
|
||||||
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
|
|
||||||
|
|
||||||
// find an empty slot of size "n_tokens" in the cache
|
|
||||||
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
|
||||||
|
|
||||||
// getters
|
|
||||||
virtual int32_t get_n_tokens() const = 0;
|
|
||||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
|
||||||
virtual llama_pos get_pos_max() const = 0;
|
|
||||||
virtual bool get_can_shift() const = 0;
|
|
||||||
|
|
||||||
bool get_can_edit() const override { return get_can_shift(); }
|
|
||||||
|
|
||||||
//
|
|
||||||
// state write/read
|
|
||||||
//
|
|
||||||
|
|
||||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
|
||||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_kv_cache_guard
|
|
||||||
//
|
|
||||||
|
|
||||||
struct llama_kv_cache_guard {
|
|
||||||
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
|
||||||
|
|
||||||
~llama_kv_cache_guard() {
|
|
||||||
kv->restore();
|
|
||||||
}
|
|
||||||
|
|
||||||
void commit() {
|
|
||||||
kv->commit();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
llama_kv_cache * kv;
|
|
||||||
};
|
|
||||||
|
|
||||||
// block of KV slots to move when defragging
|
|
||||||
struct llama_kv_defrag_move {
|
|
||||||
uint32_t src;
|
|
||||||
uint32_t dst;
|
|
||||||
uint32_t len;
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_kv_cache_unified
|
|
||||||
//
|
|
||||||
|
|
||||||
// TODO: add notion of max sequences
|
|
||||||
class llama_kv_cache_unified : public llama_kv_cache {
|
|
||||||
public:
|
|
||||||
struct kv_cell {
|
|
||||||
llama_pos pos = -1;
|
|
||||||
llama_pos delta = 0;
|
|
||||||
|
|
||||||
std::set<llama_seq_id> seq_id;
|
|
||||||
|
|
||||||
bool has_seq_id(const llama_seq_id & id) const {
|
|
||||||
return seq_id.find(id) != seq_id.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_empty() const {
|
|
||||||
return seq_id.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_same_seq(const kv_cell & other) const {
|
|
||||||
return seq_id == other.seq_id;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static uint32_t get_padding(const llama_cparams & cparams);
|
|
||||||
|
|
||||||
llama_kv_cache_unified(
|
|
||||||
const llama_model & model,
|
|
||||||
ggml_type type_k,
|
|
||||||
ggml_type type_v,
|
|
||||||
bool v_trans,
|
|
||||||
bool offload,
|
|
||||||
uint32_t kv_size,
|
|
||||||
uint32_t padding);
|
|
||||||
|
|
||||||
~llama_kv_cache_unified() = default;
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_memory_i
|
|
||||||
//
|
|
||||||
|
|
||||||
void clear() override;
|
|
||||||
|
|
||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
|
||||||
void seq_keep(llama_seq_id seq_id) override;
|
|
||||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
|
||||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
|
||||||
|
|
||||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_kv_cache
|
|
||||||
//
|
|
||||||
|
|
||||||
void restore() override;
|
|
||||||
void commit() override;
|
|
||||||
|
|
||||||
bool update(llama_context & ctx) override;
|
|
||||||
|
|
||||||
void defrag_sched(float thold) override;
|
|
||||||
|
|
||||||
void set_full() override;
|
|
||||||
|
|
||||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
|
||||||
|
|
||||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
|
||||||
|
|
||||||
// updates the cache head
|
|
||||||
// Note: On success, it's important that cache.head points
|
|
||||||
// to the first cell of the slot.
|
|
||||||
bool find_slot(const llama_ubatch & batch) override;
|
|
||||||
|
|
||||||
int32_t get_n_tokens() const override;
|
|
||||||
int32_t get_used_cells() const override;
|
|
||||||
|
|
||||||
// TODO: better data structures to reduce the cost of this operation
|
|
||||||
llama_pos get_pos_max() const override;
|
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
|
||||||
|
|
||||||
// state write/load
|
|
||||||
|
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
|
||||||
|
|
||||||
// Note: The value of head isn't only used to optimize searching
|
|
||||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
|
||||||
// cannot be freely changed after a slot has been allocated.
|
|
||||||
uint32_t head = 0;
|
|
||||||
uint32_t size = 0;
|
|
||||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
|
||||||
|
|
||||||
// computed before each graph build
|
|
||||||
uint32_t n = 0;
|
|
||||||
|
|
||||||
std::vector<kv_cell> cells;
|
|
||||||
|
|
||||||
std::vector<ggml_tensor *> k_l; // per layer
|
|
||||||
std::vector<ggml_tensor *> v_l;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const llama_model & model;
|
|
||||||
const llama_hparams & hparams;
|
|
||||||
|
|
||||||
bool has_shift = false;
|
|
||||||
bool do_defrag = false;
|
|
||||||
|
|
||||||
bool v_trans = true; // the value tensor is transposed
|
|
||||||
bool can_shift = false;
|
|
||||||
|
|
||||||
// required padding
|
|
||||||
uint32_t padding = 1;
|
|
||||||
|
|
||||||
ggml_type type_k = GGML_TYPE_F16;
|
|
||||||
ggml_type type_v = GGML_TYPE_F16;
|
|
||||||
|
|
||||||
std::vector<ggml_context_ptr> ctxs;
|
|
||||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
|
||||||
|
|
||||||
// defrag
|
|
||||||
struct {
|
|
||||||
std::vector<llama_kv_defrag_move> moves;
|
|
||||||
} defrag_info;
|
|
||||||
|
|
||||||
// return true if cells have been moved
|
|
||||||
bool defrag_prepare(int32_t n_max_nodes);
|
|
||||||
|
|
||||||
// commit/restore cache
|
|
||||||
struct slot_range {
|
|
||||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
|
||||||
uint32_t c1 = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// pending cell updates that are not yet committed
|
|
||||||
struct {
|
|
||||||
std::vector<slot_range> ranges;
|
|
||||||
} pending;
|
|
||||||
|
|
||||||
// find how many cells are currently in use
|
|
||||||
uint32_t cell_max() const;
|
|
||||||
|
|
||||||
size_t total_size() const;
|
|
||||||
|
|
||||||
size_t size_k_bytes() const;
|
|
||||||
size_t size_v_bytes() const;
|
|
||||||
|
|
||||||
ggml_tensor * build_rope_shift(
|
|
||||||
const llama_cparams & cparams,
|
|
||||||
ggml_context * ctx,
|
|
||||||
ggml_tensor * cur,
|
|
||||||
ggml_tensor * shift,
|
|
||||||
ggml_tensor * factors,
|
|
||||||
float freq_base,
|
|
||||||
float freq_scale) const;
|
|
||||||
|
|
||||||
llm_graph_result_ptr build_graph_shift(
|
|
||||||
const llama_cparams & cparams,
|
|
||||||
ggml_context * ctx,
|
|
||||||
ggml_cgraph * gf) const;
|
|
||||||
|
|
||||||
llm_graph_result_ptr build_graph_defrag(
|
|
||||||
const llama_cparams & cparams,
|
|
||||||
ggml_context * ctx,
|
|
||||||
ggml_cgraph * gf,
|
|
||||||
const std::vector<llama_kv_defrag_move> & moves) const;
|
|
||||||
|
|
||||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
|
||||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
||||||
|
|
||||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
|
||||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_kv_cache_recurrent
|
|
||||||
//
|
|
||||||
|
|
||||||
class llama_kv_cache_recurrent : public llama_kv_cache {
|
|
||||||
public:
|
|
||||||
struct kv_cell {
|
|
||||||
llama_pos pos = -1;
|
|
||||||
int32_t src = -1; // used to copy states
|
|
||||||
int32_t tail = -1;
|
|
||||||
|
|
||||||
std::set<llama_seq_id> seq_id;
|
|
||||||
|
|
||||||
bool has_seq_id(const llama_seq_id & id) const {
|
|
||||||
return seq_id.find(id) != seq_id.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_empty() const {
|
|
||||||
return seq_id.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_same_seq(const kv_cell & other) const {
|
|
||||||
return seq_id == other.seq_id;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
llama_kv_cache_recurrent(
|
|
||||||
const llama_model & model,
|
|
||||||
ggml_type type_k,
|
|
||||||
ggml_type type_v,
|
|
||||||
bool offload,
|
|
||||||
uint32_t kv_size);
|
|
||||||
|
|
||||||
~llama_kv_cache_recurrent() = default;
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_memory_i
|
|
||||||
//
|
|
||||||
|
|
||||||
void clear() override;
|
|
||||||
|
|
||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
|
||||||
void seq_keep(llama_seq_id seq_id) override;
|
|
||||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
|
||||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
|
||||||
|
|
||||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_kv_cache
|
|
||||||
//
|
|
||||||
|
|
||||||
void restore() override;
|
|
||||||
void commit() override;
|
|
||||||
|
|
||||||
bool update(llama_context & lctx) override;
|
|
||||||
|
|
||||||
void defrag_sched(float thold) override;
|
|
||||||
|
|
||||||
void set_full() override;
|
|
||||||
|
|
||||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
|
||||||
|
|
||||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
|
||||||
|
|
||||||
bool find_slot(const llama_ubatch & batch) override;
|
|
||||||
|
|
||||||
int32_t get_n_tokens() const override;
|
|
||||||
int32_t get_used_cells() const override;
|
|
||||||
|
|
||||||
// TODO: better data structures to reduce the cost of this operation
|
|
||||||
llama_pos get_pos_max() const override;
|
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
|
||||||
|
|
||||||
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
|
||||||
int32_t s_copy(int i) const;
|
|
||||||
float s_mask(int i) const;
|
|
||||||
|
|
||||||
// state write/load
|
|
||||||
|
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
|
||||||
|
|
||||||
// Note: The value of head isn't only used to optimize searching
|
|
||||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
|
||||||
// cannot be freely changed after a slot has been allocated.
|
|
||||||
uint32_t head = 0;
|
|
||||||
uint32_t size = 0;
|
|
||||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
|
||||||
|
|
||||||
// computed before each graph build
|
|
||||||
uint32_t n = 0;
|
|
||||||
|
|
||||||
std::vector<kv_cell> cells;
|
|
||||||
|
|
||||||
std::vector<ggml_tensor *> k_l; // per layer
|
|
||||||
std::vector<ggml_tensor *> v_l;
|
|
||||||
|
|
||||||
private:
|
|
||||||
//const llama_model & model;
|
|
||||||
const llama_hparams & hparams;
|
|
||||||
|
|
||||||
// commit/restore cache
|
|
||||||
// TODO: rework for recurrent cache
|
|
||||||
struct slot_range {
|
|
||||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
|
||||||
uint32_t c1 = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// pending cell updates that are not yet committed
|
|
||||||
struct {
|
|
||||||
std::vector<slot_range> ranges;
|
|
||||||
} pending;
|
|
||||||
|
|
||||||
ggml_type type_k = GGML_TYPE_F16;
|
|
||||||
ggml_type type_v = GGML_TYPE_F16;
|
|
||||||
|
|
||||||
std::vector<ggml_context_ptr> ctxs;
|
|
||||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
|
||||||
|
|
||||||
// find how many cells are currently in use
|
|
||||||
uint32_t cell_max() const;
|
|
||||||
|
|
||||||
size_t total_size() const;
|
|
||||||
|
|
||||||
size_t size_k_bytes() const;
|
|
||||||
size_t size_v_bytes() const;
|
|
||||||
|
|
||||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
|
||||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
||||||
|
|
||||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
|
||||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
//
|
|
||||||
// kv cache view
|
|
||||||
//
|
|
||||||
|
|
||||||
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
|
|
||||||
|
|
||||||
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
|
|
||||||
|
|
@ -0,0 +1,410 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
#include "llama-cparams.h"
|
||||||
|
|
||||||
|
#include <bitset>
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||||
|
// TODO: add unit tests
|
||||||
|
class llama_kv_cells_unified {
|
||||||
|
public:
|
||||||
|
void reset() {
|
||||||
|
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||||
|
pos[i] = -1;
|
||||||
|
shift[i] = 0;
|
||||||
|
seq[i].reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
has_shift = false;
|
||||||
|
|
||||||
|
used.clear();
|
||||||
|
|
||||||
|
for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||||
|
seq_pos[s].clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset_shift() {
|
||||||
|
has_shift = false;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < shift.size(); ++i) {
|
||||||
|
shift[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t size() const {
|
||||||
|
return pos.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void resize(uint32_t n) {
|
||||||
|
pos.resize(n);
|
||||||
|
shift.resize(n);
|
||||||
|
seq.resize(n);
|
||||||
|
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_empty(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
|
||||||
|
|
||||||
|
return pos[i] == -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t get_used() const {
|
||||||
|
return used.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// the index of the first cell that is used
|
||||||
|
// return 0 if no cells are used
|
||||||
|
uint32_t used_min() const {
|
||||||
|
return used.empty() ? 0 : *used.begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
// the index of the last cell that is used + 1
|
||||||
|
// return 0 if no cells are used
|
||||||
|
uint32_t used_max_p1() const {
|
||||||
|
return used.empty() ? 0 : *used.rbegin() + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool get_has_shift() const {
|
||||||
|
return has_shift;
|
||||||
|
}
|
||||||
|
|
||||||
|
// move cell isrc to idst (used during defrag)
|
||||||
|
void mv(uint32_t isrc, uint32_t idst) {
|
||||||
|
assert(isrc < pos.size());
|
||||||
|
assert(idst < pos.size());
|
||||||
|
|
||||||
|
pos [idst] = pos [isrc];
|
||||||
|
shift[idst] = shift[isrc];
|
||||||
|
seq [idst] = seq [isrc];
|
||||||
|
|
||||||
|
pos [isrc] = -1;
|
||||||
|
shift[isrc] = 0;
|
||||||
|
seq [isrc].reset();
|
||||||
|
|
||||||
|
used.erase (isrc);
|
||||||
|
used.insert(idst);
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||||
|
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
|
||||||
|
assert(i + n <= pos.size());
|
||||||
|
|
||||||
|
llama_kv_cells_unified res;
|
||||||
|
|
||||||
|
res.resize(n);
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < n; ++j) {
|
||||||
|
res.pos[j] = pos[i + j];
|
||||||
|
res.seq[j] = seq[i + j];
|
||||||
|
|
||||||
|
assert(shift[i + j] == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
|
||||||
|
void set(uint32_t i, const llama_kv_cells_unified & other) {
|
||||||
|
assert(i + other.pos.size() <= pos.size());
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||||
|
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||||
|
used.insert(i + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||||
|
used.erase(i + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos[i + j] != -1) {
|
||||||
|
seq_pos_rm(i + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
pos[i + j] = other.pos[j];
|
||||||
|
seq[i + j] = other.seq[j];
|
||||||
|
|
||||||
|
if (pos[i + j] != -1) {
|
||||||
|
seq_pos_add(i + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(shift[i + j] == 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear a non-empty cell
|
||||||
|
void rm(uint32_t i) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
seq_pos_rm(i);
|
||||||
|
|
||||||
|
pos[i] = -1;
|
||||||
|
seq[i].reset();
|
||||||
|
|
||||||
|
used.erase(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell has seq_id
|
||||||
|
// return true if the cell becomes empty
|
||||||
|
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(seq[i].test(seq_id));
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
assert(seq_id >= 0);
|
||||||
|
|
||||||
|
seq[i].reset(seq_id);
|
||||||
|
seq_pos[seq_id].erase(pos[i]);
|
||||||
|
|
||||||
|
if (seq[i].none()) {
|
||||||
|
pos[i] = -1;
|
||||||
|
|
||||||
|
used.erase(i);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
|
||||||
|
bool seq_keep(uint32_t i, llama_seq_id seq_id) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
|
||||||
|
if (seq[i].test(seq_id)) {
|
||||||
|
seq_pos_rm(i);
|
||||||
|
seq[i].reset();
|
||||||
|
|
||||||
|
seq[i].set(seq_id);
|
||||||
|
seq_pos[seq_id].insert(pos[i]);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seq[i].any()) {
|
||||||
|
seq_pos_rm(i);
|
||||||
|
seq[i].reset();
|
||||||
|
|
||||||
|
pos[i] = -1;
|
||||||
|
|
||||||
|
used.erase(i);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(pos[i] == -1);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// number of different sequences in the cell
|
||||||
|
int seq_count(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
return seq[i].count();
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the cell contains seq_id
|
||||||
|
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(seq_id >= 0);
|
||||||
|
|
||||||
|
return seq[i].test(seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell is not empty and the seq_id is not in the cell
|
||||||
|
void seq_add(uint32_t i, llama_seq_id seq_id) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
assert(!seq[i].test(seq_id));
|
||||||
|
|
||||||
|
seq[i].set(seq_id);
|
||||||
|
seq_pos[seq_id].insert(pos[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the sequence id of this cell
|
||||||
|
// note: call only for cells with exactly one sequence
|
||||||
|
llama_seq_id seq_get(uint32_t i) const {
|
||||||
|
assert(seq[i].count() == 1);
|
||||||
|
|
||||||
|
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||||
|
if (seq[i].test(s)) {
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the minimum position of sequence seq_id currently present in any of the cells
|
||||||
|
// return -1 if the sequence is not present
|
||||||
|
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
||||||
|
assert(seq_id >= 0);
|
||||||
|
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||||
|
|
||||||
|
if (seq_pos[seq_id].empty()) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return *seq_pos[seq_id].begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
// the maximum position of sequence seq_id currently present in any of the cells
|
||||||
|
// return -1 if the sequence is not present
|
||||||
|
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
||||||
|
assert(seq_id >= 0);
|
||||||
|
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||||
|
|
||||||
|
if (seq_pos[seq_id].empty()) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return *seq_pos[seq_id].rbegin();
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
llama_pos pos_get(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
return pos[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
llama_pos get_shift(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
return shift[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if a cell is not empty and its position is within [p0, p1)
|
||||||
|
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
|
||||||
|
return pos[i] >= p0 && pos[i] < p1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the position of an empty cell
|
||||||
|
// does not modify "has_shift"
|
||||||
|
// note: call only if the cell is empty
|
||||||
|
void pos_set(uint32_t i, llama_pos p) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] == -1);
|
||||||
|
assert(seq[i].none());
|
||||||
|
|
||||||
|
pos[i] = p;
|
||||||
|
|
||||||
|
used.insert(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// pos[i] = pos[i] + d
|
||||||
|
// sets "has_shift" to true
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
bool pos_add(uint32_t i, llama_pos d) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
seq_pos_rm(i);
|
||||||
|
|
||||||
|
pos[i] += d;
|
||||||
|
shift[i] += d;
|
||||||
|
|
||||||
|
seq_pos_add(i);
|
||||||
|
|
||||||
|
has_shift = true;
|
||||||
|
|
||||||
|
if (pos[i] < 0) {
|
||||||
|
seq_pos_rm(i);
|
||||||
|
|
||||||
|
seq[i].reset();
|
||||||
|
pos[i] = -1;
|
||||||
|
|
||||||
|
used.erase(i);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// pos[i] = pos[i] / d
|
||||||
|
// sets "has_shift" to true
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
void pos_div(uint32_t i, int d) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
const llama_pos p_old = pos[i];
|
||||||
|
|
||||||
|
seq_pos_rm(i);
|
||||||
|
|
||||||
|
pos[i] /= d;
|
||||||
|
shift[i] += p_old - pos[i];
|
||||||
|
|
||||||
|
seq_pos_add(i);
|
||||||
|
|
||||||
|
has_shift = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool has_shift = false;
|
||||||
|
|
||||||
|
// set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||||
|
std::set<uint32_t> used;
|
||||||
|
|
||||||
|
std::vector<llama_pos> pos;
|
||||||
|
|
||||||
|
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||||
|
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
||||||
|
//
|
||||||
|
// cells.pos_add(x, shift_x);
|
||||||
|
// cells.pos_div(y, shift_y);
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// if (cells.has_shift()) {
|
||||||
|
// for (int i = 0; i < n; ++i) {
|
||||||
|
// auto shift_i = cells.get_shift(i);
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
// cells.reset_shift();
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
std::vector<llama_pos> shift;
|
||||||
|
|
||||||
|
using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
|
||||||
|
|
||||||
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
||||||
|
std::vector<bits_t> seq;
|
||||||
|
|
||||||
|
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
||||||
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
||||||
|
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
|
||||||
|
|
||||||
|
// helper functions for updating `seq_pos`, once cell at a time:
|
||||||
|
|
||||||
|
// remove cell i
|
||||||
|
void seq_pos_rm(uint32_t i) {
|
||||||
|
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||||
|
if (seq[i].test(s)) {
|
||||||
|
seq_pos[s].erase(pos[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add cell i
|
||||||
|
void seq_pos_add(uint32_t i) {
|
||||||
|
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||||
|
if (seq[i].test(s)) {
|
||||||
|
seq_pos[s].insert(pos[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
@ -1 +1,42 @@
|
||||||
#include "llama-memory.h"
|
#include "llama-memory.h"
|
||||||
|
|
||||||
|
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
|
||||||
|
bool has_update = false;
|
||||||
|
|
||||||
|
switch (s0) {
|
||||||
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
|
{
|
||||||
|
has_update = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||||
|
{
|
||||||
|
return s0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (s1) {
|
||||||
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
|
{
|
||||||
|
has_update = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||||
|
{
|
||||||
|
return s1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if either status has an update, then the combined status has an update
|
||||||
|
return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,30 +2,116 @@
|
||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct llama_ubatch;
|
||||||
|
|
||||||
|
class llama_io_write_i;
|
||||||
|
class llama_io_read_i;
|
||||||
|
|
||||||
struct llama_memory_params {
|
struct llama_memory_params {
|
||||||
// kv cache
|
// kv cache
|
||||||
ggml_type type_k;
|
ggml_type type_k;
|
||||||
ggml_type type_v;
|
ggml_type type_v;
|
||||||
|
|
||||||
// parameters for other types of memory
|
// use full-size SWA cache
|
||||||
// ...
|
bool swa_full;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum llama_memory_status {
|
||||||
|
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
||||||
|
LLAMA_MEMORY_STATUS_NO_UPDATE,
|
||||||
|
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
||||||
|
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
||||||
|
};
|
||||||
|
|
||||||
|
// helper function for combining the status of two memory states
|
||||||
|
// useful for implementing hybrid memory types (e.g. iSWA)
|
||||||
|
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
|
||||||
|
|
||||||
|
// the interface for managing the memory state during batch processing
|
||||||
|
// this interface is implemented per memory type. see:
|
||||||
|
// - llama_kv_cache_unified_state
|
||||||
|
// - llama_kv_cache_unified_iswa_state
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
||||||
|
//
|
||||||
|
// TODO: rename to llama_memory_context_i ?
|
||||||
|
struct llama_memory_state_i {
|
||||||
|
virtual ~llama_memory_state_i() = default;
|
||||||
|
|
||||||
|
// consume the current ubatch from the state and proceed to the next one
|
||||||
|
// return false if we are done
|
||||||
|
virtual bool next() = 0;
|
||||||
|
|
||||||
|
// apply the memory state for the current ubatch to the memory object
|
||||||
|
// return false on failure
|
||||||
|
virtual bool apply() = 0;
|
||||||
|
|
||||||
|
// TODO: this might get reworked in the future when refactoring llama_batch
|
||||||
|
virtual std::vector<int64_t> & out_ids() = 0;
|
||||||
|
|
||||||
|
// get the current ubatch
|
||||||
|
virtual const llama_ubatch & get_ubatch() const = 0;
|
||||||
|
|
||||||
|
// get the status of the memory state - used for error handling and checking if any updates would be applied
|
||||||
|
virtual llama_memory_status get_status() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
||||||
|
|
||||||
// general concept of LLM memory
|
// general concept of LLM memory
|
||||||
// the KV cache is a type of LLM memory, but there can be other types
|
// the KV cache is a type of LLM memory, but there can be other types
|
||||||
class llama_memory_i {
|
struct llama_memory_i {
|
||||||
public:
|
|
||||||
virtual ~llama_memory_i() = default;
|
virtual ~llama_memory_i() = default;
|
||||||
|
|
||||||
|
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||||
|
// return a state object containing the ubatches and KV cache state required to process them
|
||||||
|
// check the llama_memory_state_i::get_status() for the result
|
||||||
|
virtual llama_memory_state_ptr init_batch(
|
||||||
|
const llama_batch & batch,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
bool embd_pooled,
|
||||||
|
bool logits_all) = 0;
|
||||||
|
|
||||||
|
// simulate full cache, used for allocating worst-case compute buffers
|
||||||
|
virtual llama_memory_state_ptr init_full() = 0;
|
||||||
|
|
||||||
|
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
||||||
|
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
||||||
|
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||||
|
|
||||||
|
// getters
|
||||||
|
virtual bool get_can_shift() const = 0;
|
||||||
|
|
||||||
|
//
|
||||||
|
// ops
|
||||||
|
//
|
||||||
|
|
||||||
virtual void clear() = 0;
|
virtual void clear() = 0;
|
||||||
|
|
||||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
||||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||||
|
|
||||||
|
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
||||||
|
|
||||||
virtual bool get_can_edit() const = 0;
|
//
|
||||||
|
// state write/read
|
||||||
|
//
|
||||||
|
|
||||||
|
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||||
|
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
||||||
|
|
||||||
|
// TODO: temporary until the llama_kv_cache is removed from the public API
|
||||||
|
struct llama_kv_cache : public llama_memory_i {
|
||||||
|
virtual ~llama_kv_cache() = default;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -401,7 +401,7 @@ struct llama_mmap::impl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
throw std::runtime_error("PrefetchVirtualMemory unavailable");
|
LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -470,7 +470,7 @@ llama_model_loader::llama_model_loader(
|
||||||
|
|
||||||
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
||||||
if (!meta) {
|
if (!meta) {
|
||||||
throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
|
throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
|
||||||
}
|
}
|
||||||
|
|
||||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||||
|
|
@ -529,7 +529,7 @@ llama_model_loader::llama_model_loader(
|
||||||
};
|
};
|
||||||
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
||||||
if (!ctx_gguf) {
|
if (!ctx_gguf) {
|
||||||
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
|
throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
|
||||||
}
|
}
|
||||||
|
|
||||||
// check idx
|
// check idx
|
||||||
|
|
@ -823,13 +823,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
||||||
mappings.reserve(files.size());
|
mappings.reserve(files.size());
|
||||||
mmaps_used.reserve(files.size());
|
mmaps_used.reserve(files.size());
|
||||||
for (const auto & file : files) {
|
for (const auto & file : files) {
|
||||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
bool is_numa = false;
|
||||||
if (!reg) {
|
|
||||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||||
|
if (dev) {
|
||||||
|
auto * reg = ggml_backend_dev_backend_reg(dev);
|
||||||
|
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||||
|
if (is_numa_fn) {
|
||||||
|
is_numa = is_numa_fn();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
|
||||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
|
||||||
mmaps_used.emplace_back(mapping->size(), 0);
|
mmaps_used.emplace_back(mapping->size(), 0);
|
||||||
if (mlock_mmaps) {
|
if (mlock_mmaps) {
|
||||||
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -401,7 +401,10 @@ struct llama_model {
|
||||||
|
|
||||||
const struct ggml_tensor * get_tensor(const char * name) const;
|
const struct ggml_tensor * get_tensor(const char * name) const;
|
||||||
|
|
||||||
ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
|
float get_rope_freq_base (const llama_cparams & cparams, int il) const;
|
||||||
|
float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
|
||||||
|
|
||||||
|
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
|
||||||
|
|
||||||
// note: can mutate `cparams`
|
// note: can mutate `cparams`
|
||||||
// TODO: move this to new llm_arch_model_i interface
|
// TODO: move this to new llm_arch_model_i interface
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,12 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
// Quantization types. Changes to this struct must be replicated in quantize.cpp
|
||||||
|
struct tensor_quantization {
|
||||||
|
std::string name;
|
||||||
|
ggml_type quant = GGML_TYPE_COUNT;
|
||||||
|
};
|
||||||
|
|
||||||
static void zeros(std::ofstream & file, size_t n) {
|
static void zeros(std::ofstream & file, size_t n) {
|
||||||
char zero = 0;
|
char zero = 0;
|
||||||
for (size_t i = 0; i < n; ++i) {
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
|
@ -48,12 +54,6 @@ struct quantize_state_impl {
|
||||||
{}
|
{}
|
||||||
};
|
};
|
||||||
|
|
||||||
// changes to this struct must be replicated in quantize.cpp
|
|
||||||
struct tensor_quantization {
|
|
||||||
std::string name;
|
|
||||||
ggml_type quant = GGML_TYPE_COUNT;
|
|
||||||
};
|
|
||||||
|
|
||||||
static void llama_tensor_dequantize_impl(
|
static void llama_tensor_dequantize_impl(
|
||||||
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||||
const size_t nelements, const int nthread
|
const size_t nelements, const int nthread
|
||||||
|
|
@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
// unless the user specifies a type
|
// unless the user specifies a type
|
||||||
if (params->tensor_types) {
|
if (params->tensor_types) {
|
||||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||||
|
const std::string tensor_name(tensor->name);
|
||||||
for (const auto & [tname, qtype] : tensor_types) {
|
for (const auto & [tname, qtype] : tensor_types) {
|
||||||
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
|
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
|
||||||
if (qtype != new_type) {
|
if (qtype != new_type) {
|
||||||
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
|
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
|
||||||
|
new_type = qtype;
|
||||||
|
break; // if two or more types are specified for the tensor, first match wins
|
||||||
}
|
}
|
||||||
new_type = qtype;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||||
new_type = params->token_embedding_type;
|
new_type = params->token_embedding_type;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we have enough values the operation was a success
|
// if we have enough values the operation was a success
|
||||||
if (filtered_tokens.size() >= ctx->min_keep) {
|
if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
|
||||||
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
||||||
cur_p->size = filtered_tokens.size();
|
cur_p->size = filtered_tokens.size();
|
||||||
min_p_applied = true;
|
min_p_applied = true;
|
||||||
|
|
@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
|
||||||
cum_sum += cur_p->data[idx].p;
|
cum_sum += cur_p->data[idx].p;
|
||||||
|
|
||||||
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
||||||
if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
|
if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
|
||||||
last_idx = i + 1;
|
last_idx = i + 1;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
|
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
|
||||||
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
|
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
|
||||||
// at the beginning tokenization score is zero
|
// at the beginning tokenization score is zero
|
||||||
tokenization_results[0] = { vocab.token_unk(), 0, 0 };
|
tokenization_results[0] = { vocab.token_unk(), 0, 0 };
|
||||||
|
|
||||||
|
|
@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
|
||||||
const double challenger_score = current_best.score_sum + token_score;
|
const double challenger_score = current_best.score_sum + token_score;
|
||||||
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
||||||
if (challenger_score > current_champ.score_sum) {
|
if (challenger_score > current_champ.score_sum) {
|
||||||
struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
|
struct best_tokenization challenger = { token_id, input_offset, challenger_score };
|
||||||
current_champ = challenger;
|
current_champ = challenger;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
|
||||||
prefix_offset = input_offset + n_utf8_code_units;
|
prefix_offset = input_offset + n_utf8_code_units;
|
||||||
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
||||||
if (challenger_score > current_champ.score_sum) {
|
if (challenger_score > current_champ.score_sum) {
|
||||||
struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
|
struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
|
||||||
current_champ = challenger;
|
current_champ = challenger;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1007,7 +1007,7 @@ private:
|
||||||
struct best_tokenization {
|
struct best_tokenization {
|
||||||
llama_token token_id;
|
llama_token token_id;
|
||||||
size_t input_offset;
|
size_t input_offset;
|
||||||
float score_sum;
|
double score_sum;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
|
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
|
||||||
|
|
@ -2070,9 +2070,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
|
|
||||||
std::string model_name;
|
std::string model_name;
|
||||||
std::string tokenizer_pre;
|
std::string tokenizer_pre;
|
||||||
|
std::string general_arch;
|
||||||
|
|
||||||
ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
|
ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
|
||||||
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
|
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
|
||||||
|
ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
|
||||||
|
|
||||||
// model name to lowercase
|
// model name to lowercase
|
||||||
std::transform(model_name.begin(), model_name.end(), model_name.begin(),
|
std::transform(model_name.begin(), model_name.end(), model_name.begin(),
|
||||||
|
|
@ -2081,9 +2083,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
// set attributes by model/tokenizer name
|
// set attributes by model/tokenizer/architecture name
|
||||||
if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
|
if (false
|
||||||
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
|
|| _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
|
||||||
|
|| _contains_any(general_arch, {"nomic-bert-moe"})
|
||||||
|
) {
|
||||||
|
if (token_to_id.count("<mask>") == 0) {
|
||||||
|
LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
|
||||||
|
} else {
|
||||||
|
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
|
||||||
|
}
|
||||||
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
|
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
|
||||||
for (auto id : cache_special_tokens) {
|
for (auto id : cache_special_tokens) {
|
||||||
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
|
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
|
||||||
|
|
|
||||||
|
|
@ -140,6 +140,11 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||||
struct llama_model_params params) {
|
struct llama_model_params params) {
|
||||||
ggml_time_init();
|
ggml_time_init();
|
||||||
|
|
||||||
|
if (!params.vocab_only && ggml_backend_reg_count() == 0) {
|
||||||
|
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
unsigned cur_percentage = 0;
|
unsigned cur_percentage = 0;
|
||||||
if (params.progress_callback == NULL) {
|
if (params.progress_callback == NULL) {
|
||||||
params.progress_callback_user_data = &cur_percentage;
|
params.progress_callback_user_data = &cur_percentage;
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
|
#include <cinttypes>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
@ -15,22 +16,26 @@
|
||||||
#define KEY_FTYPE "general.file_type"
|
#define KEY_FTYPE "general.file_type"
|
||||||
#define KEY_NAME "general.name"
|
#define KEY_NAME "general.name"
|
||||||
#define KEY_DESCRIPTION "general.description"
|
#define KEY_DESCRIPTION "general.description"
|
||||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||||
|
#define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder"
|
||||||
|
#define KEY_HAS_VISION_ENC "clip.has_vision_encoder"
|
||||||
#define KEY_USE_GELU "clip.use_gelu"
|
#define KEY_USE_GELU "clip.use_gelu"
|
||||||
#define KEY_USE_SILU "clip.use_silu"
|
#define KEY_USE_SILU "clip.use_silu"
|
||||||
#define KEY_N_EMBD "clip.vision.embedding_length"
|
|
||||||
#define KEY_N_FF "clip.vision.feed_forward_length"
|
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||||
#define KEY_N_BLOCK "clip.vision.block_count"
|
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||||
#define KEY_N_HEAD "clip.vision.attention.head_count"
|
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||||
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
|
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||||
#define KEY_PROJ_DIM "clip.vision.projection_dim"
|
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||||
|
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||||
|
|
||||||
|
// vision-specific
|
||||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
|
||||||
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
|
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
|
||||||
|
|
||||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||||
|
|
@ -38,6 +43,11 @@
|
||||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||||
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
||||||
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
||||||
|
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||||
|
|
||||||
|
// audio-specific
|
||||||
|
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
|
||||||
|
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
@ -94,6 +104,13 @@
|
||||||
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
|
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
|
||||||
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
|
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
|
||||||
|
|
||||||
|
// ultravox
|
||||||
|
#define TN_CONV1D "a.conv1d.%d.%s"
|
||||||
|
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
|
||||||
|
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
|
||||||
|
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
|
||||||
|
#define TN_MM_NORM_MID "mm.a.norm_mid.%s"
|
||||||
|
|
||||||
// align x to upper multiple of n
|
// align x to upper multiple of n
|
||||||
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||||
|
|
||||||
|
|
@ -109,7 +126,11 @@ enum projector_type {
|
||||||
PROJECTOR_TYPE_IDEFICS3,
|
PROJECTOR_TYPE_IDEFICS3,
|
||||||
PROJECTOR_TYPE_PIXTRAL,
|
PROJECTOR_TYPE_PIXTRAL,
|
||||||
PROJECTOR_TYPE_QWEN25VL,
|
PROJECTOR_TYPE_QWEN25VL,
|
||||||
|
PROJECTOR_TYPE_ULTRAVOX,
|
||||||
PROJECTOR_TYPE_INTERNVL,
|
PROJECTOR_TYPE_INTERNVL,
|
||||||
|
PROJECTOR_TYPE_LLAMA4,
|
||||||
|
PROJECTOR_TYPE_QWEN2A,
|
||||||
|
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||||
PROJECTOR_TYPE_UNKNOWN,
|
PROJECTOR_TYPE_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -124,7 +145,11 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||||
|
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
|
||||||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||||
|
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||||
|
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||||
|
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||||
};
|
};
|
||||||
|
|
||||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||||
|
|
@ -144,8 +169,10 @@ struct clip_image_u8 {
|
||||||
std::vector<uint8_t> buf;
|
std::vector<uint8_t> buf;
|
||||||
};
|
};
|
||||||
|
|
||||||
// RGB float32 image (NHWC)
|
// For images, buf.size() == nx*ny*3
|
||||||
// Memory layout: RGBRGBRGB...
|
// Memory layout: RGBRGBRGB...
|
||||||
|
// For audio, only one channel is used, buf.size() == nx*ny
|
||||||
|
// nx will be n_frames and ny will be n_mel
|
||||||
struct clip_image_f32 {
|
struct clip_image_f32 {
|
||||||
int nx;
|
int nx;
|
||||||
int ny;
|
int ny;
|
||||||
|
|
@ -239,9 +266,20 @@ struct clip_image_u8_batch {
|
||||||
|
|
||||||
struct clip_image_f32_batch {
|
struct clip_image_f32_batch {
|
||||||
std::vector<clip_image_f32_ptr> entries;
|
std::vector<clip_image_f32_ptr> entries;
|
||||||
|
bool is_audio = false;
|
||||||
|
|
||||||
|
// for llava-uhd style models, we need to know the grid size
|
||||||
|
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
|
||||||
|
int grid_x = 0;
|
||||||
|
int grid_y = 0;
|
||||||
|
|
||||||
clip_image_f32_batch clone() const {
|
clip_image_f32_batch clone() const {
|
||||||
clip_image_f32_batch new_batch;
|
clip_image_f32_batch new_batch{
|
||||||
|
/* entries */ {},
|
||||||
|
/* is_audio */ is_audio,
|
||||||
|
/* grid_x */ grid_x,
|
||||||
|
/* grid_y */ grid_y,
|
||||||
|
};
|
||||||
new_batch.entries.reserve(entries.size());
|
new_batch.entries.reserve(entries.size());
|
||||||
for (const auto & entry : entries) {
|
for (const auto & entry : entries) {
|
||||||
new_batch.entries.emplace_back(new clip_image_f32(*entry));
|
new_batch.entries.emplace_back(new clip_image_f32(*entry));
|
||||||
|
|
@ -358,6 +396,70 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// debugging
|
||||||
|
//
|
||||||
|
|
||||||
|
static void print_tensor_shape(ggml_tensor * t) {
|
||||||
|
printf("%s.shape = [", t->name);
|
||||||
|
for (int i = 0; i < ggml_n_dims(t); ++i) {
|
||||||
|
printf("%" PRId64, t->ne[i]);
|
||||||
|
if (i < ggml_n_dims(t) - 1) {
|
||||||
|
printf(", ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("]\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
|
||||||
|
ggml_type type = t->type;
|
||||||
|
int64_t * ne = t->ne;
|
||||||
|
size_t * nb = t->nb;
|
||||||
|
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
|
||||||
|
printf("%s.data: [\n", t->name);
|
||||||
|
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
|
||||||
|
if (i2 == n && ne[2] > 2*n) {
|
||||||
|
printf(" ..., \n");
|
||||||
|
i2 = ne[2] - n;
|
||||||
|
}
|
||||||
|
printf(" [\n");
|
||||||
|
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
if (i1 == n && ne[1] > 2*n) {
|
||||||
|
printf(" ..., \n");
|
||||||
|
i1 = ne[1] - n;
|
||||||
|
}
|
||||||
|
printf(" [");
|
||||||
|
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
if (i0 == n && ne[0] > 2*n) {
|
||||||
|
printf("..., ");
|
||||||
|
i0 = ne[0] - n;
|
||||||
|
}
|
||||||
|
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||||
|
float v;
|
||||||
|
if (type == GGML_TYPE_F16) {
|
||||||
|
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||||
|
} else if (type == GGML_TYPE_F32) {
|
||||||
|
v = *(float *) &data[i];
|
||||||
|
} else if (type == GGML_TYPE_I32) {
|
||||||
|
v = (float) *(int32_t *) &data[i];
|
||||||
|
} else if (type == GGML_TYPE_I16) {
|
||||||
|
v = (float) *(int16_t *) &data[i];
|
||||||
|
} else if (type == GGML_TYPE_I8) {
|
||||||
|
v = (float) *(int8_t *) &data[i];
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
printf("%8.4f", v);
|
||||||
|
if (i0 < ne[0] - 1) printf(", ");
|
||||||
|
}
|
||||||
|
printf("],\n");
|
||||||
|
}
|
||||||
|
printf(" ],\n");
|
||||||
|
}
|
||||||
|
printf(" ]\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// API used internally with mtmd
|
// API used internally with mtmd
|
||||||
//
|
//
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,24 +1,9 @@
|
||||||
#ifndef CLIP_H
|
#pragma once
|
||||||
#define CLIP_H
|
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#ifdef LLAMA_SHARED
|
|
||||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
|
||||||
# ifdef LLAMA_BUILD
|
|
||||||
# define CLIP_API __declspec(dllexport)
|
|
||||||
# else
|
|
||||||
# define CLIP_API __declspec(dllimport)
|
|
||||||
# endif
|
|
||||||
# else
|
|
||||||
# define CLIP_API __attribute__ ((visibility ("default")))
|
|
||||||
# endif
|
|
||||||
#else
|
|
||||||
# define CLIP_API
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -34,102 +19,102 @@ struct clip_image_f32;
|
||||||
struct clip_image_u8_batch;
|
struct clip_image_u8_batch;
|
||||||
struct clip_image_f32_batch;
|
struct clip_image_f32_batch;
|
||||||
|
|
||||||
|
enum clip_modality {
|
||||||
|
CLIP_MODALITY_VISION,
|
||||||
|
CLIP_MODALITY_AUDIO,
|
||||||
|
};
|
||||||
|
|
||||||
struct clip_context_params {
|
struct clip_context_params {
|
||||||
bool use_gpu;
|
bool use_gpu;
|
||||||
enum ggml_log_level verbosity;
|
enum ggml_log_level verbosity;
|
||||||
};
|
};
|
||||||
|
|
||||||
// deprecated, use clip_init
|
struct clip_init_result {
|
||||||
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity);
|
struct clip_ctx * ctx_v; // vision context
|
||||||
|
struct clip_ctx * ctx_a; // audio context
|
||||||
|
};
|
||||||
|
|
||||||
CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params);
|
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params);
|
||||||
|
|
||||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
void clip_free(struct clip_ctx * ctx);
|
||||||
|
|
||||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
|
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
|
||||||
|
|
||||||
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
|
int32_t clip_get_image_size (const struct clip_ctx * ctx);
|
||||||
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
|
int32_t clip_get_patch_size (const struct clip_ctx * ctx);
|
||||||
CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
|
int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
// TODO: should be enum, not string
|
// TODO: should be enum, not string
|
||||||
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||||
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
|
size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
|
int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||||
"use clip_n_output_tokens instead");
|
|
||||||
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
|
|
||||||
"use clip_n_output_tokens instead");
|
|
||||||
|
|
||||||
CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
|
||||||
|
|
||||||
// for M-RoPE, this will be the number of token positions in X and Y directions
|
// for M-RoPE, this will be the number of token positions in X and Y directions
|
||||||
// for other models, X will be the total number of tokens and Y will be 1
|
// for other models, X will be the total number of tokens and Y will be 1
|
||||||
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||||
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||||
|
|
||||||
// this should be equal to the embedding dimension of the text model
|
// this should be equal to the embedding dimension of the text model
|
||||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
struct clip_image_size * clip_image_size_init(void);
|
||||||
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
struct clip_image_u8 * clip_image_u8_init (void);
|
||||||
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
struct clip_image_f32 * clip_image_f32_init(void);
|
||||||
|
struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
|
||||||
CLIP_API struct clip_image_size * clip_image_size_init(void);
|
|
||||||
CLIP_API struct clip_image_u8 * clip_image_u8_init (void);
|
|
||||||
CLIP_API struct clip_image_f32 * clip_image_f32_init(void);
|
|
||||||
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
|
|
||||||
|
|
||||||
// nx, ny are the output image dimensions
|
// nx, ny are the output image dimensions
|
||||||
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
|
unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
|
||||||
|
|
||||||
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
|
void clip_image_size_free (struct clip_image_size * img_size);
|
||||||
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
|
void clip_image_u8_free (struct clip_image_u8 * img);
|
||||||
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
|
void clip_image_f32_free(struct clip_image_f32 * img);
|
||||||
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
|
void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
|
||||||
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
||||||
|
|
||||||
// use for accessing underlay data of clip_image_f32_batch
|
// use for accessing underlay data of clip_image_f32_batch
|
||||||
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
|
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
|
||||||
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
|
size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
|
||||||
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
|
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
|
||||||
CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
|
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
|
||||||
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
|
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
|
||||||
*/
|
*/
|
||||||
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
|
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
|
||||||
|
|
||||||
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||||
|
|
||||||
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
|
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
|
||||||
CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
|
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
|
||||||
|
|
||||||
/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
|
/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
|
||||||
CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
|
bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
|
||||||
|
|
||||||
CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
|
struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
|
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
|
||||||
CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
|
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
|
||||||
|
|
||||||
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
|
int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||||
|
bool clip_is_glm(const struct clip_ctx * ctx);
|
||||||
|
bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||||
|
bool clip_is_llava(const struct clip_ctx * ctx);
|
||||||
|
bool clip_is_gemma3(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
|
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||||
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
|
|
||||||
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
|
||||||
CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
|
|
||||||
CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
|
|
||||||
|
|
||||||
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
// use by audio input
|
||||||
|
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
|
||||||
|
|
||||||
|
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
|
||||||
|
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
|
||||||
|
bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // CLIP_H
|
|
||||||
|
|
@ -1,591 +0,0 @@
|
||||||
#include "clip.h"
|
|
||||||
#include "llava.h"
|
|
||||||
|
|
||||||
#include "llama.h"
|
|
||||||
#include "ggml-cpp.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cerrno>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstring>
|
|
||||||
#include <limits>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#if defined(LLAVA_LOG_OFF)
|
|
||||||
# define LOG_INF(...)
|
|
||||||
# define LOG_WRN(...)
|
|
||||||
# define LOG_ERR(...)
|
|
||||||
# define LOG_DBG(...)
|
|
||||||
#else // defined(LLAVA_LOG_OFF)
|
|
||||||
# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
|
||||||
# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
|
||||||
# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
|
||||||
# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
|
||||||
#endif // defined(LLAVA_LOG_OFF)
|
|
||||||
|
|
||||||
// RGB uint8 image
// Interleaved layout RGBRGBRGB..., row-major; buf holds 3*nx*ny bytes.
struct clip_image_u8 {
    int nx;  // width in pixels
    int ny;  // height in pixels

    std::vector<uint8_t> buf;  // pixel data, size 3*nx*ny
};
|
|
||||||
|
|
||||||
// RGB float32 image (NHWC)
// Memory layout: RGBRGBRGB...
// Same interleaved layout as clip_image_u8 but with normalized float channels.
struct clip_image_f32 {
    int nx;  // width in pixels
    int ny;  // height in pixels

    std::vector<float> buf;  // pixel data, size 3*nx*ny
};
|
|
||||||
|
|
||||||
// Grid dimensions (in patches) used by the LLaVA "anyres" scheme.
struct clip_image_grid_shape {
    int first;   // number of grid columns (width / patch size)
    int second;  // number of grid rows (height / patch size)
};
|
|
||||||
|
|
||||||
// convenience cpp wrapper
// RAII deleter so a clip_image_f32_batch owned by a unique_ptr is released
// through the C API's clip_image_f32_batch_free().
struct clip_image_f32_batch_deleter {
    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
};
typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
|
|
||||||
|
|
||||||
struct clip_image_size_deleter {
|
|
||||||
void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
|
|
||||||
};
|
|
||||||
typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
|
|
||||||
|
|
||||||
/**
 * Selects the best resolution from a list of possible resolutions based on the original size.
 *
 * "Best" maximizes the effective (non-upscaled) pixel count after an
 * aspect-preserving fit; ties are broken by the least wasted canvas area.
 *
 * @param original_size        The original size of the image as (width, height).
 * @param possible_resolutions Candidate resolutions as [(w1,h1), (w2,h2), ...].
 * @return The best-fit resolution as (width, height); (0,0) if the list is empty.
 */
static std::pair<int, int> select_best_resolution(const std::pair<int, int>& original_size, const std::vector<std::pair<int, int>>& possible_resolutions) {
    const int src_w = original_size.first;
    const int src_h = original_size.second;

    std::pair<int, int> best_fit;
    int best_effective = 0;
    int best_wasted    = std::numeric_limits<int>::max();

    for (const auto & candidate : possible_resolutions) {
        const int cand_w = candidate.first;
        const int cand_h = candidate.second;

        // aspect-preserving scale factor to fit the source inside the candidate
        const float scale = std::min(static_cast<float>(cand_w) / src_w, static_cast<float>(cand_h) / src_h);
        const int fit_w = static_cast<int>(src_w * scale);
        const int fit_h = static_cast<int>(src_h * scale);

        // pixels actually carrying source information (upscaling adds none)
        const int effective = std::min(fit_w * fit_h, src_w * src_h);
        const int wasted    = (cand_w * cand_h) - effective;

        const bool strictly_better = effective > best_effective;
        const bool tie_less_waste  = effective == best_effective && wasted < best_wasted;
        if (strictly_better || tie_less_waste) {
            best_effective = effective;
            best_wasted    = wasted;
            best_fit       = candidate;
        }
    }

    return best_fit;
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the anyres image grid shape object
|
|
||||||
*
|
|
||||||
* @param image_size
|
|
||||||
* @param grid_pinpoints
|
|
||||||
* @param image_patch_size
|
|
||||||
* @return <int, int>
|
|
||||||
*/
|
|
||||||
static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, int> & image_size, const std::vector<std::pair<int, int>> & grid_pinpoints, int image_patch_size) {
|
|
||||||
/**
|
|
||||||
Conversion from gguf flat array to vector:
|
|
||||||
std::vector<std::pair<int, int>> possible_resolutions;
|
|
||||||
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
|
|
||||||
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
auto best_resolution = select_best_resolution(image_size, grid_pinpoints);
|
|
||||||
return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
// The base (global) embedding image_embd_v[0] is copied first, then the grid
// patch embeddings image_embd_v[1..] are permuted with a small ggml graph and
// appended. image_embd_out must be large enough for (grid+1) full embeddings.
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) {
    // anonymous wrapper so the context reads as model.ctx below
    struct {
        struct ggml_context * ctx;
    } model;

    const int32_t image_size = clip_get_image_size(ctx_clip);
    const int32_t patch_size = clip_get_patch_size(ctx_clip);

    int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)

    int num_patches_width = grid_shape.first; // grid 1-4
    int num_patches_height = grid_shape.second; // grid 1-4

    // +1 for the base image embedding that precedes the grid patches
    const size_t num_images = num_patches_width * num_patches_height + 1;

    // TODO: size calculation is not calculated - it's only tens of MB
    size_t ctx_size = 0;

    {
        ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features
        ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32);
    }

    struct ggml_init_params params {
        /*.mem_size =*/ ctx_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API
    };

    // Python reference code for full unpad:
    /*
        base_image_feature = image_feature[0]
        image_feature = image_feature[1:]
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_sizes[image_idx])
        image_feature = torch.cat((
            image_feature,
            self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)
        ), dim=-1)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
    */
    // We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval.
    // In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet.
    // Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them.
    // Once all images are processed to prepended the base_image_features without any changes.

    // Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling))
    /*
        image_feature = image_feature.view(2, 2, 24, 24, 4096)
        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
        image_feature = image_feature.view(2, 24, 2, 24, 4096)
        image_feature = image_feature.flatten(0, 3)

        // Reshape to 4D tensor by merging the last two dimensions
        image_feature = image_feature.view(2, 2, 24, 24*4096)
        image_feature = image_feature.permute(0, 2, 1, 3).contiguous()
        image_feature = image_feature.view(-1, 4096)
    */

    // NOTE(review): ggml_init result is not checked here — presumably it cannot
    // fail for this mem_size with no_alloc=false; confirm against ggml docs.
    model.ctx = ggml_init(params);

    struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4
    // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
    // fill it with the image embeddings, ignoring the base
    for (size_t i = 1; i < num_images; i++) {
        size_t offset = (i-1) * clip_embd_nbytes(ctx_clip);
        memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip));
    }

    struct ggml_cgraph * gf = ggml_new_graph(model.ctx);
    size_t size_ele = ggml_type_size(GGML_TYPE_F32);

    // view the flat patch embeddings as (embd*side, side, grid_w, grid_h)
    struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features,
                                                                num_patches_per_side * clip_n_mmproj_embd(ctx_clip),
                                                                num_patches_per_side,
                                                                num_patches_width,
                                                                num_patches_height,
                                                                size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip),
                                                                size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side,
                                                                size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0);
    // ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false);
    struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3));
    /**
     At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings
         image_feature = torch.cat((
        image_feature,
        self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
    ), dim=-1)
     *
     */

    // ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false);
    struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
    // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
    ggml_build_forward_expand(gf, flatten);

    // run the permute/cont graph on the CPU backend
    ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
    GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend");
    ggml_backend_graph_compute(backend.get(), gf);

    struct ggml_tensor* result = ggml_graph_node(gf, -1);

    memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
    // append without newline tokens (default behavior in llava_arch when not using unpad ):
    memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
    *n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input));

    // Debug: Test single segments
    // Current findings: sending base image, sending a segment embedding all works similar to python
    // However, permuted embeddings do not work yet (stride issue?)
    // memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context
    // memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context
    // *n_img_pos_out=576;

    ggml_free(model.ctx);
    return true;
}
|
|
||||||
|
|
||||||
// Re-pack an image into a single horizontal strip of patch_size x patch_size
// tiles (strip is patch_size tall and patch_size*num_patches wide).
// Returns a newly allocated clip_image_f32; the caller owns the result.
static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
    const int w = image->nx;
    const int h = image->ny;
    const int n_patches = (h / patch_size) * (w / patch_size);

    clip_image_f32 * out = clip_image_f32_init();
    out->nx = patch_size * n_patches;
    out->ny = patch_size;
    out->buf.resize(3 * out->nx * out->ny);

    int patch_idx = 0;
    for (int y0 = 0; y0 < h; y0 += patch_size) {
        for (int x0 = 0; x0 < w; x0 += patch_size) {
            // copy one tile, row by row, into its slot in the strip
            for (int py = 0; py < patch_size; ++py) {
                for (int px = 0; px < patch_size; ++px) {
                    const int src = ((y0 + py) * w + (x0 + px)) * 3;
                    const int dst = (py * patch_size * n_patches + patch_idx * patch_size + px) * 3;
                    out->buf[dst + 0] = image->buf[src + 0];
                    out->buf[dst + 1] = image->buf[src + 1];
                    out->buf[dst + 2] = image->buf[src + 2];
                }
            }
            patch_idx++;
        }
    }
    return out;
}
|
|
||||||
|
|
||||||
// Preprocess `img`, encode it with the CLIP model and write the embeddings
// into the preallocated `image_embd`; the number of positions written is
// returned through `n_img_pos`. Returns false on preprocessing/encoding error.
// BUG FIX: the early-return error paths inside the per-segment loops used to
// leak every previously malloc'd image_embd_v[] buffer; they are now freed.
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
    // std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336
    clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init());
    if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) {
        LOG_ERR("%s: unable to preprocess image\n", __func__);
        return false;
    }

    const int64_t t_img_enc_start_us = ggml_time_us();

    const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);

    const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get());

    if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
        std::vector<float *> image_embd_v;
        image_embd_v.resize(n_imgs);
        clip_image_size load_image_size;

        for (size_t i = 0; i < n_imgs; i++) {
            const int64_t t_img_enc_step_start_us = ggml_time_us();
            int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
            int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
            image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny));
            int patch_size = 14; // NOTE(review): hard-coded; presumably should match clip_get_patch_size(ctx_clip) — confirm
            load_image_size.width = nx;
            load_image_size.height = ny;
            clip_add_load_image_size(ctx_clip, &load_image_size);

            bool encoded = false;
            clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
            if (clip_is_qwen2vl(ctx_clip)) {
                encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]);
            }
            else {
                // NOTE(review): the clip_image_f32 returned by reshape_by_patch()
                // is never released here (pre-existing leak; no free function is
                // visible in this TU to fix it) — TODO confirm and plug upstream.
                encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]);
            }

            if (!encoded) {
                LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
                // free every buffer allocated so far (including this step's)
                for (size_t j = 0; j <= i; j++) {
                    free(image_embd_v[j]);
                }
                return false;
            }
            const int64_t t_img_enc_steop_batch_us = ggml_time_us();
            LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
        }
        const int64_t t_img_enc_batch_us = ggml_time_us();
        LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);

        // concatenate the per-segment embeddings into the output buffer
        int n_img_pos_out = 0;
        for (size_t i = 0; i < image_embd_v.size(); i++) {
            int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
            int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
            clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
            std::memcpy(
                image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
                image_embd_v[i],
                clip_embd_nbytes_by_img(ctx_clip, nx, ny));
            n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res);
        }
        *n_img_pos = n_img_pos_out;
        for (size_t i = 0; i < image_embd_v.size(); i++) {
            free(image_embd_v[i]);
        }
        image_embd_v.clear();
        load_image_size.width = img->nx;
        load_image_size.height = img->ny;
        clip_add_load_image_size(ctx_clip, &load_image_size);
        LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height);
    }
    else if (clip_is_glm(ctx_clip)){
        // NOTE(review): ownership of load_image_size presumably transfers to
        // ctx_clip via clip_add_load_image_size — confirm against clip.h.
        struct clip_image_size * load_image_size = clip_image_size_init();
        load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0);
        load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0);
        clip_add_load_image_size(ctx_clip, load_image_size);

        clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
        bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
        int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2);
        *n_img_pos = (pos * pos + 2);
        if (!encoded){
            LOG_ERR("Unable to encode image \n");
            return false;
        }
    }
    else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
        // flat / default llava-1.5 type embedding
        clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
        *n_img_pos = clip_n_output_tokens(ctx_clip, img_res);
        bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
        if (!encoded) {
            LOG_ERR("Unable to encode image\n");

            return false;
        }
    }
    else {
        // spatial_unpad llava-1.6 type embedding
        // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
        std::vector<float *> image_embd_v;
        image_embd_v.resize(n_imgs);
        for (size_t i = 0; i < n_imgs; i++) {
            clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
            image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
            const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
            if (!encoded) {
                LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
                // free every buffer allocated so far (including this step's)
                for (size_t j = 0; j <= i; j++) {
                    free(image_embd_v[j]);
                }
                return false;
            }
        }
        const int64_t t_img_enc_batch_us = ggml_time_us();
        LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);

        const int32_t * image_grid = clip_image_grid(ctx_clip);
        const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);

        std::vector<std::pair<int, int>> grid_pinpoints;
        for (size_t i = 0; i < num_gridpoints; i += 2) {
            grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
        }

        const int32_t image_size = clip_get_image_size(ctx_clip);

        struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);

        int n_img_pos_out;
        clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0);
        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input);
        *n_img_pos = n_img_pos_out;

        for (size_t i = 0; i < image_embd_v.size(); i++) {
            free(image_embd_v[i]);
        }
        image_embd_v.clear();

        // debug image/segment/normalization content:
        // clip_image_u8 * tmp = clip_image_u8_init();
        // clip_image_convert_f32_to_u8(*image_feature, *tmp);
        // clip_image_save_to_bmp(*tmp, "image_feature.bmp");
    }

    LOG_INF("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);

    const int64_t t_img_enc_end_us = ggml_time_us();
    float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;

    LOG_INF("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos);

    return true;
}
|
|
||||||
|
|
||||||
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
|
|
||||||
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
|
||||||
int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama));
|
|
||||||
auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
|
|
||||||
if (n_image_embd != n_llama_embd) {
|
|
||||||
LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate an embedding buffer sized for the model family, encode `img` into
// it, and hand ownership of the buffer to the caller through image_embd_out.
// Returns false (allocating/encoding nothing) on failure.
bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
    // Granite vision uses up to 10 patches + base patch
    int num_max_patches = 11;
    if (clip_is_minicpmv(ctx_clip)) {
        num_max_patches = 10;
    }
    if (clip_is_glm(ctx_clip)) {
        num_max_patches = 1;
    }

    // qwen2vl don't split image into chunks, so `num_max_patches` is not needed.
    const size_t embd_bytes = clip_is_qwen2vl(ctx_clip)
        ? clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny)
        : clip_embd_nbytes(ctx_clip) * num_max_patches; // TODO: base on gridsize/llava model
    float * image_embd = (float *)malloc(embd_bytes);
    if (!image_embd) {
        LOG_ERR("Unable to allocate memory for image embeddings\n");
        return false;
    }

    int n_img_pos;
    if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
        LOG_ERR("%s: cannot encode image, aborting\n", __func__);
        free(image_embd);
        return false;
    }

    *image_embd_out = image_embd;
    *n_img_pos_out  = n_img_pos;

    return true;
}
|
|
||||||
|
|
||||||
// Owns the arrays that back a llama_batch built from raw embeddings (no token
// ids). Keeping the vectors in the same object ties their lifetime to `batch`,
// so the raw pointers stored inside `batch` stay valid while this object lives.
struct llava_embd_batch {
    std::vector<llama_pos> pos;            // position of each embedding row
    std::vector<int32_t> n_seq_id;         // number of sequence ids per row (always 1 here)
    std::vector<llama_seq_id> seq_id_0;    // the single shared sequence id
    std::vector<llama_seq_id *> seq_ids;   // per-row pointers into seq_id_0; null-terminated
    std::vector<int8_t> logits;            // per-row "compute logits" flags (all false)
    llama_batch batch;                     // the batch handed to llama_decode
    // Build a batch of n_tokens embedding rows starting at position pos_0,
    // all attributed to sequence `seq_id`. `embd` is borrowed, not owned.
    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos     .resize(n_tokens);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_id_0[0] = seq_id;
        seq_ids [n_tokens] = nullptr;  // sentinel terminator expected by llama_batch
        batch = {
            /*n_tokens       =*/ n_tokens,
            /*tokens         =*/ nullptr,
            /*embd           =*/ embd,
            /*pos            =*/ pos.data(),
            /*n_seq_id       =*/ n_seq_id.data(),
            /*seq_id         =*/ seq_ids.data(),
            /*logits         =*/ logits.data(),
        };
        for (int i = 0; i < n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }
};
|
|
||||||
|
|
||||||
// Feed the image embedding to the model in chunks of at most n_batch
// positions, advancing *n_past as each chunk is decoded.
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
    const int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));

    for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
        const int n_eval = std::min(image_embed->n_image_pos - i, n_batch);
        float * embd = image_embed->embed + i * n_embd;
        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
        if (llama_decode(ctx_llama, llava_batch.batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
        }
        *n_past += n_eval;
    }
    return true;
}
|
|
||||||
|
|
||||||
struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) {
|
|
||||||
clip_image_u8 * img = clip_image_u8_init();
|
|
||||||
if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
|
|
||||||
clip_image_u8_free(img);
|
|
||||||
LOG_ERR("%s: can't load image from bytes, is it a valid image?", __func__);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
float* image_embed = NULL;
|
|
||||||
int n_image_pos = 0;
|
|
||||||
bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
|
|
||||||
if (!image_embed_result) {
|
|
||||||
clip_image_u8_free(img);
|
|
||||||
LOG_ERR("%s: couldn't embed the image\n", __func__);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
clip_image_u8_free(img);
|
|
||||||
auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
|
|
||||||
result->embed = image_embed;
|
|
||||||
result->n_image_pos = n_image_pos;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) {
|
|
||||||
auto file = fopen(path, "rb");
|
|
||||||
if (file == NULL) {
|
|
||||||
LOG_ERR("%s: can't read file %s\n", __func__, path);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
fseek(file, 0, SEEK_END);
|
|
||||||
auto fileSize = ftell(file);
|
|
||||||
fseek(file, 0, SEEK_SET);
|
|
||||||
|
|
||||||
auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
|
|
||||||
if (buffer == NULL) {
|
|
||||||
LOG_ERR("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
|
|
||||||
perror("Memory allocation error");
|
|
||||||
fclose(file);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
errno = 0;
|
|
||||||
size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
|
|
||||||
if (ferror(file)) {
|
|
||||||
LOG_ERR("read error: %s", strerror(errno));
|
|
||||||
free(buffer);
|
|
||||||
fclose(file);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (ret != (size_t) fileSize) {
|
|
||||||
LOG_ERR("unexpectedly reached end of file");
|
|
||||||
free(buffer);
|
|
||||||
fclose(file);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
fclose(file); // Close the file
|
|
||||||
|
|
||||||
*bytesOut = buffer;
|
|
||||||
*sizeOut = fileSize;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
|
|
||||||
unsigned char* image_bytes;
|
|
||||||
long image_bytes_length;
|
|
||||||
auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
|
|
||||||
if (!loaded) {
|
|
||||||
LOG_ERR("%s: failed to load %s\n", __func__, image_path);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
llava_image_embed *embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length);
|
|
||||||
free(image_bytes);
|
|
||||||
|
|
||||||
return embed;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Free an embedding created by one of the llava_image_embed_make_* functions.
 * Safe to call with NULL (mirrors free() semantics); previously a NULL
 * argument dereferenced embed->embed and crashed.
 */
void llava_image_embed_free(struct llava_image_embed * embed) {
    if (embed == NULL) {
        return;
    }
    free(embed->embed);
    free(embed);
}
|
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
// llava.h — public C API for building and evaluating LLaVA image embeddings.
#ifndef LLAVA_H
#define LLAVA_H

#include "ggml.h"

// LLAVA_API controls symbol visibility when building/consuming a shared library.
#ifdef LLAMA_SHARED
#    if defined(_WIN32) && !defined(__MINGW32__)
#        ifdef LLAMA_BUILD
#            define LLAVA_API __declspec(dllexport)
#        else
#            define LLAVA_API __declspec(dllimport)
#        endif
#    else
#        define LLAVA_API __attribute__ ((visibility ("default")))
#    endif
#else
#    define LLAVA_API
#endif

#ifdef __cplusplus
extern "C" {
#endif

struct clip_ctx;

// An image encoded into the LLM embedding space:
// `embed` holds n_image_pos consecutive embedding vectors (heap-allocated,
// released via llava_image_embed_free).
struct llava_image_embed {
    float * embed;
    int n_image_pos;
};

/** sanity check for clip <-> llava embed size match */
LLAVA_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip);

// Build an embedding from an already-decoded clip image; on success the
// caller owns *image_embd_out (malloc'd) and *n_img_pos_out is set.
LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);

/** build an image embed from image file bytes */
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
/** build an image embed from a path to an image filename */
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
/** free an embedding made with llava_image_embed_make_* */
LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);

/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);

#ifdef __cplusplus
}
#endif

#endif
|
|
||||||
|
|
@ -0,0 +1,769 @@
|
||||||
|
#include "mtmd-audio.h"
|
||||||
|
|
||||||
|
#define _USE_MATH_DEFINES // for M_PI
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
#include <thread>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
// most of the code here is copied from whisper.cpp
|
||||||
|
|
||||||
|
// align x to upper multiple of n
|
||||||
|
#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||||
|
|
||||||
|
namespace whisper_preprocessor {
|
||||||
|
|
||||||
|
// Size of the shared sin/cos lookup tables; tied to the FFT frame size so the
// FFT/DFT routines can index them with a simple stride.
#define SIN_COS_N_COUNT WHISPER_N_FFT
namespace {
// File-local cache of trig tables and the Hann window, filled once at static
// initialization time and shared (read-only) by all FFT worker threads.
struct whisper_global_cache {
    // In FFT, we frequently use sine and cosine operations with the same values.
    // We can use precalculated values to speed up the process.
    float sin_vals[SIN_COS_N_COUNT];
    float cos_vals[SIN_COS_N_COUNT];

    // Hann window (Use cosf to eliminate difference)
    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
    float hann_window[WHISPER_N_FFT];

    whisper_global_cache() {
        fill_sin_cos_table();
        fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
    }

    // Precompute sin/cos of 2*pi*i/SIN_COS_N_COUNT for every table slot.
    void fill_sin_cos_table() {
        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
            double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
            sin_vals[i] = sinf(theta);
            cos_vals[i] = cosf(theta);
        }
    }

    // Fill `output` with a Hann window of the given length.
    // periodic=true matches torch.hann_window(periodic=True), i.e. the window
    // divides by `length` rather than `length - 1`.
    void fill_hann_window(int length, bool periodic, float * output) {
        int offset = -1;
        if (periodic) {
            offset = 0;
        }
        for (int i = 0; i < length; i++) {
            output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
        }
    }
} global_cache;
}
|
||||||
|
|
||||||
|
// naive Discrete Fourier Transform
|
||||||
|
// input is real-valued
|
||||||
|
// output is complex-valued
|
||||||
|
static void dft(const float* in, int N, float* out) {
|
||||||
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
||||||
|
|
||||||
|
for (int k = 0; k < N; k++) {
|
||||||
|
float re = 0;
|
||||||
|
float im = 0;
|
||||||
|
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
||||||
|
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
||||||
|
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
out[k*2 + 0] = re;
|
||||||
|
out[k*2 + 1] = im;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued (interleaved re/im pairs)
//
// NOTE: this routine reuses its own arguments as scratch space:
//  - the even/odd halves are staged at in + N, so `in` must have room for at
//    least 2*N floats (the caller in log_mel_spectrogram_worker_thread
//    allocates fft_in with frame_size * 2 entries);
//  - recursive results are staged at out + 2*N onward, so `out` must be
//    oversized as well (the caller allocates frame_size * 8 entries).
static void fft(float* in, int N, float* out) {
    if (N == 1) {
        // base case: a single real sample is its own transform
        out[0] = in[0];
        out[1] = 0;
        return;
    }

    const int half_N = N / 2;
    if (N - half_N*2 == 1) {
        // odd N cannot be split evenly: fall back to the naive DFT
        dft(in, N, out);
        return;
    }

    // gather even-indexed samples into the scratch area past the input
    float* even = in + N;
    for (int i = 0; i < half_N; ++i) {
        even[i]= in[2*i];
    }
    float* even_fft = out + 2 * N;
    fft(even, half_N, even_fft);

    // reuse the same scratch slots for the odd-indexed samples
    // (safe: `even` has already been fully consumed by the recursive call)
    float* odd = even;
    for (int i = 0; i < half_N; ++i) {
        odd[i] = in[2*i + 1];
    }
    float* odd_fft = even_fft + N;
    fft(odd, half_N, odd_fft);

    // combine: X[k] = E[k] + w^k * O[k], X[k + N/2] = E[k] - w^k * O[k]
    const int sin_cos_step = SIN_COS_N_COUNT / N;
    for (int k = 0; k < half_N; k++) {
        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
        float re = global_cache.cos_vals[idx]; // cos(t)
        float im = -global_cache.sin_vals[idx]; // sin(t)

        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];

        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}
|
||||||
|
|
||||||
|
// Compute one thread's share of the log-mel spectrogram.
// Work is striped across threads: thread `ith` processes frames
// i = ith, ith + n_threads, ith + 2*n_threads, ... so each thread writes a
// disjoint set of columns of `mel` (no synchronization needed).
// `hann` is the precomputed window of `frame_size` coefficients;
// frames past the real audio (beyond n_samples) are filled with the constant
// log10(1e-10) floor in the trailing loop.
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                              int n_samples, int frame_size, int frame_step, int n_threads,
                                              const whisper_filters & filters, whisper_mel & mel) {
    std::vector<float> fft_in(frame_size * 2, 0.0);  // extra capacity: fft() stages scratch at in + N
    std::vector<float> fft_out(frame_size * 2 * 2 * 2);  // extra capacity: fft() stages recursion at out + 2N

    int n_fft = filters.n_fft;
    int i = ith;

    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
    WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));

    // calculate FFT only when fft_in are not all zero
    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
        const int offset = i * frame_step;

        // apply Hann window (~10% faster)
        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
            fft_in[j] = hann[j] * samples[offset + j];
        }

        // fill the rest with zeros
        if (n_samples - offset < frame_size) {
            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
        }

        // FFT
        fft(fft_in.data(), frame_size, fft_out.data());

        // Calculate modulus^2 of complex numbers
        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
        // (power spectrum is written back in-place over the first n_fft slots)
        for (int j = 0; j < n_fft; j++) {
            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
        }

        // mel spectrogram: dot-product each mel filter row with the power spectrum
        for (int j = 0; j < mel.n_mel; j++) {
            double sum = 0.0;
            // unroll loop (suggested by GH user @lunixbochs)
            int k = 0;
            for (k = 0; k < n_fft - 3; k += 4) {
                sum +=
                        fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
                        fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
            }
            // handle n_fft remainder
            for (; k < n_fft; k++) {
                sum += fft_out[k] * filters.data[j * n_fft + k];
            }
            // clamp to a -100 dB floor before taking the log
            sum = log10(std::max(sum, 1e-10));
            mel.data[j * mel.n_len + i] = sum;
        }
    }

    // Otherwise fft_out are all zero: remaining (fully padded) frames get the log floor
    double sum = log10(1e-10);
    for (; i < mel.n_len; i += n_threads) {
        for (int j = 0; j < mel.n_mel; j++) {
            mel.data[j * mel.n_len + i] = sum;
        }
    }
}
|
||||||
|
|
||||||
|
// Compute the full log-mel spectrogram of `samples` into `mel`.
// Pads the audio like OpenAI Whisper does (reflective pad at the front,
// 30 s of zeros at the back), fans the frame computation out over
// `n_threads`, then clamps and normalizes the result to Whisper's range.
// Returns true on success; if `debug` is set, also dumps the spectrogram to
// log_mel_spectrogram.json in the working directory.
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
static bool log_mel_spectrogram(
        const float * samples,
        const int n_samples,
        const int /*sample_rate*/,
        const int frame_size,
        const int frame_step,
        const int n_mel,
        const int n_threads,
        const whisper_filters & filters,
        const bool debug,
        whisper_mel & mel) {
    //const int64_t t_start_us = ggml_time_us();

    // Hann window
    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
    const float * hann = global_cache.hann_window;

    // Calculate the length of padding
    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
    int64_t stage_2_pad = frame_size / 2;

    // Initialize a vector and copy data from C array to it.
    // layout: [stage_2_pad reflect | n_samples audio | stage_1_pad + stage_2_pad zeros]
    std::vector<float> samples_padded;
    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);

    // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);

    // reflective pad 200 samples at the beginning of audio
    std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());

    mel.n_mel = n_mel;
    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
    // Calculate number of frames + remove the last frame
    mel.n_len = (samples_padded.size() - frame_size) / frame_step;
    // Calculate semi-padded sample length to ensure compatibility
    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
    mel.data.resize(mel.n_mel * mel.n_len);

    {
        // fan out: worker iw+1 handles frames striped by n_threads; the main
        // thread takes stripe 0, so exactly n_threads stripes cover all frames
        std::vector<std::thread> workers(n_threads - 1);
        for (int iw = 0; iw < n_threads - 1; ++iw) {
            workers[iw] = std::thread(
                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
                    n_samples + stage_2_pad, frame_size, frame_step, n_threads,
                    std::cref(filters), std::ref(mel));
        }

        // main thread
        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);

        for (int iw = 0; iw < n_threads - 1; ++iw) {
            workers[iw].join();
        }
    }

    // clamping and normalization: clamp to (max - 8 dB), then map into
    // Whisper's expected (x + 4) / 4 range
    double mmax = -1e20;
    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] > mmax) {
            mmax = mel.data[i];
        }
    }

    mmax -= 8.0;

    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] < mmax) {
            mel.data[i] = mmax;
        }

        mel.data[i] = (mel.data[i] + 4.0)/4.0;
    }

    // Dump log_mel_spectrogram
    if (debug) {
        std::ofstream outFile("log_mel_spectrogram.json");
        outFile << "[";
        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
            outFile << mel.data[i] << ", ";
        }
        outFile << mel.data[mel.data.size() - 1] << "]";
        outFile.close();
    }

    return true;
}
|
||||||
|
|
||||||
|
// Convert raw PCM samples into a sequence of fixed-size (3000-frame) mel
// spectrogram chunks suitable for the clip.cpp audio cgraph.
// `samples`   : mono float PCM at COMMON_SAMPLE_RATE (16 kHz)
// `n_samples` : number of samples; 0 is rejected
// `filters`   : precomputed mel filter bank (see whisper_precalc_filters)
// `output`    : chunks are appended here (existing contents are kept)
// Returns false on empty input or spectrogram failure.
bool preprocess_audio(
        const float * samples,
        size_t n_samples,
        const whisper_filters & filters,
        std::vector<whisper_mel> & output) {

    if (n_samples == 0) {
        // empty audio
        return false;
    }

    whisper_mel out_full;
    bool ok = log_mel_spectrogram(
                samples,
                n_samples,
                COMMON_SAMPLE_RATE,
                WHISPER_N_FFT,
                WHISPER_HOP_LENGTH,
                filters.n_mel,
                4, // n_threads
                filters,
                false, // debug
                out_full);
    if (!ok) {
        return false;
    }

    // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
    // we always expect the mel to have 3000 silent frames at the end
    // (guaranteed by the 30 s zero-pad added in log_mel_spectrogram)
    // printf("n_len %d\n", out_full.n_len);
    const size_t frames_per_chunk = 3000;
    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
        if ((size_t)n_len < frames_per_chunk) {
            break; // last incomplete chunk will always be a padded chunk, safe to ignore
        }

        whisper_mel out_chunk;
        out_chunk.n_len = n_len;
        out_chunk.n_mel = out_full.n_mel;
        // NOTE(review): assigns n_mel rather than a frame count — looks like a
        // typo, but the field is unused downstream per the original comment
        out_chunk.n_len_org = out_full.n_mel; // unused
        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);

        // copy this chunk's column range row-by-row (mel.data is row-major:
        // one row per mel bin, out_full.n_len frames per row)
        for (int i = 0; i < out_full.n_mel; i++) {
            auto src = out_full.data.begin() + i*out_full.n_len + off;
            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
        }

        output.push_back(std::move(out_chunk));
    }

    return true;
}
|
||||||
|
|
||||||
|
} // namespace whisper_preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
// precalculated mel filter banks
// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
//
// generated from python code:
//
// from numpy import load
// data = load('mel_filters.npz')
// lst = data.files
// for item in lst:
//     print(item)
//     print(data[item].shape)
//     n_mel = data[item].shape[0]
//     n_fft = data[item].shape[1]
//     for i, row in enumerate(data[item]):
//         for j, val in enumerate(row):
//             val = val * 1000.0
//             if val != 0:
//                 print(f"data[{i*n_fft + j}] = {val:.6f};")

namespace whisper_precalc_filters {

// Build the 128-bin mel filter bank used by the audio encoder.
// Returns a whisper_filters with n_mel = 128, n_fft = 201 and a dense
// row-major (n_mel x n_fft) coefficient matrix; entries not listed below are
// zero (each filter is a narrow triangular band, so the matrix is sparse).
whisper_preprocessor::whisper_filters get_128_bins() {
    whisper_preprocessor::whisper_filters filters;
    filters.n_mel = 128;
    filters.n_fft = 201;
    // all coefficients start at zero; only the non-zero entries are assigned
    std::vector data(filters.n_mel * filters.n_fft, 0.0f);

    data[1] = 12.37398665;
    data[202] = 30.39256483;
    data[404] = 24.74797331;
    data[605] = 18.01857911;
    data[807] = 37.12195903;
    data[1008] = 5.64459199;
    data[1009] = 6.72939420;
    data[1210] = 36.03715822;
    data[1412] = 19.10337992;
    data[1613] = 23.66316877;
    data[1815] = 31.47736564;
    data[2016] = 11.28918398;
    data[2017] = 1.08480197;
    data[2218] = 41.68175161;
    data[2420] = 13.45878839;
    data[2621] = 29.30776216;
    data[2823] = 25.83277412;
    data[3024] = 16.93377644;
    data[3226] = 38.20675984;
    data[3427] = 4.55979025;
    data[3428] = 7.81419594;
    data[3629] = 34.95235741;
    data[3831] = 20.18818259;
    data[4032] = 22.57836796;
    data[4234] = 32.56217018;
    data[4435] = 10.20438317;
    data[4436] = 2.16960395;
    data[4637] = 40.59694707;
    data[4839] = 14.54358920;
    data[5040] = 28.22295949;
    data[5242] = 26.91757679;
    data[5443] = 15.84897563;
    data[5645] = 39.29156065;
    data[5846] = 3.47498828;
    data[5847] = 8.89899861;
    data[6048] = 33.86755288;
    data[6250] = 21.27298526;
    data[6451] = 21.49356715;
    data[6653] = 33.64697099;
    data[6854] = 9.11958050;
    data[6855] = 3.25440569;
    data[7056] = 39.51214626;
    data[7258] = 15.62839188;
    data[7459] = 27.13815868;
    data[7661] = 28.00237760;
    data[7862] = 14.76417296;
    data[8064] = 40.37636518;
    data[8265] = 2.38068704;
    data[8266] = 10.20263787;
    data[8467] = 31.61146119;
    data[8669] = 24.54700135;
    data[8870] = 15.32919332;
    data[8871] = 1.66583748;
    data[9072] = 36.72905266;
    data[9274] = 20.09709924;
    data[9475] = 16.93102531;
    data[9476] = 2.90265540;
    data[9677] = 32.84499049;
    data[9879] = 23.52004871;
    data[10080] = 11.03894413;
    data[10081] = 10.72582975;
    data[10282] = 22.71829173;
    data[10484] = 32.27872774;
    data[10685] = 0.11626833;
    data[10686] = 22.85348251;
    data[10887] = 8.56344029;
    data[10888] = 14.97978810;
    data[11089] = 15.51398356;
    data[11090] = 8.51490628;
    data[11291] = 21.10680379;
    data[11292] = 3.32652032;
    data[11493] = 25.47064796;
    data[11695] = 27.35907957;
    data[11896] = 0.65853616;
    data[11897] = 23.83812517;
    data[12098] = 3.44359246;
    data[12099] = 21.22455277;
    data[12300] = 5.35842171;
    data[12301] = 19.42555793;
    data[12502] = 6.49324711;
    data[12503] = 18.35542172;
    data[12704] = 6.93138083;
    data[12705] = 17.93504693;
    data[12906] = 6.74968259;
    data[12907] = 18.09151843;
    data[13108] = 6.01899112;
    data[13109] = 18.75767298;
    data[13310] = 4.80452832;
    data[13311] = 19.87172849;
    data[13512] = 3.16627859;
    data[13513] = 21.37690969;
    data[13514] = 1.25317345;
    data[13714] = 1.15934468;
    data[13715] = 20.80361731;
    data[13716] = 4.04486805;
    data[13917] = 17.55363122;
    data[13918] = 7.08320038;
    data[14119] = 14.07538634;
    data[14120] = 10.32655034;
    data[14321] = 10.40921453;
    data[14322] = 13.73696327;
    data[14523] = 6.59187697;
    data[14524] = 17.27988198;
    data[14525] = 1.46804214;
    data[14725] = 2.65681883;
    data[14726] = 18.09193194;
    data[14727] = 5.85655728;
    data[14928] = 13.34277913;
    data[14929] = 10.28267574;
    data[15130] = 8.56800377;
    data[15131] = 14.72230814;
    data[15132] = 1.04039861;
    data[15332] = 3.79085587;
    data[15333] = 17.14678481;
    data[15334] = 6.11609267;
    data[15535] = 11.75929047;
    data[15536] = 11.13393717;
    data[15737] = 6.43857848;
    data[15738] = 16.07806236;
    data[15739] = 4.23917221;
    data[15939] = 1.19989377;
    data[15940] = 12.75671553;
    data[15941] = 9.65298992;
    data[16142] = 7.06935255;
    data[16143] = 14.94054683;
    data[16144] = 4.19024844;
    data[16344] = 1.51483389;
    data[16345] = 12.00899947;
    data[16346] = 9.84823331;
    data[16547] = 6.10224018;
    data[16548] = 15.33857174;
    data[16549] = 5.57676842;
    data[16749] = 0.36827257;
    data[16750] = 9.89749376;
    data[16751] = 11.35340426;
    data[16752] = 2.05122307;
    data[16952] = 3.89297144;
    data[16953] = 12.97352277;
    data[16954] = 8.06631614;
    data[17155] = 6.74493238;
    data[17156] = 13.85874674;
    data[17157] = 5.41190524;
    data[17357] = 0.74220158;
    data[17358] = 8.98779090;
    data[17359] = 11.37871388;
    data[17360] = 3.32958088;
    data[17560] = 2.82313535;
    data[17561] = 10.68049297;
    data[17562] = 9.43340641;
    data[17563] = 1.76325557;
    data[17763] = 4.39018616;
    data[17764] = 11.87758986;
    data[17765] = 7.97005836;
    data[17766] = 0.66104700;
    data[17966] = 5.49466675;
    data[17967] = 12.62953598;
    data[17968] = 6.93987962;
    data[18169] = 6.18401915;
    data[18170] = 12.93473132;
    data[18171] = 6.29778765;
    data[18371] = 0.02325210;
    data[18372] = 6.50206627;
    data[18373] = 12.32661773;
    data[18374] = 6.00216538;
    data[18574] = 0.31548753;
    data[18575] = 6.48925547;
    data[18576] = 12.04130240;
    data[18577] = 6.01462880;
    data[18777] = 0.29979556;
    data[18778] = 6.18288014;
    data[18779] = 12.04272825;
    data[18780] = 6.29981188;
    data[18781] = 0.55689598;
    data[18980] = 0.01120471;
    data[18981] = 5.61729167;
    data[18982] = 11.22337859;
    data[18983] = 6.82516303;
    data[18984] = 1.35264499;
    data[19184] = 4.82410006;
    data[19185] = 10.16623247;
    data[19186] = 7.56075513;
    data[19187] = 2.34590308;
    data[19387] = 3.83235747;
    data[19388] = 8.92296247;
    data[19389] = 8.47910438;
    data[19390] = 3.50978645;
    data[19590] = 2.66873185;
    data[19591] = 7.51965167;
    data[19592] = 9.55500547;
    data[19593] = 4.81966138;
    data[19594] = 0.08431751;
    data[19793] = 1.35767367;
    data[19794] = 5.98019501;
    data[19795] = 10.60271543;
    data[19796] = 6.25298498;
    data[19797] = 1.74059917;
    data[19997] = 4.32644226;
    data[19998] = 8.73131864;
    data[19999] = 7.78916525;
    data[20000] = 3.48923868;
    data[20200] = 2.57835095;
    data[20201] = 6.77582854;
    data[20202] = 9.40941647;
    data[20203] = 5.31194592;
    data[20204] = 1.21447595;
    data[20403] = 0.75411191;
    data[20404] = 4.75395704;
    data[20405] = 8.75380263;
    data[20406] = 7.19209015;
    data[20407] = 3.28754401;
    data[20607] = 2.68179690;
    data[20608] = 6.49331464;
    data[20609] = 9.11457930;
    data[20610] = 5.39387390;
    data[20611] = 1.67316827;
    data[20810] = 0.57394296;
    data[20811] = 4.20600036;
    data[20812] = 7.83805829;
    data[20813] = 7.52023002;
    data[20814] = 3.97470826;
    data[20815] = 0.42918732;
    data[21014] = 1.90464477;
    data[21015] = 5.36569161;
    data[21016] = 8.82673822;
    data[21017] = 6.27609482;
    data[21018] = 2.89750961;
    data[21218] = 2.89885257;
    data[21219] = 6.19694078;
    data[21220] = 8.56699049;
    data[21221] = 5.34748193;
    data[21222] = 2.12797290;
    data[21421] = 0.44750227;
    data[21422] = 3.59030394;
    data[21423] = 6.73310598;
    data[21424] = 7.77023612;
    data[21425] = 4.70231380;
    data[21426] = 1.63439126;
    data[21625] = 1.01536023;
    data[21626] = 4.01018746;
    data[21627] = 7.00501446;
    data[21628] = 7.23442994;
    data[21629] = 4.31095669;
    data[21630] = 1.38748321;
    data[21829] = 1.33348850;
    data[21830] = 4.18730825;
    data[21831] = 7.04112789;
    data[21832] = 6.93188375;
    data[21833] = 4.14605811;
    data[21834] = 1.36023236;
    data[22033] = 1.42879714;
    data[22034] = 4.14824858;
    data[22035] = 6.86769979;
    data[22036] = 6.83705276;
    data[22037] = 4.18239459;
    data[22038] = 1.52773573;
    data[22237] = 1.32610439;
    data[22238] = 3.91751388;
    data[22239] = 6.50892360;
    data[22240] = 6.92639686;
    data[22241] = 4.39672917;
    data[22242] = 1.86706171;
    data[22441] = 1.04827771;
    data[22442] = 3.51767405;
    data[22443] = 5.98707050;
    data[22444] = 7.17824046;
    data[22445] = 4.76767914;
    data[22446] = 2.35711760;
    data[22645] = 0.61636406;
    data[22646] = 2.96949223;
    data[22647] = 5.32262027;
    data[22648] = 7.57265091;
    data[22649] = 5.27558755;
    data[22650] = 2.97852419;
    data[22651] = 0.68146095;
    data[22849] = 0.04971400;
    data[22850] = 2.29204819;
    data[22851] = 4.53438237;
    data[22852] = 6.77671656;
    data[22853] = 5.90240723;
    data[22854] = 3.71349836;
    data[22855] = 1.52458926;
    data[23054] = 1.50285335;
    data[23055] = 3.63961048;
    data[23056] = 5.77636715;
    data[23057] = 6.63159089;
    data[23058] = 4.54574358;
    data[23059] = 2.45989650;
    data[23060] = 0.37404924;
    data[23258] = 0.61795861;
    data[23259] = 2.65410915;
    data[23260] = 4.69025923;
    data[23261] = 6.72641024;
    data[23262] = 5.46034705;
    data[23263] = 3.47270933;
    data[23264] = 1.48507138;
    data[23463] = 1.59233576;
    data[23464] = 3.53261665;
    data[23465] = 5.47289755;
    data[23466] = 6.44368259;
    data[23467] = 4.54962999;
    data[23468] = 2.65557761;
    data[23469] = 0.76152512;
    data[23667] = 0.46749352;
    data[23668] = 2.31641904;
    data[23669] = 4.16534441;
    data[23670] = 6.01426978;
    data[23671] = 5.67844696;
    data[23672] = 3.87357362;
    data[23673] = 2.06870004;
    data[23674] = 0.26382666;
    data[23872] = 1.05349103;
    data[23873] = 2.81536230;
    data[23874] = 4.57723346;
    data[23875] = 6.33910485;
    data[23876] = 5.12815686;
    data[23877] = 3.40826320;
    data[23878] = 1.68837002;
    data[24077] = 1.43350090;
    data[24078] = 3.11241671;
    data[24079] = 4.79133241;
    data[24080] = 6.40943693;
    data[24081] = 4.77052201;
    data[24082] = 3.13160778;
    data[24083] = 1.49269309;
    data[24281] = 0.02932359;
    data[24282] = 1.62918994;
    data[24283] = 3.22905602;
    data[24284] = 4.82892245;
    data[24285] = 6.14671456;
    data[24286] = 4.58496623;
    data[24287] = 3.02321767;
    data[24288] = 1.46146910;
    data[24486] = 0.13601698;
    data[24487] = 1.66055572;
    data[24488] = 3.18509457;
    data[24489] = 4.70963307;
    data[24490] = 6.04072399;
    data[24491] = 4.55250870;
    data[24492] = 3.06429295;
    data[24493] = 1.57607743;
    data[24494] = 0.08786193;
    data[24691] = 0.09328097;
    data[24692] = 1.54603878;
    data[24693] = 2.99879676;
    data[24694] = 4.45155473;
    data[24695] = 5.90431225;
    data[24696] = 4.65566106;
    data[24697] = 3.23751615;
    data[24698] = 1.81937125;
    data[24699] = 0.40122634;
    data[24897] = 1.30262633;
    data[24898] = 2.68698297;
    data[24899] = 4.07133950;
    data[24900] = 5.45569602;
    data[24901] = 4.87832492;
    data[24902] = 3.52695142;
    data[24903] = 2.17557792;
    data[24904] = 0.82420459;
    data[25102] = 0.94595028;
    data[25103] = 2.26512621;
    data[25104] = 3.58430226;
    data[25105] = 4.90347855;
    data[25106] = 5.20569785;
    data[25107] = 3.91795207;
    data[25108] = 2.63020652;
    data[25109] = 1.34246063;
    data[25110] = 0.05471494;
    data[25307] = 0.49037894;
    data[25308] = 1.74744334;
    data[25309] = 3.00450763;
    data[25310] = 4.26157191;
    data[25311] = 5.51863620;
    data[25312] = 4.39707236;
    data[25313] = 3.16995848;
    data[25314] = 1.94284460;
    data[25315] = 0.71573065;
    data[25513] = 1.14698056;
    data[25514] = 2.34485767;
    data[25515] = 3.54273478;
    data[25516] = 4.74061165;
    data[25517] = 4.95198462;
    data[25518] = 3.78264743;
    data[25519] = 2.61331047;
    data[25520] = 1.44397374;
    data[25521] = 0.27463681;
    data[25718] = 0.47569509;
    data[25719] = 1.61717169;
    data[25720] = 2.75864848;
    data[25721] = 3.90012516;
    data[25722] = 5.04160160;
    data[25723] = 4.45712078;
    data[25724] = 3.34284059;
    data[25725] = 2.22856039;
    data[25726] = 1.11428020;

    // undo the x1000 scaling applied when the table was generated
    for (auto & val : data) {
        val /= 1000.0f;
    }

    filters.data = std::move(data);
    return filters;
}

} // namespace whisper_precalc_filters
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
// mtmd-audio.h — Whisper-style audio preprocessing (log-mel spectrograms)
// for multimodal audio input. Most of this is adapted from whisper.cpp.
#pragma once

#include "ggml.h"

#include <cstdint>
#include <vector>
#include <string>

#define WHISPER_ASSERT GGML_ASSERT

// Whisper audio constants: 16 kHz input, 400-sample FFT frames advanced by
// 160 samples (10 ms hop), processed in 30-second chunks.
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT       400
#define WHISPER_HOP_LENGTH  160
#define WHISPER_CHUNK_SIZE  30

// expected sample rate of the incoming PCM audio
#define COMMON_SAMPLE_RATE 16000

namespace whisper_preprocessor {

// A log-mel spectrogram: `data` is row-major with one row per mel bin,
// n_len frames per row.
struct whisper_mel {
    int n_len;      // number of frames (columns), including padding
    int n_len_org;  // number of frames covering the original (un-padded) audio
    int n_mel;      // number of mel bins (rows)

    std::vector<float> data;
};

// A mel filter bank: dense row-major (n_mel x n_fft) coefficient matrix.
struct whisper_filters {
    int32_t n_mel;
    int32_t n_fft;

    std::vector<float> data;
};

// Convert raw 16 kHz mono PCM into fixed-size mel spectrogram chunks,
// appended to `output`. Returns false on empty input or failure.
bool preprocess_audio(
        const float * samples,
        size_t n_samples,
        const whisper_filters & filters,
        std::vector<whisper_mel> & output);

} // namespace whisper_preprocessor

namespace whisper_precalc_filters {

// Precomputed 128-bin mel filter bank (n_fft = 201).
whisper_preprocessor::whisper_filters get_128_bins();

} // namespace whisper_precalc_filters
|
||||||
|
|
@ -0,0 +1,460 @@
|
||||||
|
// fix problem with std::min and std::max
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#ifndef NOMINMAX
|
||||||
|
# define NOMINMAX
|
||||||
|
#endif
|
||||||
|
#include <windows.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "mtmd.h"
|
||||||
|
#include "mtmd-helper.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
//#define MTMD_AUDIO_DEBUG
|
||||||
|
|
||||||
|
#define MINIAUDIO_IMPLEMENTATION
|
||||||
|
#ifndef MTMD_AUDIO_DEBUG
|
||||||
|
# define MA_NO_ENCODING
|
||||||
|
#endif
|
||||||
|
#define MA_NO_DEVICE_IO
|
||||||
|
#define MA_NO_RESOURCE_MANAGER
|
||||||
|
#define MA_NO_NODE_GRAPH
|
||||||
|
#define MA_NO_ENGINE
|
||||||
|
#define MA_NO_GENERATION
|
||||||
|
#define MA_API static
|
||||||
|
#include "miniaudio/miniaudio.h"
|
||||||
|
|
||||||
|
#define STB_IMAGE_IMPLEMENTATION
|
||||||
|
#include "stb/stb_image.h"
|
||||||
|
|
||||||
|
#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
|
||||||
|
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
|
||||||
|
|
||||||
|
size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
|
||||||
|
size_t n_tokens = 0;
|
||||||
|
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
|
||||||
|
auto chunk = mtmd_input_chunks_get(chunks, i);
|
||||||
|
n_tokens += mtmd_input_chunk_get_n_tokens(chunk);
|
||||||
|
}
|
||||||
|
return n_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
|
||||||
|
llama_pos n_pos = 0;
|
||||||
|
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
|
||||||
|
auto chunk = mtmd_input_chunks_get(chunks, i);
|
||||||
|
n_pos += mtmd_input_chunk_get_n_pos(chunk);
|
||||||
|
}
|
||||||
|
return n_pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper struct to make working with embd batch easier
|
||||||
|
// note: this will be removed after llama_batch_ext refactoring
|
||||||
|
struct decode_embd_batch {
|
||||||
|
int n_pos_per_embd;
|
||||||
|
int n_mmproj_embd;
|
||||||
|
std::vector<llama_pos> pos;
|
||||||
|
std::vector<llama_pos> pos_view; // used by mrope
|
||||||
|
std::vector<int32_t> n_seq_id;
|
||||||
|
std::vector<llama_seq_id> seq_id_0;
|
||||||
|
std::vector<llama_seq_id *> seq_ids;
|
||||||
|
std::vector<int8_t> logits;
|
||||||
|
llama_batch batch;
|
||||||
|
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
|
||||||
|
pos .resize(n_tokens * n_pos_per_embd);
|
||||||
|
n_seq_id.resize(n_tokens);
|
||||||
|
seq_ids .resize(n_tokens + 1);
|
||||||
|
logits .resize(n_tokens);
|
||||||
|
seq_id_0.resize(1);
|
||||||
|
seq_ids [n_tokens] = nullptr;
|
||||||
|
batch = {
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*tokens =*/ nullptr,
|
||||||
|
/*embd =*/ embd,
|
||||||
|
/*pos =*/ pos.data(),
|
||||||
|
/*n_seq_id =*/ n_seq_id.data(),
|
||||||
|
/*seq_id =*/ seq_ids.data(),
|
||||||
|
/*logits =*/ logits.data(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
|
||||||
|
seq_id_0[0] = seq_id;
|
||||||
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
batch.pos [i] = pos_0 + i;
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id [i] = seq_id_0.data();
|
||||||
|
batch.logits [i] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// M-RoPE for image
|
||||||
|
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
|
||||||
|
GGML_ASSERT(n_pos_per_embd == 4);
|
||||||
|
seq_id_0[0] = seq_id;
|
||||||
|
for (int y = 0; y < ny; y++) {
|
||||||
|
for (int x = 0; x < nx; x++) {
|
||||||
|
int i = y * nx + x;
|
||||||
|
pos[i ] = pos_0;
|
||||||
|
pos[i + batch.n_tokens ] = pos_0 + y;
|
||||||
|
pos[i + batch.n_tokens * 2] = pos_0 + x;
|
||||||
|
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id [i] = seq_id_0.data();
|
||||||
|
batch.logits [i] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// M-RoPE for audio
|
||||||
|
void set_position_mrope_1d(llama_pos pos_0, llama_seq_id seq_id) {
|
||||||
|
GGML_ASSERT(n_pos_per_embd == 4);
|
||||||
|
seq_id_0[0] = seq_id;
|
||||||
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
pos[i ] = pos_0 + i;
|
||||||
|
pos[i + batch.n_tokens ] = pos_0 + i;
|
||||||
|
pos[i + batch.n_tokens * 2] = pos_0 + i;
|
||||||
|
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
|
||||||
|
}
|
||||||
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id [i] = seq_id_0.data();
|
||||||
|
batch.logits [i] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch get_view(int offset, int n_tokens) {
|
||||||
|
llama_pos * pos_ptr;
|
||||||
|
pos_view.clear();
|
||||||
|
pos_view.reserve(n_tokens * n_pos_per_embd);
|
||||||
|
if (n_pos_per_embd > 1) {
|
||||||
|
// mrope
|
||||||
|
// for example, with layout of src: 1234...1234...1234...1234...
|
||||||
|
// offset 2 will give us dst: 34...34...34...34...
|
||||||
|
for (int i = 0; i < n_pos_per_embd; i++) {
|
||||||
|
// assume n_tokens is less than or equal to batch.n_tokens
|
||||||
|
// batch.n_tokens is number of **total** tokens
|
||||||
|
// n_tokens is number of viewed token
|
||||||
|
size_t src_idx = i * batch.n_tokens + offset;
|
||||||
|
pos_view.insert(pos_view.end(),
|
||||||
|
pos.data() + src_idx,
|
||||||
|
pos.data() + src_idx + n_tokens);
|
||||||
|
}
|
||||||
|
pos_ptr = pos_view.data();
|
||||||
|
} else {
|
||||||
|
// normal
|
||||||
|
pos_ptr = pos.data() + offset;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*tokens =*/ nullptr,
|
||||||
|
/*embd =*/ batch.embd + offset * n_mmproj_embd,
|
||||||
|
/*pos =*/ pos_ptr,
|
||||||
|
/*n_seq_id =*/ batch.n_seq_id + offset,
|
||||||
|
/*seq_id =*/ batch.seq_id + offset,
|
||||||
|
/*logits =*/ batch.logits + offset,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper function for decoding an image whose embeddings have already been calculated
|
||||||
|
int32_t mtmd_helper_decode_image_chunk(
|
||||||
|
mtmd_context * ctx,
|
||||||
|
struct llama_context * lctx,
|
||||||
|
const mtmd_input_chunk * chunk,
|
||||||
|
float * encoded_embd,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
int32_t n_batch,
|
||||||
|
llama_pos * new_n_past) {
|
||||||
|
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||||
|
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||||
|
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||||
|
LOG_ERR("failed to decode chunk: input chunk not of image/audio type\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_model * model = llama_get_model(lctx);
|
||||||
|
int n_mmproj_embd = llama_model_n_embd(model);
|
||||||
|
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
|
||||||
|
|
||||||
|
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
|
||||||
|
int32_t i_batch = 0;
|
||||||
|
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
|
||||||
|
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
|
||||||
|
|
||||||
|
if (mtmd_decode_use_mrope(ctx)) {
|
||||||
|
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||||
|
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||||
|
if (!image_tokens) {
|
||||||
|
LOG_ERR("failed to decode chunk: image tokens are null\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
const int nx = mtmd_image_tokens_get_nx(image_tokens);
|
||||||
|
const int ny = mtmd_image_tokens_get_ny(image_tokens);
|
||||||
|
batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id);
|
||||||
|
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||||
|
batch_embd.set_position_mrope_1d(n_past, seq_id);
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("invalid chunk type for M-RoPE");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
batch_embd.set_position_normal(n_past, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mtmd_decode_use_non_causal(ctx)) {
|
||||||
|
llama_set_causal_attn(lctx, false);
|
||||||
|
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||||
|
}
|
||||||
|
|
||||||
|
while (i_batch < n_img_batches) { // split into batches
|
||||||
|
int pos_offset = i_batch*n_batch;
|
||||||
|
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
|
||||||
|
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
|
||||||
|
|
||||||
|
LOG_INF("decoding %s batch %d/%d, n_tokens_batch = %d\n", name, i_batch+1, n_img_batches, n_tokens_batch);
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
int32_t ret = llama_decode(lctx, batch_embd_view);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_ERR("failed to decode %s\n", name);
|
||||||
|
llama_set_causal_attn(lctx, true); // restore causal attn
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||||
|
|
||||||
|
i_batch++;
|
||||||
|
}
|
||||||
|
|
||||||
|
n_past += mtmd_input_chunk_get_n_pos(chunk);
|
||||||
|
*new_n_past = n_past;
|
||||||
|
|
||||||
|
if (mtmd_decode_use_non_causal(ctx)) {
|
||||||
|
llama_set_causal_attn(lctx, true);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||||
|
struct llama_context * lctx,
|
||||||
|
const mtmd_input_chunk * chunk,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
int32_t n_batch,
|
||||||
|
bool logits_last,
|
||||||
|
llama_pos * new_n_past) {
|
||||||
|
int32_t ret;
|
||||||
|
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
|
||||||
|
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||||
|
|
||||||
|
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||||
|
size_t n_tokens;
|
||||||
|
const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||||
|
// LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
|
||||||
|
size_t i = 0;
|
||||||
|
while (i < n_tokens) { // split into batches
|
||||||
|
text_batch.n_tokens = 0; // clear the batch
|
||||||
|
for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
|
||||||
|
int32_t j = text_batch.n_tokens;
|
||||||
|
text_batch.token [j] = tokens[i];
|
||||||
|
text_batch.pos [j] = n_past++;
|
||||||
|
text_batch.n_seq_id[j] = 1;
|
||||||
|
text_batch.seq_id [j][0] = seq_id;
|
||||||
|
text_batch.logits [j] = false;
|
||||||
|
|
||||||
|
text_batch.n_tokens++;
|
||||||
|
}
|
||||||
|
bool is_last_token = (i == n_tokens);
|
||||||
|
if (logits_last && is_last_token) {
|
||||||
|
text_batch.logits[text_batch.n_tokens - 1] = true;
|
||||||
|
}
|
||||||
|
ret = llama_decode(lctx, text_batch);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_ERR("failed to decode text\n");
|
||||||
|
llama_batch_free(text_batch);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
*new_n_past += text_batch.n_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||||
|
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
|
||||||
|
LOG_INF("encoding %s slice...\n", name);
|
||||||
|
|
||||||
|
ret = mtmd_encode_chunk(ctx, chunk);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_ERR("failed to encode %s slice\n", name);
|
||||||
|
llama_batch_free(text_batch);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
|
||||||
|
|
||||||
|
float * embd = mtmd_get_output_embd(ctx);
|
||||||
|
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_ERR("failed to decode %s\n", name);
|
||||||
|
llama_batch_free(text_batch);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("chunk type not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch_free(text_batch);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
|
||||||
|
struct llama_context * lctx,
|
||||||
|
const mtmd_input_chunks * chunks,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
int32_t n_batch,
|
||||||
|
bool logits_last,
|
||||||
|
llama_pos * new_n_past) {
|
||||||
|
size_t n_chunks = mtmd_input_chunks_size(chunks);
|
||||||
|
if (n_chunks == 0) {
|
||||||
|
LOG_ERR("no chunks to eval\n");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n_chunks; i++) {
|
||||||
|
bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
|
||||||
|
auto chunk = mtmd_input_chunks_get(chunks, i);
|
||||||
|
|
||||||
|
int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
|
||||||
|
if (res != 0) {
|
||||||
|
LOG_ERR("failed to eval chunk %zu\n", i);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
*new_n_past = n_past;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace audio_helpers {
|
||||||
|
|
||||||
|
static bool is_audio_file(const char * buf, size_t len) {
|
||||||
|
if (len < 12) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
|
||||||
|
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
|
||||||
|
bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
|
||||||
|
bool is_mp3 = len >= 3 && (
|
||||||
|
memcmp(buf, "ID3", 3) == 0 ||
|
||||||
|
// Check for MPEG sync word (simplified check)
|
||||||
|
((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
|
||||||
|
);
|
||||||
|
bool is_flac = memcmp(buf, "fLaC", 4) == 0;
|
||||||
|
|
||||||
|
return is_wav || is_mp3 || is_flac;
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns true if the buffer is a valid audio file
|
||||||
|
static bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
|
||||||
|
ma_result result;
|
||||||
|
const int channels = 1;
|
||||||
|
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
|
||||||
|
ma_decoder decoder;
|
||||||
|
|
||||||
|
result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
|
||||||
|
if (result != MA_SUCCESS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ma_uint64 frame_count;
|
||||||
|
ma_uint64 frames_read;
|
||||||
|
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
|
||||||
|
if (result != MA_SUCCESS) {
|
||||||
|
ma_decoder_uninit(&decoder);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
pcmf32_mono.resize(frame_count);
|
||||||
|
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
|
||||||
|
if (result != MA_SUCCESS) {
|
||||||
|
ma_decoder_uninit(&decoder);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef MTMD_AUDIO_DEBUG
|
||||||
|
// save audio to wav file
|
||||||
|
ma_encoder_config config = ma_encoder_config_init(ma_encoding_format_wav, ma_format_f32, 1, target_sampler_rate);
|
||||||
|
ma_encoder encoder;
|
||||||
|
ma_encoder_init_file("output.wav", &config, &encoder);
|
||||||
|
ma_encoder_write_pcm_frames(&encoder, pcmf32_mono.data(), pcmf32_mono.size(), &frames_read);
|
||||||
|
ma_encoder_uninit(&encoder);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ma_decoder_uninit(&decoder);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace audio_helpers
|
||||||
|
|
||||||
|
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len) {
|
||||||
|
if (audio_helpers::is_audio_file((const char *)buf, len)) {
|
||||||
|
std::vector<float> pcmf32;
|
||||||
|
int bitrate = mtmd_get_audio_bitrate(ctx);
|
||||||
|
if (bitrate < 0) {
|
||||||
|
LOG_ERR("This model does not support audio input\n");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (!audio_helpers::decode_audio_from_buf(buf, len, bitrate, pcmf32)) {
|
||||||
|
LOG_ERR("Unable to read WAV audio file from buffer\n");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise, we assume it's an image
|
||||||
|
mtmd_bitmap * result = nullptr;
|
||||||
|
{
|
||||||
|
int nx, ny, nc;
|
||||||
|
auto * data = stbi_load_from_memory(buf, len, &nx, &ny, &nc, 3);
|
||||||
|
if (!data) {
|
||||||
|
LOG_ERR("%s: failed to decode image bytes\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
result = mtmd_bitmap_init(nx, ny, data);
|
||||||
|
stbi_image_free(data);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname) {
|
||||||
|
std::vector<unsigned char> buf;
|
||||||
|
FILE * f = fopen(fname, "rb");
|
||||||
|
if (!f) {
|
||||||
|
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
fseek(f, 0, SEEK_END);
|
||||||
|
long file_size = ftell(f);
|
||||||
|
fseek(f, 0, SEEK_SET);
|
||||||
|
buf.resize(file_size);
|
||||||
|
|
||||||
|
size_t n_read = fread(buf.data(), 1, file_size, f);
|
||||||
|
fclose(f);
|
||||||
|
if (n_read != (size_t)file_size) {
|
||||||
|
LOG_ERR("Failed to read entire file %s", fname);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return mtmd_helper_bitmap_init_from_buf(ctx, buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
#ifndef MTMD_HELPER_H
|
||||||
|
#define MTMD_HELPER_H
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "llama.h"
|
||||||
|
#include "mtmd.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//
|
||||||
|
// libmtmd helper functions
|
||||||
|
//
|
||||||
|
// Please note that these helpers are not guaranteed to be stable.
|
||||||
|
// BREAKING CHANGES are expected.
|
||||||
|
//
|
||||||
|
|
||||||
|
// helper function to construct a mtmd_bitmap from a file
|
||||||
|
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||||
|
// returns nullptr on failure
|
||||||
|
// this function is thread-safe
|
||||||
|
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname);
|
||||||
|
|
||||||
|
// helper function to construct a mtmd_bitmap from a buffer containing a file
|
||||||
|
// supported formats:
|
||||||
|
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
|
||||||
|
// audio: formats supported by miniaudio: wav, mp3, flac
|
||||||
|
// note: audio files will be auto-detected based on magic bytes
|
||||||
|
// returns nullptr on failure
|
||||||
|
// this function is thread-safe
|
||||||
|
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len);
|
||||||
|
|
||||||
|
// helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
|
||||||
|
MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
|
||||||
|
|
||||||
|
// helper to count the total position of tokens from a list of chunks, useful to keep track of n_past
|
||||||
|
// normally, n_pos is equal to n_tokens, but for M-RoPE it is different
|
||||||
|
MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
|
||||||
|
|
||||||
|
// helper function that automatically:
|
||||||
|
// 1. run llama_decode() on text chunks
|
||||||
|
// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
|
||||||
|
// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
|
||||||
|
// otherwise, returns 0 on success
|
||||||
|
// this function is NOT thread-safe
|
||||||
|
MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
|
||||||
|
struct llama_context * lctx,
|
||||||
|
const mtmd_input_chunks * chunks,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
int32_t n_batch,
|
||||||
|
bool logits_last,
|
||||||
|
llama_pos * new_n_past);
|
||||||
|
|
||||||
|
// works like mtmd_helper_eval_chunks(), but only for a single chunk
|
||||||
|
// this function is NOT thread-safe
|
||||||
|
MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||||
|
struct llama_context * lctx,
|
||||||
|
const mtmd_input_chunk * chunk,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
int32_t n_batch,
|
||||||
|
bool logits_last,
|
||||||
|
llama_pos * new_n_past);
|
||||||
|
|
||||||
|
// helper function to decode an image whose embeddings have already been calculated
|
||||||
|
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
|
||||||
|
// ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure
|
||||||
|
MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
|
||||||
|
struct llama_context * lctx,
|
||||||
|
const mtmd_input_chunk * chunk,
|
||||||
|
float * encoded_embd,
|
||||||
|
llama_pos n_past,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
int32_t n_batch,
|
||||||
|
llama_pos * new_n_past);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern "C"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//
|
||||||
|
// C++ wrappers
|
||||||
|
//
|
||||||
|
|
||||||
|
#endif
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,8 @@
|
||||||
package mtmd
|
package mtmd
|
||||||
|
|
||||||
// #cgo CXXFLAGS: -std=c++11
|
// #cgo CXXFLAGS: -std=c++17
|
||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../common
|
// #cgo CPPFLAGS: -I${SRCDIR}/../../include
|
||||||
|
// #cgo CPPFLAGS: -I${SRCDIR}/../../common
|
||||||
|
// #cgo CPPFLAGS: -I${SRCDIR}/../../vendor
|
||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../../../ml/backend/ggml/ggml/include
|
// #cgo CPPFLAGS: -I${SRCDIR}/../../../../ml/backend/ggml/ggml/include
|
||||||
import "C"
|
import "C"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,298 @@
|
||||||
|
#ifndef MTMD_H
|
||||||
|
#define MTMD_H
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <memory>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* libmtmd: A library for multimodal support in llama.cpp.
|
||||||
|
*
|
||||||
|
* WARNING: This API is experimental and subject to many BREAKING CHANGES.
|
||||||
|
* Issues related to API usage may receive lower priority support.
|
||||||
|
*
|
||||||
|
* For the usage, see an example in mtmd-cli.cpp
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifdef LLAMA_SHARED
|
||||||
|
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||||
|
# ifdef LLAMA_BUILD
|
||||||
|
# define MTMD_API __declspec(dllexport)
|
||||||
|
# else
|
||||||
|
# define MTMD_API __declspec(dllimport)
|
||||||
|
# endif
|
||||||
|
# else
|
||||||
|
# define MTMD_API __attribute__ ((visibility ("default")))
|
||||||
|
# endif
|
||||||
|
#else
|
||||||
|
# define MTMD_API
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// deprecated marker, use mtmd_default_marker() instead
|
||||||
|
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
enum mtmd_input_chunk_type {
|
||||||
|
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||||
|
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||||
|
MTMD_INPUT_CHUNK_TYPE_AUDIO,
|
||||||
|
};
|
||||||
|
|
||||||
|
// opaque types
|
||||||
|
struct mtmd_context;
|
||||||
|
struct mtmd_bitmap;
|
||||||
|
struct mtmd_image_tokens;
|
||||||
|
struct mtmd_input_chunk;
|
||||||
|
struct mtmd_input_chunks;
|
||||||
|
|
||||||
|
struct mtmd_input_text {
|
||||||
|
const char * text;
|
||||||
|
bool add_special;
|
||||||
|
bool parse_special;
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// C API
|
||||||
|
//
|
||||||
|
|
||||||
|
typedef struct mtmd_context mtmd_context;
|
||||||
|
typedef struct mtmd_bitmap mtmd_bitmap;
|
||||||
|
typedef struct mtmd_image_tokens mtmd_image_tokens;
|
||||||
|
typedef struct mtmd_input_chunk mtmd_input_chunk;
|
||||||
|
typedef struct mtmd_input_chunks mtmd_input_chunks;
|
||||||
|
typedef struct mtmd_input_text mtmd_input_text;
|
||||||
|
|
||||||
|
struct mtmd_context_params {
|
||||||
|
bool use_gpu;
|
||||||
|
bool print_timings;
|
||||||
|
int n_threads;
|
||||||
|
enum ggml_log_level verbosity;
|
||||||
|
const char * image_marker; // deprecated, use media_marker instead
|
||||||
|
const char * media_marker;
|
||||||
|
};
|
||||||
|
|
||||||
|
MTMD_API const char * mtmd_default_marker(void);
|
||||||
|
|
||||||
|
MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
|
||||||
|
|
||||||
|
// initialize the mtmd context
|
||||||
|
// return nullptr on failure
|
||||||
|
MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||||
|
const struct llama_model * text_model,
|
||||||
|
const struct mtmd_context_params ctx_params);
|
||||||
|
|
||||||
|
MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||||
|
|
||||||
|
// whether we need to set non-causal mask before llama_decode
|
||||||
|
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||||
|
|
||||||
|
// whether the current model use M-RoPE for llama_decode
|
||||||
|
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||||
|
|
||||||
|
// whether the current model supports vision input
|
||||||
|
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
|
||||||
|
|
||||||
|
// whether the current model supports audio input
|
||||||
|
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
|
||||||
|
|
||||||
|
// get audio bitrate in Hz, for example 16000 for Whisper
|
||||||
|
// return -1 if audio is not supported
|
||||||
|
MTMD_API int mtmd_get_audio_bitrate(mtmd_context * ctx);
|
||||||
|
|
||||||
|
// mtmd_bitmap
|
||||||
|
//
|
||||||
|
// if bitmap is image:
|
||||||
|
// length of data must be nx * ny * 3
|
||||||
|
// the data is in RGBRGBRGB... format
|
||||||
|
// if bitmap is audio:
|
||||||
|
// length of data must be n_samples * sizeof(float)
|
||||||
|
// the data is in float format (PCM F32)
|
||||||
|
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
|
||||||
|
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
|
||||||
|
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
|
||||||
|
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
|
||||||
|
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
|
||||||
|
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
|
||||||
|
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
|
||||||
|
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
|
||||||
|
// bitmap ID is optional, but useful for KV cache tracking
|
||||||
|
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
|
||||||
|
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
|
||||||
|
MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id);
|
||||||
|
|
||||||
|
|
||||||
|
// mtmd_input_chunks
|
||||||
|
//
|
||||||
|
// this is simply a list of mtmd_input_chunk
|
||||||
|
// the elements can only be populated via mtmd_tokenize()
|
||||||
|
MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void);
|
||||||
|
MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks);
|
||||||
|
MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get (const mtmd_input_chunks * chunks, size_t idx);
|
||||||
|
// Frees a chunk list returned by mtmd_tokenize() (also frees the chunks it owns).
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);

// mtmd_input_chunk
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunks
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type        (const mtmd_input_chunk * chunk);
// Returned token array is owned by the chunk; n_tokens_output receives its length.
MTMD_API const llama_token *        mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
// Non-owning view of the image tokens carried by an image chunk.
MTMD_API const mtmd_image_tokens *  mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
MTMD_API size_t                     mtmd_input_chunk_get_n_tokens    (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char *               mtmd_input_chunk_get_id          (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos                  mtmd_input_chunk_get_n_pos       (const mtmd_input_chunk * chunk);

// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
// you can move the chunk ownership to your own code by copying it
// remember to free the chunk when you are done with it
MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk);
MTMD_API void               mtmd_input_chunk_free(mtmd_input_chunk * chunk);


// mtmd_image_tokens
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunk
MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens); // TODO: deprecate

// tokenize an input text prompt and a list of bitmaps (images/audio)
// the prompt must have the input image marker (default: "<__media__>") in it
// the default marker is defined by mtmd_default_marker()
// the marker will be replaced with the image/audio chunk
// for example:
//   "here is an image: <__media__>\ndescribe it in detail."
//   this will gives 3 chunks:
//   1. "here is an image: <start_of_image>"
//   2. (image/audio tokens)
//   3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of markers in the prompt
// this function is thread-safe (shared ctx)
// return values:
//   0 on success
//   1 on number of bitmaps not matching the number of markers
//   2 on image preprocessing error
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
                               mtmd_input_chunks * output,
                               const mtmd_input_text * text,
                               const mtmd_bitmap ** bitmaps,
                               size_t n_bitmaps);

// returns 0 on success
// TODO: deprecate
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
                             const mtmd_image_tokens * image_tokens);

// returns 0 on success
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
                                   const mtmd_input_chunk * chunk);

// get output embeddings from the last encode pass
// the reading size (in bytes) is equal to:
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
// NOTE(review): buffer appears to be owned by the context and valid until the next
// encode call — confirm against the implementation before holding onto it.
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

/////////////////////////////////////////

// test function, to be used in test-mtmd-c-api.c
MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void);

#ifdef __cplusplus
} // extern "C"
#endif
|
||||||
|
|
||||||
|
//
// C++ wrappers
//

#ifdef __cplusplus

namespace mtmd {

// RAII deleters + unique_ptr aliases so the C API objects above are freed
// automatically via the matching *_free() function.

struct mtmd_context_deleter {
    void operator()(mtmd_context * val) { mtmd_free(val); }
};
using context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;

struct mtmd_bitmap_deleter {
    void operator()(mtmd_bitmap * val) { mtmd_bitmap_free(val); }
};
using bitmap_ptr = std::unique_ptr<mtmd_bitmap, mtmd_bitmap_deleter>;

struct mtmd_input_chunks_deleter {
    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
};
using input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;

struct mtmd_input_chunk_deleter {
    void operator()(mtmd_input_chunk * val) { mtmd_input_chunk_free(val); }
};
using input_chunk_ptr = std::unique_ptr<mtmd_input_chunk, mtmd_input_chunk_deleter>;

// Owning, move-only wrapper around mtmd_bitmap; accessors forward to the C API.
struct bitmap {
    bitmap_ptr ptr;
    bitmap() : ptr(nullptr) {}
    // Takes ownership of the raw bitmap.
    bitmap(mtmd_bitmap * bitmap) : ptr(bitmap) {}
    bitmap(bitmap && other) noexcept : ptr(std::move(other.ptr)) {}
    // Constructs a new bitmap from raw pixel data via mtmd_bitmap_init().
    bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) {
        ptr.reset(mtmd_bitmap_init(nx, ny, data));
    }
    ~bitmap() = default;
    uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
    uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
    const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
    size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
    std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
    void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
};

// Convenience container of owned bitmaps.
struct bitmaps {
    std::vector<bitmap> entries;
    ~bitmaps() = default;
    // return list of pointers to mtmd_bitmap
    // example:
    //   auto bitmaps_c_ptr = bitmaps.c_ptr();
    //   int32_t res = mtmd_tokenize(... bitmaps_c_ptr.data(), bitmaps_c_ptr.size());
    // NOTE: the returned pointers are non-owning views into `entries`.
    std::vector<const mtmd_bitmap *> c_ptr() {
        std::vector<const mtmd_bitmap *> res(entries.size());
        for (size_t i = 0; i < entries.size(); i++) {
            res[i] = entries[i].ptr.get();
        }
        return res;
    }
};

// Owning wrapper around mtmd_input_chunks with indexed, non-owning access.
struct input_chunks {
    input_chunks_ptr ptr;
    input_chunks() = default;
    // Takes ownership of the chunk list (e.g. output of mtmd_tokenize()).
    input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {}
    ~input_chunks() = default;
    size_t size() { return mtmd_input_chunks_size(ptr.get()); }
    const mtmd_input_chunk * operator[](size_t idx) {
        return mtmd_input_chunks_get(ptr.get(), idx);
    }
};

} // namespace mtmd

#endif // __cplusplus

#endif // header include guard (opened above this fragment)
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,541 @@
|
||||||
|
/*
|
||||||
|
Copyright 2024 Google LLC
|
||||||
|
|
||||||
|
Use of this source code is governed by an MIT-style
|
||||||
|
license that can be found in the LICENSE file or at
|
||||||
|
https://opensource.org/licenses/MIT.
|
||||||
|
*/
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "minja.hpp"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <ctime>
|
||||||
|
#include <exception>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
namespace minja {
|
||||||
|
|
||||||
|
// Capabilities detected for a chat template by rendering synthetic probe
// conversations (see chat_template's constructor). `supports_*` flags record
// what the template renders natively; `requires_*` flags record input shapes
// it insists on (used to decide which polyfills to apply).
struct chat_template_caps {
    // Template renders a top-level `tools` array.
    bool supports_tools = false;
    // Template renders assistant `tool_calls`.
    bool supports_tool_calls = false;
    // Template renders messages with role "tool".
    bool supports_tool_responses = false;
    // Template renders a message with role "system".
    bool supports_system_role = false;
    // Template renders multiple tool calls in one assistant message.
    bool supports_parallel_tool_calls = false;
    // Template renders the tool call's `id` field.
    bool supports_tool_call_id = false;
    // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool requires_object_arguments = false;
    // CohereForAI/c4ai-command-r-plus simple variant
    bool requires_non_null_content = false;
    // MiniMaxAI/MiniMax-Text-01 special
    bool requires_typed_content = false;
};
|
||||||
|
|
||||||
|
// Inputs for chat_template::apply().
struct chat_template_inputs {
    // OpenAI-style chat messages (JSON array).
    nlohmann::ordered_json messages;
    // Tool/function definitions (JSON array), or null/empty when unused.
    nlohmann::ordered_json tools;
    // Whether to append the assistant generation prompt at the end.
    bool add_generation_prompt = true;
    // Extra variables exposed to the template context (null when unused).
    nlohmann::ordered_json extra_context;
    // Timestamp exposed via the template's strftime_now() function.
    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
|
||||||
|
|
||||||
|
// Options for chat_template::apply(). Polyfill flags only take effect when the
// template's detected caps indicate the corresponding feature is unsupported.
struct chat_template_options {
    // Master switch: when false, messages are passed to the template untouched.
    bool apply_polyfills = true;
    // Expose the real bos/eos tokens to the template (empty string when false).
    bool use_bos_token = true;
    bool use_eos_token = true;
    // Define strftime_now(format) in the template context.
    bool define_strftime_now = true;

    // Inject tool definitions into a system message when tools are unsupported.
    bool polyfill_tools = true;
    // Include an inferred example of tool-call syntax in that system message.
    bool polyfill_tool_call_examples = true;
    // Serialize assistant tool_calls into plain content when unsupported.
    bool polyfill_tool_calls = true;
    // Rewrite role:"tool" messages as user messages when unsupported.
    bool polyfill_tool_responses = true;
    // Merge system messages into the next user message when unsupported.
    bool polyfill_system_role = true;
    // Parse stringified arguments into objects when the template requires objects.
    bool polyfill_object_arguments = true;
    // Wrap string content as [{type:"text", text:...}] when required.
    bool polyfill_typed_content = true;
};
|
||||||
|
|
||||||
|
// Wraps a Jinja-style chat template (parsed with minja) and renders OpenAI-style
// message arrays into a prompt string. On construction it probes the template
// with synthetic conversations to detect its capabilities (chat_template_caps),
// then apply() polyfills unsupported features (system role, tools, tool calls,
// tool responses, typed content) before rendering.
class chat_template {

  private:
    chat_template_caps caps_;
    std::string source_;        // raw template source text
    std::string bos_token_;
    std::string eos_token_;
    std::shared_ptr<minja::TemplateNode> template_root_;  // parsed template AST
    std::string tool_call_example_;  // inferred example of this template's tool-call syntax (may stay empty)

    // Renders without polyfills and with a fixed clock; returns "" on any
    // exception. Used only for capability probing in the constructor.
    std::string try_raw_render(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
    {
        try {
            chat_template_inputs inputs;
            inputs.messages = messages;
            inputs.tools = tools;
            inputs.add_generation_prompt = add_generation_prompt;
            inputs.extra_context = extra_context;
            // Use fixed date for tests
            inputs.now = std::chrono::system_clock::from_time_t(0);

            chat_template_options opts;
            opts.apply_polyfills = false;

            auto prompt = apply(inputs, opts);
            // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
            return prompt;
        } catch (const std::exception & e) {
            // fprintf(stderr, "try_raw_render error: %s\n", e.what());
            // A failed render is treated as "template can't handle this input".
            return "";
        }
    }

  public:

    // Parses the template and probes its capabilities. Each probe renders a
    // synthetic conversation containing a unique "needle" string and checks
    // whether the needle survives into the output. NOTE: probe order matters —
    // requires_typed_content must be detected first because later probes build
    // their dummy messages based on it.
    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
        : source_(source), bos_token_(bos_token), eos_token_(eos_token)
    {
        template_root_ = minja::Parser::parse(source_, {
            /* .trim_blocks = */ true,
            /* .lstrip_blocks = */ true,
            /* .keep_trailing_newline = */ false,
        });

        auto contains = [](const std::string & haystack, const std::string & needle) {
            return haystack.find(needle) != std::string::npos;
        };

        const std::string user_needle = "<User Needle>";
        const std::string sys_needle = "<System Needle>";
        const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
        const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};

        // Typed content is required iff plain string content is dropped but the
        // [{type:"text",...}] form renders.
        caps_.requires_typed_content =
            !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
            && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);

        const auto dummy_user_msg = caps_.requires_typed_content
            ? dummy_typed_user_msg
            : dummy_str_user_msg;
        const json needle_system_msg = {
            {"role", "system"},
            {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
        };

        caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);

        // Probe tools support by passing a dummy tool definition and checking
        // whether its name appears in the rendered prompt.
        auto out = try_raw_render(json::array({
            dummy_user_msg
        }), json::array({
            {
                {"name", "some_tool"},
                {"type", "function"},
                {"function", {
                    {"name", "some_tool"},
                    {"description", "Some tool."},
                    {"parameters", {
                        {"type", "object"},
                        {"properties", {
                            {"arg", {
                                {"type", "string"},
                                {"description", "Some argument."},
                            }},
                        }},
                        {"required", json::array({ "arg" })},
                    }},
                }},
            },
        }), false);
        caps_.supports_tools = contains(out, "some_tool");

        auto make_tool_calls_msg = [&](const json & tool_calls) {
            return json {
                {"role", "assistant"},
                {"content", nullptr},
                {"tool_calls", tool_calls},
            };
        };
        auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
            return json {
                {"id", "call_1___"},
                {"type", "function"},
                {"function", {
                    {"arguments", arguments},
                    {"name", tool_name},
                }},
            };
        };
        const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};

        // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
        // First pass arguments as a stringified JSON object...
        out = try_raw_render(json::array({
            dummy_user_msg,
            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
        }), {}, false);
        auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
        // ...then as a raw JSON object, to see which form the template accepts.
        out = try_raw_render(json::array({
            dummy_user_msg,
            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
        }), {}, false);
        auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");

        caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
        caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
        // Non-null content is required iff "" renders but null drops the message.
        auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
        auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
        caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);

        if (caps_.supports_tool_calls) {
            auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
            auto tc1 = make_tool_call("test_tool1", dummy_args);
            auto tc2 = make_tool_call("test_tool2", dummy_args);
            auto out = try_raw_render(json::array({
                dummy_user_msg,
                make_tool_calls_msg(json::array({tc1, tc2})),
            }), {}, false);
            caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");

            out = try_raw_render(json::array({
                dummy_user_msg,
                make_tool_calls_msg(json::array({tc1})),
                {
                    {"role", "tool"},
                    {"name", "test_tool1"},
                    {"content", "Some response!"},
                    {"tool_call_id", "call_911_"},
                }
            }), {}, false);
            caps_.supports_tool_responses = contains(out, "Some response!");
            caps_.supports_tool_call_id = contains(out, "call_911_");
        }

        // If the template has no native tools support, infer an example of its
        // tool-call syntax by diffing a render with vs. without a tool call;
        // the suffix unique to the tool-call render becomes tool_call_example_.
        try {
            if (!caps_.supports_tools) {
                const json user_msg {
                    {"role", "user"},
                    {"content", "Hey"},
                };
                const json args {
                    {"arg1", "some_value"},
                };
                const json tool_call_msg {
                    {"role", "assistant"},
                    {"content", nullptr},
                    {"tool_calls", json::array({
                        {
                            // TODO: detect if requires numerical id or fixed length == 6 like Nemo
                            {"id", "call_1___"},
                            {"type", "function"},
                            {"function", {
                                {"name", "tool_name"},
                                {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
                            }},
                        },
                    })},
                };
                std::string prefix, full;
                {
                    chat_template_inputs inputs;
                    inputs.messages = json::array({user_msg});
                    inputs.add_generation_prompt = true;
                    prefix = apply(inputs);
                }
                {
                    chat_template_inputs inputs;
                    inputs.messages = json::array({user_msg, tool_call_msg});
                    inputs.add_generation_prompt = false;
                    full = apply(inputs);
                }
                // Strip a trailing eos token (optionally followed by a newline)
                // so it doesn't end up in the example.
                auto eos_pos_last = full.rfind(eos_token_);
                if (eos_pos_last == prefix.size() - eos_token_.size() ||
                        (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
                    full = full.substr(0, eos_pos_last);
                }
                size_t common_prefix_length = 0;
                for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
                    if (prefix[i] != full[i]) {
                        break;
                    }
                    if (prefix[i] == '<') {
                        // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
                        // but it removes thinking tags for past messages.
                        // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
                        continue;
                    }
                    common_prefix_length = i + 1;
                }
                auto example = full.substr(common_prefix_length);
                if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
                    fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
                } else {
                    tool_call_example_ = example;
                }
            }
        } catch (const std::exception & e) {
            fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
        }
    }

    const std::string & source() const { return source_; }
    const std::string & bos_token() const { return bos_token_; }
    const std::string & eos_token() const { return eos_token_; }
    const chat_template_caps & original_caps() const { return caps_; }

    // Deprecated, please use the form with chat_template_inputs and chat_template_options
    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
        bool apply_polyfills = true)
    {
        fprintf(stderr, "[%s] Deprecated!\n", __func__);
        chat_template_inputs inputs;
        inputs.messages = messages;
        inputs.tools = tools;
        inputs.add_generation_prompt = add_generation_prompt;
        inputs.extra_context = extra_context;
        inputs.now = std::chrono::system_clock::now();

        chat_template_options opts;
        opts.apply_polyfills = apply_polyfills;

        return apply(inputs, opts);
    }

    // Renders the template: optionally rewrites the messages to polyfill
    // features the template lacks (per caps_ and opts), then evaluates the
    // parsed template with messages/tools/extra context bound.
    std::string apply(
        const chat_template_inputs & inputs,
        const chat_template_options & opts = chat_template_options()) const
    {
        json actual_messages;

        // Scan the input to see which polyfills could be relevant at all.
        auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
        auto has_tool_calls = false;
        auto has_tool_responses = false;
        auto has_string_content = false;
        for (const auto & message : inputs.messages) {
            if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
                has_tool_calls = true;
            }
            if (message.contains("role") && message["role"] == "tool") {
                has_tool_responses = true;
            }
            if (message.contains("content") && message["content"].is_string()) {
                has_string_content = true;
            }
        }

        // A polyfill activates only when requested, needed by the input, and
        // not supported natively by the template.
        auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
        auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
        auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
        auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
        auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
        auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;

        auto needs_polyfills = opts.apply_polyfills && (false
            || polyfill_system_role
            || polyfill_tools
            || polyfill_tool_calls
            || polyfill_tool_responses
            || polyfill_object_arguments
            || polyfill_typed_content
        );

        if (needs_polyfills) {
            actual_messages = json::array();

            // Appends a message, wrapping string content in typed form if needed.
            auto add_message = [&](const json & msg) {
                if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
                    actual_messages.push_back({
                        {"role", msg.at("role")},
                        {"content", {{
                            {"type", "text"},
                            {"text", msg.at("content")},
                        }}},
                    });
                } else {
                    actual_messages.push_back(msg);
                }
            };

            // Accumulated system text waiting to be merged into a user message.
            std::string pending_system;
            auto flush_sys = [&]() {
                if (!pending_system.empty()) {
                    add_message({
                        {"role", "user"},
                        {"content", pending_system},
                    });
                    pending_system.clear();
                }
            };

            // Inject tool definitions (and an optional syntax example) into a
            // system message when the template has no native tools support.
            json adjusted_messages;
            if (polyfill_tools) {
                adjusted_messages = add_system(inputs.messages,
                    "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
            } else {
                adjusted_messages = inputs.messages;
            }

            for (const auto & message_ : adjusted_messages) {
                auto message = message_;  // mutable copy; may be rewritten below
                if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
                    throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
                }
                std::string role = message.at("role");

                if (message.contains("tool_calls")) {
                    // Normalize stringified arguments into JSON objects first.
                    if (polyfill_object_arguments || polyfill_tool_calls) {
                        for (auto & tool_call : message.at("tool_calls")) {
                            if (tool_call["type"] == "function") {
                                auto & function = tool_call.at("function");
                                auto & arguments = function.at("arguments");
                                if (arguments.is_string()) {
                                    try {
                                        arguments = json::parse(arguments.get<std::string>());
                                    } catch (const std::exception & ecvt) {
                                        // Leave arguments as the original string on parse failure.
                                        fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
                                    }
                                }
                            }
                        }
                    }
                    // Serialize the tool calls into plain JSON content when the
                    // template cannot render tool_calls itself.
                    if (polyfill_tool_calls) {
                        auto tool_calls = json::array();
                        for (const auto & tool_call : message.at("tool_calls")) {
                            if (tool_call.at("type") != "function") {
                                continue;
                            }
                            const auto & function = tool_call.at("function");
                            auto tc = json {
                                {"name", function.at("name")},
                                {"arguments", function.at("arguments")},
                            };
                            if (tool_call.contains("id")) {
                                tc["id"] = tool_call["id"];
                            }
                            tool_calls.push_back(tc);
                        }
                        auto obj = json {
                            {"tool_calls", tool_calls},
                        };
                        if (message.contains("content")) {
                            auto content = message.at("content");
                            if (!content.is_null() && !content.empty()) {
                                obj["content"] = content;
                            }
                        }
                        message["content"] = obj.dump(2);
                        message.erase("tool_calls");
                    }
                }
                // Rewrite tool responses as user messages carrying a JSON payload.
                if (polyfill_tool_responses && role == "tool") {
                    message["role"] = "user";
                    auto obj = json {
                        {"tool_response", json::object()},
                    };
                    if (message.contains("name")) {
                        obj["tool_response"]["tool"] = message.at("name");
                    }
                    obj["tool_response"]["content"] = message.at("content");
                    if (message.contains("tool_call_id")) {
                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
                    }
                    message["content"] = obj.dump(2);
                    message.erase("name");
                }

                // System-role polyfill: buffer system text and prepend it to the
                // next user message (or emit it as its own user message via
                // flush_sys when something else intervenes).
                if (!message["content"].is_null() && polyfill_system_role) {
                    std::string content = message.at("content");
                    if (role == "system") {
                        if (!pending_system.empty()) pending_system += "\n";
                        pending_system += content;
                        continue;  // do not emit the system message itself
                    } else {
                        if (role == "user") {
                            if (!pending_system.empty()) {
                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
                                pending_system.clear();
                            }
                        } else {
                            flush_sys();
                        }
                    }
                }
                add_message(message);
            }
            flush_sys();  // don't drop trailing system text
        } else {
            actual_messages = inputs.messages;
        }

        // Bind the template context and render.
        auto context = minja::Context::make(json({
            {"messages", actual_messages},
            {"add_generation_prompt", inputs.add_generation_prompt},
        }));
        context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
        context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
        if (opts.define_strftime_now) {
            auto now = inputs.now;
            // strftime_now(format): formats `inputs.now` with std::put_time.
            // NOTE(review): std::localtime is not thread-safe — it returns a
            // pointer to shared static storage; confirm single-threaded use.
            context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
                args.expectArgs("strftime_now", {1, 1}, {0, 0});
                auto format = args.args[0].get<std::string>();

                auto time = std::chrono::system_clock::to_time_t(now);
                auto local_time = *std::localtime(&time);
                std::ostringstream ss;
                ss << std::put_time(&local_time, format.c_str());
                return ss.str();
            }));
        }
        if (!inputs.tools.is_null()) {
            context->set("tools", minja::Value(inputs.tools));
        }
        if (!inputs.extra_context.is_null()) {
            for (auto & kv : inputs.extra_context.items()) {
                context->set(kv.key(), minja::Value(kv.value()));
            }
        }

        auto ret = template_root_->render(context);
        // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
        // fprintf(stderr, "apply: %s\n\n", ret.c_str());
        return ret;
    }

    // Returns a copy of `messages` with `system_prompt` merged into the leading
    // system message, or inserted as a new leading system message if none exists.
    static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
        json messages_with_system = messages;

        if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") {
            std::string existing_system = messages_with_system.at(0).at("content");
            messages_with_system[0] = json {
                {"role", "system"},
                {"content", existing_system + "\n\n" + system_prompt},
            };
        } else {
            messages_with_system.insert(messages_with_system.begin(), json {
                {"role", "system"},
                {"content", system_prompt},
            });
        }
        return messages_with_system;
    }
};
|
||||||
|
|
||||||
|
} // namespace minja
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,187 @@
|
||||||
|
// __ _____ _____ _____
|
||||||
|
// __| | __| | | | JSON for Modern C++
|
||||||
|
// | | |__ | | | | | | version 3.12.0
|
||||||
|
// |_____|_____|_____|_|___| https://github.com/nlohmann/json
|
||||||
|
//
|
||||||
|
// SPDX-FileCopyrightText: 2013 - 2025 Niels Lohmann <https://nlohmann.me>
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_
|
||||||
|
#define INCLUDE_NLOHMANN_JSON_FWD_HPP_
|
||||||
|
|
||||||
|
#include <cstdint> // int64_t, uint64_t
|
||||||
|
#include <map> // map
|
||||||
|
#include <memory> // allocator
|
||||||
|
#include <string> // string
|
||||||
|
#include <vector> // vector
|
||||||
|
|
||||||
|
// #include <nlohmann/detail/abi_macros.hpp>
|
||||||
|
// __ _____ _____ _____
|
||||||
|
// __| | __| | | | JSON for Modern C++
|
||||||
|
// | | |__ | | | | | | version 3.12.0
|
||||||
|
// |_____|_____|_____|_|___| https://github.com/nlohmann/json
|
||||||
|
//
|
||||||
|
// SPDX-FileCopyrightText: 2013 - 2025 Niels Lohmann <https://nlohmann.me>
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// This file contains all macro definitions affecting or depending on the ABI

// Warn if a different nlohmann/json version was already included in this TU
// (mixing versions can cause ODR/ABI issues).
#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK
    #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH)
        #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 12 || NLOHMANN_JSON_VERSION_PATCH != 0
            #warning "Already included a different version of the library!"
        #endif
    #endif
#endif

#define NLOHMANN_JSON_VERSION_MAJOR 3   // NOLINT(modernize-macro-to-enum)
#define NLOHMANN_JSON_VERSION_MINOR 12  // NOLINT(modernize-macro-to-enum)
#define NLOHMANN_JSON_VERSION_PATCH 0   // NOLINT(modernize-macro-to-enum)

// Feature toggles default to off unless the user defined them beforehand.
#ifndef JSON_DIAGNOSTICS
    #define JSON_DIAGNOSTICS 0
#endif

#ifndef JSON_DIAGNOSTIC_POSITIONS
    #define JSON_DIAGNOSTIC_POSITIONS 0
#endif

#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
    #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0
#endif

// Each enabled ABI-affecting feature contributes a suffix to the inline
// namespace name, so incompatible configurations don't link together.
#if JSON_DIAGNOSTICS
    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag
#else
    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS
#endif

#if JSON_DIAGNOSTIC_POSITIONS
    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTIC_POSITIONS _dp
#else
    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTIC_POSITIONS
#endif

#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp
#else
    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON
#endif

#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION
    #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0
#endif

// Construct the namespace ABI tags component
#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b, c) json_abi ## a ## b ## c
#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b, c) \
    NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b, c)

#define NLOHMANN_JSON_ABI_TAGS                                       \
    NLOHMANN_JSON_ABI_TAGS_CONCAT(                                   \
            NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS,                       \
            NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON, \
            NLOHMANN_JSON_ABI_TAG_DIAGNOSTIC_POSITIONS)

// Construct the namespace version component
#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \
    _v ## major ## _ ## minor ## _ ## patch
#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \
    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch)

#if NLOHMANN_JSON_NAMESPACE_NO_VERSION
    #define NLOHMANN_JSON_NAMESPACE_VERSION
#else
    #define NLOHMANN_JSON_NAMESPACE_VERSION                                 \
        NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \
                                               NLOHMANN_JSON_VERSION_MINOR, \
                                               NLOHMANN_JSON_VERSION_PATCH)
#endif

// Combine namespace components
#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b
#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \
    NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b)

// Fully-qualified inline namespace name, e.g. nlohmann::json_abi_v3_12_0.
#ifndef NLOHMANN_JSON_NAMESPACE
    #define NLOHMANN_JSON_NAMESPACE               \
        nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \
                NLOHMANN_JSON_ABI_TAGS,           \
                NLOHMANN_JSON_NAMESPACE_VERSION)
#endif

#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
    #define NLOHMANN_JSON_NAMESPACE_BEGIN                \
        namespace nlohmann                               \
        {                                                \
        inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \
                    NLOHMANN_JSON_ABI_TAGS,              \
                    NLOHMANN_JSON_NAMESPACE_VERSION)     \
        {
#endif

#ifndef NLOHMANN_JSON_NAMESPACE_END
    #define NLOHMANN_JSON_NAMESPACE_END                                     \
        }  /* namespace (inline namespace) NOLINT(readability/namespace) */ \
        }  // namespace nlohmann
#endif
|
||||||
|
|
||||||
|
|
||||||
|
/*!
|
||||||
|
@brief namespace for Niels Lohmann
|
||||||
|
@see https://github.com/nlohmann
|
||||||
|
@since version 1.0.0
|
||||||
|
*/
|
||||||
|
NLOHMANN_JSON_NAMESPACE_BEGIN
|
||||||
|
|
||||||
|
/*!
|
||||||
|
@brief default JSONSerializer template argument
|
||||||
|
|
||||||
|
This serializer ignores the template arguments and uses ADL
|
||||||
|
([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl))
|
||||||
|
for serialization.
|
||||||
|
*/
|
||||||
|
template<typename T = void, typename SFINAE = void>
|
||||||
|
struct adl_serializer;
|
||||||
|
|
||||||
|
/// a class to store JSON values
|
||||||
|
/// @sa https://json.nlohmann.me/api/basic_json/
|
||||||
|
template<template<typename U, typename V, typename... Args> class ObjectType =
|
||||||
|
std::map,
|
||||||
|
template<typename U, typename... Args> class ArrayType = std::vector,
|
||||||
|
class StringType = std::string, class BooleanType = bool,
|
||||||
|
class NumberIntegerType = std::int64_t,
|
||||||
|
class NumberUnsignedType = std::uint64_t,
|
||||||
|
class NumberFloatType = double,
|
||||||
|
template<typename U> class AllocatorType = std::allocator,
|
||||||
|
template<typename T, typename SFINAE = void> class JSONSerializer =
|
||||||
|
adl_serializer,
|
||||||
|
class BinaryType = std::vector<std::uint8_t>, // cppcheck-suppress syntaxError
|
||||||
|
class CustomBaseClass = void>
|
||||||
|
class basic_json;
|
||||||
|
|
||||||
|
/// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document
|
||||||
|
/// @sa https://json.nlohmann.me/api/json_pointer/
|
||||||
|
template<typename RefStringType>
|
||||||
|
class json_pointer;
|
||||||
|
|
||||||
|
/*!
|
||||||
|
@brief default specialization
|
||||||
|
@sa https://json.nlohmann.me/api/json/
|
||||||
|
*/
|
||||||
|
using json = basic_json<>;
|
||||||
|
|
||||||
|
/// @brief a minimal map-like container that preserves insertion order
|
||||||
|
/// @sa https://json.nlohmann.me/api/ordered_map/
|
||||||
|
template<class Key, class T, class IgnoredLess, class Allocator>
|
||||||
|
struct ordered_map;
|
||||||
|
|
||||||
|
/// @brief specialization that maintains the insertion order of object keys
|
||||||
|
/// @sa https://json.nlohmann.me/api/ordered_json/
|
||||||
|
using ordered_json = basic_json<nlohmann::ordered_map>;
|
||||||
|
|
||||||
|
NLOHMANN_JSON_NAMESPACE_END
|
||||||
|
|
||||||
|
#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_
|
||||||
105
llama/llama.go
105
llama/llama.go
|
|
@ -7,6 +7,7 @@ package llama
|
||||||
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
|
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
|
||||||
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
|
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
|
||||||
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/tools/mtmd
|
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/tools/mtmd
|
||||||
|
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/vendor
|
||||||
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
|
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
|
||||||
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include
|
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include
|
||||||
|
|
||||||
|
|
@ -14,7 +15,6 @@ package llama
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "clip.h"
|
#include "clip.h"
|
||||||
#include "llava.h"
|
|
||||||
#include "gguf.h"
|
#include "gguf.h"
|
||||||
|
|
||||||
#include "sampling_ext.h"
|
#include "sampling_ext.h"
|
||||||
|
|
@ -159,14 +159,6 @@ func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
|
||||||
C.llama_kv_self_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
|
C.llama_kv_self_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) KvCacheClear() {
|
|
||||||
C.llama_kv_self_clear(c.c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) KvCacheDefrag() {
|
|
||||||
C.llama_kv_self_defrag(c.c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) KvCacheCanShift() bool {
|
func (c *Context) KvCacheCanShift() bool {
|
||||||
return bool(C.llama_kv_self_can_shift(c.c))
|
return bool(C.llama_kv_self_can_shift(c.c))
|
||||||
}
|
}
|
||||||
|
|
@ -467,18 +459,36 @@ type ClipContext struct {
|
||||||
func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
|
func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
|
||||||
mp := C.CString(modelPath)
|
mp := C.CString(modelPath)
|
||||||
defer C.free(unsafe.Pointer(mp))
|
defer C.free(unsafe.Pointer(mp))
|
||||||
c := C.clip_model_load(mp, 1)
|
|
||||||
if c == nil {
|
// Set up clip context parameters
|
||||||
return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
|
params := C.struct_clip_context_params{
|
||||||
|
use_gpu: true,
|
||||||
|
verbosity: C.GGML_LOG_LEVEL_INFO,
|
||||||
}
|
}
|
||||||
|
|
||||||
projEmbedSize := int(C.clip_n_mmproj_embd(c))
|
// Initialize clip contexts (returns both vision and audio)
|
||||||
|
result := C.clip_init(mp, params)
|
||||||
|
if result.ctx_v == nil {
|
||||||
|
if result.ctx_a != nil {
|
||||||
|
C.clip_free(result.ctx_a)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unable to load vision model: %v", modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free audio context if we don't need it
|
||||||
|
if result.ctx_a != nil {
|
||||||
|
C.clip_free(result.ctx_a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate embedding sizes
|
||||||
|
projEmbedSize := int(C.clip_n_mmproj_embd(result.ctx_v))
|
||||||
modelEmbedSize := llamaContext.Model().NEmbd()
|
modelEmbedSize := llamaContext.Model().NEmbd()
|
||||||
if projEmbedSize != modelEmbedSize {
|
if projEmbedSize != modelEmbedSize {
|
||||||
|
C.clip_free(result.ctx_v)
|
||||||
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
|
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ClipContext{c: c}, nil
|
return &ClipContext{c: result.ctx_v}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClipContext) Free() {
|
func (c *ClipContext) Free() {
|
||||||
|
|
@ -486,25 +496,66 @@ func (c *ClipContext) Free() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
|
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
|
||||||
l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
|
// Step 1: Load image from bytes (same as before)
|
||||||
if l == nil {
|
img := C.clip_image_u8_init()
|
||||||
return nil, errors.New("unable to make llava embedding from image")
|
if img == nil {
|
||||||
|
return nil, errors.New("failed to initialize image")
|
||||||
|
}
|
||||||
|
defer C.clip_image_u8_free(img)
|
||||||
|
|
||||||
|
ok := C.clip_image_load_from_bytes(
|
||||||
|
(*C.uchar)(unsafe.Pointer(&data[0])),
|
||||||
|
C.size_t(len(data)),
|
||||||
|
img,
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("failed to load image from bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Preprocess image
|
||||||
|
batch := C.clip_image_f32_batch_init()
|
||||||
|
if batch == nil {
|
||||||
|
return nil, errors.New("failed to initialize image batch")
|
||||||
|
}
|
||||||
|
defer C.clip_image_f32_batch_free(batch)
|
||||||
|
|
||||||
|
ok = C.clip_image_preprocess(c.c, img, batch)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("failed to preprocess image")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Calculate total tokens and allocate memory
|
||||||
|
nImages := C.clip_image_f32_batch_n_images(batch)
|
||||||
|
totalTokens := 0
|
||||||
|
for i := C.size_t(0); i < nImages; i++ {
|
||||||
|
imgF32 := C.clip_image_f32_get_img(batch, C.int(i))
|
||||||
|
tokens := int(C.clip_n_output_tokens(c.c, imgF32))
|
||||||
|
totalTokens += tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalTokens == 0 {
|
||||||
|
return nil, errors.New("no tokens generated from image")
|
||||||
}
|
}
|
||||||
|
|
||||||
numTokens := int(l.n_image_pos)
|
|
||||||
numEmbed := llamaContext.Model().NEmbd()
|
numEmbed := llamaContext.Model().NEmbd()
|
||||||
|
imageEmbd := make([]float32, totalTokens*numEmbed)
|
||||||
|
|
||||||
s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
|
// Step 4: Encode the image batch
|
||||||
|
ok = C.clip_image_batch_encode(
|
||||||
embed := make([][]float32, numTokens)
|
c.c,
|
||||||
rows := make([]float32, len(s))
|
C.int(llamaContext.numThreads),
|
||||||
copy(rows, s)
|
batch,
|
||||||
|
(*C.float)(unsafe.Pointer(&imageEmbd[0])),
|
||||||
for i := range embed {
|
)
|
||||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
if !ok {
|
||||||
|
return nil, errors.New("failed to encode image")
|
||||||
}
|
}
|
||||||
|
|
||||||
C.llava_image_embed_free(l)
|
// Step 5: Convert to slice of slices format
|
||||||
|
embed := make([][]float32, totalTokens)
|
||||||
|
for i := range embed {
|
||||||
|
embed[i] = imageEmbd[i*numEmbed : (i+1)*numEmbed]
|
||||||
|
}
|
||||||
|
|
||||||
return embed, nil
|
return embed, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ problem.
|
||||||
9 files changed, 21 insertions(+), 2 deletions(-)
|
9 files changed, 21 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||||
index b30b4cb3..0ce73a99 100644
|
index b1050ad5..e8694e5c 100644
|
||||||
--- a/ggml/src/ggml-backend.cpp
|
--- a/ggml/src/ggml-backend.cpp
|
||||||
+++ b/ggml/src/ggml-backend.cpp
|
+++ b/ggml/src/ggml-backend.cpp
|
||||||
@@ -107,7 +107,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
|
@@ -107,7 +107,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
|
||||||
|
|
@ -43,7 +43,7 @@ index b30b4cb3..0ce73a99 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||||
@@ -1871,6 +1871,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
|
@@ -1879,6 +1879,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
|
|
||||||
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_aligned_free(buffer->context, buffer->size);
|
ggml_aligned_free(buffer->context, buffer->size);
|
||||||
|
|
@ -55,7 +55,7 @@ index b30b4cb3..0ce73a99 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||||
@@ -1918,7 +1923,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
|
@@ -1926,7 +1931,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
|
static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
|
||||||
|
|
@ -65,10 +65,10 @@ index b30b4cb3..0ce73a99 100644
|
||||||
/* .init_tensor = */ NULL, // no initialization required
|
/* .init_tensor = */ NULL, // no initialization required
|
||||||
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
|
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
|
||||||
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
|
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
|
||||||
index e2617b06..242e50a7 100644
|
index c0ea2600..6c3398da 100755
|
||||||
--- a/ggml/src/ggml-cann/ggml-cann.cpp
|
--- a/ggml/src/ggml-cann/ggml-cann.cpp
|
||||||
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
|
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
|
||||||
@@ -800,6 +800,7 @@ static void ggml_backend_cann_buffer_free_buffer(
|
@@ -801,6 +801,7 @@ static void ggml_backend_cann_buffer_free_buffer(
|
||||||
ggml_backend_cann_buffer_context* ctx =
|
ggml_backend_cann_buffer_context* ctx =
|
||||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
(ggml_backend_cann_buffer_context*)buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -76,7 +76,7 @@ index e2617b06..242e50a7 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -1472,6 +1473,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
|
@@ -1473,6 +1474,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
|
||||||
*/
|
*/
|
||||||
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
|
||||||
ACL_CHECK(aclrtFreeHost(buffer->context));
|
ACL_CHECK(aclrtFreeHost(buffer->context));
|
||||||
|
|
@ -85,7 +85,7 @@ index e2617b06..242e50a7 100644
|
||||||
|
|
||||||
/**
|
/**
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index b4b85abc..cb0d8528 100644
|
index 2a6f7f10..ec031650 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -534,6 +534,7 @@ struct ggml_backend_cuda_buffer_context {
|
@@ -534,6 +534,7 @@ struct ggml_backend_cuda_buffer_context {
|
||||||
|
|
@ -104,7 +104,7 @@ index b4b85abc..cb0d8528 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
@@ -1067,6 +1069,7 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
|
@@ -1071,6 +1073,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
|
||||||
|
|
||||||
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
CUDA_CHECK(cudaFreeHost(buffer->context));
|
CUDA_CHECK(cudaFreeHost(buffer->context));
|
||||||
|
|
@ -125,10 +125,10 @@ index 50579227..2799a0a5 100644
|
||||||
|
|
||||||
static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||||
index 576f9581..1b56f858 100644
|
index bc93bc63..fd3a9d1b 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||||
@@ -5214,6 +5214,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
@@ -5272,6 +5272,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
free(ctx);
|
free(ctx);
|
||||||
|
|
@ -137,10 +137,10 @@ index 576f9581..1b56f858 100644
|
||||||
|
|
||||||
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
||||||
index 05a2f4e6..392cc18d 100644
|
index 80a36438..6abb0ab2 100644
|
||||||
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
|
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
|
||||||
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
||||||
@@ -1940,6 +1940,7 @@ struct ggml_backend_opencl_buffer_context {
|
@@ -2366,6 +2366,7 @@ struct ggml_backend_opencl_buffer_context {
|
||||||
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -161,10 +161,10 @@ index 4f0abb5a..de1ec184 100644
|
||||||
|
|
||||||
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
||||||
index 0ea72994..ae3a3c33 100644
|
index 78513114..0dabdfe7 100644
|
||||||
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
|
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
|
||||||
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
||||||
@@ -320,6 +320,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
|
@@ -331,6 +331,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
|
||||||
ggml_sycl_set_device(ctx->device);
|
ggml_sycl_set_device(ctx->device);
|
||||||
|
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -172,7 +172,7 @@ index 0ea72994..ae3a3c33 100644
|
||||||
}
|
}
|
||||||
catch (sycl::exception const &exc) {
|
catch (sycl::exception const &exc) {
|
||||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||||
@@ -765,6 +766,7 @@ struct ggml_backend_sycl_split_buffer_context {
|
@@ -791,6 +792,7 @@ struct ggml_backend_sycl_split_buffer_context {
|
||||||
static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
|
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -180,7 +180,7 @@ index 0ea72994..ae3a3c33 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
@@ -1099,6 +1101,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
@@ -1133,6 +1135,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
||||||
|
|
||||||
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_sycl_host_free(buffer->context);
|
ggml_sycl_host_free(buffer->context);
|
||||||
|
|
@ -189,10 +189,10 @@ index 0ea72994..ae3a3c33 100644
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
index e2b357fd..68768029 100644
|
index 3e43b03b..01776f3d 100644
|
||||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
@@ -8962,6 +8962,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
@@ -9272,6 +9272,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||||
ggml_vk_destroy_buffer(ctx->dev_buffer);
|
ggml_vk_destroy_buffer(ctx->dev_buffer);
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -200,7 +200,7 @@ index e2b357fd..68768029 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
@@ -9105,6 +9106,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
|
@@ -9415,6 +9416,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
|
||||||
static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
|
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
|
||||||
ggml_vk_host_free(vk_instance.devices[0], buffer->context);
|
ggml_vk_host_free(vk_instance.devices[0], buffer->context);
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ logs instead of throwing an error
|
||||||
1 file changed, 3 insertions(+), 11 deletions(-)
|
1 file changed, 3 insertions(+), 11 deletions(-)
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||||
index 9389ca80..806c1b3d 100644
|
index ba2e1864..0d7ad157 100644
|
||||||
--- a/src/llama-vocab.cpp
|
--- a/src/llama-vocab.cpp
|
||||||
+++ b/src/llama-vocab.cpp
|
+++ b/src/llama-vocab.cpp
|
||||||
@@ -1503,16 +1503,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
@@ -1503,16 +1503,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,10 @@ instead of forcing one or the error
|
||||||
1 file changed, 3 insertions(+), 3 deletions(-)
|
1 file changed, 3 insertions(+), 3 deletions(-)
|
||||||
|
|
||||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||||
index 62246c10..dca22d8b 100644
|
index c29fe7e4..148d1132 100644
|
||||||
--- a/src/llama-context.cpp
|
--- a/src/llama-context.cpp
|
||||||
+++ b/src/llama-context.cpp
|
+++ b/src/llama-context.cpp
|
||||||
@@ -901,7 +901,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
@@ -952,7 +952,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||||
int64_t n_outputs_all = 0;
|
int64_t n_outputs_all = 0;
|
||||||
|
|
||||||
// count outputs
|
// count outputs
|
||||||
|
|
@ -23,7 +23,7 @@ index 62246c10..dca22d8b 100644
|
||||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||||
n_outputs_all += batch.logits[i] != 0;
|
n_outputs_all += batch.logits[i] != 0;
|
||||||
}
|
}
|
||||||
@@ -982,7 +982,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
@@ -1083,7 +1083,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
|
@ -32,7 +32,7 @@ index 62246c10..dca22d8b 100644
|
||||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||||
|
|
||||||
if (t_embd && res->get_embd_pooled()) {
|
if (t_embd && res->get_embd_pooled()) {
|
||||||
@@ -1151,7 +1151,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
@@ -1244,7 +1244,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
|
|
||||||
// TODO: use a per-batch flag for logits presence instead
|
// TODO: use a per-batch flag for logits presence instead
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ filesystems for paths that include wide characters
|
||||||
1 file changed, 39 insertions(+)
|
1 file changed, 39 insertions(+)
|
||||||
|
|
||||||
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
|
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
|
||||||
index 41ba45a7..cdd8ca44 100644
|
index c25bacc1..b3f92814 100644
|
||||||
--- a/tools/mtmd/clip.cpp
|
--- a/tools/mtmd/clip.cpp
|
||||||
+++ b/tools/mtmd/clip.cpp
|
+++ b/tools/mtmd/clip.cpp
|
||||||
@@ -31,6 +31,19 @@
|
@@ -28,6 +28,19 @@
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
|
|
@ -33,7 +33,7 @@ index 41ba45a7..cdd8ca44 100644
|
||||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||||
|
|
||||||
enum ffn_op_type {
|
enum ffn_op_type {
|
||||||
@@ -2190,7 +2203,29 @@ struct clip_model_loader {
|
@@ -2552,7 +2565,29 @@ struct clip_model_loader {
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> read_buf;
|
std::vector<uint8_t> read_buf;
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ index 41ba45a7..cdd8ca44 100644
|
||||||
if (!fin) {
|
if (!fin) {
|
||||||
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
||||||
}
|
}
|
||||||
@@ -2217,7 +2252,11 @@ struct clip_model_loader {
|
@@ -2579,7 +2614,11 @@ struct clip_model_loader {
|
||||||
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ adds support for the Solar Pro architecture
|
||||||
7 files changed, 248 insertions(+)
|
7 files changed, 248 insertions(+)
|
||||||
|
|
||||||
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
||||||
index f2bc8ca7..5ab3f572 100644
|
index c0590e10..6d9f0719 100644
|
||||||
--- a/src/llama-arch.cpp
|
--- a/src/llama-arch.cpp
|
||||||
+++ b/src/llama-arch.cpp
|
+++ b/src/llama-arch.cpp
|
||||||
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
|
|
@ -34,7 +34,7 @@ index f2bc8ca7..5ab3f572 100644
|
||||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||||
|
|
||||||
@@ -1502,6 +1504,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
@@ -1508,6 +1510,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -59,7 +59,7 @@ index f2bc8ca7..5ab3f572 100644
|
||||||
{
|
{
|
||||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||||
{
|
{
|
||||||
@@ -1680,6 +1700,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
@@ -1686,6 +1706,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||||
// this tensor is loaded for T5, but never used
|
// this tensor is loaded for T5, but never used
|
||||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||||
|
|
@ -68,7 +68,7 @@ index f2bc8ca7..5ab3f572 100644
|
||||||
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||||
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||||
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
||||||
index 41a023da..525c1b7d 100644
|
index 930cb4ec..591bc14e 100644
|
||||||
--- a/src/llama-arch.h
|
--- a/src/llama-arch.h
|
||||||
+++ b/src/llama-arch.h
|
+++ b/src/llama-arch.h
|
||||||
@@ -73,6 +73,7 @@ enum llm_arch {
|
@@ -73,6 +73,7 @@ enum llm_arch {
|
||||||
|
|
@ -87,7 +87,7 @@ index 41a023da..525c1b7d 100644
|
||||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||||
|
|
||||||
@@ -346,6 +348,7 @@ enum llm_tensor {
|
@@ -348,6 +350,7 @@ enum llm_tensor {
|
||||||
LLM_TENSOR_ENC_OUTPUT_NORM,
|
LLM_TENSOR_ENC_OUTPUT_NORM,
|
||||||
LLM_TENSOR_CLS,
|
LLM_TENSOR_CLS,
|
||||||
LLM_TENSOR_CLS_OUT,
|
LLM_TENSOR_CLS_OUT,
|
||||||
|
|
@ -96,10 +96,10 @@ index 41a023da..525c1b7d 100644
|
||||||
LLM_TENSOR_CONVNEXT_DW,
|
LLM_TENSOR_CONVNEXT_DW,
|
||||||
LLM_TENSOR_CONVNEXT_NORM,
|
LLM_TENSOR_CONVNEXT_NORM,
|
||||||
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
|
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
|
||||||
index 90dfe7a7..8a667960 100644
|
index 1499eb08..aa7a4b23 100644
|
||||||
--- a/src/llama-hparams.cpp
|
--- a/src/llama-hparams.cpp
|
||||||
+++ b/src/llama-hparams.cpp
|
+++ b/src/llama-hparams.cpp
|
||||||
@@ -70,6 +70,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
||||||
return ssm_d_state * ssm_d_inner;
|
return ssm_d_state * ssm_d_inner;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -113,12 +113,12 @@ index 90dfe7a7..8a667960 100644
|
||||||
+
|
+
|
||||||
bool llama_hparams::is_swa(uint32_t il) const {
|
bool llama_hparams::is_swa(uint32_t il) const {
|
||||||
if (il < n_layer) {
|
if (il < n_layer) {
|
||||||
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
|
return swa_layers[il];
|
||||||
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
|
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
|
||||||
index 7ee6a5b7..48dce407 100644
|
index b2bcb8b0..347d239d 100644
|
||||||
--- a/src/llama-hparams.h
|
--- a/src/llama-hparams.h
|
||||||
+++ b/src/llama-hparams.h
|
+++ b/src/llama-hparams.h
|
||||||
@@ -55,6 +55,8 @@ struct llama_hparams {
|
@@ -59,6 +59,8 @@ struct llama_hparams {
|
||||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
||||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
||||||
|
|
||||||
|
|
@ -127,7 +127,7 @@ index 7ee6a5b7..48dce407 100644
|
||||||
uint32_t n_layer_dense_lead = 0;
|
uint32_t n_layer_dense_lead = 0;
|
||||||
uint32_t n_lora_q = 0;
|
uint32_t n_lora_q = 0;
|
||||||
uint32_t n_lora_kv = 0;
|
uint32_t n_lora_kv = 0;
|
||||||
@@ -154,6 +156,9 @@ struct llama_hparams {
|
@@ -186,6 +188,9 @@ struct llama_hparams {
|
||||||
// dimension of the recurrent state embeddings
|
// dimension of the recurrent state embeddings
|
||||||
uint32_t n_embd_v_s() const;
|
uint32_t n_embd_v_s() const;
|
||||||
|
|
||||||
|
|
@ -138,7 +138,7 @@ index 7ee6a5b7..48dce407 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
|
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
|
||||||
index 4cce5166..7f6617fa 100644
|
index ddb1b036..f4a6c2cd 100644
|
||||||
--- a/src/llama-model-loader.cpp
|
--- a/src/llama-model-loader.cpp
|
||||||
+++ b/src/llama-model-loader.cpp
|
+++ b/src/llama-model-loader.cpp
|
||||||
@@ -439,6 +439,7 @@ namespace GGUFMeta {
|
@@ -439,6 +439,7 @@ namespace GGUFMeta {
|
||||||
|
|
@ -150,10 +150,10 @@ index 4cce5166..7f6617fa 100644
|
||||||
llama_model_loader::llama_model_loader(
|
llama_model_loader::llama_model_loader(
|
||||||
const std::string & fname,
|
const std::string & fname,
|
||||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||||
index 3a4e72a3..831b68c0 100644
|
index afef8487..c042546c 100644
|
||||||
--- a/src/llama-model.cpp
|
--- a/src/llama-model.cpp
|
||||||
+++ b/src/llama-model.cpp
|
+++ b/src/llama-model.cpp
|
||||||
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
@@ -1417,6 +1417,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
default: type = LLM_TYPE_UNKNOWN;
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -175,7 +175,7 @@ index 3a4e72a3..831b68c0 100644
|
||||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
@@ -3774,6 +3789,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
@@ -3797,6 +3812,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
|
|
||||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||||
|
|
||||||
|
|
@ -210,7 +210,7 @@ index 3a4e72a3..831b68c0 100644
|
||||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||||
@@ -12397,6 +12440,165 @@ struct llm_build_chameleon : public llm_graph_context {
|
@@ -12721,6 +12764,165 @@ struct llm_build_chameleon : public llm_graph_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -270,7 +270,7 @@ index 3a4e72a3..831b68c0 100644
|
||||||
+ // self-attention
|
+ // self-attention
|
||||||
+ {
|
+ {
|
||||||
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
|
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
||||||
+
|
+
|
||||||
+ // compute Q and K and RoPE them
|
+ // compute Q and K and RoPE them
|
||||||
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||||
|
|
@ -376,7 +376,7 @@ index 3a4e72a3..831b68c0 100644
|
||||||
struct llm_build_wavtokenizer_dec : public llm_graph_context {
|
struct llm_build_wavtokenizer_dec : public llm_graph_context {
|
||||||
llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||||
ggml_tensor * cur;
|
ggml_tensor * cur;
|
||||||
@@ -13157,6 +13359,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
@@ -13515,6 +13717,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||||
{
|
{
|
||||||
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
|
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -387,7 +387,7 @@ index 3a4e72a3..831b68c0 100644
|
||||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||||
{
|
{
|
||||||
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
|
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
|
||||||
@@ -13301,6 +13507,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
@@ -13663,6 +13869,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||||
case LLM_ARCH_GRANITE:
|
case LLM_ARCH_GRANITE:
|
||||||
case LLM_ARCH_GRANITE_MOE:
|
case LLM_ARCH_GRANITE_MOE:
|
||||||
case LLM_ARCH_CHAMELEON:
|
case LLM_ARCH_CHAMELEON:
|
||||||
|
|
@ -396,7 +396,7 @@ index 3a4e72a3..831b68c0 100644
|
||||||
return LLAMA_ROPE_TYPE_NORM;
|
return LLAMA_ROPE_TYPE_NORM;
|
||||||
|
|
||||||
diff --git a/src/llama-model.h b/src/llama-model.h
|
diff --git a/src/llama-model.h b/src/llama-model.h
|
||||||
index 6bdec263..43746c7d 100644
|
index cbea2cb3..43e7fcda 100644
|
||||||
--- a/src/llama-model.h
|
--- a/src/llama-model.h
|
||||||
+++ b/src/llama-model.h
|
+++ b/src/llama-model.h
|
||||||
@@ -65,6 +65,7 @@ enum llm_type {
|
@@ -65,6 +65,7 @@ enum llm_type {
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ regex
|
||||||
2 files changed, 22 insertions(+), 1 deletion(-)
|
2 files changed, 22 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||||
index 806c1b3d..10f34d33 100644
|
index 0d7ad157..d007039f 100644
|
||||||
--- a/src/llama-vocab.cpp
|
--- a/src/llama-vocab.cpp
|
||||||
+++ b/src/llama-vocab.cpp
|
+++ b/src/llama-vocab.cpp
|
||||||
@@ -298,7 +298,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
@@ -298,7 +298,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ Subject: [PATCH] maintain ordering for rules for grammar
|
||||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
|
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
|
||||||
index 5b3059c2..656b3eca 100644
|
index d38a74f9..2a8aeca6 100644
|
||||||
--- a/common/json-schema-to-grammar.cpp
|
--- a/common/json-schema-to-grammar.cpp
|
||||||
+++ b/common/json-schema-to-grammar.cpp
|
+++ b/common/json-schema-to-grammar.cpp
|
||||||
@@ -349,7 +349,7 @@ private:
|
@@ -350,7 +350,7 @@ private:
|
||||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||||
std::function<json(const std::string &)> _fetch_json;
|
std::function<json(const std::string &)> _fetch_json;
|
||||||
bool _dotall;
|
bool _dotall;
|
||||||
|
|
|
||||||
|
|
@ -1,352 +0,0 @@
|
||||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
||||||
From: jmorganca <jmorganca@gmail.com>
|
|
||||||
Date: Tue, 15 Apr 2025 14:27:40 -0400
|
|
||||||
Subject: [PATCH] ensure KV cache is fully defragmented
|
|
||||||
|
|
||||||
Sometimes the KV cache requires defragmentation even without
|
|
||||||
triggering the threshold heuristic. In this case, decoding
|
|
||||||
will not being able to find a KV cache slot. This is particularly
|
|
||||||
difficult for the caller to handle if it happens in between
|
|
||||||
ubatches. To avoid this, we should immediately trigger a defrag.
|
|
||||||
|
|
||||||
In addition, a heavily fragmented cache can require more than
|
|
||||||
max_moves to defragment. Currently, we stop when we hit the limit
|
|
||||||
but this can leave a cache that still does not have adequate space
|
|
||||||
even after defragmentation is triggered. Instead, we should do
|
|
||||||
multiple batches of processing until everything is complete.
|
|
||||||
---
|
|
||||||
src/llama-context.cpp | 18 ++++---
|
|
||||||
src/llama-context.h | 1 +
|
|
||||||
src/llama-kv-cache.cpp | 107 ++++++++++++++---------------------------
|
|
||||||
src/llama-kv-cache.h | 12 ++++-
|
|
||||||
4 files changed, 59 insertions(+), 79 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
|
||||||
index c22687e4..c5948e8f 100644
|
|
||||||
--- a/src/llama-context.cpp
|
|
||||||
+++ b/src/llama-context.cpp
|
|
||||||
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
||||||
|
|
||||||
// find KV slot
|
|
||||||
if (!kv_self->find_slot(ubatch)) {
|
|
||||||
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
|
||||||
-
|
|
||||||
- return 1;
|
|
||||||
+ kv_self->defrag_sched(-1.0f);
|
|
||||||
+ kv_self->update(*this);
|
|
||||||
+ if (!kv_self->find_slot(ubatch)) {
|
|
||||||
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_sched_reset(sched.get());
|
|
||||||
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
|
|
||||||
|
|
||||||
// TODO: not sure if this is needed
|
|
||||||
if (!kv_self->find_slot(ubatch)) {
|
|
||||||
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
|
||||||
-
|
|
||||||
- GGML_ABORT("TODO: handle this error");
|
|
||||||
+ kv_self->defrag_sched(-1.0f);
|
|
||||||
+ kv_self->update(*this);
|
|
||||||
+ if (!kv_self->find_slot(ubatch)) {
|
|
||||||
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
|
||||||
+ GGML_ABORT("TODO: handle this error");
|
|
||||||
+ }
|
|
||||||
}
|
|
||||||
|
|
||||||
auto * gf = graph_init();
|
|
||||||
diff --git a/src/llama-context.h b/src/llama-context.h
|
|
||||||
index c0ceacb1..0264e937 100644
|
|
||||||
--- a/src/llama-context.h
|
|
||||||
+++ b/src/llama-context.h
|
|
||||||
@@ -5,6 +5,7 @@
|
|
||||||
#include "llama-cparams.h"
|
|
||||||
#include "llama-graph.h"
|
|
||||||
#include "llama-adapter.h"
|
|
||||||
+#include "llama-kv-cache.h"
|
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
|
||||||
#include "ggml-opt.h"
|
|
||||||
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
|
|
||||||
index 3dcad65b..60e67b03 100644
|
|
||||||
--- a/src/llama-kv-cache.cpp
|
|
||||||
+++ b/src/llama-kv-cache.cpp
|
|
||||||
@@ -364,8 +364,6 @@ void llama_kv_cache_unified::commit() {
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
||||||
- bool need_reserve = false;
|
|
||||||
-
|
|
||||||
auto * sched = lctx.get_sched();
|
|
||||||
|
|
||||||
if (has_shift) {
|
|
||||||
@@ -388,8 +386,6 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
||||||
res->set_inputs(nullptr);
|
|
||||||
|
|
||||||
lctx.graph_compute(gf, false);
|
|
||||||
-
|
|
||||||
- need_reserve = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
@@ -403,27 +399,36 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
||||||
|
|
||||||
if (do_defrag) {
|
|
||||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
|
||||||
+ const uint32_t n_max_nodes = lctx.graph_max_nodes();
|
|
||||||
+ const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
|
|
||||||
+ if (!defrag_prepare(n_max_nodes)) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
|
|
||||||
+ return false;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ for (std::size_t i = 0; i < defrag_info.moves.size(); i += max_moves) {
|
|
||||||
+ std::vector<struct llama_kv_defrag_move> chunk;
|
|
||||||
+ auto end = std::min(i + max_moves, defrag_info.moves.size());
|
|
||||||
+ chunk.assign(defrag_info.moves.begin() + i, defrag_info.moves.begin() + end);
|
|
||||||
|
|
||||||
- if (defrag_prepare(lctx.graph_max_nodes())) {
|
|
||||||
ggml_backend_sched_reset(sched);
|
|
||||||
|
|
||||||
auto * gf = lctx.graph_init();
|
|
||||||
|
|
||||||
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
|
||||||
+ auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf, chunk);
|
|
||||||
|
|
||||||
ggml_backend_sched_alloc_graph(sched, gf);
|
|
||||||
|
|
||||||
res->set_inputs(nullptr);
|
|
||||||
|
|
||||||
lctx.graph_compute(gf, false);
|
|
||||||
-
|
|
||||||
- need_reserve = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
do_defrag = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
- return need_reserve;
|
|
||||||
+ // we never need to reserve a worst case graph
|
|
||||||
+ return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_unified::defrag_sched(float thold) {
|
|
||||||
@@ -707,11 +712,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|
||||||
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|
||||||
const llama_cparams & cparams,
|
|
||||||
ggml_context * ctx,
|
|
||||||
- ggml_cgraph * gf) const {
|
|
||||||
+ ggml_cgraph * gf,
|
|
||||||
+ const std::vector<struct llama_kv_defrag_move> & moves) const {
|
|
||||||
auto res = std::make_unique<llm_graph_result>();
|
|
||||||
|
|
||||||
- const auto & ids = defrag_info.ids;
|
|
||||||
-
|
|
||||||
#if 0
|
|
||||||
// CPU defrag
|
|
||||||
//
|
|
||||||
@@ -783,32 +787,20 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|
||||||
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
- for (uint32_t i = 0; i < ids.size(); ++i) {
|
|
||||||
- const uint32_t id = ids[i];
|
|
||||||
-
|
|
||||||
- if (i == id || id == ids.size()) {
|
|
||||||
- continue;
|
|
||||||
- }
|
|
||||||
-
|
|
||||||
- uint32_t nm = 1;
|
|
||||||
-
|
|
||||||
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
|
||||||
- nm++;
|
|
||||||
- }
|
|
||||||
-
|
|
||||||
+ for (const auto & move : moves) {
|
|
||||||
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
|
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
||||||
|
|
||||||
ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
|
|
||||||
- n_embd_k_gqa, nm,
|
|
||||||
+ n_embd_k_gqa, move.len,
|
|
||||||
ggml_row_size(k_l[il]->type, n_embd_k_gqa),
|
|
||||||
- ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
|
|
||||||
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.src));
|
|
||||||
|
|
||||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
|
|
||||||
- n_embd_k_gqa, nm,
|
|
||||||
+ n_embd_k_gqa, move.len,
|
|
||||||
ggml_row_size(k_l[il]->type, n_embd_k_gqa),
|
|
||||||
- ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
|
|
||||||
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.dst));
|
|
||||||
|
|
||||||
ggml_tensor * view_v_src;
|
|
||||||
ggml_tensor * view_v_dst;
|
|
||||||
@@ -816,31 +808,29 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|
||||||
if (cparams.flash_attn) {
|
|
||||||
// NOTE: the V cache is not transposed when using flash attention
|
|
||||||
view_v_src = ggml_view_2d(ctx, v_l[il],
|
|
||||||
- n_embd_v_gqa, nm,
|
|
||||||
+ n_embd_v_gqa, move.len,
|
|
||||||
ggml_row_size(v_l[il]->type, n_embd_v_gqa),
|
|
||||||
- ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
|
|
||||||
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.dst));
|
|
||||||
|
|
||||||
view_v_dst = ggml_view_2d(ctx, v_l[il],
|
|
||||||
- n_embd_v_gqa, nm,
|
|
||||||
+ move.len, n_embd_v_gqa,
|
|
||||||
ggml_row_size(v_l[il]->type, n_embd_v_gqa),
|
|
||||||
- ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
|
|
||||||
+ ggml_row_size(v_l[il]->type, move.src));
|
|
||||||
} else {
|
|
||||||
view_v_src = ggml_view_2d(ctx, v_l[il],
|
|
||||||
- nm, n_embd_v_gqa,
|
|
||||||
+ move.len, n_embd_v_gqa,
|
|
||||||
ggml_row_size(v_l[il]->type, size),
|
|
||||||
- ggml_row_size(v_l[il]->type, i));
|
|
||||||
+ ggml_row_size(v_l[il]->type, move.src));
|
|
||||||
|
|
||||||
view_v_dst = ggml_view_2d(ctx, v_l[il],
|
|
||||||
- nm, n_embd_v_gqa,
|
|
||||||
+ move.len, n_embd_v_gqa,
|
|
||||||
ggml_row_size(v_l[il]->type, size),
|
|
||||||
- ggml_row_size(v_l[il]->type, id));
|
|
||||||
+ ggml_row_size(v_l[il]->type, move.dst));
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
|
|
||||||
}
|
|
||||||
-
|
|
||||||
- i += nm - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
|
||||||
@@ -857,17 +847,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
||||||
|
|
||||||
assert(n_used <= n_kv);
|
|
||||||
|
|
||||||
- //const int64_t t_start = ggml_time_us();
|
|
||||||
-
|
|
||||||
- // number of cells moved
|
|
||||||
- uint32_t n_moves = 0;
|
|
||||||
-
|
|
||||||
- // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
|
|
||||||
- // - source view, destination view, copy operation
|
|
||||||
- // - x2 for keys and values
|
|
||||||
- //const uint32_t max_moves = max_nodes()/(6*n_layer);
|
|
||||||
- // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
|
||||||
- const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
|
||||||
+ defrag_info.moves.clear();
|
|
||||||
|
|
||||||
// determine which KV cells to move where
|
|
||||||
//
|
|
||||||
@@ -875,10 +855,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
||||||
//
|
|
||||||
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
|
||||||
//
|
|
||||||
- auto & ids = defrag_info.ids;
|
|
||||||
-
|
|
||||||
- ids.clear();
|
|
||||||
- ids.resize(n_kv, n_kv);
|
|
||||||
+ std::vector<uint32_t> ids(n_kv, n_kv);
|
|
||||||
|
|
||||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
|
||||||
const auto & cell0 = cells[i0];
|
|
||||||
@@ -927,19 +904,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
||||||
// are we moving a continuous block of memory?
|
|
||||||
bool cont = false;
|
|
||||||
|
|
||||||
- // should we stop searching for the next move?
|
|
||||||
- bool stop = false;
|
|
||||||
-
|
|
||||||
// go back and move the nf cells to the hole
|
|
||||||
for (; i1 < n_kv; ++i1) {
|
|
||||||
auto & cell1 = cells[i1];
|
|
||||||
|
|
||||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
|
||||||
- if (n_moves == max_moves) {
|
|
||||||
- stop = true;
|
|
||||||
- break;
|
|
||||||
- }
|
|
||||||
-
|
|
||||||
cont = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
@@ -955,8 +924,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
||||||
head = n_used;
|
|
||||||
|
|
||||||
if (!cont) {
|
|
||||||
- n_moves++;
|
|
||||||
+ defrag_info.moves.push_back({i1, i0 + nf, 1});
|
|
||||||
cont = true;
|
|
||||||
+ } else {
|
|
||||||
+ defrag_info.moves.back().len++;
|
|
||||||
}
|
|
||||||
|
|
||||||
nf++;
|
|
||||||
@@ -966,22 +937,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
- if (stop || n_moves == max_moves) {
|
|
||||||
- break;
|
|
||||||
- }
|
|
||||||
-
|
|
||||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
|
||||||
|
|
||||||
i0 += nh - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
- if (n_moves == 0) {
|
|
||||||
+ if (defrag_info.moves.size() == 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
- LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
|
||||||
-
|
|
||||||
- LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
|
||||||
+ // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
|
|
||||||
index bf3b4b6a..928b9712 100644
|
|
||||||
--- a/src/llama-kv-cache.h
|
|
||||||
+++ b/src/llama-kv-cache.h
|
|
||||||
@@ -82,6 +82,13 @@ struct llama_kv_cache_guard {
|
|
||||||
private:
|
|
||||||
llama_kv_cache * kv;
|
|
||||||
};
|
|
||||||
+
|
|
||||||
+// block of KV slots to move when defragging
|
|
||||||
+struct llama_kv_defrag_move {
|
|
||||||
+ uint32_t src;
|
|
||||||
+ uint32_t dst;
|
|
||||||
+ uint32_t len;
|
|
||||||
+};
|
|
||||||
|
|
||||||
//
|
|
||||||
// llama_kv_cache_unified
|
|
||||||
@@ -207,7 +214,7 @@ private:
|
|
||||||
|
|
||||||
// defrag
|
|
||||||
struct {
|
|
||||||
- std::vector<uint32_t> ids;
|
|
||||||
+ std::vector<llama_kv_defrag_move> moves;
|
|
||||||
} defrag_info;
|
|
||||||
|
|
||||||
// return true if cells have been moved
|
|
||||||
@@ -249,7 +256,8 @@ private:
|
|
||||||
llm_graph_result_ptr build_graph_defrag(
|
|
||||||
const llama_cparams & cparams,
|
|
||||||
ggml_context * ctx,
|
|
||||||
- ggml_cgraph * gf) const;
|
|
||||||
+ ggml_cgraph * gf,
|
|
||||||
+ const std::vector<llama_kv_defrag_move> & moves) const;
|
|
||||||
|
|
||||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
|
||||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
||||||
|
|
@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
|
||||||
1 file changed, 2 insertions(+)
|
1 file changed, 2 insertions(+)
|
||||||
|
|
||||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||||
index ddea5ad3..45918bf6 100644
|
index 7dcb031f..770e18bc 100644
|
||||||
--- a/ggml/src/CMakeLists.txt
|
--- a/ggml/src/CMakeLists.txt
|
||||||
+++ b/ggml/src/CMakeLists.txt
|
+++ b/ggml/src/CMakeLists.txt
|
||||||
@@ -279,6 +279,7 @@ function(ggml_add_cpu_backend_variant tag_name)
|
@@ -282,6 +282,7 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
ggml_add_cpu_backend_variant_impl(${tag_name})
|
ggml_add_cpu_backend_variant_impl(${tag_name})
|
||||||
|
|
@ -19,11 +19,11 @@ index ddea5ad3..45918bf6 100644
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
ggml_add_backend(CPU)
|
ggml_add_backend(CPU)
|
||||||
@@ -287,6 +288,7 @@ if (GGML_CPU_ALL_VARIANTS)
|
@@ -290,6 +291,7 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||||
if (NOT GGML_BACKEND_DL)
|
if (NOT GGML_BACKEND_DL)
|
||||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
||||||
endif()
|
endif()
|
||||||
+ add_custom_target(ggml-cpu)
|
+ add_custom_target(ggml-cpu)
|
||||||
ggml_add_cpu_backend_variant(x64)
|
if (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||||
ggml_add_cpu_backend_variant(sse42 SSE42)
|
ggml_add_cpu_backend_variant(x64)
|
||||||
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
|
ggml_add_cpu_backend_variant(sse42 SSE42)
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: jmorganca <jmorganca@gmail.com>
|
||||||
|
Date: Thu, 1 May 2025 15:05:08 -0700
|
||||||
|
Subject: [PATCH] remove amx
|
||||||
|
|
||||||
|
disable amx as it reduces performance on some systems
|
||||||
|
---
|
||||||
|
ggml/src/CMakeLists.txt | 4 ----
|
||||||
|
1 file changed, 4 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||||
|
index 770e18bc..62f3dbf6 100644
|
||||||
|
--- a/ggml/src/CMakeLists.txt
|
||||||
|
+++ b/ggml/src/CMakeLists.txt
|
||||||
|
@@ -300,10 +300,6 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||||
|
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
|
||||||
|
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||||
|
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
||||||
|
- if (NOT MSVC)
|
||||||
|
- # MSVC doesn't support AMX
|
||||||
|
- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||||
|
- endif()
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported on ${GGML_SYSTEM_ARCH}")
|
||||||
|
endif()
|
||||||
|
|
@ -25,10 +25,10 @@ index 79ee2020..3efb22f0 100644
|
||||||
// get ith C string from array with given key_id
|
// get ith C string from array with given key_id
|
||||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
||||||
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
|
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
|
||||||
index 381a9c7d..e45b453d 100644
|
index a0a318a2..b3326b94 100644
|
||||||
--- a/ggml/src/gguf.cpp
|
--- a/ggml/src/gguf.cpp
|
||||||
+++ b/ggml/src/gguf.cpp
|
+++ b/ggml/src/gguf.cpp
|
||||||
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
|
@@ -794,10 +794,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
|
||||||
|
|
||||||
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
|
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
|
|
@ -44,7 +44,7 @@ index 381a9c7d..e45b453d 100644
|
||||||
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
|
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
|
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
|
||||||
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
|
@@ -891,7 +895,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
|
||||||
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
|
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
|
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
|
||||||
|
|
@ -53,7 +53,7 @@ index 381a9c7d..e45b453d 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||||
index 10f34d33..9f5fd57b 100644
|
index d007039f..4a6c3ad6 100644
|
||||||
--- a/src/llama-vocab.cpp
|
--- a/src/llama-vocab.cpp
|
||||||
+++ b/src/llama-vocab.cpp
|
+++ b/src/llama-vocab.cpp
|
||||||
@@ -1469,9 +1469,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
@@ -1469,9 +1469,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
||||||
From: jmorganca <jmorganca@gmail.com>
|
|
||||||
Date: Thu, 1 May 2025 15:05:08 -0700
|
|
||||||
Subject: [PATCH] remove amx
|
|
||||||
|
|
||||||
disable amx as it reduces performance on some systems
|
|
||||||
---
|
|
||||||
ggml/src/CMakeLists.txt | 4 ----
|
|
||||||
1 file changed, 4 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
|
||||||
index 45918bf6..0beaed86 100644
|
|
||||||
--- a/ggml/src/CMakeLists.txt
|
|
||||||
+++ b/ggml/src/CMakeLists.txt
|
|
||||||
@@ -296,10 +296,6 @@ if (GGML_CPU_ALL_VARIANTS)
|
|
||||||
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
|
|
||||||
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
|
||||||
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
|
||||||
- if (NOT MSVC)
|
|
||||||
- # MSVC doesn't support AMX
|
|
||||||
- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
|
||||||
- endif()
|
|
||||||
elseif (GGML_CPU)
|
|
||||||
ggml_add_cpu_backend_variant_impl("")
|
|
||||||
endif()
|
|
||||||
|
|
@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor
|
||||||
1 file changed, 6 insertions(+)
|
1 file changed, 6 insertions(+)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
|
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
|
||||||
index a30e67f2..2462d2b8 100644
|
index c7426df2..23441678 100644
|
||||||
--- a/ggml/src/ggml-cpu/ggml-cpu.c
|
--- a/ggml/src/ggml-cpu/ggml-cpu.c
|
||||||
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
|
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
|
||||||
@@ -15,6 +15,8 @@
|
@@ -15,6 +15,8 @@
|
||||||
|
|
@ -20,7 +20,7 @@ index a30e67f2..2462d2b8 100644
|
||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
||||||
@@ -2841,6 +2843,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
@@ -2873,6 +2875,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||||
|
|
||||||
ggml_compute_forward(¶ms, node);
|
ggml_compute_forward(¶ms, node);
|
||||||
|
|
||||||
|
|
@ -10,7 +10,7 @@ Subject: [PATCH] add ollama vocab for grammar support
|
||||||
3 files changed, 58 insertions(+), 9 deletions(-)
|
3 files changed, 58 insertions(+), 9 deletions(-)
|
||||||
|
|
||||||
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
|
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
|
||||||
index 973b47ae..60d58236 100644
|
index bed706bb..b51cee09 100644
|
||||||
--- a/src/llama-grammar.cpp
|
--- a/src/llama-grammar.cpp
|
||||||
+++ b/src/llama-grammar.cpp
|
+++ b/src/llama-grammar.cpp
|
||||||
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
|
@ -90,7 +90,7 @@ index 973b47ae..60d58236 100644
|
||||||
|
|
||||||
if (grammar.awaiting_trigger) {
|
if (grammar.awaiting_trigger) {
|
||||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||||
@@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
@@ -1201,13 +1210,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -107,7 +107,7 @@ index 973b47ae..60d58236 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_accept_str(grammar, piece);
|
llama_grammar_accept_str(grammar, piece);
|
||||||
@@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
@@ -1227,3 +1237,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -184,7 +184,7 @@ index f8c291de..2a3a62db 100644
|
||||||
const char * grammar_root,
|
const char * grammar_root,
|
||||||
bool lazy,
|
bool lazy,
|
||||||
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
|
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
|
||||||
index 804b11e0..15a10ca8 100644
|
index bfbf5fa2..11f93f42 100644
|
||||||
--- a/src/llama-sampling.cpp
|
--- a/src/llama-sampling.cpp
|
||||||
+++ b/src/llama-sampling.cpp
|
+++ b/src/llama-sampling.cpp
|
||||||
@@ -1466,7 +1466,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
@@ -1466,7 +1466,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||||
|
|
@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
|
||||||
3 files changed, 192 insertions(+), 2 deletions(-)
|
3 files changed, 192 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||||
index becdae07..7a44b6cf 100644
|
index 08facb6d..aa5cf56b 100644
|
||||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||||
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
|
@@ -6925,6 +6925,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -59,7 +59,7 @@ index becdae07..7a44b6cf 100644
|
||||||
void ggml_compute_forward_argsort(
|
void ggml_compute_forward_argsort(
|
||||||
const ggml_compute_params * params,
|
const ggml_compute_params * params,
|
||||||
ggml_tensor * dst) {
|
ggml_tensor * dst) {
|
||||||
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
|
@@ -6936,6 +6975,10 @@ void ggml_compute_forward_argsort(
|
||||||
{
|
{
|
||||||
ggml_compute_forward_argsort_f32(params, dst);
|
ggml_compute_forward_argsort_f32(params, dst);
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -195,10 +195,10 @@ index 607ded85..53b02634 100644
|
||||||
+ }
|
+ }
|
||||||
}
|
}
|
||||||
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
||||||
index 2d46176e..47383486 100644
|
index 2c55d214..90d95d32 100644
|
||||||
--- a/ggml/src/ggml-cuda/cpy.cu
|
--- a/ggml/src/ggml-cuda/cpy.cu
|
||||||
+++ b/ggml/src/ggml-cuda/cpy.cu
|
+++ b/ggml/src/ggml-cuda/cpy.cu
|
||||||
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
@@ -41,6 +41,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||||
*dsti = *xi;
|
*dsti = *xi;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -212,7 +212,7 @@ index 2d46176e..47383486 100644
|
||||||
template <cpy_kernel_t cpy_1>
|
template <cpy_kernel_t cpy_1>
|
||||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
|
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
|
@@ -71,6 +78,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
|
||||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,7 +257,7 @@ index 2d46176e..47383486 100644
|
||||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||||
const float * xi = (const float *) cxi;
|
const float * xi = (const float *) cxi;
|
||||||
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||||
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
@@ -643,6 +688,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||||
|
|
@ -266,7 +266,7 @@ index 2d46176e..47383486 100644
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||||
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
@@ -698,6 +745,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||||
|
|
@ -134,10 +134,10 @@ index 5fd379f6..04812990 100644
|
||||||
|
|
||||||
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
|
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
|
||||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||||
index 0ce73a99..be335e8c 100644
|
index e8694e5c..36f11537 100644
|
||||||
--- a/ggml/src/ggml-backend.cpp
|
--- a/ggml/src/ggml-backend.cpp
|
||||||
+++ b/ggml/src/ggml-backend.cpp
|
+++ b/ggml/src/ggml-backend.cpp
|
||||||
@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
@@ -1637,6 +1637,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
||||||
return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
|
return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -24,10 +24,10 @@ index 74e46716..a880df33 100644
|
||||||
size_t memory_total;
|
size_t memory_total;
|
||||||
enum ggml_backend_dev_type type;
|
enum ggml_backend_dev_type type;
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index cb0d8528..4c829153 100644
|
index ec031650..8d5edd04 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
|
@@ -2893,6 +2893,7 @@ struct ggml_backend_cuda_device_context {
|
||||||
int device;
|
int device;
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string description;
|
std::string description;
|
||||||
|
|
@ -35,7 +35,7 @@ index cb0d8528..4c829153 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||||
@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
@@ -2905,6 +2906,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||||
return ctx->description.c_str();
|
return ctx->description.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,7 +47,7 @@ index cb0d8528..4c829153 100644
|
||||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
ggml_cuda_set_device(ctx->device);
|
ggml_cuda_set_device(ctx->device);
|
||||||
@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
@@ -2919,6 +2925,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
||||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||||
props->name = ggml_backend_cuda_device_get_name(dev);
|
props->name = ggml_backend_cuda_device_get_name(dev);
|
||||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||||
|
|
@ -55,7 +55,7 @@ index cb0d8528..4c829153 100644
|
||||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||||
|
|
||||||
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
@@ -3473,6 +3480,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||||
dev_ctx->description = prop.name;
|
dev_ctx->description = prop.name;
|
||||||
|
|
||||||
|
|
@ -89,10 +89,10 @@ index cb0d8528..4c829153 100644
|
||||||
/* .iface = */ ggml_backend_cuda_device_interface,
|
/* .iface = */ ggml_backend_cuda_device_interface,
|
||||||
/* .reg = */ ®,
|
/* .reg = */ ®,
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||||
index 1b56f858..ee4f2dcb 100644
|
index fd3a9d1b..884bde80 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||||
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
@@ -5761,6 +5761,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||||
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||||
props->name = ggml_backend_metal_device_get_name(dev);
|
props->name = ggml_backend_metal_device_get_name(dev);
|
||||||
props->description = ggml_backend_metal_device_get_description(dev);
|
props->description = ggml_backend_metal_device_get_description(dev);
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "sampling_ext.h"
|
#include "sampling_ext.h"
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,8 @@ extern "C" {
|
||||||
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
||||||
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
||||||
|
|
||||||
|
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
||||||
|
|
||||||
// get underlying tensors that store data
|
// get underlying tensors that store data
|
||||||
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
||||||
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
||||||
|
|
|
||||||
|
|
@ -536,6 +536,7 @@ extern "C" {
|
||||||
GGML_UNARY_OP_HARDSWISH,
|
GGML_UNARY_OP_HARDSWISH,
|
||||||
GGML_UNARY_OP_HARDSIGMOID,
|
GGML_UNARY_OP_HARDSIGMOID,
|
||||||
GGML_UNARY_OP_EXP,
|
GGML_UNARY_OP_EXP,
|
||||||
|
GGML_UNARY_OP_GELU_ERF,
|
||||||
|
|
||||||
GGML_UNARY_OP_COUNT,
|
GGML_UNARY_OP_COUNT,
|
||||||
};
|
};
|
||||||
|
|
@ -934,6 +935,15 @@ extern "C" {
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
// repeat a to the specified shape
|
||||||
|
GGML_API struct ggml_tensor * ggml_repeat_4d(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int64_t ne0,
|
||||||
|
int64_t ne1,
|
||||||
|
int64_t ne2,
|
||||||
|
int64_t ne3);
|
||||||
|
|
||||||
// sums repetitions in a into shape of b
|
// sums repetitions in a into shape of b
|
||||||
GGML_API struct ggml_tensor * ggml_repeat_back(
|
GGML_API struct ggml_tensor * ggml_repeat_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
|
@ -1024,6 +1034,16 @@ extern "C" {
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
// GELU using erf (error function) when possible
|
||||||
|
// some backends may fallback to approximation based on Abramowitz and Stegun formula
|
||||||
|
GGML_API struct ggml_tensor * ggml_gelu_erf(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_gelu_quick(
|
GGML_API struct ggml_tensor * ggml_gelu_quick(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
@ -2075,9 +2095,6 @@ extern "C" {
|
||||||
GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
|
GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
|
||||||
GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
|
GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
|
||||||
|
|
||||||
GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
|
|
||||||
GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
|
|
||||||
|
|
||||||
// print info and performance information for the graph
|
// print info and performance information for the graph
|
||||||
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
|
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
|
||||||
|
|
||||||
|
|
@ -2161,6 +2178,7 @@ extern "C" {
|
||||||
|
|
||||||
// scheduling priorities
|
// scheduling priorities
|
||||||
enum ggml_sched_priority {
|
enum ggml_sched_priority {
|
||||||
|
GGML_SCHED_PRIO_LOW = -1,
|
||||||
GGML_SCHED_PRIO_NORMAL,
|
GGML_SCHED_PRIO_NORMAL,
|
||||||
GGML_SCHED_PRIO_MEDIUM,
|
GGML_SCHED_PRIO_MEDIUM,
|
||||||
GGML_SCHED_PRIO_HIGH,
|
GGML_SCHED_PRIO_HIGH,
|
||||||
|
|
|
||||||
|
|
@ -109,6 +109,8 @@ if (MSVC)
|
||||||
else ()
|
else ()
|
||||||
set(CMAKE_GENERATOR_PLATFORM_LWR "")
|
set(CMAKE_GENERATOR_PLATFORM_LWR "")
|
||||||
endif ()
|
endif ()
|
||||||
|
ggml_get_system_arch()
|
||||||
|
message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")
|
||||||
|
|
||||||
if (NOT MSVC)
|
if (NOT MSVC)
|
||||||
if (GGML_STATIC)
|
if (GGML_STATIC)
|
||||||
|
|
@ -123,7 +125,6 @@ if (NOT MSVC)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MINGW)
|
if (MINGW)
|
||||||
# Target Windows 8 for PrefetchVirtualMemory
|
|
||||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
@ -194,6 +195,7 @@ add_library(ggml-base
|
||||||
../include/ggml-opt.h
|
../include/ggml-opt.h
|
||||||
../include/gguf.h
|
../include/gguf.h
|
||||||
ggml.c
|
ggml.c
|
||||||
|
ggml.cpp
|
||||||
ggml-alloc.c
|
ggml-alloc.c
|
||||||
ggml-backend.cpp
|
ggml-backend.cpp
|
||||||
ggml-opt.cpp
|
ggml-opt.cpp
|
||||||
|
|
@ -224,6 +226,7 @@ function(ggml_add_backend_library backend)
|
||||||
set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||||
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
|
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
|
||||||
add_dependencies(ggml ${backend})
|
add_dependencies(ggml ${backend})
|
||||||
|
install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||||
else()
|
else()
|
||||||
add_library(${backend} ${ARGN})
|
add_library(${backend} ${ARGN})
|
||||||
target_link_libraries(ggml PUBLIC ${backend})
|
target_link_libraries(ggml PUBLIC ${backend})
|
||||||
|
|
@ -289,13 +292,17 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
||||||
endif()
|
endif()
|
||||||
add_custom_target(ggml-cpu)
|
add_custom_target(ggml-cpu)
|
||||||
ggml_add_cpu_backend_variant(x64)
|
if (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||||
ggml_add_cpu_backend_variant(sse42 SSE42)
|
ggml_add_cpu_backend_variant(x64)
|
||||||
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
|
ggml_add_cpu_backend_variant(sse42 SSE42)
|
||||||
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
|
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
|
||||||
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
|
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
|
||||||
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
|
||||||
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||||
|
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported on ${GGML_SYSTEM_ARCH}")
|
||||||
|
endif()
|
||||||
elseif (GGML_CPU)
|
elseif (GGML_CPU)
|
||||||
ggml_add_cpu_backend_variant_impl("")
|
ggml_add_cpu_backend_variant_impl("")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
||||||
|
|
@ -1340,7 +1340,10 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
|
||||||
// allocate graph
|
// allocate graph
|
||||||
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
||||||
// the re-allocation may cause the split inputs to be moved to a different address
|
// the re-allocation may cause the split inputs to be moved to a different address
|
||||||
ggml_backend_sched_synchronize(sched);
|
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
|
||||||
|
for (int i = 0; i < sched->n_backends; i++) {
|
||||||
|
ggml_backend_synchronize(sched->backends[i]);
|
||||||
|
}
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -1564,7 +1567,6 @@ bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgra
|
||||||
|
|
||||||
ggml_backend_sched_split_graph(sched, graph);
|
ggml_backend_sched_split_graph(sched, graph);
|
||||||
|
|
||||||
|
|
||||||
if (!ggml_backend_sched_alloc_splits(sched)) {
|
if (!ggml_backend_sched_alloc_splits(sched)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -1598,6 +1600,12 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
|
||||||
for (int i = 0; i < sched->n_backends; i++) {
|
for (int i = 0; i < sched->n_backends; i++) {
|
||||||
ggml_backend_synchronize(sched->backends[i]);
|
ggml_backend_synchronize(sched->backends[i]);
|
||||||
}
|
}
|
||||||
|
if (!sched->is_alloc) {
|
||||||
|
// if the graph is not already allocated, always use copy 0 after a synchronization
|
||||||
|
// this ensures that during generation the same copy is used every time,
|
||||||
|
// which avoids changes in the graph that could cause CUDA or other graphs to be disabled
|
||||||
|
sched->cur_copy = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
|
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ if (BLAS_FOUND)
|
||||||
target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
|
target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
|
||||||
target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
|
target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||||
else()
|
else()
|
||||||
message(ERROR "BLAS not found, please refer to "
|
message(FATAL_ERROR "BLAS not found, please refer to "
|
||||||
"https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
|
"https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
|
||||||
" to set correct GGML_BLAS_VENDOR")
|
" to set correct GGML_BLAS_VENDOR")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
||||||
|
|
@ -82,13 +82,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind)
|
target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
|
if (GGML_SYSTEM_ARCH STREQUAL "ARM")
|
||||||
CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
|
|
||||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
|
||||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
|
|
||||||
|
|
||||||
message(STATUS "ARM detected")
|
message(STATUS "ARM detected")
|
||||||
|
|
||||||
if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
||||||
message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
|
message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
|
||||||
else()
|
else()
|
||||||
|
|
@ -170,12 +165,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
endforeach()
|
endforeach()
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
|
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
|
||||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$"))
|
|
||||||
|
|
||||||
message(STATUS "x86 detected")
|
message(STATUS "x86 detected")
|
||||||
|
|
||||||
if (MSVC)
|
if (MSVC)
|
||||||
# instruction set detection for MSVC only
|
# instruction set detection for MSVC only
|
||||||
if (GGML_NATIVE)
|
if (GGML_NATIVE)
|
||||||
|
|
@ -299,7 +290,26 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
|
|
||||||
|
if (GGML_BACKEND_DL)
|
||||||
|
if (GGML_NATIVE)
|
||||||
|
# the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE
|
||||||
|
message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# The feature detection code is compiled as a separate target so that
|
||||||
|
# it can be built without the architecture flags
|
||||||
|
# Since multiple variants of the CPU backend may be included in the same
|
||||||
|
# build, using set_source_files_properties() to set the arch flags is not possible
|
||||||
|
set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats)
|
||||||
|
add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp)
|
||||||
|
target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include)
|
||||||
|
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
||||||
|
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
||||||
|
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME})
|
||||||
|
endif()
|
||||||
|
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
|
||||||
message(STATUS "PowerPC detected")
|
message(STATUS "PowerPC detected")
|
||||||
if (GGML_NATIVE)
|
if (GGML_NATIVE)
|
||||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||||
|
|
@ -308,7 +318,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
|
string(TOUPPER "${POWER10_M}" POWER10_M_UPPER)
|
||||||
|
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
|
||||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||||
|
|
||||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||||
|
|
@ -325,9 +336,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
|
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
elseif (GGML_SYSTEM_ARCH STREQUAL "loongarch64")
|
||||||
message(STATUS "loongarch64 detected")
|
message(STATUS "loongarch64 detected")
|
||||||
|
|
||||||
list(APPEND ARCH_FLAGS -march=loongarch64)
|
list(APPEND ARCH_FLAGS -march=loongarch64)
|
||||||
if (GGML_LASX)
|
if (GGML_LASX)
|
||||||
list(APPEND ARCH_FLAGS -mlasx)
|
list(APPEND ARCH_FLAGS -mlasx)
|
||||||
|
|
@ -335,16 +345,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
if (GGML_LSX)
|
if (GGML_LSX)
|
||||||
list(APPEND ARCH_FLAGS -mlsx)
|
list(APPEND ARCH_FLAGS -mlsx)
|
||||||
endif()
|
endif()
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
|
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||||
message(STATUS "RISC-V detected")
|
message(STATUS "riscv64 detected")
|
||||||
if (GGML_RVV)
|
if (GGML_RVV)
|
||||||
if (GGML_RV_ZFH)
|
if (GGML_XTHEADVECTOR)
|
||||||
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d)
|
list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
|
||||||
|
elseif (GGML_RV_ZFH)
|
||||||
|
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
|
||||||
else()
|
else()
|
||||||
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
|
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||||
message(STATUS "s390x detected")
|
message(STATUS "s390x detected")
|
||||||
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
||||||
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
||||||
|
|
@ -385,9 +397,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
|
|
||||||
# Fetch KleidiAI sources:
|
# Fetch KleidiAI sources:
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
set(KLEIDIAI_COMMIT_TAG "v1.5.0")
|
set(KLEIDIAI_COMMIT_TAG "v1.6.0")
|
||||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||||
set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
|
set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")
|
||||||
|
|
||||||
if (POLICY CMP0135)
|
if (POLICY CMP0135)
|
||||||
cmake_policy(SET CMP0135 NEW)
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
|
@ -477,25 +489,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
|
target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
|
||||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
||||||
|
|
||||||
if (GGML_BACKEND_DL)
|
|
||||||
if (GGML_NATIVE)
|
|
||||||
# the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE
|
|
||||||
message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# The feature detection code is compiled as a separate target so that
|
|
||||||
# it can be built without the architecture flags
|
|
||||||
# Since multiple variants of the CPU backend may be included in the same
|
|
||||||
# build, using set_source_files_properties() to set the arch flags is not possible
|
|
||||||
set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats)
|
|
||||||
add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp)
|
|
||||||
target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include)
|
|
||||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
|
||||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
|
||||||
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
|
||||||
target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (EMSCRIPTEN)
|
if (EMSCRIPTEN)
|
||||||
set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128")
|
set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
||||||
|
|
@ -1191,7 +1191,7 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
#elif defined(__riscv_v_intrinsic)
|
#elif defined __riscv_v
|
||||||
if (__riscv_vlenb() >= QK4_0) {
|
if (__riscv_vlenb() >= QK4_0) {
|
||||||
const size_t vl = QK4_0;
|
const size_t vl = QK4_0;
|
||||||
|
|
||||||
|
|
@ -3783,7 +3783,7 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#elif defined(__riscv_v_intrinsic)
|
#elif defined __riscv_v
|
||||||
if (__riscv_vlenb() >= QK4_0) {
|
if (__riscv_vlenb() >= QK4_0) {
|
||||||
const size_t vl = QK4_0;
|
const size_t vl = QK4_0;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -320,21 +320,17 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||||
|
|
||||||
#ifdef __wasm_simd128__
|
#ifdef __wasm_simd128__
|
||||||
#include <wasm_simd128.h>
|
#include <wasm_simd128.h>
|
||||||
#else
|
#endif
|
||||||
|
|
||||||
#ifdef __POWER9_VECTOR__
|
#ifdef __POWER9_VECTOR__
|
||||||
#include <altivec.h>
|
#include <altivec.h>
|
||||||
#else
|
#endif
|
||||||
|
|
||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
#include <intrin.h>
|
#include <intrin.h>
|
||||||
#else
|
#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
|
||||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
|
|
||||||
#if !defined(__riscv)
|
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
#endif
|
#endif
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef __riscv_v_intrinsic
|
#ifdef __riscv_v_intrinsic
|
||||||
#include <riscv_vector.h>
|
#include <riscv_vector.h>
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue