1 package eu.svjatoslav.alyverkko_cli;
3 import eu.svjatoslav.alyverkko_cli.commands.MailQuery;
4 import eu.svjatoslav.alyverkko_cli.model.Model;
7 import java.nio.file.Files;
9 import static eu.svjatoslav.alyverkko_cli.Main.configuration;
10 import static java.lang.String.join;
/** Marker appended to the end of the prompt; cleanupAiResponse keeps only the text after its last occurrence. */
13 public static final String AI_RESPONSE_MARKER = "ASSISTANT:";
/** Prefix of llama.cpp stderr lines carrying model meta-info; such lines are echoed to stdout with the prefix removed. */
14 private static final String LLAMA_CPP_META_INFO_MARKER = "llm_load_print_meta: ";
/** Model used for inference; its filesystem path, context size and optional end-of-text marker are read. */
16 private final Model model;
/** Sampling temperature passed to llama.cpp via the --temp argument; set from the configured default. */
17 private final Float temperature;
/** Text placed in the SYSTEM: section of the prompt. */
18 private final String systemPrompt;
/** User prompt text; may contain "* USER:" and "* ASSISTANT:" header lines, which are converted before use. */
19 private final String userPrompt;
23 * Creates a new AI task.
* The model, system prompt and user prompt are copied from the query. The temperature is NOT taken
* from the query; the global configuration's default temperature is always used.
*
* @param mailQuery the query supplying the model and the prompts.
25 public AiTask(MailQuery mailQuery) {
26 this.model = mailQuery.model;
27 this.temperature = configuration.getDefaultTemperature();
28 this.systemPrompt = mailQuery.systemPrompt;
29 this.userPrompt = mailQuery.userPrompt;
/**
 * Builds the prompt text that is written to llama.cpp's input file.
 * The layout is:
 * - a "SYSTEM:" line followed by the system prompt;
 * - the user prompt, with its participant headers converted by filterParticipantsInUserInput, preceded by
 *   a "USER:" line unless the converted prompt already starts with "USER:";
 * - AI_RESPONSE_MARKER with no trailing newline, presumably so the model continues as the assistant.
 *
 * @return the assembled prompt.
 */
32 private String buildAiQuery() {
33 StringBuilder sb = new StringBuilder();
34 sb.append("SYSTEM:\n").append(systemPrompt).append("\n");
36 String filteredUserPrompt = filterParticipantsInUserInput(userPrompt);
// Avoid a duplicate header when the user text already opens with its own USER: line.
37 if (!filteredUserPrompt.startsWith("USER:")) sb.append("USER:\n");
38 sb.append(filteredUserPrompt).append("\n");
// Trailing marker; cleanupAiResponse later discards everything up to its last occurrence in the output.
40 sb.append(AI_RESPONSE_MARKER);
/**
 * Converts marked participant header lines into the plain form used in the prompt.
 * A line that is exactly "* ASSISTANT:" becomes "ASSISTANT:", and a line that is exactly "* USER:" becomes "USER:".
 * This is the reverse of filterParticipantsInAiResponse. All other lines are left unchanged, and
 * lines are separated by '\n' in the result.
 * NOTE(review): String.split("\n") discards trailing empty strings, so trailing newlines of the input
 * are lost. Lines ending in '\r' (CRLF input) will not match the headers.
 *
 * @param input the user prompt text.
 * @return the text with participant headers converted.
 */
44 public static String filterParticipantsInUserInput(String input) {
45 StringBuilder result = new StringBuilder();
46 String[] lines = input.split("\n");
47 for (int i = 0; i < lines.length; i++) {
48 String line = lines[i];
// Separator before every line except the first.
49 if (i > 0) result.append("\n");
50 if ("* ASSISTANT:".equals(line)) line = "ASSISTANT:";
51 if ("* USER:".equals(line)) line = "USER:";
54 return result.toString();
/**
 * Converts plain participant header lines in model output back to the marked form.
 * A line that is exactly "ASSISTANT:" becomes "* ASSISTANT:", and a line that is exactly "USER:" becomes "* USER:".
 * This is the reverse of filterParticipantsInUserInput.
 * Finally "\n* USER:\n" is appended, presumably to open the next user turn below the answer.
 * NOTE(review): String.split("\n") discards trailing empty strings, so trailing newlines of the
 * response are lost before the final "* USER:" line is added.
 *
 * @param response the model output (runAiQuery passes the cleaned response).
 * @return the response with participant headers marked and a trailing "* USER:" line.
 */
57 public static String filterParticipantsInAiResponse(String response) {
58 StringBuilder result = new StringBuilder();
59 String[] lines = response.split("\n");
60 for (int i = 0; i < lines.length; i++) {
61 String line = lines[i];
// Separator before every line except the first.
62 if (i > 0) result.append("\n");
63 if ("ASSISTANT:".equals(line)) line = "* ASSISTANT:";
64 if ("USER:".equals(line)) line = "* USER:";
67 result.append("\n* USER:\n");
68 return result.toString();
72 * Compute the AI task.
* The method:
* - writes the prompt from buildAiQuery to a temporary file;
* - runs the llama.cpp executable on it and forwards its stderr on a background thread;
* - collects its stdout, then waits for both the process and the stdout reader thread;
* - returns the cleaned response, with plain USER: and ASSISTANT: lines converted to the marked form
*   and a trailing marked USER: line appended.
* The temporary input file is removed by deleteTemporaryFile (presumably in a finally block; confirm).
* NOTE(review): the command string is split on whitespace, so an executable or model path containing
* spaces breaks the invocation. Passing an argument list to ProcessBuilder would avoid this.
* NOTE(review): the process exit code is not checked. On failure, whatever output was captured is returned.
* NOTE(review): the stderr thread is not joined, so error output may still be printed after this returns.
73 * @return The result of the AI task.
* @throws InterruptedException if interrupted while waiting for the process or the output thread.
* @throws IOException if the input file cannot be written or the process cannot be started.
75 public String runAiQuery() throws InterruptedException, IOException {
77 initializeInputFile(buildAiQuery());
79 ProcessBuilder processBuilder = new ProcessBuilder();
80 processBuilder.command(getCliCommand().split("\\s+")); // Splitting the command string into parts
82 Process process = processBuilder.start();
83 handleErrorThread(process);
84 StringBuilder result = new StringBuilder();
85 Thread outputThread = handleResultThread(process, result);
86 process.waitFor(); // Wait for the main AI computing process to finish
// join() also guarantees the reader thread's writes to result are visible here.
87 outputThread.join(); // Wait for the output thread to finish
88 return filterParticipantsInAiResponse(cleanupAiResponse(result.toString()));
90 deleteTemporaryFile();
95 * Initializes the input file for the AI task.
* Creates a fresh temporary file, stores it in the inputFile field, and writes the query text into it.
* NOTE(review): aiQuery.getBytes() uses the platform default charset. If llama.cpp expects UTF-8 input
* (confirm), getBytes(StandardCharsets.UTF_8) would make this platform-independent.
*
* @param aiQuery the prompt text to write.
* @throws IOException if the file cannot be created or written.
97 private void initializeInputFile(String aiQuery ) throws IOException {
98 // write AI input to file
99 inputFile = createTemporaryFile();
100 Files.write(inputFile.toPath(), aiQuery.getBytes());
104 * Creates and starts a thread to handle the error stream of an AI inference process.
* Each line read from the process's stderr is passed to handleErrorStreamLine. An IOException while
* reading is reported on System.err (message only) and ends the thread.
* The thread is not returned, so callers cannot wait for it to finish.
* NOTE(review): the InputStreamReader uses the platform default charset.
106 * @param process the process to read the error stream from.
108 private static void handleErrorThread(Process process) {
109 Thread errorThread = new Thread(() -> {
110 try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
112 while ((line = reader.readLine()) != null) handleErrorStreamLine(line);
113 } catch (IOException e) {
114 System.err.println("Error reading error stream: " + e.getMessage());
122 * Handles a single line from the error stream of an AI inference process.
123 * If the line starts with LLAMA_CPP_META_INFO_MARKER (meta-info), it is printed to stdout with that prefix removed.
* Other lines are treated as errors and printed with Utils.printRedMessageToConsole.
125 * @param line the line to be handled.
127 private static void handleErrorStreamLine(String line) {
128 if (line.startsWith(LLAMA_CPP_META_INFO_MARKER)) {
129 // Print the meta-info to console
130 System.out.println(line.substring(LLAMA_CPP_META_INFO_MARKER.length()));
134 // Print the error to console
135 Utils.printRedMessageToConsole(line);
139 * Gets the full command to be executed by the AI inference process.
* The executable path and the option strings are combined into one string (presumably with the statically
* imported String.join), which runAiQuery then splits on whitespace.
* The model, thread counts, temperature, context size and input file are taken from the model,
* the configuration and this task.
* NOTE(review): an executable or model path containing spaces will be split into several arguments.
* Returning a List of arguments for ProcessBuilder would avoid this.
* NOTE(review): the repeat penalty (1.1) is hard-coded rather than configurable.
141 * @return the full command to be executed by the AI inference process.
143 private String getCliCommand() {
146 configuration.getLlamaCppExecutablePath().getAbsolutePath(),
147 "--model " + model.filesystemPath,
148 "--threads " + configuration.getThreadCount(),
149 "--threads-batch " + configuration.getBatchThreadCount(),
152 "--temp " + temperature,
153 "--ctx-size " + model.contextSizeTokens,
156 "--repeat_penalty 1.1",
157 "--file " + inputFile);
163 * Creates and starts a thread to handle the result of the AI inference process.
164 * The result is read from the process's input stream and saved in a StringBuilder.
* Each line is also echoed to System.out prefixed with "AI: ", and every saved line gets a trailing '\n'.
* The StringBuilder is not synchronized. The caller must join() the returned thread before reading it.
* NOTE(review): an IOException is rethrown as RuntimeException inside the worker thread. It therefore
* terminates that thread as an uncaught exception instead of reaching the caller; join() still returns
* and the partial output is used.
166 * @param process the process to read the result from.
167 * @param result the StringBuilder to save the result in.
168 * @return the thread that handles the result.
170 private static Thread handleResultThread(Process process, StringBuilder result) {
171 Thread outputThread = new Thread(() -> {
172 try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
174 while ((aiResultLine = reader.readLine()) != null) {
175 System.out.print("AI: " + aiResultLine + "\n"); // Display each line as it's being read
176 result.append(aiResultLine).append("\n"); // Save the result
178 } catch (IOException e) {
179 throw new RuntimeException(e);
182 outputThread.start();
187 * Returns the temporary file for the AI to work with.
* A new, empty file is created in the default temporary-file directory. Its name starts with
* "ai-inference" and ends with ".tmp".
*
* @throws IOException if the file cannot be created.
189 private File createTemporaryFile() throws IOException {
190 File file = Files.createTempFile("ai-inference", ".tmp").toFile();
196 * Cleans up the AI response by removing unnecessary text.
* The method:
* - keeps only the text after the last occurrence of AI_RESPONSE_MARKER (the whole text if it is absent);
* - if the model defines an endOfTextMarker, cuts the text at that marker's first occurrence;
* - appends a newline.
* NOTE(review): because lastIndexOf is used, an "ASSISTANT:" emitted by the model inside its own
* answer causes the earlier part of the answer to be discarded.
197 * @param result the AI response string to be cleaned up.
198 * @return the cleaned-up AI response.
200 private String cleanupAiResponse(String result) {
202 // remove text before AI response marker
203 int aIResponseIndex = result.lastIndexOf(AI_RESPONSE_MARKER);
204 if (aIResponseIndex != -1) {
205 result = result.substring(aIResponseIndex + AI_RESPONSE_MARKER.length());
208 // remove text after end of text marker, if it exists
209 if (model.endOfTextMarker != null) {
210 int endOfTextMarkerIndex = result.indexOf(model.endOfTextMarker);
211 if (endOfTextMarkerIndex != -1) {
212 result = result.substring(0, endOfTextMarkerIndex);
215 return result + "\n";
/**
 * Deletes the temporary input file if it was created and still exists.
 * A deletion failure is reported on System.err (message only) rather than thrown.
 */
218 private void deleteTemporaryFile() {
219 if (inputFile != null && inputFile.exists()) {
221 Files.delete(inputFile.toPath());
222 } catch (IOException e) {
223 System.err.println("Failed to delete temporary file: " + e.getMessage());