clarified ideas.
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 /*
2  * Imagesqueeze - Image codec optimized for photos.
3  * Copyright (C) 2012, Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
4  * 
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of version 2 of the GNU General Public License
7  * as published by the Free Software Foundation.
8  */
9
10 package eu.svjatoslav.imagesqueeze.codec;
11
12 /**
13  * Compressed image pixels encoder.
14  */
15
16 import java.awt.image.DataBufferByte;
17 import java.awt.image.WritableRaster;
18 import java.io.IOException;
19
20 import eu.svjatoslav.commons.data.BitOutputStream;
21
22 public class ImageEncoder {
23
        // Image to be encoded; assigned in the constructor. encode() reads its
        // metaData (width, height) and its bufferedImage raster.
        Image image;
        // Image dimensions, copied from image.metaData at the start of encode().
        int width, height;

        // Per-component channels. Created on the first encode() call and reset
        // on subsequent calls.
        Channel yChannel;
        Channel uChannel;
        Channel vChannel;

        // Supplies the yTable / uTable / vTable passed to encodeChannel(); it is
        // initialized and saved to the output stream at the start of encode().
        Approximator approximator;

        // NOTE(review): not referenced anywhere in the visible part of this class.
        int bitsForY;
        int bitsForU;
        int bitsForV;

        // ColorStats colorStats = new ColorStats();
        // Context used by the rangeGrid* methods; initialized with the original
        // channel maps.
        OperatingContext context = new OperatingContext();
        // Context used by the saveGrid* methods; initialized with the decoded
        // channel maps.
        OperatingContext context2 = new OperatingContext();

        // Destination stream; assigned at the start of encode().
        BitOutputStream bitOutputStream;
42
43         public ImageEncoder(final Image image) {
44                 approximator = new Approximator();
45
46                 // bitOutputStream = outputStream;
47
48                 this.image = image;
49
50         }
51
52         public void encode(final BitOutputStream bitOutputStream)
53                         throws IOException {
54                 this.bitOutputStream = bitOutputStream;
55
56                 approximator.initialize();
57
58                 approximator.save(bitOutputStream);
59
60                 width = image.metaData.width;
61                 height = image.metaData.height;
62
63                 final WritableRaster raster = image.bufferedImage.getRaster();
64                 final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
65                 final byte[] pixels = dbi.getData();
66
67                 if (yChannel == null) {
68                         yChannel = new Channel(width, height);
69                 } else {
70                         yChannel.reset();
71                 }
72
73                 if (uChannel == null) {
74                         uChannel = new Channel(width, height);
75                 } else {
76                         uChannel.reset();
77                 }
78
79                 if (vChannel == null) {
80                         vChannel = new Channel(width, height);
81                 } else {
82                         vChannel.reset();
83                 }
84
85                 // create YUV map out of RGB raster data
86                 final Color color = new Color();
87
88                 for (int y = 0; y < height; y++) {
89                         for (int x = 0; x < width; x++) {
90
91                                 final int index = (y * width) + x;
92                                 final int colorBufferIndex = index * 3;
93
94                                 int blue = pixels[colorBufferIndex];
95                                 if (blue < 0)
96                                         blue = blue + 256;
97
98                                 int green = pixels[colorBufferIndex + 1];
99                                 if (green < 0)
100                                         green = green + 256;
101
102                                 int red = pixels[colorBufferIndex + 2];
103                                 if (red < 0)
104                                         red = red + 256;
105
106                                 color.r = red;
107                                 color.g = green;
108                                 color.b = blue;
109
110                                 color.RGB2YUV();
111
112                                 yChannel.map[index] = (byte) color.y;
113                                 uChannel.map[index] = (byte) color.u;
114                                 vChannel.map[index] = (byte) color.v;
115                         }
116                 }
117
118                 yChannel.decodedMap[0] = yChannel.map[0];
119                 uChannel.decodedMap[0] = uChannel.map[0];
120                 vChannel.decodedMap[0] = vChannel.map[0];
121
122                 bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);
123                 bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);
124                 bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);
125
126                 // detect initial step
127                 int largestDimension;
128                 int initialStep = 2;
129                 if (width > height) {
130                         largestDimension = width;
131                 } else {
132                         largestDimension = height;
133                 }
134
135                 while (initialStep < largestDimension) {
136                         initialStep = initialStep * 2;
137                 }
138
139                 rangeGrid(initialStep);
140                 rangeRoundGrid(2);
141                 saveGrid(initialStep);
142         }
143
144         private void encodeChannel(final Table table, final Channel channel,
145                         final int averageDecodedValue, final int index, final int value,
146                         final int range, final int parentIndex) throws IOException {
147
148                 final byte[] decodedRangeMap = channel.decodedRangeMap;
149                 final byte[] decodedMap = channel.decodedMap;
150
151                 final int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
152
153                 final int inheritedBitCount = table
154                                 .proposeBitcountForRange(inheritedRange);
155
156                 if (inheritedBitCount > 0) {
157                         int computedRange;
158                         computedRange = table.proposeRangeForRange(range, inheritedRange);
159                         decodedRangeMap[index] = (byte) computedRange;
160
161                         channel.bitCount++;
162                         if (computedRange != inheritedRange) {
163                                 // brightness range shrinked
164                                 bitOutputStream.storeBits(1, 1);
165                         } else {
166                                 // brightness range stayed the same
167                                 bitOutputStream.storeBits(0, 1);
168                         }
169
170                         // encode brightness into available amount of bits
171                         final int computedBitCount = table
172                                         .proposeBitcountForRange(computedRange);
173
174                         if (computedBitCount > 0) {
175
176                                 final int differenceToEncode = -(value - averageDecodedValue);
177                                 final int bitEncodedDifference = encodeValueIntoGivenBits(
178                                                 differenceToEncode, computedRange, computedBitCount);
179
180                                 channel.bitCount = channel.bitCount + computedBitCount;
181                                 bitOutputStream.storeBits(bitEncodedDifference,
182                                                 computedBitCount);
183
184                                 final int decodedDifference = decodeValueFromGivenBits(
185                                                 bitEncodedDifference, computedRange, computedBitCount);
186                                 int decodedValue = averageDecodedValue - decodedDifference;
187                                 if (decodedValue > 255)
188                                         decodedValue = 255;
189                                 if (decodedValue < 0)
190                                         decodedValue = 0;
191
192                                 decodedMap[index] = (byte) decodedValue;
193                         } else {
194                                 decodedMap[index] = (byte) averageDecodedValue;
195                         }
196
197                 } else {
198                         decodedRangeMap[index] = (byte) inheritedRange;
199                         decodedMap[index] = (byte) averageDecodedValue;
200                 }
201         }
202
203         public void printStatistics() {
204                 System.out.println("Y channel:");
205                 yChannel.printStatistics();
206
207                 System.out.println("U channel:");
208                 uChannel.printStatistics();
209
210                 System.out.println("V channel:");
211                 vChannel.printStatistics();
212         }
213
214         public void rangeGrid(final int step) {
215
216                 // gridSquare(step / 2, step / 2, step, pixels);
217
218                 rangeGridDiagonal(step / 2, step / 2, step);
219                 rangeGridSquare(step / 2, 0, step);
220                 rangeGridSquare(0, step / 2, step);
221
222                 if (step > 2)
223                         rangeGrid(step / 2);
224         }
225
226         public void rangeGridDiagonal(final int offsetX, final int offsetY,
227                         final int step) {
228                 for (int y = offsetY; y < height; y = y + step) {
229                         for (int x = offsetX; x < width; x = x + step) {
230
231                                 final int index = (y * width) + x;
232                                 final int halfStep = step / 2;
233
234                                 context.initialize(image, yChannel.map, uChannel.map,
235                                                 vChannel.map);
236
237                                 context.measureNeighborEncode(x - halfStep, y - halfStep);
238                                 context.measureNeighborEncode(x + halfStep, y - halfStep);
239                                 context.measureNeighborEncode(x - halfStep, y + halfStep);
240                                 context.measureNeighborEncode(x + halfStep, y + halfStep);
241
242                                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
243                                 uChannel.rangeMap[index] = (byte) context.getURange(index);
244                                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
245                         }
246                 }
247         }
248
249         public void rangeGridSquare(final int offsetX, final int offsetY,
250                         final int step) {
251                 for (int y = offsetY; y < height; y = y + step) {
252                         for (int x = offsetX; x < width; x = x + step) {
253
254                                 final int index = (y * width) + x;
255                                 final int halfStep = step / 2;
256
257                                 context.initialize(image, yChannel.map, uChannel.map,
258                                                 vChannel.map);
259
260                                 context.measureNeighborEncode(x - halfStep, y);
261                                 context.measureNeighborEncode(x + halfStep, y);
262                                 context.measureNeighborEncode(x, y - halfStep);
263                                 context.measureNeighborEncode(x, y + halfStep);
264
265                                 yChannel.rangeMap[index] = (byte) context.getYRange(index);
266                                 uChannel.rangeMap[index] = (byte) context.getURange(index);
267                                 vChannel.rangeMap[index] = (byte) context.getVRange(index);
268                         }
269                 }
270         }
271
272         public void rangeRoundGrid(final int step) {
273
274                 rangeRoundGridDiagonal(step / 2, step / 2, step);
275                 rangeRoundGridSquare(step / 2, 0, step);
276                 rangeRoundGridSquare(0, step / 2, step);
277
278                 if (step < 1024)
279                         rangeRoundGrid(step * 2);
280         }
281
282         public void rangeRoundGridDiagonal(final int offsetX, final int offsetY,
283                         final int step) {
284                 for (int y = offsetY; y < height; y = y + step) {
285                         for (int x = offsetX; x < width; x = x + step) {
286
287                                 final int index = (y * width) + x;
288
289                                 final int yRange = byteToInt(yChannel.rangeMap[index]);
290                                 final int uRange = byteToInt(uChannel.rangeMap[index]);
291                                 final int vRange = byteToInt(vChannel.rangeMap[index]);
292
293                                 final int halfStep = step / 2;
294
295                                 final int parentIndex = ((y - halfStep) * width)
296                                                 + (x - halfStep);
297
298                                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
299
300                                 if (parentYRange < yRange) {
301                                         parentYRange = yRange;
302                                         yChannel.rangeMap[parentIndex] = (byte) parentYRange;
303                                 }
304
305                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
306
307                                 if (parentURange < uRange) {
308                                         parentURange = uRange;
309                                         uChannel.rangeMap[parentIndex] = (byte) parentURange;
310                                 }
311
312                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
313
314                                 if (parentVRange < vRange) {
315                                         parentVRange = vRange;
316                                         vChannel.rangeMap[parentIndex] = (byte) parentVRange;
317                                 }
318                         }
319                 }
320         }
321
322         public void rangeRoundGridSquare(final int offsetX, final int offsetY,
323                         final int step) {
324                 for (int y = offsetY; y < height; y = y + step) {
325                         for (int x = offsetX; x < width; x = x + step) {
326
327                                 final int index = (y * width) + x;
328
329                                 final int yRange = byteToInt(yChannel.rangeMap[index]);
330                                 final int uRange = byteToInt(uChannel.rangeMap[index]);
331                                 final int vRange = byteToInt(vChannel.rangeMap[index]);
332
333                                 final int halfStep = step / 2;
334
335                                 int parentIndex;
336                                 if (offsetX > 0) {
337                                         parentIndex = (y * width) + (x - halfStep);
338                                 } else {
339                                         parentIndex = ((y - halfStep) * width) + x;
340                                 }
341
342                                 int parentYRange = byteToInt(yChannel.rangeMap[parentIndex]);
343
344                                 if (parentYRange < yRange) {
345                                         parentYRange = yRange;
346                                         yChannel.rangeMap[parentIndex] = (byte) parentYRange;
347                                 }
348
349                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
350
351                                 if (parentURange < uRange) {
352                                         parentURange = uRange;
353                                         uChannel.rangeMap[parentIndex] = (byte) parentURange;
354                                 }
355
356                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
357
358                                 if (parentVRange < vRange) {
359                                         parentVRange = vRange;
360                                         vChannel.rangeMap[parentIndex] = (byte) parentVRange;
361                                 }
362
363                         }
364                 }
365         }
366
367         public void saveGrid(final int step) throws IOException {
368
369                 saveGridDiagonal(step / 2, step / 2, step);
370                 saveGridSquare(step / 2, 0, step);
371                 saveGridSquare(0, step / 2, step);
372
373                 if (step > 2)
374                         saveGrid(step / 2);
375         }
376
377         public void saveGridDiagonal(final int offsetX, final int offsetY,
378                         final int step) throws IOException {
379                 for (int y = offsetY; y < height; y = y + step) {
380                         for (int x = offsetX; x < width; x = x + step) {
381
382                                 final int halfStep = step / 2;
383
384                                 context2.initialize(image, yChannel.decodedMap,
385                                                 uChannel.decodedMap, vChannel.decodedMap);
386                                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
387                                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
388                                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
389                                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
390
391                                 savePixel(step, offsetX, offsetY, x, y,
392                                                 context2.colorStats.getAverageY(),
393                                                 context2.colorStats.getAverageU(),
394                                                 context2.colorStats.getAverageV());
395
396                         }
397                 }
398         }
399
400         public void saveGridSquare(final int offsetX, final int offsetY,
401                         final int step) throws IOException {
402                 for (int y = offsetY; y < height; y = y + step) {
403                         for (int x = offsetX; x < width; x = x + step) {
404
405                                 final int halfStep = step / 2;
406
407                                 context2.initialize(image, yChannel.decodedMap,
408                                                 uChannel.decodedMap, vChannel.decodedMap);
409                                 context2.measureNeighborEncode(x - halfStep, y);
410                                 context2.measureNeighborEncode(x + halfStep, y);
411                                 context2.measureNeighborEncode(x, y - halfStep);
412                                 context2.measureNeighborEncode(x, y + halfStep);
413
414                                 savePixel(step, offsetX, offsetY, x, y,
415                                                 context2.colorStats.getAverageY(),
416                                                 context2.colorStats.getAverageU(),
417                                                 context2.colorStats.getAverageV());
418
419                         }
420                 }
421         }
422
423         public void savePixel(final int step, final int offsetX, final int offsetY,
424                         final int x, final int y, final int averageDecodedY,
425                         final int averageDecodedU, final int averageDecodedV)
426                         throws IOException {
427
428                 final int index = (y * width) + x;
429
430                 final int py = byteToInt(yChannel.map[index]);
431                 final int pu = byteToInt(uChannel.map[index]);
432                 final int pv = byteToInt(vChannel.map[index]);
433
434                 final int yRange = byteToInt(yChannel.rangeMap[index]);
435                 final int uRange = byteToInt(uChannel.rangeMap[index]);
436                 final int vRange = byteToInt(vChannel.rangeMap[index]);
437
438                 final int halfStep = step / 2;
439
440                 int parentIndex;
441                 if (offsetX > 0) {
442                         if (offsetY > 0) {
443                                 // diagonal approach
444                                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
445                         } else {
446                                 // take left pixel
447                                 parentIndex = (y * width) + (x - halfStep);
448                         }
449                 } else {
450                         // take upper pixel
451                         parentIndex = ((y - halfStep) * width) + x;
452                 }
453
454                 encodeChannel(approximator.yTable, yChannel, averageDecodedY, index,
455                                 py, yRange, parentIndex);
456
457                 encodeChannel(approximator.uTable, uChannel, averageDecodedU, index,
458                                 pu, uRange, parentIndex);
459
460                 encodeChannel(approximator.vTable, vChannel, averageDecodedV, index,
461                                 pv, vRange, parentIndex);
462
463         }
464
465         public static int byteToInt(final byte input) {
466                 int result = input;
467                 if (result < 0)
468                         result = result + 256;
469                 return result;
470         }
471
472         public static int decodeValueFromGivenBits(final int encodedBits,
473                         final int range, final int bitCount) {
474                 final int negativeBit = encodedBits & 1;
475
476                 final int remainingBitCount = bitCount - 1;
477
478                 if (remainingBitCount == 0) {
479                         // no more bits remaining to encode actual value
480
481                         if (negativeBit == 0) {
482                                 return range;
483                         } else {
484                                 return -range;
485                         }
486
487                 } else {
488                         // still one or more bits left, encode value as precisely as
489                         // possible
490
491                         final int encodedValue = (encodedBits >>> 1) + 1;
492
493                         final int realvalueForThisBitcount = 1 << remainingBitCount;
494
495                         // int valueMultiplier = range / realvalueForThisBitcount;
496                         int decodedValue = (range * encodedValue)
497                                         / realvalueForThisBitcount;
498
499                         if (decodedValue > range)
500                                 decodedValue = range;
501
502                         if (negativeBit == 0) {
503                                 return decodedValue;
504                         } else {
505                                 return -decodedValue;
506                         }
507
508                 }
509         }
510
511         public static int encodeValueIntoGivenBits(int value, final int range,
512                         final int bitCount) {
513
514                 int negativeBit = 0;
515
516                 if (value < 0) {
517                         negativeBit = 1;
518                         value = -value;
519                 }
520
521                 final int remainingBitCount = bitCount - 1;
522
523                 if (remainingBitCount == 0) {
524                         // no more bits remaining to encode actual value
525
526                         return negativeBit;
527
528                 } else {
529                         // still one or more bits left, encode value as precisely as
530                         // possible
531
532                         if (value > range)
533                                 value = range;
534
535                         final int realvalueForThisBitcount = 1 << remainingBitCount;
536                         // int valueMultiplier = range / realvalueForThisBitcount;
537                         int encodedValue = (value * realvalueForThisBitcount) / range;
538
539                         if (encodedValue >= realvalueForThisBitcount)
540                                 encodedValue = realvalueForThisBitcount - 1;
541
542                         encodedValue = (encodedValue << 1) + negativeBit;
543
544                         return encodedValue;
545                 }
546         }
547
548 }