Improve selftest
[alyverkko-cli.git] / src / main / java / eu / svjatoslav / alyverkko_cli / AiTask.java
1 package eu.svjatoslav.alyverkko_cli;
2
3 import eu.svjatoslav.alyverkko_cli.commands.MailQuery;
4 import eu.svjatoslav.alyverkko_cli.model.Model;
5
6 import java.io.*;
7 import java.nio.file.Files;
8
9 import static eu.svjatoslav.alyverkko_cli.Main.configuration;
10 import static java.lang.String.join;
11
12 public class AiTask {
13     public static final String AI_RESPONSE_MARKER = "ASSISTANT:";
14     private static final String LLAMA_CPP_META_INFO_MARKER = "llm_load_print_meta: ";
15
16     private final Model model;
17     private final Float temperature;
18     private final String systemPrompt;
19     private final String userPrompt;
20     File inputFile;
21
22     /**
23      * Creates a new AI task.
24      */
25     public AiTask(MailQuery mailQuery) {
26         this.model = mailQuery.model;
27         this.temperature = configuration.getDefaultTemperature();
28         this.systemPrompt = mailQuery.systemPrompt;
29         this.userPrompt  = mailQuery.userPrompt;
30     }
31
32     private String buildAiQuery() {
33         StringBuilder sb = new StringBuilder();
34         sb.append("SYSTEM:\n").append(systemPrompt).append("\n");
35
36         String filteredUserPrompt = filterParticipantsInUserInput(userPrompt);
37         if (!filteredUserPrompt.startsWith("USER:")) sb.append("USER:\n");
38         sb.append(filteredUserPrompt).append("\n");
39
40         sb.append(AI_RESPONSE_MARKER);
41         return sb.toString();
42     }
43
44     public static String filterParticipantsInUserInput(String input) {
45         StringBuilder result = new StringBuilder();
46         String[] lines = input.split("\n");
47         for (int i = 0; i < lines.length; i++) {
48             String line = lines[i];
49             if (i > 0) result.append("\n");
50             if ("* ASSISTANT:".equals(line)) line = "ASSISTANT:";
51             if ("* USER:".equals(line)) line = "USER:";
52             result.append(line);
53         }
54         return result.toString();
55     }
56
57     public static String filterParticipantsInAiResponse(String response) {
58         StringBuilder result = new StringBuilder();
59         String[] lines = response.split("\n");
60         for (int i = 0; i < lines.length; i++) {
61             String line = lines[i];
62             if (i > 0) result.append("\n");
63             if ("ASSISTANT:".equals(line)) line = "* ASSISTANT:";
64             if ("USER:".equals(line)) line = "* USER:";
65             result.append(line);
66         }
67         result.append("\n* USER:\n");
68         return result.toString();
69     }
70
71     /**
72      * Compute the AI task.
73      * @return The result of the AI task.
74      */
75     public String runAiQuery() throws InterruptedException, IOException {
76         try {
77             initializeInputFile(buildAiQuery());
78
79             ProcessBuilder processBuilder = new ProcessBuilder();
80             processBuilder.command(getCliCommand().split("\\s+")); // Splitting the command string into parts
81
82             Process process = processBuilder.start();
83             handleErrorThread(process);
84             StringBuilder result = new StringBuilder();
85             Thread outputThread = handleResultThread(process, result);
86             process.waitFor(); // Wait for the main AI computing process to finish
87             outputThread.join(); // Wait for the output thread to finish
88             return filterParticipantsInAiResponse(cleanupAiResponse(result.toString()));
89         } finally {
90             deleteTemporaryFile();
91         }
92     }
93
94     /**
95      * Initializes the input file for the AI task.
96      */
97     private void initializeInputFile(String aiQuery ) throws IOException {
98         // write AI input to file
99         inputFile = createTemporaryFile();
100         Files.write(inputFile.toPath(), aiQuery.getBytes());
101     }
102
103     /**
104      * Creates and starts a thread to handle the error stream of an AI inference process.
105      *
106      * @param process the process to read the error stream from.
107      */
108     private static void handleErrorThread(Process process) {
109         Thread errorThread = new Thread(() -> {
110             try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
111                 String line;
112                 while ((line = reader.readLine()) != null) handleErrorStreamLine(line);
113             } catch (IOException e) {
114                 System.err.println("Error reading error stream: " + e.getMessage());
115             }
116         });
117         errorThread.start();
118     }
119
120
121     /**
122      * Handles a single line from the error stream of an AI inference process.
123      * If the line contains meta-info, it is printed to the console.
124      *
125      * @param line the line to be handled.
126      */
127     private static void handleErrorStreamLine(String line) {
128         if (line.startsWith(LLAMA_CPP_META_INFO_MARKER)) {
129             // Print the meta-info to console
130             System.out.println(line.substring(LLAMA_CPP_META_INFO_MARKER.length()));
131             return;
132         }
133
134         // Print the error to console
135         Utils.printRedMessageToConsole(line);
136     }
137
138     /**
139      * Gets the full command to be executed by the AI inference process.
140      *
141      * @return the full command to be executed by the AI inference process.
142      */
143     private String getCliCommand() {
144
145         return join(" ",
146                 configuration.getLlamaCppExecutablePath().getAbsolutePath(),
147                 "--model " + model.filesystemPath,
148                 "--threads " + configuration.getThreadCount(),
149                 "--threads-batch " + configuration.getBatchThreadCount(),
150                 "--mirostat 2",
151                 "--log-disable",
152                 "--temp " + temperature,
153                 "--ctx-size " + model.contextSizeTokens,
154                 "--batch-size 8",
155                 "-n -1",
156                 "--repeat_penalty 1.1",
157                 "--file " + inputFile);
158
159     }
160
161
162     /**
163      * Creates and starts a thread to handle the result of the AI inference process.
164      * The result is read from the process's input stream and saved in a StringBuilder.
165      *
166      * @param process the process to read the result from.
167      * @param result the StringBuilder to save the result in.
168      * @return the thread that handles the result.
169      */
170     private static Thread handleResultThread(Process process, StringBuilder result) {
171         Thread outputThread = new Thread(() -> {
172             try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
173                 String aiResultLine;
174                 while ((aiResultLine = reader.readLine()) != null) {
175                     System.out.print("AI: " + aiResultLine + "\n"); // Display each line as it's being read
176                     result.append(aiResultLine).append("\n"); // Save the result
177                 }
178             } catch (IOException e) {
179                 throw new RuntimeException(e);
180             }
181         });
182         outputThread.start();
183         return outputThread;
184     }
185
186     /**
187      * Returns the temporary file for the AI to work with.
188      */
189     private File createTemporaryFile() throws IOException {
190         File file = Files.createTempFile("ai-inference", ".tmp").toFile();
191         file.deleteOnExit();
192         return file;
193     }
194
195     /**
196      * Cleans up the AI response by removing unnecessary text.
197      * @param result the AI response string to be cleaned up.
198      * @return the cleaned-up AI response.k
199      */
200     private String cleanupAiResponse(String result) {
201
202         // remove text before AI response marker
203         int aIResponseIndex = result.lastIndexOf(AI_RESPONSE_MARKER);
204         if (aIResponseIndex != -1) {
205             result = result.substring(aIResponseIndex + AI_RESPONSE_MARKER.length());
206         }
207
208         // remove text after end of text marker, if it exists
209         if (model.endOfTextMarker != null) {
210             int endOfTextMarkerIndex = result.indexOf(model.endOfTextMarker);
211             if (endOfTextMarkerIndex != -1) {
212                 result = result.substring(0, endOfTextMarkerIndex);
213             }
214         }
215         return result + "\n";
216     }
217
218     private void deleteTemporaryFile() {
219         if (inputFile != null && inputFile.exists()) {
220             try {
221                 Files.delete(inputFile.toPath());
222             } catch (IOException e) {
223                 System.err.println("Failed to delete temporary file: " + e.getMessage());
224             }
225         }
226     }
227
228
229 }