//
// llm_demo.cpp
//
// Created by MNN on 2023/03/24.
// ZhaodeWang
//

#include "llm.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdlib.h>
#include <initializer_list>
#ifdef LLM_SUPPORT_AUDIO
#include "audio/audio.hpp"
#endif
using namespace DEEPCAM::Transformer;

/**
 * Invokes Llm::tuning for OP_ENCODER_NUMBER with a fixed list of candidate
 * values, logging a message before and after the call.
 *
 * @param llm model instance to tune; dereferenced unconditionally.
 */
static void tuning_prepare(Llm *llm)
{
    MNN_PRINT("Prepare for tuning opt Begin\n");
    // Candidate values handed to the tuner.
    llm->tuning(OP_ENCODER_NUMBER, {1, 5, 10, 20, 30, 50, 100});
    MNN_PRINT("Prepare for tuning opt End\n");
}

/**
 * Parses CSV text that has already been split into physical lines.
 *
 * Rules implemented:
 *  - `,` outside quotes ends the current cell;
 *  - the end of a physical line (or a literal '\n') outside quotes ends the row;
 *  - `"` toggles quoted mode, and `""` inside quotes yields one literal `"`;
 *  - inside quotes, commas and line breaks are kept as part of the cell, so a
 *    quoted field may span several entries of `lines`.
 *
 * An empty physical line produces a row holding a single empty cell.
 * If the input ends while still inside a quoted field, the partially
 * collected row is still emitted instead of being silently dropped.
 *
 * @param lines physical lines of the CSV file, without line terminators.
 * @return one vector of cells per logical CSV row.
 */
std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string> &lines)
{
    std::vector<std::vector<std::string>> csv_data;
    std::vector<std::string> row;
    std::string cell;
    bool insideQuotes = false;

    for (const auto &line : lines)
    {
        const std::size_t len = line.size();
        // Index `len` stands for the line break that followed this physical line.
        for (std::size_t i = 0; i <= len; ++i)
        {
            const char c = (i < len) ? line[i] : '\n';
            if (c == '"')
            {
                if (insideQuotes && i + 1 < len && line[i + 1] == '"')
                {
                    // Escaped quote ("") inside a quoted field.
                    cell += '"';
                    ++i;
                }
                else
                {
                    // Opening or closing quote of a quoted field.
                    insideQuotes = !insideQuotes;
                }
            }
            else if (c == ',' && !insideQuotes)
            {
                // End of cell.
                row.push_back(cell);
                cell.clear();
            }
            else if (c == '\n' && !insideQuotes)
            {
                // End of row.
                row.push_back(cell);
                csv_data.push_back(row);
                cell.clear();
                row.clear();
            }
            else
            {
                cell += c;
            }
        }
    }

    if (insideQuotes)
    {
        // Unterminated quoted field: keep what was collected. The last
        // character is the synthetic line break of the final line, drop it.
        if (!cell.empty() && cell.back() == '\n')
        {
            cell.pop_back();
        }
        row.push_back(cell);
        csv_data.push_back(row);
    }
    return csv_data;
}

/**
 * Feeds every prompt to the model and prints aggregate statistics.
 *
 * After each prompt the counters exposed by the model context (prompt_len,
 * gen_seq_len, vision_us, audio_us, prefill_us, decode_us, sample_us) are
 * accumulated; at the end the totals are printed, with the *_us values
 * divided by 1e6 and reported as seconds.
 *
 * Prompts whose first character is '#' are skipped.
 *
 * When max_token_number > 0 the config is set to {"max_new_tokens":1}, each
 * prompt is submitted via response(prompt, &std::cout, nullptr, 0) and
 * generate(1) is then called repeatedly until the model reports stoped() or
 * context->gen_seq_len reaches max_token_number. Otherwise response(prompt)
 * is called once per prompt.
 *
 * @param llm              model instance; dereferenced unconditionally.
 * @param prompts          one prompt per element.
 * @param max_token_number per-prompt generation limit; values <= 0 leave the
 *                         limit to the model configuration.
 * @return always 0.
 */
static int benchmark(Llm *llm, const std::vector<std::string> &prompts, int max_token_number)
{
    int prompt_len = 0;
    int decode_len = 0;
    int64_t vision_time = 0;
    int64_t audio_time = 0;
    int64_t prefill_time = 0;
    int64_t decode_time = 0;
    int64_t sample_time = 0;
    // llm->warmup();
    auto context = llm->getContext();
    if (max_token_number > 0)
    {
        llm->set_config("{\"max_new_tokens\":1}");
    }
#ifdef LLM_SUPPORT_AUDIO
    std::vector<float> waveform;
    // NOTE(review): the callback captures `waveform` by reference and stays
    // registered on `llm`; it must not be triggered after this function returns.
    llm->setWavformCallback([&](const float *ptr, size_t size, bool last_chunk)
                            {
        // Collect chunks; on the last one write everything to output.wav.
        waveform.insert(waveform.end(), ptr, ptr + size);
        if (last_chunk) {
            auto waveform_var = MNN::Express::_Const(waveform.data(), {static_cast<int>(waveform.size())}, MNN::Express::NCHW, halide_type_of<float>());
            MNN::AUDIO::save("output.wav", waveform_var, 24000);
            waveform.clear();
        }
        return true; });
#endif
    for (const auto &prompt : prompts)
    {
        /**
        update config.json and llm_config.json if need. example:
        llm->set_config("{\"assistant_prompt_template\":\"<|im_start|>assistant\\n<think>\\n</think>\%s<|im_end|>\\n\"}");
        */
        // prompt start with '#' will be ignored
        if (!prompt.empty() && prompt[0] == '#')
        {
            continue;
        }
        if (max_token_number > 0)
        {
            llm->response(prompt, &std::cout, nullptr, 0);
            // Decode one token at a time until stopped or the limit is hit.
            while (!llm->stoped() && context->gen_seq_len < max_token_number)
            {
                llm->generate(1);
            }
        }
        else
        {
            llm->response(prompt);
        }
        prompt_len += context->prompt_len;
        decode_len += context->gen_seq_len;
        vision_time += context->vision_us;
        audio_time += context->audio_us;
        prefill_time += context->prefill_us;
        decode_time += context->decode_us;
        sample_time += context->sample_us;
    }
    llm->generateWavform();

    const double vision_s = vision_time / 1e6;
    const double audio_s = audio_time / 1e6;
    const double prefill_s = prefill_time / 1e6;
    const double decode_s = decode_time / 1e6;
    const double sample_s = sample_time / 1e6;
    // Avoid printing inf/nan when no time was accumulated (e.g. every prompt
    // was a '#' comment or nothing was decoded).
    const double prefill_speed = prefill_s > 0 ? prompt_len / prefill_s : 0.0;
    const double decode_speed = decode_s > 0 ? decode_len / decode_s : 0.0;
    printf("\n#################################\n");
    printf("prompt tokens num = %d\n", prompt_len);
    printf("decode tokens num = %d\n", decode_len);
    printf(" vision time = %.2f s\n", vision_s);
    printf(" audio time = %.2f s\n", audio_s);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf(" sample time = %.2f s\n", sample_s);
    printf("prefill speed = %.2f tok/s\n", prefill_speed);
    printf(" decode speed = %.2f tok/s\n", decode_speed);
    printf("##################################\n");
    return 0;
}

/**
 * Runs a multiple-choice evaluation set through the model and stores the
 * raw answers as CSV.
 *
 * `lines` is parsed with parse_csv(); the first row is treated as the header
 * and skipped. For every following row the prompt is built from column 1
 * (question) and columns 2-5 (options A-D). Rows with fewer than six columns
 * (for example a trailing blank line) are reported and skipped instead of
 * being indexed out of bounds.
 *
 * The output file name is the base name of `filename` (text after the last
 * '/') with the first "_val" replaced by "_res"; it is created relative to
 * the current working directory. Each output row is `id,"answer"`, where id
 * is the zero-based position of the data row and embedded double quotes in
 * the answer are doubled so the file stays parseable.
 *
 * NOTE(review): when `filename` contains no "_val" the output name equals the
 * input base name, so running from the input's directory overwrites the input.
 *
 * @param llm      model instance; dereferenced unconditionally.
 * @param lines    physical lines of the evaluation CSV, header included.
 * @param filename path of the evaluation CSV, used to derive the output name.
 * @return 0 on success, 1 when the output file cannot be opened.
 */
static int ceval(Llm *llm, const std::vector<std::string> &lines, std::string filename)
{
    const auto csv_data = parse_csv(lines);
    const size_t total = csv_data.empty() ? 0 : csv_data.size() - 1;
    std::vector<size_t> ids;
    std::vector<std::string> answers;
    for (size_t i = 1; i < csv_data.size(); i++)
    {
        const auto &elements = csv_data[i];
        if (elements.size() < 6)
        {
            std::cerr << "skip malformed row " << i << " (" << elements.size() << " columns)" << std::endl;
            continue;
        }
        std::string prompt = elements[1];
        prompt += "\n\nA. " + elements[2];
        prompt += "\nB. " + elements[3];
        prompt += "\nC. " + elements[4];
        prompt += "\nD. " + elements[5];
        prompt += "\n\n";
        printf("%s", prompt.c_str());
        // Progress: current row / number of data rows.
        printf("## 进度: %zu / %zu\n", i, total);
        std::ostringstream lineOs;
        llm->response(prompt, &lineOs);
        auto line = lineOs.str();
        printf("%s", line.c_str());
        // Keep the row position so ids stay aligned even if rows were skipped.
        ids.push_back(i - 1);
        answers.push_back(line);
    }
    {
        auto position = filename.rfind("/");
        if (position != std::string::npos)
        {
            filename = filename.substr(position + 1);
        }
        position = filename.find("_val");
        if (position != std::string::npos)
        {
            filename.replace(position, 4, "_res");
        }
        std::cout << "store to " << filename << std::endl;
    }
    std::ofstream ofp(filename);
    if (!ofp.is_open())
    {
        std::cerr << "failed to open " << filename << " for writing" << std::endl;
        return 1;
    }
    ofp << "id,answer" << std::endl;
    for (size_t k = 0; k < answers.size(); k++)
    {
        ofp << ids[k] << ",\"";
        for (const char ch : answers[k])
        {
            if (ch == '"')
            {
                // CSV escaping: a quote inside a quoted field is doubled.
                ofp << '"';
            }
            ofp << ch;
        }
        ofp << "\"" << std::endl;
    }
    ofp.close();
    return 0;
}

/**
 * Loads prompts from a text file and dispatches to ceval() or benchmark().
 *
 * By default the file is read line by line and a trailing '\r' is stripped
 * from each line (CRLF input). When LLM_DEMO_ONELINE is defined the whole
 * file becomes a single prompt instead.
 *
 * If the first line equals "id,question,A,B,C,D,answer" the file is handled
 * by ceval(); otherwise the prompts are passed to benchmark().
 *
 * @param llm              model instance, forwarded to ceval()/benchmark().
 * @param prompt_file      path of the prompt file.
 * @param max_token_number generation limit forwarded to benchmark().
 * @return 1 when the file cannot be opened or yields no prompts, otherwise
 *         the result of ceval()/benchmark().
 */
static int eval(Llm *llm, std::string prompt_file, int max_token_number)
{
    std::cout << "prompt file is " << prompt_file << std::endl;
    std::ifstream prompt_fs(prompt_file);
    if (!prompt_fs.is_open())
    {
        std::cerr << "failed to open prompt file " << prompt_file << std::endl;
        return 1;
    }
    std::vector<std::string> prompts;
    std::string prompt;
// #define LLM_DEMO_ONELINE
#ifdef LLM_DEMO_ONELINE
    std::ostringstream tempOs;
    tempOs << prompt_fs.rdbuf();
    prompt = tempOs.str();
    prompts = {prompt};
#else
    while (std::getline(prompt_fs, prompt))
    {
        // back() on an empty string is undefined, so check emptiness first
        // (blank lines are legal in a prompt file).
        if (!prompt.empty() && prompt.back() == '\r')
        {
            prompt.pop_back();
        }
        prompts.push_back(prompt);
    }
#endif
    prompt_fs.close();
    if (prompts.empty())
    {
        return 1;
    }
    // ceval
    if (prompts[0] == "id,question,A,B,C,D,answer")
    {
        return ceval(llm, prompts, prompt_file);
    }
    return benchmark(llm, prompts, max_token_number);
}

|
|
/**
 * Interactive chat loop on stdin/stdout.
 *
 * The conversation starts with a fixed system message. Every line read from
 * stdin is appended as a "user" message, the whole history is passed to
 * llm->response(), and context->generate_str is appended as the "assistant"
 * message.
 *
 * Commands:
 *  - "/exit"  leaves the loop;
 *  - "/reset" calls llm->reset() and discards the accumulated history so the
 *             next turn starts from the system message only.
 *
 * The loop also ends when stdin reaches end-of-file or fails, instead of
 * spinning forever on empty input.
 *
 * @param llm model instance; dereferenced unconditionally.
 */
void chat(Llm *llm)
{
    ChatMessages messages;
    // (Re)initialise the history with just the system message.
    auto start_conversation = [&messages]()
    {
        messages.clear();
        messages.emplace_back("system", "You are a helpful assistant.");
    };
    start_conversation();
    auto context = llm->getContext();
    while (true)
    {
        std::cout << "\nUser: ";
        std::string user_str;
        if (!std::getline(std::cin, user_str))
        {
            // EOF or stream error: no more input will ever arrive.
            return;
        }
        if (user_str == "/exit")
        {
            return;
        }
        if (user_str == "/reset")
        {
            llm->reset();
            // Without this the old turns would be re-sent on the next request.
            start_conversation();
            std::cout << "\nA: reset done." << std::endl;
            continue;
        }
        messages.emplace_back("user", user_str);
        std::cout << "\nA: " << std::flush;
        llm->response(messages);
        messages.emplace_back("assistant", context->generate_str);
    }
}
int main(int argc, const char *argv[])
|
|
{
|
|
if (argc < 2)
|
|
{
|
|
std::cout << "Usage: " << argv[0] << " config.json <prompt.txt>" << std::endl;
|
|
return 0;
|
|
}
|
|
MNN::BackendConfig backendConfig;
|
|
auto executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 1);
|
|
MNN::Express::ExecutorScope s(executor);
|
|
|
|
std::string config_path = argv[1];
|
|
std::cout << "config path is " << config_path << std::endl;
|
|
std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
|
|
llm->set_config("{\"tmp_path\":\"tmp\"}");
|
|
{
|
|
AUTOTIME;
|
|
llm->load();
|
|
}
|
|
if (true)
|
|
{
|
|
AUTOTIME;
|
|
tuning_prepare(llm.get());
|
|
}
|
|
if (argc < 3)
|
|
{
|
|
chat(llm.get());
|
|
return 0;
|
|
}
|
|
int max_token_number = -1;
|
|
if (argc >= 4)
|
|
{
|
|
std::istringstream os(argv[3]);
|
|
os >> max_token_number;
|
|
}
|
|
if (argc >= 5)
|
|
{
|
|
MNN_PRINT("Set not thinking, only valid for Qwen3\n");
|
|
llm->set_config(R"({
|
|
"jinja": {
|
|
"context": {
|
|
"enable_thinking":false
|
|
}
|
|
}
|
|
})");
|
|
}
|
|
std::string prompt_file = argv[2];
|
|
return eval(llm.get(), prompt_file, max_token_number);
|
|
}
|