1c5a6082d84bf1ca3eaf105f913ca630937f7cff
[alyverkko-cli.git] / src / main / java / eu / svjatoslav / alyverkko_cli / AiTask.java
1 package eu.svjatoslav.alyverkko_cli;
2
3 import eu.svjatoslav.alyverkko_cli.model.Model;
4
5 import java.io.*;
6 import java.nio.file.Files;
7
8 import static eu.svjatoslav.alyverkko_cli.Main.configuration;
9 import static java.lang.String.join;
10
11 public class AiTask {
12     public static final String AI_RESPONSE_MARKER = "ASSISTANT:";
13     private static final String LLAMA_CPP_META_INFO_MARKER = "llm_load_print_meta: ";
14
15     private final String aiQuery;
16     private final Model model;
17     private final Float temperature;
18     File inputFile;
19
20     /**
21      * Creates a new AI task.
22      *
23      * @param input       Problem statement to be used for the AI task.
24      * @param model       The model to be used for the AI task.
25      * @param temperature The temperature to be used for the AI inference process.
26      */
27     public AiTask(String input, Model model, Float temperature) {
28         this.aiQuery = buildAiQuery(input);
29         this.model = model;
30         this.temperature = temperature == null ? configuration.getDefaultTemperature() : temperature;
31     }
32
33     private String buildAiQuery(String input) {
34         StringBuilder sb = new StringBuilder();
35
36         sb.append("SYSTEM:\nThis conversation involves a user and AI assistant where the AI " +
37                 "is expected to provide not only immediate responses but also detailed and " +
38                 "well-reasoned analysis. The AI should consider all aspects of the query " +
39                 "and deliver insights based on logical deductions and comprehensive understanding." +
40                 "AI assistant should reply using emacs org-mode syntax.\n" +
41                 "Quick recap: *this is bold* [[http://domain.org][This is link]]\n" +
42                 "* Heading level 1\n" +
43                 "** Heading level 2\n" +
44                 "| Col 1 Row 1 | Col 2 Row 1 |\n" +
45                 "| Col 1 Row 2 | Col 2 Row 2 |\n" +
46                 "#+BEGIN_SRC python\n" +
47                 "  print ('Hello, world!')\n" +
48                 "#+END_SRC\n\n");
49
50
51         String filteredInput = filterParticipantsInUserInput(input);
52
53         // if filtered input does not start with "USER:", add it
54         if (!filteredInput.startsWith("USER:")) {
55             filteredInput = "USER:\n" + filteredInput;
56         }
57
58         sb.append(filteredInput).append("\n").append(AI_RESPONSE_MARKER);
59         return sb.toString();
60     }
61
62     public static String filterParticipantsInUserInput(String input) {
63         StringBuilder result = new StringBuilder();
64         String[] lines = input.split("\n");
65         for (int i = 0; i < lines.length; i++) {
66             String line = lines[i];
67             if (i > 0) result.append("\n");
68             if ("* ASSISTANT:".equals(line)) line = "ASSISTANT:";
69             if ("* USER:".equals(line)) line = "USER:";
70             result.append(line);
71         }
72         return result.toString();
73     }
74
75     public static String filterParticipantsInAiResponse(String response) {
76         StringBuilder result = new StringBuilder();
77         String[] lines = response.split("\n");
78         for (int i = 0; i < lines.length; i++) {
79             String line = lines[i];
80             if (i > 0) result.append("\n");
81             if ("ASSISTANT:".equals(line)) line = "* ASSISTANT:";
82             if ("USER:".equals(line)) line = "* USER:";
83             result.append(line);
84         }
85         result.append("\n* USER:\n");
86         return result.toString();
87     }
88
89     /**
90      * Compute the AI task.
91      * @return The result of the AI task.
92      */
93     public String runAiQuery() throws InterruptedException, IOException {
94         try {
95             initializeInputFile();
96
97             ProcessBuilder processBuilder = new ProcessBuilder();
98             processBuilder.command(getCliCommand().split("\\s+")); // Splitting the command string into parts
99
100             Process process = processBuilder.start();
101             handleErrorThread(process);
102             StringBuilder result = new StringBuilder();
103             Thread outputThread = handleResultThread(process, result);
104             process.waitFor(); // Wait for the main AI computing process to finish
105             outputThread.join(); // Wait for the output thread to finish
106             return filterParticipantsInAiResponse(cleanupAiResponse(result.toString()));
107         } finally {
108             deleteTemporaryFile();
109         }
110     }
111
112     /**
113      * Initializes the input file for the AI task.
114      */
115     private void initializeInputFile() throws IOException {
116         // write AI input to file
117         inputFile = createTemporaryFile();
118         Files.write(inputFile.toPath(), aiQuery.getBytes());
119     }
120
121     /**
122      * Creates and starts a thread to handle the error stream of an AI inference process.
123      *
124      * @param process the process to read the error stream from.
125      */
126     private static void handleErrorThread(Process process) {
127         Thread errorThread = new Thread(() -> {
128             try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
129                 String line;
130                 while ((line = reader.readLine()) != null) handleErrorStreamLine(line);
131             } catch (IOException e) {
132                 System.err.println("Error reading error stream: " + e.getMessage());
133             }
134         });
135         errorThread.start();
136     }
137
138
139     /**
140      * Handles a single line from the error stream of an AI inference process.
141      * If the line contains meta-info, it is printed to the console.
142      *
143      * @param line the line to be handled.
144      */
145     private static void handleErrorStreamLine(String line) {
146         if (line.startsWith(LLAMA_CPP_META_INFO_MARKER)) {
147             // Print the meta-info to console
148             System.out.println(line.substring(LLAMA_CPP_META_INFO_MARKER.length()));
149             return;
150         }
151
152         // Print the error to console
153         Utils.printRedMessageToConsole(line);
154     }
155
156     /**
157      * Gets the full command to be executed by the AI inference process.
158      *
159      * @return the full command to be executed by the AI inference process.
160      */
161     private String getCliCommand() {
162
163         return join(" ",
164                 configuration.getLlamaCppExecutablePath().getAbsolutePath(),
165                 "--model " + model.filesystemPath,
166                 "--threads " + configuration.getThreadCount(),
167                 "--threads-batch " + configuration.getBatchThreadCount(),
168                 "--mirostat 2",
169                 "--log-disable",
170                 "--temp " + temperature,
171                 "--ctx-size " + model.contextSizeTokens,
172                 "--batch-size 8",
173                 "-n -1",
174                 "--repeat_penalty 1.1",
175                 "--file " + inputFile);
176
177     }
178
179
180     /**
181      * Creates and starts a thread to handle the result of the AI inference process.
182      * The result is read from the process's input stream and saved in a StringBuilder.
183      *
184      * @param process the process to read the result from.
185      * @param result the StringBuilder to save the result in.
186      * @return the thread that handles the result.
187      */
188     private static Thread handleResultThread(Process process, StringBuilder result) {
189         Thread outputThread = new Thread(() -> {
190             try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
191                 String aiResultLine;
192                 while ((aiResultLine = reader.readLine()) != null) {
193                     System.out.print("AI: " + aiResultLine + "\n"); // Display each line as it's being read
194                     result.append(aiResultLine).append("\n"); // Save the result
195                 }
196             } catch (IOException e) {
197                 throw new RuntimeException(e);
198             }
199         });
200         outputThread.start();
201         return outputThread;
202     }
203
204     /**
205      * Returns the temporary file for the AI to work with.
206      */
207     private File createTemporaryFile() throws IOException {
208         File file = Files.createTempFile("ai-inference", ".tmp").toFile();
209         file.deleteOnExit();
210         return file;
211     }
212
213     /**
214      * Cleans up the AI response by removing unnecessary text.
215      * @param result the AI response string to be cleaned up.
216      * @return the cleaned-up AI response.k
217      */
218     private String cleanupAiResponse(String result) {
219
220         // remove text before AI response marker
221         int aIResponseIndex = result.lastIndexOf(AI_RESPONSE_MARKER);
222         if (aIResponseIndex != -1) {
223             result = result.substring(aIResponseIndex + AI_RESPONSE_MARKER.length());
224         }
225
226         // remove text after end of text marker, if it exists
227         if (model.endOfTextMarker != null) {
228             int endOfTextMarkerIndex = result.indexOf(model.endOfTextMarker);
229             if (endOfTextMarkerIndex != -1) {
230                 result = result.substring(0, endOfTextMarkerIndex);
231             }
232         }
233         return result + "\n";
234     }
235
236     private void deleteTemporaryFile() {
237         if (inputFile != null && inputFile.exists()) {
238             try {
239                 Files.delete(inputFile.toPath());
240             } catch (IOException e) {
241                 System.err.println("Failed to delete temporary file: " + e.getMessage());
242             }
243         }
244     }
245
246     public static String runAiQuery(String problemStatement, Model model, Float temperature) throws IOException, InterruptedException {
247         AiTask ai = new AiTask(problemStatement, model, temperature);
248         return ai.runAiQuery();
249     }
250
251 }