5b0e8db5f85a28a5b8607c118aa0a6d73091404b
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 /*
2  * Imagesqueeze - Image codec optimized for photos.
3  * Copyright (C) 2012, Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
4  *
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of version 2 of the GNU General Public License
7  * as published by the Free Software Foundation.
8  */
9
10 package eu.svjatoslav.imagesqueeze.codec;
11
12 /**
13  * Compressed image pixels encoder.
14  */
15
16 import java.awt.image.DataBufferByte;
17 import java.awt.image.WritableRaster;
18 import java.io.IOException;
19
20 import eu.svjatoslav.commons.data.BitOutputStream;
21
22 public class ImageEncoder {
23
24         public static int byteToInt(final byte input) {
25                 int result = input;
26                 if (result < 0)
27                         result = result + 256;
28                 return result;
29         }
30
31         public static int decodeValueFromGivenBits(final int encodedBits,
32                         final int range, final int bitCount) {
33                 final int negativeBit = encodedBits & 1;
34
35                 final int remainingBitCount = bitCount - 1;
36
37                 if (remainingBitCount == 0) {
38                         // no more bits remaining to encode actual value
39
40                         if (negativeBit == 0)
41                                 return range;
42                         else
43                                 return -range;
44
45                 } else {
46                         // still one or more bits left, encode value as precisely as
47                         // possible
48
49                         final int encodedValue = (encodedBits >>> 1) + 1;
50
51                         final int realvalueForThisBitcount = 1 << remainingBitCount;
52
53                         // int valueMultiplier = range / realvalueForThisBitcount;
54                         int decodedValue = (range * encodedValue)
55                                         / realvalueForThisBitcount;
56
57                         if (decodedValue > range)
58                                 decodedValue = range;
59
60                         if (negativeBit == 0)
61                                 return decodedValue;
62                         else
63                                 return -decodedValue;
64
65                 }
66         }
67
68         public static int encodeValueIntoGivenBits(int value, final int range,
69                         final int bitCount) {
70
71                 int negativeBit = 0;
72
73                 if (value < 0) {
74                         negativeBit = 1;
75                         value = -value;
76                 }
77
78                 final int remainingBitCount = bitCount - 1;
79
80                 if (remainingBitCount == 0)
81                         return negativeBit;
82                 else {
83                         // still one or more bits left, encode value as precisely as
84                         // possible
85
86                         if (value > range)
87                                 value = range;
88
89                         final int realvalueForThisBitcount = 1 << remainingBitCount;
90                         // int valueMultiplier = range / realvalueForThisBitcount;
91                         int encodedValue = (value * realvalueForThisBitcount) / range;
92
93                         if (encodedValue >= realvalueForThisBitcount)
94                                 encodedValue = realvalueForThisBitcount - 1;
95
96                         encodedValue = (encodedValue << 1) + negativeBit;
97
98                         return encodedValue;
99                 }
100         }
101
102         public static void storeIntegerCompressed8(
103                         final BitOutputStream outputStream, final int data)
104                         throws IOException {
105
106                 if (data < 256) {
107                         outputStream.storeBits(0, 1);
108                         outputStream.storeBits(data, 8);
109                 } else {
110                         outputStream.storeBits(1, 1);
111                         outputStream.storeBits(data, 32);
112                 }
113         }
114
115         Image image;
116         int width, height;
117
118         Channel yChannel;
119
120         Channel uChannel;
121         Channel vChannel;
122         Approximator approximator;
123
124         int bitsForY;
125         int bitsForU;
126
127         int bitsForV;
128
129         // ColorStats colorStats = new ColorStats();
130         OperatingContext context = new OperatingContext();
131
132         OperatingContext context2 = new OperatingContext();
133
134         BitOutputStream bitOutputStream;
135
136         public ImageEncoder(final Image image) {
137                 approximator = new Approximator();
138
139                 // bitOutputStream = outputStream;
140
141                 this.image = image;
142
143         }
144
145         public void encode(final BitOutputStream bitOutputStream)
146                         throws IOException {
147                 this.bitOutputStream = bitOutputStream;
148
149                 approximator.initialize();
150
151                 approximator.save(bitOutputStream);
152
153                 width = image.metaData.width;
154                 height = image.metaData.height;
155
156                 final WritableRaster raster = image.bufferedImage.getRaster();
157                 final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
158                 final byte[] pixels = dbi.getData();
159
160                 if (yChannel == null)
161                         yChannel = new Channel(width, height);
162                 else
163                         yChannel.reset();
164
165                 if (uChannel == null)
166                         uChannel = new Channel(width, height);
167                 else
168                         uChannel.reset();
169
170                 if (vChannel == null)
171                         vChannel = new Channel(width, height);
172                 else
173                         vChannel.reset();
174
175                 // create YUV map out of RGB raster data
176                 final Color color = new Color();
177
178                 for (int y = 0; y < height; y++)
179                         for (int x = 0; x < width; x++) {
180
181                                 final int index = (y * width) + x;
182                                 final int colorBufferIndex = index * 3;
183
184                                 int blue = pixels[colorBufferIndex];
185                                 if (blue < 0)
186                                         blue = blue + 256;
187
188                                 int green = pixels[colorBufferIndex + 1];
189                                 if (green < 0)
190                                         green = green + 256;
191
192                                 int red = pixels[colorBufferIndex + 2];
193                                 if (red < 0)
194                                         red = red + 256;
195
196                                 color.r = red;
197                                 color.g = green;
198                                 color.b = blue;
199
200                                 color.RGB2YUV();
201
202                                 yChannel.map[index] = (byte) color.y;
203                                 uChannel.map[index] = (byte) color.u;
204                                 vChannel.map[index] = (byte) color.v;
205                         }
206
207                 yChannel.decodedMap[0] = yChannel.map[0];
208                 uChannel.decodedMap[0] = uChannel.map[0];
209                 vChannel.decodedMap[0] = vChannel.map[0];
210
211                 bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);
212                 bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);
213                 bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);
214
215                 // detect initial step
216                 int largestDimension;
217                 int initialStep = 2;
218                 if (width > height)
219                         largestDimension = width;
220                 else
221                         largestDimension = height;
222
223                 while (initialStep < largestDimension)
224                         initialStep = initialStep * 2;
225
226                 rangeGrid(initialStep);
227                 rangeRoundGrid(2);
228                 saveGrid(initialStep);
229         }
230
231         private void encodeChannel(final Table table, final Channel channel,
232                         final int averageDecodedValue, final int index, final int value,
233                         final int range, final int parentIndex) throws IOException {
234
235                 final byte[] decodedRangeMap = channel.decodedRangeMap;
236                 final byte[] decodedMap = channel.decodedMap;
237
238                 final int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
239
240                 final int inheritedBitCount = table
241                                 .proposeBitcountForRange(inheritedRange);
242
243                 if (inheritedBitCount > 0) {
244                         int computedRange;
245                         computedRange = table.proposeRangeForRange(range, inheritedRange);
246                         decodedRangeMap[index] = (byte) computedRange;
247
248                         channel.bitCount++;
249                         if (computedRange != inheritedRange)
250                                 // brightness range shrinked
251                                 bitOutputStream.storeBits(1, 1);
252                         else
253                                 // brightness range stayed the same
254                                 bitOutputStream.storeBits(0, 1);
255
256                         // encode brightness into available amount of bits
257                         final int computedBitCount = table
258                                         .proposeBitcountForRange(computedRange);
259
260                         if (computedBitCount > 0) {
261
262                                 final int differenceToEncode = -(value - averageDecodedValue);
263                                 final int bitEncodedDifference = encodeValueIntoGivenBits(
264                                                 differenceToEncode, computedRange, computedBitCount);
265
266                                 channel.bitCount = channel.bitCount + computedBitCount;
267                                 bitOutputStream.storeBits(bitEncodedDifference,
268                                                 computedBitCount);
269
270                                 final int decodedDifference = decodeValueFromGivenBits(
271                                                 bitEncodedDifference, computedRange, computedBitCount);
272                                 int decodedValue = averageDecodedValue - decodedDifference;
273                                 if (decodedValue > 255)
274                                         decodedValue = 255;
275                                 if (decodedValue < 0)
276                                         decodedValue = 0;
277
278                                 decodedMap[index] = (byte) decodedValue;
279                         } else
280                                 decodedMap[index] = (byte) averageDecodedValue;
281
282                 } else {
283                         decodedRangeMap[index] = (byte) inheritedRange;
284                         decodedMap[index] = (byte) averageDecodedValue;
285                 }
286         }
287
288         public void printStatistics() {
289                 System.out.println("Y channel:");
290                 yChannel.printStatistics();
291
292                 System.out.println("U channel:");
293                 uChannel.printStatistics();
294
295                 System.out.println("V channel:");
296                 vChannel.printStatistics();
297         }
298
299         public void rangeGrid(final int step) {
300
301                 // gridSquare(step / 2, step / 2, step, pixels);
302
303                 rangeGridDiagonal(step / 2, step / 2, step);
304                 rangeGridSquare(step / 2, 0, step);
305                 rangeGridSquare(0, step / 2, step);
306
307                 if (step > 2)
308                         rangeGrid(step / 2);
309         }
310
311         public void rangeGridDiagonal(final int offsetX, final int offsetY,
312                         final int step) {
313                 for (int y = offsetY; y < height; y = y + step)
314                         for (int x = offsetX; x < width; x = x + step) {
315
316                                 final int index = (y * width) + x;
317                                 final int halfStep = step / 2;
318
319                                 context.initialize(image, yChannel.map, uChannel.map,
320                                                 vChannel.map);
321
322                                 context.measureNeighborEncode(x - halfStep, y - halfStep);
323                                 context.measureNeighborEncode(x + halfStep, y - halfStep);
324                                 context.measureNeighborEncode(x - halfStep, y + halfStep);
325                                 context.measureNeighborEncode(x + halfStep, y + halfStep);
326
327                                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
328                                 uChannel.rangeMap[index] = (byte) context.getURange(index);
329                                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
330                         }
331         }
332
333         public void rangeGridSquare(final int offsetX, final int offsetY,
334                         final int step) {
335                 for (int y = offsetY; y < height; y = y + step)
336                         for (int x = offsetX; x < width; x = x + step) {
337
338                                 final int index = (y * width) + x;
339                                 final int halfStep = step / 2;
340
341                                 context.initialize(image, yChannel.map, uChannel.map,
342                                                 vChannel.map);
343
344                                 context.measureNeighborEncode(x - halfStep, y);
345                                 context.measureNeighborEncode(x + halfStep, y);
346                                 context.measureNeighborEncode(x, y - halfStep);
347                                 context.measureNeighborEncode(x, y + halfStep);
348
349                                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
350                                 uChannel.rangeMap[index] = (byte) context.getURange(index);
351                                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
352                         }
353         }
354
355         public void rangeRoundGrid(final int step) {
356
357                 rangeRoundGridDiagonal(step / 2, step / 2, step);
358                 rangeRoundGridSquare(step / 2, 0, step);
359                 rangeRoundGridSquare(0, step / 2, step);
360
361                 if (step < 1024)
362                         rangeRoundGrid(step * 2);
363         }
364
365         public void rangeRoundGridDiagonal(final int offsetX, final int offsetY,
366                         final int step) {
367                 for (int y = offsetY; y < height; y = y + step)
368                         for (int x = offsetX; x < width; x = x + step) {
369
370                                 final int index = (y * width) + x;
371
372                                 final int yRange = byteToInt(yChannel.rangeMap[index]);
373                                 final int uRange = byteToInt(uChannel.rangeMap[index]);
374                                 final int vRange = byteToInt(vChannel.rangeMap[index]);
375
376                                 final int halfStep = step / 2;
377
378                                 final int parentIndex = ((y - halfStep) * width)
379                                                 + (x - halfStep);
380
381                                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
382
383                                 if (parentYRange < yRange) {
384                                         parentYRange = yRange;
385                                         yChannel.rangeMap[parentIndex] = (byte) parentYRange;
386                                 }
387
388                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
389
390                                 if (parentURange < uRange) {
391                                         parentURange = uRange;
392                                         uChannel.rangeMap[parentIndex] = (byte) parentURange;
393                                 }
394
395                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
396
397                                 if (parentVRange < vRange) {
398                                         parentVRange = vRange;
399                                         vChannel.rangeMap[parentIndex] = (byte) parentVRange;
400                                 }
401                         }
402         }
403
404         public void rangeRoundGridSquare(final int offsetX, final int offsetY,
405                         final int step) {
406                 for (int y = offsetY; y < height; y = y + step)
407                         for (int x = offsetX; x < width; x = x + step) {
408
409                                 final int index = (y * width) + x;
410
411                                 final int yRange = byteToInt(yChannel.rangeMap[index]);
412                                 final int uRange = byteToInt(uChannel.rangeMap[index]);
413                                 final int vRange = byteToInt(vChannel.rangeMap[index]);
414
415                                 final int halfStep = step / 2;
416
417                                 int parentIndex;
418                                 if (offsetX > 0)
419                                         parentIndex = (y * width) + (x - halfStep);
420                                 else
421                                         parentIndex = ((y - halfStep) * width) + x;
422
423                                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
424
425                                 if (parentYRange < yRange) {
426                                         parentYRange = yRange;
427                                         yChannel.rangeMap[parentIndex] = (byte) parentYRange;
428                                 }
429
430                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
431
432                                 if (parentURange < uRange) {
433                                         parentURange = uRange;
434                                         uChannel.rangeMap[parentIndex] = (byte) parentURange;
435                                 }
436
437                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
438
439                                 if (parentVRange < vRange) {
440                                         parentVRange = vRange;
441                                         vChannel.rangeMap[parentIndex] = (byte) parentVRange;
442                                 }
443
444                         }
445         }
446
447         public void saveGrid(final int step) throws IOException {
448
449                 saveGridDiagonal(step / 2, step / 2, step);
450                 saveGridSquare(step / 2, 0, step);
451                 saveGridSquare(0, step / 2, step);
452
453                 if (step > 2)
454                         saveGrid(step / 2);
455         }
456
457         public void saveGridDiagonal(final int offsetX, final int offsetY,
458                         final int step) throws IOException {
459                 for (int y = offsetY; y < height; y = y + step)
460                         for (int x = offsetX; x < width; x = x + step) {
461
462                                 final int halfStep = step / 2;
463
464                                 context2.initialize(image, yChannel.decodedMap,
465                                                 uChannel.decodedMap, vChannel.decodedMap);
466                                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
467                                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
468                                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
469                                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
470
471                                 savePixel(step, offsetX, offsetY, x, y,
472                                                 context2.colorStats.getAverageY(),
473                                                 context2.colorStats.getAverageU(),
474                                                 context2.colorStats.getAverageV());
475
476                         }
477         }
478
479         public void saveGridSquare(final int offsetX, final int offsetY,
480                         final int step) throws IOException {
481                 for (int y = offsetY; y < height; y = y + step)
482                         for (int x = offsetX; x < width; x = x + step) {
483
484                                 final int halfStep = step / 2;
485
486                                 context2.initialize(image, yChannel.decodedMap,
487                                                 uChannel.decodedMap, vChannel.decodedMap);
488                                 context2.measureNeighborEncode(x - halfStep, y);
489                                 context2.measureNeighborEncode(x + halfStep, y);
490                                 context2.measureNeighborEncode(x, y - halfStep);
491                                 context2.measureNeighborEncode(x, y + halfStep);
492
493                                 savePixel(step, offsetX, offsetY, x, y,
494                                                 context2.colorStats.getAverageY(),
495                                                 context2.colorStats.getAverageU(),
496                                                 context2.colorStats.getAverageV());
497
498                         }
499         }
500
501         public void savePixel(final int step, final int offsetX, final int offsetY,
502                         final int x, final int y, final int averageDecodedY,
503                         final int averageDecodedU, final int averageDecodedV)
504                         throws IOException {
505
506                 final int index = (y * width) + x;
507
508                 final int py = byteToInt(yChannel.map[index]);
509                 final int pu = byteToInt(uChannel.map[index]);
510                 final int pv = byteToInt(vChannel.map[index]);
511
512                 final int yRange = byteToInt(yChannel.rangeMap[index]);
513                 final int uRange = byteToInt(uChannel.rangeMap[index]);
514                 final int vRange = byteToInt(vChannel.rangeMap[index]);
515
516                 final int halfStep = step / 2;
517
518                 int parentIndex;
519                 if (offsetX > 0) {
520                         if (offsetY > 0)
521                                 // diagonal approach
522                                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
523                         else
524                                 // take left pixel
525                                 parentIndex = (y * width) + (x - halfStep);
526                 } else
527                         // take upper pixel
528                         parentIndex = ((y - halfStep) * width) + x;
529
530                 encodeChannel(approximator.yTable, yChannel, averageDecodedY, index,
531                                 py, yRange, parentIndex);
532
533                 encodeChannel(approximator.uTable, uChannel, averageDecodedU, index,
534                                 pu, uRange, parentIndex);
535
536                 encodeChannel(approximator.vTable, vChannel, averageDecodedV, index,
537                                 pv, vRange, parentIndex);
538
539         }
540
541 }