f93f9182d1b9418f159228ff55bf7c34e7279a5e
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 /*
2  * Imagesqueeze - Image codec. Copyright ©2012-2019, Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of version 3 of the GNU Lesser General Public License
6  * or later as published by the Free Software Foundation.
7  */
8
9 package eu.svjatoslav.imagesqueeze.codec;
10
11 /**
12  * Compressed image pixels encoder.
13  */
14
15 import eu.svjatoslav.commons.data.BitOutputStream;
16
17 import java.awt.image.DataBufferByte;
18 import java.awt.image.WritableRaster;
19 import java.io.IOException;
20
21 class ImageEncoder {
22
23     private final Image image;
24     private final Approximator approximator;
25     // ColorStats colorStats = new ColorStats();
26     private final OperatingContext context = new OperatingContext();
27     private final OperatingContext context2 = new OperatingContext();
28     int bitsForY;
29     int bitsForU;
30     int bitsForV;
31     private int width;
32     private int height;
33     private Channel yChannel;
34     private Channel uChannel;
35     private Channel vChannel;
36     private BitOutputStream bitOutputStream;
37
38     public ImageEncoder(final Image image) {
39         approximator = new Approximator();
40
41         // bitOutputStream = outputStream;
42
43         this.image = image;
44
45     }
46
47     public static int byteToInt(final byte input) {
48         int result = input;
49         if (result < 0)
50             result = result + 256;
51         return result;
52     }
53
54     public static int decodeValueFromGivenBits(final int encodedBits,
55                                                final int range, final int bitCount) {
56         final int negativeBit = encodedBits & 1;
57
58         final int remainingBitCount = bitCount - 1;
59
60         if (remainingBitCount == 0) {
61             // no more bits remaining to encode actual value
62
63             if (negativeBit == 0)
64                 return range;
65             else
66                 return -range;
67
68         } else {
69             // still one or more bits left, encode value as precisely as
70             // possible
71
72             final int encodedValue = (encodedBits >>> 1) + 1;
73
74             final int realvalueForThisBitcount = 1 << remainingBitCount;
75
76             // int valueMultiplier = range / realvalueForThisBitcount;
77             int decodedValue = (range * encodedValue)
78                     / realvalueForThisBitcount;
79
80             if (decodedValue > range)
81                 decodedValue = range;
82
83             if (negativeBit == 0)
84                 return decodedValue;
85             else
86                 return -decodedValue;
87
88         }
89     }
90
91     private static int encodeValueIntoGivenBits(int value, final int range,
92                                                 final int bitCount) {
93
94         int negativeBit = 0;
95
96         if (value < 0) {
97             negativeBit = 1;
98             value = -value;
99         }
100
101         final int remainingBitCount = bitCount - 1;
102
103         if (remainingBitCount == 0)
104             return negativeBit;
105         else {
106             // still one or more bits left, encode value as precisely as
107             // possible
108
109             if (value > range)
110                 value = range;
111
112             final int realvalueForThisBitcount = 1 << remainingBitCount;
113             // int valueMultiplier = range / realvalueForThisBitcount;
114             int encodedValue = (value * realvalueForThisBitcount) / range;
115
116             if (encodedValue >= realvalueForThisBitcount)
117                 encodedValue = realvalueForThisBitcount - 1;
118
119             encodedValue = (encodedValue << 1) + negativeBit;
120
121             return encodedValue;
122         }
123     }
124
125     public static void storeIntegerCompressed8(
126             final BitOutputStream outputStream, final int data)
127             throws IOException {
128
129         if (data < 256) {
130             outputStream.storeBits(0, 1);
131             outputStream.storeBits(data, 8);
132         } else {
133             outputStream.storeBits(1, 1);
134             outputStream.storeBits(data, 32);
135         }
136     }
137
138     public void encode(final BitOutputStream bitOutputStream)
139             throws IOException {
140         this.bitOutputStream = bitOutputStream;
141
142         approximator.initialize();
143
144         approximator.save(bitOutputStream);
145
146         width = image.metaData.width;
147         height = image.metaData.height;
148
149         final WritableRaster raster = image.bufferedImage.getRaster();
150         final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
151         final byte[] pixels = dbi.getData();
152
153         if (yChannel == null)
154             yChannel = new Channel(width, height);
155         else
156             yChannel.reset();
157
158         if (uChannel == null)
159             uChannel = new Channel(width, height);
160         else
161             uChannel.reset();
162
163         if (vChannel == null)
164             vChannel = new Channel(width, height);
165         else
166             vChannel.reset();
167
168         // create YUV map out of RGB raster data
169         final Color color = new Color();
170
171         for (int y = 0; y < height; y++)
172             for (int x = 0; x < width; x++) {
173
174                 final int index = (y * width) + x;
175                 final int colorBufferIndex = index * 3;
176
177                 int blue = pixels[colorBufferIndex];
178                 if (blue < 0)
179                     blue = blue + 256;
180
181                 int green = pixels[colorBufferIndex + 1];
182                 if (green < 0)
183                     green = green + 256;
184
185                 int red = pixels[colorBufferIndex + 2];
186                 if (red < 0)
187                     red = red + 256;
188
189                 color.r = red;
190                 color.g = green;
191                 color.b = blue;
192
193                 color.RGB2YUV();
194
195                 yChannel.map[index] = (byte) color.y;
196                 uChannel.map[index] = (byte) color.u;
197                 vChannel.map[index] = (byte) color.v;
198             }
199
200         yChannel.decodedMap[0] = yChannel.map[0];
201         uChannel.decodedMap[0] = uChannel.map[0];
202         vChannel.decodedMap[0] = vChannel.map[0];
203
204         bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);
205         bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);
206         bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);
207
208         // detect initial step
209         int largestDimension;
210         int initialStep = 2;
211         if (width > height)
212             largestDimension = width;
213         else
214             largestDimension = height;
215
216         while (initialStep < largestDimension)
217             initialStep = initialStep * 2;
218
219         rangeGrid(initialStep);
220         rangeRoundGrid(2);
221         saveGrid(initialStep);
222     }
223
224     private void encodeChannel(final Table table, final Channel channel,
225                                final int averageDecodedValue, final int index, final int value,
226                                final int range, final int parentIndex) throws IOException {
227
228         final byte[] decodedRangeMap = channel.decodedRangeMap;
229         final byte[] decodedMap = channel.decodedMap;
230
231         final int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
232
233         final int inheritedBitCount = table
234                 .proposeBitcountForRange(inheritedRange);
235
236         if (inheritedBitCount > 0) {
237             int computedRange;
238             computedRange = table.proposeRangeForRange(range, inheritedRange);
239             decodedRangeMap[index] = (byte) computedRange;
240
241             channel.bitCount++;
242             if (computedRange != inheritedRange)
243                 // brightness range shrinked
244                 bitOutputStream.storeBits(1, 1);
245             else
246                 // brightness range stayed the same
247                 bitOutputStream.storeBits(0, 1);
248
249             // encode brightness into available amount of bits
250             final int computedBitCount = table
251                     .proposeBitcountForRange(computedRange);
252
253             if (computedBitCount > 0) {
254
255                 final int differenceToEncode = -(value - averageDecodedValue);
256                 final int bitEncodedDifference = encodeValueIntoGivenBits(
257                         differenceToEncode, computedRange, computedBitCount);
258
259                 channel.bitCount = channel.bitCount + computedBitCount;
260                 bitOutputStream.storeBits(bitEncodedDifference,
261                         computedBitCount);
262
263                 final int decodedDifference = decodeValueFromGivenBits(
264                         bitEncodedDifference, computedRange, computedBitCount);
265                 int decodedValue = averageDecodedValue - decodedDifference;
266                 if (decodedValue > 255)
267                     decodedValue = 255;
268                 if (decodedValue < 0)
269                     decodedValue = 0;
270
271                 decodedMap[index] = (byte) decodedValue;
272             } else
273                 decodedMap[index] = (byte) averageDecodedValue;
274
275         } else {
276             decodedRangeMap[index] = (byte) inheritedRange;
277             decodedMap[index] = (byte) averageDecodedValue;
278         }
279     }
280
281     public void printStatistics() {
282         System.out.println("Y channel:");
283         yChannel.printStatistics();
284
285         System.out.println("U channel:");
286         uChannel.printStatistics();
287
288         System.out.println("V channel:");
289         vChannel.printStatistics();
290     }
291
292     private void rangeGrid(final int step) {
293
294         // gridSquare(step / 2, step / 2, step, pixels);
295
296         rangeGridDiagonal(step / 2, step / 2, step);
297         rangeGridSquare(step / 2, 0, step);
298         rangeGridSquare(0, step / 2, step);
299
300         if (step > 2)
301             rangeGrid(step / 2);
302     }
303
304     private void rangeGridDiagonal(final int offsetX, final int offsetY,
305                                    final int step) {
306         for (int y = offsetY; y < height; y = y + step)
307             for (int x = offsetX; x < width; x = x + step) {
308
309                 final int index = (y * width) + x;
310                 final int halfStep = step / 2;
311
312                 context.initialize(image, yChannel.map, uChannel.map,
313                         vChannel.map);
314
315                 context.measureNeighborEncode(x - halfStep, y - halfStep);
316                 context.measureNeighborEncode(x + halfStep, y - halfStep);
317                 context.measureNeighborEncode(x - halfStep, y + halfStep);
318                 context.measureNeighborEncode(x + halfStep, y + halfStep);
319
320                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
321                 uChannel.rangeMap[index] = (byte) context.getURange(index);
322                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
323             }
324     }
325
326     private void rangeGridSquare(final int offsetX, final int offsetY,
327                                  final int step) {
328         for (int y = offsetY; y < height; y = y + step)
329             for (int x = offsetX; x < width; x = x + step) {
330
331                 final int index = (y * width) + x;
332                 final int halfStep = step / 2;
333
334                 context.initialize(image, yChannel.map, uChannel.map,
335                         vChannel.map);
336
337                 context.measureNeighborEncode(x - halfStep, y);
338                 context.measureNeighborEncode(x + halfStep, y);
339                 context.measureNeighborEncode(x, y - halfStep);
340                 context.measureNeighborEncode(x, y + halfStep);
341
342                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
343                 uChannel.rangeMap[index] = (byte) context.getURange(index);
344                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
345             }
346     }
347
348     private void rangeRoundGrid(final int step) {
349
350         rangeRoundGridDiagonal(step / 2, step / 2, step);
351         rangeRoundGridSquare(step / 2, 0, step);
352         rangeRoundGridSquare(0, step / 2, step);
353
354         if (step < 1024)
355             rangeRoundGrid(step * 2);
356     }
357
358     private void rangeRoundGridDiagonal(final int offsetX, final int offsetY,
359                                         final int step) {
360         for (int y = offsetY; y < height; y = y + step)
361             for (int x = offsetX; x < width; x = x + step) {
362
363                 final int index = (y * width) + x;
364
365                 final int yRange = byteToInt(yChannel.rangeMap[index]);
366                 final int uRange = byteToInt(uChannel.rangeMap[index]);
367                 final int vRange = byteToInt(vChannel.rangeMap[index]);
368
369                 final int halfStep = step / 2;
370
371                 final int parentIndex = ((y - halfStep) * width)
372                         + (x - halfStep);
373
374                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
375
376                 if (parentYRange < yRange) {
377                     parentYRange = yRange;
378                     yChannel.rangeMap[parentIndex] = (byte) parentYRange;
379                 }
380
381                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
382
383                 if (parentURange < uRange) {
384                     parentURange = uRange;
385                     uChannel.rangeMap[parentIndex] = (byte) parentURange;
386                 }
387
388                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
389
390                 if (parentVRange < vRange) {
391                     parentVRange = vRange;
392                     vChannel.rangeMap[parentIndex] = (byte) parentVRange;
393                 }
394             }
395     }
396
397     private void rangeRoundGridSquare(final int offsetX, final int offsetY,
398                                       final int step) {
399         for (int y = offsetY; y < height; y = y + step)
400             for (int x = offsetX; x < width; x = x + step) {
401
402                 final int index = (y * width) + x;
403
404                 final int yRange = byteToInt(yChannel.rangeMap[index]);
405                 final int uRange = byteToInt(uChannel.rangeMap[index]);
406                 final int vRange = byteToInt(vChannel.rangeMap[index]);
407
408                 final int halfStep = step / 2;
409
410                 int parentIndex;
411                 if (offsetX > 0)
412                     parentIndex = (y * width) + (x - halfStep);
413                 else
414                     parentIndex = ((y - halfStep) * width) + x;
415
416                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
417
418                 if (parentYRange < yRange) {
419                     parentYRange = yRange;
420                     yChannel.rangeMap[parentIndex] = (byte) parentYRange;
421                 }
422
423                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
424
425                 if (parentURange < uRange) {
426                     parentURange = uRange;
427                     uChannel.rangeMap[parentIndex] = (byte) parentURange;
428                 }
429
430                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
431
432                 if (parentVRange < vRange) {
433                     parentVRange = vRange;
434                     vChannel.rangeMap[parentIndex] = (byte) parentVRange;
435                 }
436
437             }
438     }
439
440     private void saveGrid(final int step) throws IOException {
441
442         saveGridDiagonal(step / 2, step / 2, step);
443         saveGridSquare(step / 2, 0, step);
444         saveGridSquare(0, step / 2, step);
445
446         if (step > 2)
447             saveGrid(step / 2);
448     }
449
450     private void saveGridDiagonal(final int offsetX, final int offsetY,
451                                   final int step) throws IOException {
452         for (int y = offsetY; y < height; y = y + step)
453             for (int x = offsetX; x < width; x = x + step) {
454
455                 final int halfStep = step / 2;
456
457                 context2.initialize(image, yChannel.decodedMap,
458                         uChannel.decodedMap, vChannel.decodedMap);
459                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
460                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
461                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
462                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
463
464                 savePixel(step, offsetX, offsetY, x, y,
465                         context2.colorStats.getAverageY(),
466                         context2.colorStats.getAverageU(),
467                         context2.colorStats.getAverageV());
468
469             }
470     }
471
472     private void saveGridSquare(final int offsetX, final int offsetY,
473                                 final int step) throws IOException {
474         for (int y = offsetY; y < height; y = y + step)
475             for (int x = offsetX; x < width; x = x + step) {
476
477                 final int halfStep = step / 2;
478
479                 context2.initialize(image, yChannel.decodedMap,
480                         uChannel.decodedMap, vChannel.decodedMap);
481                 context2.measureNeighborEncode(x - halfStep, y);
482                 context2.measureNeighborEncode(x + halfStep, y);
483                 context2.measureNeighborEncode(x, y - halfStep);
484                 context2.measureNeighborEncode(x, y + halfStep);
485
486                 savePixel(step, offsetX, offsetY, x, y,
487                         context2.colorStats.getAverageY(),
488                         context2.colorStats.getAverageU(),
489                         context2.colorStats.getAverageV());
490
491             }
492     }
493
494     private void savePixel(final int step, final int offsetX, final int offsetY,
495                            final int x, final int y, final int averageDecodedY,
496                            final int averageDecodedU, final int averageDecodedV)
497             throws IOException {
498
499         final int index = (y * width) + x;
500
501         final int py = byteToInt(yChannel.map[index]);
502         final int pu = byteToInt(uChannel.map[index]);
503         final int pv = byteToInt(vChannel.map[index]);
504
505         final int yRange = byteToInt(yChannel.rangeMap[index]);
506         final int uRange = byteToInt(uChannel.rangeMap[index]);
507         final int vRange = byteToInt(vChannel.rangeMap[index]);
508
509         final int halfStep = step / 2;
510
511         int parentIndex;
512         if (offsetX > 0) {
513             if (offsetY > 0)
514                 // diagonal approach
515                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
516             else
517                 // take left pixel
518                 parentIndex = (y * width) + (x - halfStep);
519         } else
520             // take upper pixel
521             parentIndex = ((y - halfStep) * width) + x;
522
523         encodeChannel(approximator.yTable, yChannel, averageDecodedY, index,
524                 py, yRange, parentIndex);
525
526         encodeChannel(approximator.uTable, uChannel, averageDecodedU, index,
527                 pu, uRange, parentIndex);
528
529         encodeChannel(approximator.vTable, vChannel, averageDecodedV, index,
530                 pv, vRange, parentIndex);
531
532     }
533
534 }