1ee823d80ab0432e25832009942b3978b572d2fa
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 /*
2  * Imagesqueeze - Image codec optimized for photos.
3  * Copyright (C) 2012, Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
4  *
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of version 2 of the GNU General Public License
7  * as published by the Free Software Foundation.
8  */
9
10 package eu.svjatoslav.imagesqueeze.codec;
11
12 /**
13  * Compressed image pixels encoder.
14  */
15
16 import eu.svjatoslav.commons.data.BitOutputStream;
17
18 import java.awt.image.DataBufferByte;
19 import java.awt.image.WritableRaster;
20 import java.io.IOException;
21
22 class ImageEncoder {
23
24     private final Image image;
25     private final Approximator approximator;
26     // ColorStats colorStats = new ColorStats();
27     private final OperatingContext context = new OperatingContext();
28     private final OperatingContext context2 = new OperatingContext();
29     int bitsForY;
30     int bitsForU;
31     int bitsForV;
32     private int width;
33     private int height;
34     private Channel yChannel;
35     private Channel uChannel;
36     private Channel vChannel;
37     private BitOutputStream bitOutputStream;
38
39     public ImageEncoder(final Image image) {
40         approximator = new Approximator();
41
42         // bitOutputStream = outputStream;
43
44         this.image = image;
45
46     }
47
48     public static int byteToInt(final byte input) {
49         int result = input;
50         if (result < 0)
51             result = result + 256;
52         return result;
53     }
54
55     public static int decodeValueFromGivenBits(final int encodedBits,
56                                                final int range, final int bitCount) {
57         final int negativeBit = encodedBits & 1;
58
59         final int remainingBitCount = bitCount - 1;
60
61         if (remainingBitCount == 0) {
62             // no more bits remaining to encode actual value
63
64             if (negativeBit == 0)
65                 return range;
66             else
67                 return -range;
68
69         } else {
70             // still one or more bits left, encode value as precisely as
71             // possible
72
73             final int encodedValue = (encodedBits >>> 1) + 1;
74
75             final int realvalueForThisBitcount = 1 << remainingBitCount;
76
77             // int valueMultiplier = range / realvalueForThisBitcount;
78             int decodedValue = (range * encodedValue)
79                     / realvalueForThisBitcount;
80
81             if (decodedValue > range)
82                 decodedValue = range;
83
84             if (negativeBit == 0)
85                 return decodedValue;
86             else
87                 return -decodedValue;
88
89         }
90     }
91
92     private static int encodeValueIntoGivenBits(int value, final int range,
93                                                 final int bitCount) {
94
95         int negativeBit = 0;
96
97         if (value < 0) {
98             negativeBit = 1;
99             value = -value;
100         }
101
102         final int remainingBitCount = bitCount - 1;
103
104         if (remainingBitCount == 0)
105             return negativeBit;
106         else {
107             // still one or more bits left, encode value as precisely as
108             // possible
109
110             if (value > range)
111                 value = range;
112
113             final int realvalueForThisBitcount = 1 << remainingBitCount;
114             // int valueMultiplier = range / realvalueForThisBitcount;
115             int encodedValue = (value * realvalueForThisBitcount) / range;
116
117             if (encodedValue >= realvalueForThisBitcount)
118                 encodedValue = realvalueForThisBitcount - 1;
119
120             encodedValue = (encodedValue << 1) + negativeBit;
121
122             return encodedValue;
123         }
124     }
125
126     public static void storeIntegerCompressed8(
127             final BitOutputStream outputStream, final int data)
128             throws IOException {
129
130         if (data < 256) {
131             outputStream.storeBits(0, 1);
132             outputStream.storeBits(data, 8);
133         } else {
134             outputStream.storeBits(1, 1);
135             outputStream.storeBits(data, 32);
136         }
137     }
138
139     public void encode(final BitOutputStream bitOutputStream)
140             throws IOException {
141         this.bitOutputStream = bitOutputStream;
142
143         approximator.initialize();
144
145         approximator.save(bitOutputStream);
146
147         width = image.metaData.width;
148         height = image.metaData.height;
149
150         final WritableRaster raster = image.bufferedImage.getRaster();
151         final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
152         final byte[] pixels = dbi.getData();
153
154         if (yChannel == null)
155             yChannel = new Channel(width, height);
156         else
157             yChannel.reset();
158
159         if (uChannel == null)
160             uChannel = new Channel(width, height);
161         else
162             uChannel.reset();
163
164         if (vChannel == null)
165             vChannel = new Channel(width, height);
166         else
167             vChannel.reset();
168
169         // create YUV map out of RGB raster data
170         final Color color = new Color();
171
172         for (int y = 0; y < height; y++)
173             for (int x = 0; x < width; x++) {
174
175                 final int index = (y * width) + x;
176                 final int colorBufferIndex = index * 3;
177
178                 int blue = pixels[colorBufferIndex];
179                 if (blue < 0)
180                     blue = blue + 256;
181
182                 int green = pixels[colorBufferIndex + 1];
183                 if (green < 0)
184                     green = green + 256;
185
186                 int red = pixels[colorBufferIndex + 2];
187                 if (red < 0)
188                     red = red + 256;
189
190                 color.r = red;
191                 color.g = green;
192                 color.b = blue;
193
194                 color.RGB2YUV();
195
196                 yChannel.map[index] = (byte) color.y;
197                 uChannel.map[index] = (byte) color.u;
198                 vChannel.map[index] = (byte) color.v;
199             }
200
201         yChannel.decodedMap[0] = yChannel.map[0];
202         uChannel.decodedMap[0] = uChannel.map[0];
203         vChannel.decodedMap[0] = vChannel.map[0];
204
205         bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);
206         bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);
207         bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);
208
209         // detect initial step
210         int largestDimension;
211         int initialStep = 2;
212         if (width > height)
213             largestDimension = width;
214         else
215             largestDimension = height;
216
217         while (initialStep < largestDimension)
218             initialStep = initialStep * 2;
219
220         rangeGrid(initialStep);
221         rangeRoundGrid(2);
222         saveGrid(initialStep);
223     }
224
225     private void encodeChannel(final Table table, final Channel channel,
226                                final int averageDecodedValue, final int index, final int value,
227                                final int range, final int parentIndex) throws IOException {
228
229         final byte[] decodedRangeMap = channel.decodedRangeMap;
230         final byte[] decodedMap = channel.decodedMap;
231
232         final int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
233
234         final int inheritedBitCount = table
235                 .proposeBitcountForRange(inheritedRange);
236
237         if (inheritedBitCount > 0) {
238             int computedRange;
239             computedRange = table.proposeRangeForRange(range, inheritedRange);
240             decodedRangeMap[index] = (byte) computedRange;
241
242             channel.bitCount++;
243             if (computedRange != inheritedRange)
244                 // brightness range shrinked
245                 bitOutputStream.storeBits(1, 1);
246             else
247                 // brightness range stayed the same
248                 bitOutputStream.storeBits(0, 1);
249
250             // encode brightness into available amount of bits
251             final int computedBitCount = table
252                     .proposeBitcountForRange(computedRange);
253
254             if (computedBitCount > 0) {
255
256                 final int differenceToEncode = -(value - averageDecodedValue);
257                 final int bitEncodedDifference = encodeValueIntoGivenBits(
258                         differenceToEncode, computedRange, computedBitCount);
259
260                 channel.bitCount = channel.bitCount + computedBitCount;
261                 bitOutputStream.storeBits(bitEncodedDifference,
262                         computedBitCount);
263
264                 final int decodedDifference = decodeValueFromGivenBits(
265                         bitEncodedDifference, computedRange, computedBitCount);
266                 int decodedValue = averageDecodedValue - decodedDifference;
267                 if (decodedValue > 255)
268                     decodedValue = 255;
269                 if (decodedValue < 0)
270                     decodedValue = 0;
271
272                 decodedMap[index] = (byte) decodedValue;
273             } else
274                 decodedMap[index] = (byte) averageDecodedValue;
275
276         } else {
277             decodedRangeMap[index] = (byte) inheritedRange;
278             decodedMap[index] = (byte) averageDecodedValue;
279         }
280     }
281
282     public void printStatistics() {
283         System.out.println("Y channel:");
284         yChannel.printStatistics();
285
286         System.out.println("U channel:");
287         uChannel.printStatistics();
288
289         System.out.println("V channel:");
290         vChannel.printStatistics();
291     }
292
293     private void rangeGrid(final int step) {
294
295         // gridSquare(step / 2, step / 2, step, pixels);
296
297         rangeGridDiagonal(step / 2, step / 2, step);
298         rangeGridSquare(step / 2, 0, step);
299         rangeGridSquare(0, step / 2, step);
300
301         if (step > 2)
302             rangeGrid(step / 2);
303     }
304
305     private void rangeGridDiagonal(final int offsetX, final int offsetY,
306                                    final int step) {
307         for (int y = offsetY; y < height; y = y + step)
308             for (int x = offsetX; x < width; x = x + step) {
309
310                 final int index = (y * width) + x;
311                 final int halfStep = step / 2;
312
313                 context.initialize(image, yChannel.map, uChannel.map,
314                         vChannel.map);
315
316                 context.measureNeighborEncode(x - halfStep, y - halfStep);
317                 context.measureNeighborEncode(x + halfStep, y - halfStep);
318                 context.measureNeighborEncode(x - halfStep, y + halfStep);
319                 context.measureNeighborEncode(x + halfStep, y + halfStep);
320
321                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
322                 uChannel.rangeMap[index] = (byte) context.getURange(index);
323                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
324             }
325     }
326
327     private void rangeGridSquare(final int offsetX, final int offsetY,
328                                  final int step) {
329         for (int y = offsetY; y < height; y = y + step)
330             for (int x = offsetX; x < width; x = x + step) {
331
332                 final int index = (y * width) + x;
333                 final int halfStep = step / 2;
334
335                 context.initialize(image, yChannel.map, uChannel.map,
336                         vChannel.map);
337
338                 context.measureNeighborEncode(x - halfStep, y);
339                 context.measureNeighborEncode(x + halfStep, y);
340                 context.measureNeighborEncode(x, y - halfStep);
341                 context.measureNeighborEncode(x, y + halfStep);
342
343                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
344                 uChannel.rangeMap[index] = (byte) context.getURange(index);
345                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
346             }
347     }
348
349     private void rangeRoundGrid(final int step) {
350
351         rangeRoundGridDiagonal(step / 2, step / 2, step);
352         rangeRoundGridSquare(step / 2, 0, step);
353         rangeRoundGridSquare(0, step / 2, step);
354
355         if (step < 1024)
356             rangeRoundGrid(step * 2);
357     }
358
359     private void rangeRoundGridDiagonal(final int offsetX, final int offsetY,
360                                         final int step) {
361         for (int y = offsetY; y < height; y = y + step)
362             for (int x = offsetX; x < width; x = x + step) {
363
364                 final int index = (y * width) + x;
365
366                 final int yRange = byteToInt(yChannel.rangeMap[index]);
367                 final int uRange = byteToInt(uChannel.rangeMap[index]);
368                 final int vRange = byteToInt(vChannel.rangeMap[index]);
369
370                 final int halfStep = step / 2;
371
372                 final int parentIndex = ((y - halfStep) * width)
373                         + (x - halfStep);
374
375                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
376
377                 if (parentYRange < yRange) {
378                     parentYRange = yRange;
379                     yChannel.rangeMap[parentIndex] = (byte) parentYRange;
380                 }
381
382                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
383
384                 if (parentURange < uRange) {
385                     parentURange = uRange;
386                     uChannel.rangeMap[parentIndex] = (byte) parentURange;
387                 }
388
389                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
390
391                 if (parentVRange < vRange) {
392                     parentVRange = vRange;
393                     vChannel.rangeMap[parentIndex] = (byte) parentVRange;
394                 }
395             }
396     }
397
398     private void rangeRoundGridSquare(final int offsetX, final int offsetY,
399                                       final int step) {
400         for (int y = offsetY; y < height; y = y + step)
401             for (int x = offsetX; x < width; x = x + step) {
402
403                 final int index = (y * width) + x;
404
405                 final int yRange = byteToInt(yChannel.rangeMap[index]);
406                 final int uRange = byteToInt(uChannel.rangeMap[index]);
407                 final int vRange = byteToInt(vChannel.rangeMap[index]);
408
409                 final int halfStep = step / 2;
410
411                 int parentIndex;
412                 if (offsetX > 0)
413                     parentIndex = (y * width) + (x - halfStep);
414                 else
415                     parentIndex = ((y - halfStep) * width) + x;
416
417                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
418
419                 if (parentYRange < yRange) {
420                     parentYRange = yRange;
421                     yChannel.rangeMap[parentIndex] = (byte) parentYRange;
422                 }
423
424                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
425
426                 if (parentURange < uRange) {
427                     parentURange = uRange;
428                     uChannel.rangeMap[parentIndex] = (byte) parentURange;
429                 }
430
431                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
432
433                 if (parentVRange < vRange) {
434                     parentVRange = vRange;
435                     vChannel.rangeMap[parentIndex] = (byte) parentVRange;
436                 }
437
438             }
439     }
440
441     private void saveGrid(final int step) throws IOException {
442
443         saveGridDiagonal(step / 2, step / 2, step);
444         saveGridSquare(step / 2, 0, step);
445         saveGridSquare(0, step / 2, step);
446
447         if (step > 2)
448             saveGrid(step / 2);
449     }
450
451     private void saveGridDiagonal(final int offsetX, final int offsetY,
452                                   final int step) throws IOException {
453         for (int y = offsetY; y < height; y = y + step)
454             for (int x = offsetX; x < width; x = x + step) {
455
456                 final int halfStep = step / 2;
457
458                 context2.initialize(image, yChannel.decodedMap,
459                         uChannel.decodedMap, vChannel.decodedMap);
460                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
461                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
462                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
463                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
464
465                 savePixel(step, offsetX, offsetY, x, y,
466                         context2.colorStats.getAverageY(),
467                         context2.colorStats.getAverageU(),
468                         context2.colorStats.getAverageV());
469
470             }
471     }
472
473     private void saveGridSquare(final int offsetX, final int offsetY,
474                                 final int step) throws IOException {
475         for (int y = offsetY; y < height; y = y + step)
476             for (int x = offsetX; x < width; x = x + step) {
477
478                 final int halfStep = step / 2;
479
480                 context2.initialize(image, yChannel.decodedMap,
481                         uChannel.decodedMap, vChannel.decodedMap);
482                 context2.measureNeighborEncode(x - halfStep, y);
483                 context2.measureNeighborEncode(x + halfStep, y);
484                 context2.measureNeighborEncode(x, y - halfStep);
485                 context2.measureNeighborEncode(x, y + halfStep);
486
487                 savePixel(step, offsetX, offsetY, x, y,
488                         context2.colorStats.getAverageY(),
489                         context2.colorStats.getAverageU(),
490                         context2.colorStats.getAverageV());
491
492             }
493     }
494
495     private void savePixel(final int step, final int offsetX, final int offsetY,
496                            final int x, final int y, final int averageDecodedY,
497                            final int averageDecodedU, final int averageDecodedV)
498             throws IOException {
499
500         final int index = (y * width) + x;
501
502         final int py = byteToInt(yChannel.map[index]);
503         final int pu = byteToInt(uChannel.map[index]);
504         final int pv = byteToInt(vChannel.map[index]);
505
506         final int yRange = byteToInt(yChannel.rangeMap[index]);
507         final int uRange = byteToInt(uChannel.rangeMap[index]);
508         final int vRange = byteToInt(vChannel.rangeMap[index]);
509
510         final int halfStep = step / 2;
511
512         int parentIndex;
513         if (offsetX > 0) {
514             if (offsetY > 0)
515                 // diagonal approach
516                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
517             else
518                 // take left pixel
519                 parentIndex = (y * width) + (x - halfStep);
520         } else
521             // take upper pixel
522             parentIndex = ((y - halfStep) * width) + x;
523
524         encodeChannel(approximator.yTable, yChannel, averageDecodedY, index,
525                 py, yRange, parentIndex);
526
527         encodeChannel(approximator.uTable, uChannel, averageDecodedU, index,
528                 pu, uRange, parentIndex);
529
530         encodeChannel(approximator.vTable, vChannel, averageDecodedV, index,
531                 pv, vRange, parentIndex);
532
533     }
534
535 }