minor fixes
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 /*
2  * Imagesqueeze - Image codec optimized for photos.
3  * Copyright (C) 2012, Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
4  * 
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of version 2 of the GNU General Public License
7  * as published by the Free Software Foundation.
8  */
9
10 package eu.svjatoslav.imagesqueeze.codec;
11
12 /**
13  * Compressed image pixels encoder.
14  */
15
16 import java.awt.image.DataBufferByte;
17 import java.awt.image.WritableRaster;
18 import java.io.IOException;
19
20 import eu.svjatoslav.commons.data.BitOutputStream;
21
22 public class ImageEncoder {
23
24         Image image;
25         int width, height;
26
27         Channel yChannel;
28         Channel uChannel;
29         Channel vChannel;
30
31         Approximator approximator;
32
33         int bitsForY;
34         int bitsForU;
35         int bitsForV;
36
37         // ColorStats colorStats = new ColorStats();
38         OperatingContext context = new OperatingContext();
39         OperatingContext context2 = new OperatingContext();
40
41         BitOutputStream bitOutputStream;
42
43         public ImageEncoder(final Image image) {
44                 approximator = new Approximator();
45
46                 // bitOutputStream = outputStream;
47
48                 this.image = image;
49
50         }
51
52         public void encode(final BitOutputStream bitOutputStream)
53                         throws IOException {
54                 this.bitOutputStream = bitOutputStream;
55
56                 approximator.initialize();
57
58                 approximator.save(bitOutputStream);
59
60                 width = image.metaData.width;
61                 height = image.metaData.height;
62
63                 final WritableRaster raster = image.bufferedImage.getRaster();
64                 final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
65                 final byte[] pixels = dbi.getData();
66
67                 if (yChannel == null) {
68                         yChannel = new Channel(width, height);
69                 } else {
70                         yChannel.reset();
71                 }
72
73                 if (uChannel == null) {
74                         uChannel = new Channel(width, height);
75                 } else {
76                         uChannel.reset();
77                 }
78
79                 if (vChannel == null) {
80                         vChannel = new Channel(width, height);
81                 } else {
82                         vChannel.reset();
83                 }
84
85                 // create YUV map out of RGB raster data
86                 final Color color = new Color();
87
88                 for (int y = 0; y < height; y++) {
89                         for (int x = 0; x < width; x++) {
90
91                                 final int index = (y * width) + x;
92                                 final int colorBufferIndex = index * 3;
93
94                                 int blue = pixels[colorBufferIndex];
95                                 if (blue < 0) {
96                                         blue = blue + 256;
97                                 }
98
99                                 int green = pixels[colorBufferIndex + 1];
100                                 if (green < 0) {
101                                         green = green + 256;
102                                 }
103
104                                 int red = pixels[colorBufferIndex + 2];
105                                 if (red < 0) {
106                                         red = red + 256;
107                                 }
108
109                                 color.r = red;
110                                 color.g = green;
111                                 color.b = blue;
112
113                                 color.RGB2YUV();
114
115                                 yChannel.map[index] = (byte) color.y;
116                                 uChannel.map[index] = (byte) color.u;
117                                 vChannel.map[index] = (byte) color.v;
118                         }
119                 }
120
121                 yChannel.decodedMap[0] = yChannel.map[0];
122                 uChannel.decodedMap[0] = uChannel.map[0];
123                 vChannel.decodedMap[0] = vChannel.map[0];
124
125                 bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);
126                 bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);
127                 bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);
128
129                 // detect initial step
130                 int largestDimension;
131                 int initialStep = 2;
132                 if (width > height) {
133                         largestDimension = width;
134                 } else {
135                         largestDimension = height;
136                 }
137
138                 while (initialStep < largestDimension) {
139                         initialStep = initialStep * 2;
140                 }
141
142                 rangeGrid(initialStep);
143                 rangeRoundGrid(2);
144                 saveGrid(initialStep);
145         }
146
147         private void encodeChannel(final Table table, final Channel channel,
148                         final int averageDecodedValue, final int index, final int value,
149                         final int range, final int parentIndex) throws IOException {
150
151                 final byte[] decodedRangeMap = channel.decodedRangeMap;
152                 final byte[] decodedMap = channel.decodedMap;
153
154                 final int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
155
156                 final int inheritedBitCount = table
157                                 .proposeBitcountForRange(inheritedRange);
158
159                 if (inheritedBitCount > 0) {
160                         int computedRange;
161                         computedRange = table.proposeRangeForRange(range, inheritedRange);
162                         decodedRangeMap[index] = (byte) computedRange;
163
164                         channel.bitCount++;
165                         if (computedRange != inheritedRange) {
166                                 // brightness range shrinked
167                                 bitOutputStream.storeBits(1, 1);
168                         } else {
169                                 // brightness range stayed the same
170                                 bitOutputStream.storeBits(0, 1);
171                         }
172
173                         // encode brightness into available amount of bits
174                         final int computedBitCount = table
175                                         .proposeBitcountForRange(computedRange);
176
177                         if (computedBitCount > 0) {
178
179                                 final int differenceToEncode = -(value - averageDecodedValue);
180                                 final int bitEncodedDifference = encodeValueIntoGivenBits(
181                                                 differenceToEncode, computedRange, computedBitCount);
182
183                                 channel.bitCount = channel.bitCount + computedBitCount;
184                                 bitOutputStream.storeBits(bitEncodedDifference,
185                                                 computedBitCount);
186
187                                 final int decodedDifference = decodeValueFromGivenBits(
188                                                 bitEncodedDifference, computedRange, computedBitCount);
189                                 int decodedValue = averageDecodedValue - decodedDifference;
190                                 if (decodedValue > 255) {
191                                         decodedValue = 255;
192                                 }
193                                 if (decodedValue < 0) {
194                                         decodedValue = 0;
195                                 }
196
197                                 decodedMap[index] = (byte) decodedValue;
198                         } else {
199                                 decodedMap[index] = (byte) averageDecodedValue;
200                         }
201
202                 } else {
203                         decodedRangeMap[index] = (byte) inheritedRange;
204                         decodedMap[index] = (byte) averageDecodedValue;
205                 }
206         }
207
208         public void printStatistics() {
209                 System.out.println("Y channel:");
210                 yChannel.printStatistics();
211
212                 System.out.println("U channel:");
213                 uChannel.printStatistics();
214
215                 System.out.println("V channel:");
216                 vChannel.printStatistics();
217         }
218
219         public void rangeGrid(final int step) {
220
221                 // gridSquare(step / 2, step / 2, step, pixels);
222
223                 rangeGridDiagonal(step / 2, step / 2, step);
224                 rangeGridSquare(step / 2, 0, step);
225                 rangeGridSquare(0, step / 2, step);
226
227                 if (step > 2) {
228                         rangeGrid(step / 2);
229                 }
230         }
231
232         public void rangeGridDiagonal(final int offsetX, final int offsetY,
233                         final int step) {
234                 for (int y = offsetY; y < height; y = y + step) {
235                         for (int x = offsetX; x < width; x = x + step) {
236
237                                 final int index = (y * width) + x;
238                                 final int halfStep = step / 2;
239
240                                 context.initialize(image, yChannel.map, uChannel.map,
241                                                 vChannel.map);
242
243                                 context.measureNeighborEncode(x - halfStep, y - halfStep);
244                                 context.measureNeighborEncode(x + halfStep, y - halfStep);
245                                 context.measureNeighborEncode(x - halfStep, y + halfStep);
246                                 context.measureNeighborEncode(x + halfStep, y + halfStep);
247
248                                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
249                                 uChannel.rangeMap[index] = (byte) context.getURange(index);
250                                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
251                         }
252                 }
253         }
254
255         public void rangeGridSquare(final int offsetX, final int offsetY,
256                         final int step) {
257                 for (int y = offsetY; y < height; y = y + step) {
258                         for (int x = offsetX; x < width; x = x + step) {
259
260                                 final int index = (y * width) + x;
261                                 final int halfStep = step / 2;
262
263                                 context.initialize(image, yChannel.map, uChannel.map,
264                                                 vChannel.map);
265
266                                 context.measureNeighborEncode(x - halfStep, y);
267                                 context.measureNeighborEncode(x + halfStep, y);
268                                 context.measureNeighborEncode(x, y - halfStep);
269                                 context.measureNeighborEncode(x, y + halfStep);
270
271                                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
272                                 uChannel.rangeMap[index] = (byte) context.getURange(index);
273                                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
274                         }
275                 }
276         }
277
278         public void rangeRoundGrid(final int step) {
279
280                 rangeRoundGridDiagonal(step / 2, step / 2, step);
281                 rangeRoundGridSquare(step / 2, 0, step);
282                 rangeRoundGridSquare(0, step / 2, step);
283
284                 if (step < 1024) {
285                         rangeRoundGrid(step * 2);
286                 }
287         }
288
289         public void rangeRoundGridDiagonal(final int offsetX, final int offsetY,
290                         final int step) {
291                 for (int y = offsetY; y < height; y = y + step) {
292                         for (int x = offsetX; x < width; x = x + step) {
293
294                                 final int index = (y * width) + x;
295
296                                 final int yRange = byteToInt(yChannel.rangeMap[index]);
297                                 final int uRange = byteToInt(uChannel.rangeMap[index]);
298                                 final int vRange = byteToInt(vChannel.rangeMap[index]);
299
300                                 final int halfStep = step / 2;
301
302                                 final int parentIndex = ((y - halfStep) * width)
303                                                 + (x - halfStep);
304
305                                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
306
307                                 if (parentYRange < yRange) {
308                                         parentYRange = yRange;
309                                         yChannel.rangeMap[parentIndex] = (byte) parentYRange;
310                                 }
311
312                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
313
314                                 if (parentURange < uRange) {
315                                         parentURange = uRange;
316                                         uChannel.rangeMap[parentIndex] = (byte) parentURange;
317                                 }
318
319                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
320
321                                 if (parentVRange < vRange) {
322                                         parentVRange = vRange;
323                                         vChannel.rangeMap[parentIndex] = (byte) parentVRange;
324                                 }
325                         }
326                 }
327         }
328
329         public void rangeRoundGridSquare(final int offsetX, final int offsetY,
330                         final int step) {
331                 for (int y = offsetY; y < height; y = y + step) {
332                         for (int x = offsetX; x < width; x = x + step) {
333
334                                 final int index = (y * width) + x;
335
336                                 final int yRange = byteToInt(yChannel.rangeMap[index]);
337                                 final int uRange = byteToInt(uChannel.rangeMap[index]);
338                                 final int vRange = byteToInt(vChannel.rangeMap[index]);
339
340                                 final int halfStep = step / 2;
341
342                                 int parentIndex;
343                                 if (offsetX > 0) {
344                                         parentIndex = (y * width) + (x - halfStep);
345                                 } else {
346                                         parentIndex = ((y - halfStep) * width) + x;
347                                 }
348
349                                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
350
351                                 if (parentYRange < yRange) {
352                                         parentYRange = yRange;
353                                         yChannel.rangeMap[parentIndex] = (byte) parentYRange;
354                                 }
355
356                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
357
358                                 if (parentURange < uRange) {
359                                         parentURange = uRange;
360                                         uChannel.rangeMap[parentIndex] = (byte) parentURange;
361                                 }
362
363                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
364
365                                 if (parentVRange < vRange) {
366                                         parentVRange = vRange;
367                                         vChannel.rangeMap[parentIndex] = (byte) parentVRange;
368                                 }
369
370                         }
371                 }
372         }
373
374         public void saveGrid(final int step) throws IOException {
375
376                 saveGridDiagonal(step / 2, step / 2, step);
377                 saveGridSquare(step / 2, 0, step);
378                 saveGridSquare(0, step / 2, step);
379
380                 if (step > 2) {
381                         saveGrid(step / 2);
382                 }
383         }
384
385         public void saveGridDiagonal(final int offsetX, final int offsetY,
386                         final int step) throws IOException {
387                 for (int y = offsetY; y < height; y = y + step) {
388                         for (int x = offsetX; x < width; x = x + step) {
389
390                                 final int halfStep = step / 2;
391
392                                 context2.initialize(image, yChannel.decodedMap,
393                                                 uChannel.decodedMap, vChannel.decodedMap);
394                                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
395                                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
396                                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
397                                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
398
399                                 savePixel(step, offsetX, offsetY, x, y,
400                                                 context2.colorStats.getAverageY(),
401                                                 context2.colorStats.getAverageU(),
402                                                 context2.colorStats.getAverageV());
403
404                         }
405                 }
406         }
407
408         public void saveGridSquare(final int offsetX, final int offsetY,
409                         final int step) throws IOException {
410                 for (int y = offsetY; y < height; y = y + step) {
411                         for (int x = offsetX; x < width; x = x + step) {
412
413                                 final int halfStep = step / 2;
414
415                                 context2.initialize(image, yChannel.decodedMap,
416                                                 uChannel.decodedMap, vChannel.decodedMap);
417                                 context2.measureNeighborEncode(x - halfStep, y);
418                                 context2.measureNeighborEncode(x + halfStep, y);
419                                 context2.measureNeighborEncode(x, y - halfStep);
420                                 context2.measureNeighborEncode(x, y + halfStep);
421
422                                 savePixel(step, offsetX, offsetY, x, y,
423                                                 context2.colorStats.getAverageY(),
424                                                 context2.colorStats.getAverageU(),
425                                                 context2.colorStats.getAverageV());
426
427                         }
428                 }
429         }
430
431         public void savePixel(final int step, final int offsetX, final int offsetY,
432                         final int x, final int y, final int averageDecodedY,
433                         final int averageDecodedU, final int averageDecodedV)
434                         throws IOException {
435
436                 final int index = (y * width) + x;
437
438                 final int py = byteToInt(yChannel.map[index]);
439                 final int pu = byteToInt(uChannel.map[index]);
440                 final int pv = byteToInt(vChannel.map[index]);
441
442                 final int yRange = byteToInt(yChannel.rangeMap[index]);
443                 final int uRange = byteToInt(uChannel.rangeMap[index]);
444                 final int vRange = byteToInt(vChannel.rangeMap[index]);
445
446                 final int halfStep = step / 2;
447
448                 int parentIndex;
449                 if (offsetX > 0) {
450                         if (offsetY > 0) {
451                                 // diagonal approach
452                                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
453                         } else {
454                                 // take left pixel
455                                 parentIndex = (y * width) + (x - halfStep);
456                         }
457                 } else {
458                         // take upper pixel
459                         parentIndex = ((y - halfStep) * width) + x;
460                 }
461
462                 encodeChannel(approximator.yTable, yChannel, averageDecodedY, index,
463                                 py, yRange, parentIndex);
464
465                 encodeChannel(approximator.uTable, uChannel, averageDecodedU, index,
466                                 pu, uRange, parentIndex);
467
468                 encodeChannel(approximator.vTable, vChannel, averageDecodedV, index,
469                                 pv, vRange, parentIndex);
470
471         }
472
473         public static int byteToInt(final byte input) {
474                 int result = input;
475                 if (result < 0) {
476                         result = result + 256;
477                 }
478                 return result;
479         }
480
481         public static int decodeValueFromGivenBits(final int encodedBits,
482                         final int range, final int bitCount) {
483                 final int negativeBit = encodedBits & 1;
484
485                 final int remainingBitCount = bitCount - 1;
486
487                 if (remainingBitCount == 0) {
488                         // no more bits remaining to encode actual value
489
490                         if (negativeBit == 0) {
491                                 return range;
492                         } else {
493                                 return -range;
494                         }
495
496                 } else {
497                         // still one or more bits left, encode value as precisely as
498                         // possible
499
500                         final int encodedValue = (encodedBits >>> 1) + 1;
501
502                         final int realvalueForThisBitcount = 1 << remainingBitCount;
503
504                         // int valueMultiplier = range / realvalueForThisBitcount;
505                         int decodedValue = (range * encodedValue)
506                                         / realvalueForThisBitcount;
507
508                         if (decodedValue > range) {
509                                 decodedValue = range;
510                         }
511
512                         if (negativeBit == 0) {
513                                 return decodedValue;
514                         } else {
515                                 return -decodedValue;
516                         }
517
518                 }
519         }
520
521         public static int encodeValueIntoGivenBits(int value, final int range,
522                         final int bitCount) {
523
524                 int negativeBit = 0;
525
526                 if (value < 0) {
527                         negativeBit = 1;
528                         value = -value;
529                 }
530
531                 final int remainingBitCount = bitCount - 1;
532
533                 if (remainingBitCount == 0) {
534                         // no more bits remaining to encode actual value
535
536                         return negativeBit;
537
538                 } else {
539                         // still one or more bits left, encode value as precisely as
540                         // possible
541
542                         if (value > range) {
543                                 value = range;
544                         }
545
546                         final int realvalueForThisBitcount = 1 << remainingBitCount;
547                         // int valueMultiplier = range / realvalueForThisBitcount;
548                         int encodedValue = (value * realvalueForThisBitcount) / range;
549
550                         if (encodedValue >= realvalueForThisBitcount) {
551                                 encodedValue = realvalueForThisBitcount - 1;
552                         }
553
554                         encodedValue = (encodedValue << 1) + negativeBit;
555
556                         return encodedValue;
557                 }
558         }
559
560 }