Changed license to CC0
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 /*
2  * Image codec. Author: Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
3  * This project is released under Creative Commons Zero (CC0) license.
4  */
5 package eu.svjatoslav.imagesqueeze.codec;
6
7 /**
8  * Compressed image pixels encoder.
9  */
10
11 import eu.svjatoslav.commons.data.BitOutputStream;
12
13 import java.awt.image.DataBufferByte;
14 import java.awt.image.WritableRaster;
15 import java.io.IOException;
16
17 class ImageEncoder {
18
19     private final Image image;
20     private final Approximator approximator;
21     // ColorStats colorStats = new ColorStats();
22     private final OperatingContext context = new OperatingContext();
23     private final OperatingContext context2 = new OperatingContext();
24     int bitsForY;
25     int bitsForU;
26     int bitsForV;
27     private int width;
28     private int height;
29     private Channel yChannel;
30     private Channel uChannel;
31     private Channel vChannel;
32     private BitOutputStream bitOutputStream;
33
34     public ImageEncoder(final Image image) {
35         approximator = new Approximator();
36
37         // bitOutputStream = outputStream;
38
39         this.image = image;
40
41     }
42
43     public static int byteToInt(final byte input) {
44         int result = input;
45         if (result < 0)
46             result = result + 256;
47         return result;
48     }
49
50     public static int decodeValueFromGivenBits(final int encodedBits,
51                                                final int range, final int bitCount) {
52         final int negativeBit = encodedBits & 1;
53
54         final int remainingBitCount = bitCount - 1;
55
56         if (remainingBitCount == 0) {
57             // no more bits remaining to encode actual value
58
59             if (negativeBit == 0)
60                 return range;
61             else
62                 return -range;
63
64         } else {
65             // still one or more bits left, encode value as precisely as
66             // possible
67
68             final int encodedValue = (encodedBits >>> 1) + 1;
69
70             final int realvalueForThisBitcount = 1 << remainingBitCount;
71
72             // int valueMultiplier = range / realvalueForThisBitcount;
73             int decodedValue = (range * encodedValue)
74                     / realvalueForThisBitcount;
75
76             if (decodedValue > range)
77                 decodedValue = range;
78
79             if (negativeBit == 0)
80                 return decodedValue;
81             else
82                 return -decodedValue;
83
84         }
85     }
86
87     private static int encodeValueIntoGivenBits(int value, final int range,
88                                                 final int bitCount) {
89
90         int negativeBit = 0;
91
92         if (value < 0) {
93             negativeBit = 1;
94             value = -value;
95         }
96
97         final int remainingBitCount = bitCount - 1;
98
99         if (remainingBitCount == 0)
100             return negativeBit;
101         else {
102             // still one or more bits left, encode value as precisely as
103             // possible
104
105             if (value > range)
106                 value = range;
107
108             final int realvalueForThisBitcount = 1 << remainingBitCount;
109             // int valueMultiplier = range / realvalueForThisBitcount;
110             int encodedValue = (value * realvalueForThisBitcount) / range;
111
112             if (encodedValue >= realvalueForThisBitcount)
113                 encodedValue = realvalueForThisBitcount - 1;
114
115             encodedValue = (encodedValue << 1) + negativeBit;
116
117             return encodedValue;
118         }
119     }
120
121     public static void storeIntegerCompressed8(
122             final BitOutputStream outputStream, final int data)
123             throws IOException {
124
125         if (data < 256) {
126             outputStream.storeBits(0, 1);
127             outputStream.storeBits(data, 8);
128         } else {
129             outputStream.storeBits(1, 1);
130             outputStream.storeBits(data, 32);
131         }
132     }
133
134     public void encode(final BitOutputStream bitOutputStream)
135             throws IOException {
136         this.bitOutputStream = bitOutputStream;
137
138         approximator.initialize();
139
140         approximator.save(bitOutputStream);
141
142         width = image.metaData.width;
143         height = image.metaData.height;
144
145         final WritableRaster raster = image.bufferedImage.getRaster();
146         final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
147         final byte[] pixels = dbi.getData();
148
149         if (yChannel == null)
150             yChannel = new Channel(width, height);
151         else
152             yChannel.reset();
153
154         if (uChannel == null)
155             uChannel = new Channel(width, height);
156         else
157             uChannel.reset();
158
159         if (vChannel == null)
160             vChannel = new Channel(width, height);
161         else
162             vChannel.reset();
163
164         // create YUV map out of RGB raster data
165         final Color color = new Color();
166
167         for (int y = 0; y < height; y++)
168             for (int x = 0; x < width; x++) {
169
170                 final int index = (y * width) + x;
171                 final int colorBufferIndex = index * 3;
172
173                 int blue = pixels[colorBufferIndex];
174                 if (blue < 0)
175                     blue = blue + 256;
176
177                 int green = pixels[colorBufferIndex + 1];
178                 if (green < 0)
179                     green = green + 256;
180
181                 int red = pixels[colorBufferIndex + 2];
182                 if (red < 0)
183                     red = red + 256;
184
185                 color.r = red;
186                 color.g = green;
187                 color.b = blue;
188
189                 color.RGB2YUV();
190
191                 yChannel.map[index] = (byte) color.y;
192                 uChannel.map[index] = (byte) color.u;
193                 vChannel.map[index] = (byte) color.v;
194             }
195
196         yChannel.decodedMap[0] = yChannel.map[0];
197         uChannel.decodedMap[0] = uChannel.map[0];
198         vChannel.decodedMap[0] = vChannel.map[0];
199
200         bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);
201         bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);
202         bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);
203
204         // detect initial step
205         int largestDimension;
206         int initialStep = 2;
207         if (width > height)
208             largestDimension = width;
209         else
210             largestDimension = height;
211
212         while (initialStep < largestDimension)
213             initialStep = initialStep * 2;
214
215         rangeGrid(initialStep);
216         rangeRoundGrid(2);
217         saveGrid(initialStep);
218     }
219
220     private void encodeChannel(final Table table, final Channel channel,
221                                final int averageDecodedValue, final int index, final int value,
222                                final int range, final int parentIndex) throws IOException {
223
224         final byte[] decodedRangeMap = channel.decodedRangeMap;
225         final byte[] decodedMap = channel.decodedMap;
226
227         final int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
228
229         final int inheritedBitCount = table
230                 .proposeBitcountForRange(inheritedRange);
231
232         if (inheritedBitCount > 0) {
233             int computedRange;
234             computedRange = table.proposeRangeForRange(range, inheritedRange);
235             decodedRangeMap[index] = (byte) computedRange;
236
237             channel.bitCount++;
238             if (computedRange != inheritedRange)
239                 // brightness range shrinked
240                 bitOutputStream.storeBits(1, 1);
241             else
242                 // brightness range stayed the same
243                 bitOutputStream.storeBits(0, 1);
244
245             // encode brightness into available amount of bits
246             final int computedBitCount = table
247                     .proposeBitcountForRange(computedRange);
248
249             if (computedBitCount > 0) {
250
251                 final int differenceToEncode = -(value - averageDecodedValue);
252                 final int bitEncodedDifference = encodeValueIntoGivenBits(
253                         differenceToEncode, computedRange, computedBitCount);
254
255                 channel.bitCount = channel.bitCount + computedBitCount;
256                 bitOutputStream.storeBits(bitEncodedDifference,
257                         computedBitCount);
258
259                 final int decodedDifference = decodeValueFromGivenBits(
260                         bitEncodedDifference, computedRange, computedBitCount);
261                 int decodedValue = averageDecodedValue - decodedDifference;
262                 if (decodedValue > 255)
263                     decodedValue = 255;
264                 if (decodedValue < 0)
265                     decodedValue = 0;
266
267                 decodedMap[index] = (byte) decodedValue;
268             } else
269                 decodedMap[index] = (byte) averageDecodedValue;
270
271         } else {
272             decodedRangeMap[index] = (byte) inheritedRange;
273             decodedMap[index] = (byte) averageDecodedValue;
274         }
275     }
276
277     public void printStatistics() {
278         System.out.println("Y channel:");
279         yChannel.printStatistics();
280
281         System.out.println("U channel:");
282         uChannel.printStatistics();
283
284         System.out.println("V channel:");
285         vChannel.printStatistics();
286     }
287
288     private void rangeGrid(final int step) {
289
290         // gridSquare(step / 2, step / 2, step, pixels);
291
292         rangeGridDiagonal(step / 2, step / 2, step);
293         rangeGridSquare(step / 2, 0, step);
294         rangeGridSquare(0, step / 2, step);
295
296         if (step > 2)
297             rangeGrid(step / 2);
298     }
299
300     private void rangeGridDiagonal(final int offsetX, final int offsetY,
301                                    final int step) {
302         for (int y = offsetY; y < height; y = y + step)
303             for (int x = offsetX; x < width; x = x + step) {
304
305                 final int index = (y * width) + x;
306                 final int halfStep = step / 2;
307
308                 context.initialize(image, yChannel.map, uChannel.map,
309                         vChannel.map);
310
311                 context.measureNeighborEncode(x - halfStep, y - halfStep);
312                 context.measureNeighborEncode(x + halfStep, y - halfStep);
313                 context.measureNeighborEncode(x - halfStep, y + halfStep);
314                 context.measureNeighborEncode(x + halfStep, y + halfStep);
315
316                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
317                 uChannel.rangeMap[index] = (byte) context.getURange(index);
318                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
319             }
320     }
321
322     private void rangeGridSquare(final int offsetX, final int offsetY,
323                                  final int step) {
324         for (int y = offsetY; y < height; y = y + step)
325             for (int x = offsetX; x < width; x = x + step) {
326
327                 final int index = (y * width) + x;
328                 final int halfStep = step / 2;
329
330                 context.initialize(image, yChannel.map, uChannel.map,
331                         vChannel.map);
332
333                 context.measureNeighborEncode(x - halfStep, y);
334                 context.measureNeighborEncode(x + halfStep, y);
335                 context.measureNeighborEncode(x, y - halfStep);
336                 context.measureNeighborEncode(x, y + halfStep);
337
338                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
339                 uChannel.rangeMap[index] = (byte) context.getURange(index);
340                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
341             }
342     }
343
344     private void rangeRoundGrid(final int step) {
345
346         rangeRoundGridDiagonal(step / 2, step / 2, step);
347         rangeRoundGridSquare(step / 2, 0, step);
348         rangeRoundGridSquare(0, step / 2, step);
349
350         if (step < 1024)
351             rangeRoundGrid(step * 2);
352     }
353
354     private void rangeRoundGridDiagonal(final int offsetX, final int offsetY,
355                                         final int step) {
356         for (int y = offsetY; y < height; y = y + step)
357             for (int x = offsetX; x < width; x = x + step) {
358
359                 final int index = (y * width) + x;
360
361                 final int yRange = byteToInt(yChannel.rangeMap[index]);
362                 final int uRange = byteToInt(uChannel.rangeMap[index]);
363                 final int vRange = byteToInt(vChannel.rangeMap[index]);
364
365                 final int halfStep = step / 2;
366
367                 final int parentIndex = ((y - halfStep) * width)
368                         + (x - halfStep);
369
370                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
371
372                 if (parentYRange < yRange) {
373                     parentYRange = yRange;
374                     yChannel.rangeMap[parentIndex] = (byte) parentYRange;
375                 }
376
377                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
378
379                 if (parentURange < uRange) {
380                     parentURange = uRange;
381                     uChannel.rangeMap[parentIndex] = (byte) parentURange;
382                 }
383
384                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
385
386                 if (parentVRange < vRange) {
387                     parentVRange = vRange;
388                     vChannel.rangeMap[parentIndex] = (byte) parentVRange;
389                 }
390             }
391     }
392
393     private void rangeRoundGridSquare(final int offsetX, final int offsetY,
394                                       final int step) {
395         for (int y = offsetY; y < height; y = y + step)
396             for (int x = offsetX; x < width; x = x + step) {
397
398                 final int index = (y * width) + x;
399
400                 final int yRange = byteToInt(yChannel.rangeMap[index]);
401                 final int uRange = byteToInt(uChannel.rangeMap[index]);
402                 final int vRange = byteToInt(vChannel.rangeMap[index]);
403
404                 final int halfStep = step / 2;
405
406                 int parentIndex;
407                 if (offsetX > 0)
408                     parentIndex = (y * width) + (x - halfStep);
409                 else
410                     parentIndex = ((y - halfStep) * width) + x;
411
412                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
413
414                 if (parentYRange < yRange) {
415                     parentYRange = yRange;
416                     yChannel.rangeMap[parentIndex] = (byte) parentYRange;
417                 }
418
419                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
420
421                 if (parentURange < uRange) {
422                     parentURange = uRange;
423                     uChannel.rangeMap[parentIndex] = (byte) parentURange;
424                 }
425
426                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
427
428                 if (parentVRange < vRange) {
429                     parentVRange = vRange;
430                     vChannel.rangeMap[parentIndex] = (byte) parentVRange;
431                 }
432
433             }
434     }
435
436     private void saveGrid(final int step) throws IOException {
437
438         saveGridDiagonal(step / 2, step / 2, step);
439         saveGridSquare(step / 2, 0, step);
440         saveGridSquare(0, step / 2, step);
441
442         if (step > 2)
443             saveGrid(step / 2);
444     }
445
446     private void saveGridDiagonal(final int offsetX, final int offsetY,
447                                   final int step) throws IOException {
448         for (int y = offsetY; y < height; y = y + step)
449             for (int x = offsetX; x < width; x = x + step) {
450
451                 final int halfStep = step / 2;
452
453                 context2.initialize(image, yChannel.decodedMap,
454                         uChannel.decodedMap, vChannel.decodedMap);
455                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
456                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
457                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
458                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
459
460                 savePixel(step, offsetX, offsetY, x, y,
461                         context2.colorStats.getAverageY(),
462                         context2.colorStats.getAverageU(),
463                         context2.colorStats.getAverageV());
464
465             }
466     }
467
468     private void saveGridSquare(final int offsetX, final int offsetY,
469                                 final int step) throws IOException {
470         for (int y = offsetY; y < height; y = y + step)
471             for (int x = offsetX; x < width; x = x + step) {
472
473                 final int halfStep = step / 2;
474
475                 context2.initialize(image, yChannel.decodedMap,
476                         uChannel.decodedMap, vChannel.decodedMap);
477                 context2.measureNeighborEncode(x - halfStep, y);
478                 context2.measureNeighborEncode(x + halfStep, y);
479                 context2.measureNeighborEncode(x, y - halfStep);
480                 context2.measureNeighborEncode(x, y + halfStep);
481
482                 savePixel(step, offsetX, offsetY, x, y,
483                         context2.colorStats.getAverageY(),
484                         context2.colorStats.getAverageU(),
485                         context2.colorStats.getAverageV());
486
487             }
488     }
489
490     private void savePixel(final int step, final int offsetX, final int offsetY,
491                            final int x, final int y, final int averageDecodedY,
492                            final int averageDecodedU, final int averageDecodedV)
493             throws IOException {
494
495         final int index = (y * width) + x;
496
497         final int py = byteToInt(yChannel.map[index]);
498         final int pu = byteToInt(uChannel.map[index]);
499         final int pv = byteToInt(vChannel.map[index]);
500
501         final int yRange = byteToInt(yChannel.rangeMap[index]);
502         final int uRange = byteToInt(uChannel.rangeMap[index]);
503         final int vRange = byteToInt(vChannel.rangeMap[index]);
504
505         final int halfStep = step / 2;
506
507         int parentIndex;
508         if (offsetX > 0) {
509             if (offsetY > 0)
510                 // diagonal approach
511                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
512             else
513                 // take left pixel
514                 parentIndex = (y * width) + (x - halfStep);
515         } else
516             // take upper pixel
517             parentIndex = ((y - halfStep) * width) + x;
518
519         encodeChannel(approximator.yTable, yChannel, averageDecodedY, index,
520                 py, yRange, parentIndex);
521
522         encodeChannel(approximator.uTable, uChannel, averageDecodedU, index,
523                 pu, uRange, parentIndex);
524
525         encodeChannel(approximator.vTable, vChannel, averageDecodedV, index,
526                 pv, vRange, parentIndex);
527
528     }
529
530 }