initial commit
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageEncoder.java
1 package eu.svjatoslav.imagesqueeze.codec;
2
3 /**
4  * Compressed image pixels encoder.
5  */
6
7 import java.awt.image.DataBufferByte;
8 import java.awt.image.WritableRaster;
9 import java.io.IOException;
10
11
12 public class ImageEncoder {
13
14         Image image;
15         int width, height;
16
17         Channel yChannel;
18         Channel uChannel;
19         Channel vChannel;
20         
21         Approximator approximator;
22
23         int bitsForY;
24         int bitsForU;
25         int bitsForV;
26         
27         //ColorStats colorStats = new ColorStats();
28         OperatingContext context = new OperatingContext();
29         OperatingContext context2 = new OperatingContext();
30
31         BitOutputStream bitOutputStream;
32
33         public ImageEncoder(Image image){
34                 approximator = new Approximator();
35
36                 //bitOutputStream = outputStream;
37
38                 this.image = image;
39
40         }
41
42
43         public void encode(BitOutputStream bitOutputStream) throws IOException {
44                 this.bitOutputStream = bitOutputStream;
45                 
46                 approximator.initialize();
47                 
48                 approximator.save(bitOutputStream);
49                 
50                 width = image.metaData.width;
51                 height = image.metaData.height;
52
53                 WritableRaster raster = image.bufferedImage.getRaster();
54                 DataBufferByte dbi = (DataBufferByte)raster.getDataBuffer();
55                 byte [] pixels = dbi.getData();
56
57                 if (yChannel == null){
58                         yChannel = new Channel(width, height);                  
59                 } else {
60                         yChannel.reset();
61                 }
62                 
63                 if (uChannel == null){
64                         uChannel = new Channel(width, height);                  
65                 } else {
66                         uChannel.reset();
67                 }
68                 
69                 if (vChannel == null){
70                         vChannel = new Channel(width, height);                  
71                 } else {
72                         vChannel.reset();
73                 }
74                 
75                 // create YUV map out of RGB raster data
76                 Color color = new Color();
77                 
78                 for (int y=0; y < height; y++){
79                         for (int x=0; x < width; x++){
80
81                                 int index = (y * width) + x;
82                                 int colorBufferIndex = index * 3;
83
84                                 int blue = pixels[colorBufferIndex];                            
85                                 if (blue < 0) blue = blue + 256;
86
87                                 int green = pixels[colorBufferIndex+1];
88                                 if (green < 0) green = green + 256;
89
90                                 int red = pixels[colorBufferIndex+2];
91                                 if (red < 0) red = red + 256;
92
93                                 color.r = red;
94                                 color.g = green;
95                                 color.b = blue;
96
97                                 color.RGB2YUV();
98
99                                 yChannel.map[index] = (byte)color.y;
100                                 uChannel.map[index] = (byte)color.u;
101                                 vChannel.map[index] = (byte)color.v;
102                         }
103                 }
104
105                 yChannel.decodedMap[0] = yChannel.map[0];
106                 uChannel.decodedMap[0] = uChannel.map[0];
107                 vChannel.decodedMap[0] = vChannel.map[0];
108                 
109                 bitOutputStream.storeBits(byteToInt(yChannel.map[0]), 8);                                                               
110                 bitOutputStream.storeBits(byteToInt(uChannel.map[0]), 8);                                                               
111                 bitOutputStream.storeBits(byteToInt(vChannel.map[0]), 8);                                                               
112
113                 // detect initial step
114                 int largestDimension;
115                 int initialStep = 2;
116                 if (width > height) {
117                         largestDimension = width;
118                 } else {
119                         largestDimension = height;
120                 }
121                 
122                 while (initialStep < largestDimension){
123                         initialStep = initialStep * 2;
124                 }
125                                 
126                 rangeGrid(initialStep);
127                 rangeRoundGrid(2);
128                 saveGrid(initialStep);
129         }
130
131         public void printStatistics(){
132                 System.out.println("Y channel:");
133                 yChannel.printStatistics();
134
135                 System.out.println("U channel:");
136                 uChannel.printStatistics();
137                 
138                 System.out.println("V channel:");
139                 vChannel.printStatistics();             
140         }
141         
142         public void rangeGrid(int step){
143
144                 //gridSquare(step / 2, step / 2, step, pixels);
145
146                 rangeGridDiagonal(step / 2, step / 2, step);
147                 rangeGridSquare(step / 2, 0, step);
148                 rangeGridSquare(0, step / 2, step);
149
150                 if (step > 2) rangeGrid(step / 2);
151         }
152
153
154         public void rangeRoundGrid(int step){
155
156                 rangeRoundGridDiagonal(step / 2, step / 2, step);
157                 rangeRoundGridSquare(step / 2, 0, step);
158                 rangeRoundGridSquare(0, step / 2, step);
159
160                 if (step < 1024) rangeRoundGrid(step * 2);
161         }
162
163         public void saveGrid(int step) throws IOException {
164
165                 saveGridDiagonal(step / 2, step / 2, step);
166                 saveGridSquare(step / 2, 0, step);
167                 saveGridSquare(0, step / 2, step);
168
169                 if (step > 2) saveGrid(step / 2);
170         }
171
172
173         public void rangeGridSquare(int offsetX, int offsetY, int step){
174                 for (int y = offsetY; y < height; y = y + step){
175                         for (int x = offsetX; x < width; x = x + step){
176
177                                 int index = (y * width) + x;
178                                 int halfStep = step / 2;
179
180                                 context.initialize(image, yChannel.map, uChannel.map, vChannel.map);
181                         
182                                 context.measureNeighborEncode(x - halfStep, y);
183                                 context.measureNeighborEncode(x + halfStep, y);
184                                 context.measureNeighborEncode(x, y - halfStep);
185                                 context.measureNeighborEncode(x, y + halfStep);
186
187                                 yChannel.rangeMap[index] = (byte)context.getYRange(index);
188                                 uChannel.rangeMap[index] = (byte)context.getURange(index);
189                                 vChannel.rangeMap[index] = (byte)context.getVRange(index);
190                         }                       
191                 }               
192         }
193
194         public void rangeGridDiagonal(int offsetX, int offsetY, int step){
195                 for (int y = offsetY; y < height; y = y + step){
196                         for (int x = offsetX; x < width; x = x + step){
197
198                                 int index = (y * width) + x;
199                                 int halfStep = step / 2;
200
201                                 context.initialize(image, yChannel.map, uChannel.map, vChannel.map);
202                                 
203                                 context.measureNeighborEncode(x - halfStep, y - halfStep);
204                                 context.measureNeighborEncode(x + halfStep, y - halfStep);
205                                 context.measureNeighborEncode(x - halfStep, y + halfStep);
206                                 context.measureNeighborEncode(x + halfStep, y + halfStep);
207
208                                 yChannel.rangeMap[index] = (byte)context.getYRange(index);
209                                 uChannel.rangeMap[index] = (byte)context.getURange(index);
210                                 vChannel.rangeMap[index] = (byte)context.getVRange(index);
211                         }                       
212                 }               
213         }
214
215         public void rangeRoundGridDiagonal(int offsetX, int offsetY, int step){
216                 for (int y = offsetY; y < height; y = y + step){
217                         for (int x = offsetX; x < width; x = x + step){
218
219                                 int index = (y * width) + x;
220
221                                 int yRange = byteToInt(yChannel.rangeMap[index]);
222                                 int uRange = byteToInt(uChannel.rangeMap[index]);
223                                 int vRange = byteToInt(vChannel.rangeMap[index]);
224
225                                 int halfStep = step / 2;
226
227                                 int parentIndex = ((y - halfStep) * width) + (x - halfStep);
228
229                                 int parentYRange =  byteToInt(yChannel.rangeMap[parentIndex]);
230
231                                 if (parentYRange < yRange){
232                                         parentYRange = yRange;
233                                         yChannel.rangeMap[parentIndex] = (byte)parentYRange;
234                                 }
235
236                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
237
238                                 if (parentURange < uRange){
239                                         parentURange = uRange;
240                                         uChannel.rangeMap[parentIndex] = (byte)parentURange;
241                                 }
242
243                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
244
245                                 if (parentVRange < vRange){
246                                         parentVRange = vRange;
247                                         vChannel.rangeMap[parentIndex] = (byte)parentVRange;
248                                 }
249                         }                       
250                 }               
251         }
252
253         public void rangeRoundGridSquare(int offsetX, int offsetY, int step){
254                 for (int y = offsetY; y < height; y = y + step){
255                         for (int x = offsetX; x < width; x = x + step){
256
257                                 int index = (y * width) + x;
258                                 
259                                 int yRange = byteToInt(yChannel.rangeMap[index]);
260                                 int uRange = byteToInt(uChannel.rangeMap[index]);
261                                 int vRange = byteToInt(vChannel.rangeMap[index]);
262
263                                 int halfStep = step / 2;
264
265                                 int parentIndex;
266                                 if (offsetX > 0){
267                                         parentIndex = (y * width) + (x - halfStep);                                     
268                                 } else {
269                                         parentIndex = ((y - halfStep) * width) + x;                                                                             
270                                 }
271
272                                 int parentYRange =  byteToInt(yChannel.rangeMap[parentIndex]);
273
274                                 if (parentYRange < yRange){
275                                         parentYRange = yRange;
276                                         yChannel.rangeMap[parentIndex] = (byte)parentYRange;
277                                 }
278
279                                 int parentURange = byteToInt(uChannel.rangeMap[parentIndex]);
280
281                                 if (parentURange < uRange){
282                                         parentURange = uRange;
283                                         uChannel.rangeMap[parentIndex] = (byte)parentURange;
284                                 }
285
286                                 int parentVRange = byteToInt(vChannel.rangeMap[parentIndex]);
287
288                                 if (parentVRange < vRange){
289                                         parentVRange = vRange;
290                                         vChannel.rangeMap[parentIndex] = (byte)parentVRange;
291                                 }
292
293                         }                       
294                 }               
295         }
296
297         public void saveGridSquare(int offsetX, int offsetY, int step) throws IOException{
298                 for (int y = offsetY; y < height; y = y + step){
299                         for (int x = offsetX; x < width; x = x + step){
300
301                                 int halfStep = step / 2;
302
303                                 context2.initialize(image, yChannel.decodedMap, uChannel.decodedMap,  vChannel.decodedMap);
304                                 context2.measureNeighborEncode(x - halfStep, y);
305                                 context2.measureNeighborEncode(x + halfStep, y);
306                                 context2.measureNeighborEncode(x, y - halfStep);
307                                 context2.measureNeighborEncode(x, y + halfStep);
308                         
309                                 
310                                 savePixel(step, offsetX, offsetY, x, y,
311                                                 context2.colorStats.getAverageY(),
312                                                 context2.colorStats.getAverageU(),
313                                                 context2.colorStats.getAverageV());                             
314
315                         }                       
316                 }               
317         }
318
319         public void saveGridDiagonal(int offsetX, int offsetY, int step) throws IOException {
320                 for (int y = offsetY; y < height; y = y + step){
321                         for (int x = offsetX; x < width; x = x + step){
322                                 
323                                 int halfStep = step / 2;
324
325                                 context2.initialize(image, yChannel.decodedMap, uChannel.decodedMap,  vChannel.decodedMap);
326                                 context2.measureNeighborEncode(x - halfStep, y - halfStep);
327                                 context2.measureNeighborEncode(x + halfStep, y - halfStep);
328                                 context2.measureNeighborEncode(x - halfStep, y + halfStep);
329                                 context2.measureNeighborEncode(x + halfStep, y + halfStep);
330                         
331                                 
332                                 savePixel(step, offsetX, offsetY, x, y,
333                                                 context2.colorStats.getAverageY(),
334                                                 context2.colorStats.getAverageU(),
335                                                 context2.colorStats.getAverageV());                             
336                                 
337                         }                       
338                 }               
339         }
340
341         public void savePixel(int step, int offsetX, int offsetY, int x, int y, int averageDecodedY, int averageDecodedU, int averageDecodedV) throws IOException {
342
343                 int index = (y * width) + x;
344
345                 int py = byteToInt(yChannel.map[index]);                
346                 int pu = byteToInt(uChannel.map[index]);        
347                 int pv = byteToInt(vChannel.map[index]);                
348                 
349                 int yRange = byteToInt(yChannel.rangeMap[index]);
350                 int uRange = byteToInt(uChannel.rangeMap[index]);
351                 int vRange = byteToInt(vChannel.rangeMap[index]);
352
353                 int halfStep = step / 2;
354
355                 int parentIndex;
356                 if (offsetX > 0){
357                         if (offsetY > 0){
358                                 // diagonal approach
359                                 parentIndex = ((y - halfStep) * width) + (x - halfStep);                                                                                                        
360                         } else {
361                                 // take left pixel
362                                 parentIndex = (y * width) + (x - halfStep);                                                                     
363                         }                       
364                 } else {
365                         // take upper pixel
366                         parentIndex = ((y - halfStep) * width) + x;                                                                             
367                 }
368
369                 encodeChannel(
370                                 approximator.yTable,
371                                 yChannel,
372                                 averageDecodedY,
373                                 index,
374                                 py,
375                                 yRange,
376                                 parentIndex);
377
378                 encodeChannel(
379                                 approximator.uTable,
380                                 uChannel,
381                                 averageDecodedU,
382                                 index,
383                                 pu,
384                                 uRange,
385                                 parentIndex);
386
387                 encodeChannel(
388                                 approximator.vTable,
389                                 vChannel,
390                                 averageDecodedV,
391                                 index,
392                                 pv,
393                                 vRange,
394                                 parentIndex);
395
396         }
397
398
399         private void encodeChannel(Table table, Channel channel, int averageDecodedValue, int index,
400                         int value, int range, int parentIndex)
401                         throws IOException {
402                 
403                 byte[] decodedRangeMap = channel.decodedRangeMap;
404                 byte[] decodedMap = channel.decodedMap;
405                 
406                 int inheritedRange = byteToInt(decodedRangeMap[parentIndex]);
407
408                 int inheritedBitCount = table.proposeBitcountForRange(inheritedRange);
409
410                 if (inheritedBitCount > 0){
411                         int computedRange;
412                         computedRange = table.proposeRangeForRange(range, inheritedRange);                                                                                                                              
413                         decodedRangeMap[index] = (byte)computedRange;
414
415                         channel.bitCount++;
416                         if (computedRange != inheritedRange){
417                                 // brightness range shrinked
418                                 bitOutputStream.storeBits(1, 1);                        
419                         } else {
420                                 // brightness range stayed the same
421                                 bitOutputStream.storeBits(0, 1);                                                
422                         }
423
424
425                         // encode brightness into available amount of bits
426                         int computedBitCount = table.proposeBitcountForRange(computedRange);
427
428                         if (computedBitCount > 0){
429
430                                 int differenceToEncode = -(value - averageDecodedValue);
431                                 int bitEncodedDifference = encodeValueIntoGivenBits(differenceToEncode, computedRange, computedBitCount);
432
433                                 channel.bitCount = channel.bitCount + computedBitCount;
434                                 bitOutputStream.storeBits(bitEncodedDifference, computedBitCount);                                                              
435
436                                 int decodedDifference = decodeValueFromGivenBits(bitEncodedDifference, computedRange, computedBitCount);
437                                 int decodedValue = averageDecodedValue - decodedDifference;
438                                 if (decodedValue > 255) decodedValue = 255;
439                                 if (decodedValue < 0) decodedValue = 0;
440
441                                 decodedMap[index] = (byte)decodedValue;                 
442                         } else {                                
443                                 decodedMap[index] = (byte)averageDecodedValue;                  
444                         }                       
445
446                 } else {
447                         decodedRangeMap[index] = (byte)inheritedRange;                  
448                         decodedMap[index] = (byte)averageDecodedValue;                  
449                 }
450         }
451
452         public static int encodeValueIntoGivenBits(int value, int range, int bitCount){
453
454                 int negativeBit = 0;
455
456                 if (value <0){
457                         negativeBit = 1;
458                         value = -value;
459                 }
460
461                 int remainingBitCount = bitCount - 1;
462
463                 if (remainingBitCount == 0){                    
464                         // no more bits remaining to encode actual value
465
466                         return negativeBit;
467
468                 } else {
469                         // still one or more bits left, encode value as precisely as possible
470
471                         if (value > range) value = range;
472
473
474                         int realvalueForThisBitcount = 1 << remainingBitCount;
475                         // int valueMultiplier = range / realvalueForThisBitcount;
476                         int encodedValue = value * realvalueForThisBitcount / range;
477
478                         if (encodedValue >= realvalueForThisBitcount) encodedValue = realvalueForThisBitcount - 1;
479
480                         encodedValue = (encodedValue << 1) + negativeBit;
481
482                         return encodedValue;
483                 }
484         }
485
486
487         public static int decodeValueFromGivenBits(int encodedBits, int range, int bitCount){
488                 int negativeBit = encodedBits & 1;
489
490                 int remainingBitCount = bitCount - 1;
491
492                 if (remainingBitCount == 0){                    
493                         // no more bits remaining to encode actual value
494
495                         if (negativeBit == 0){
496                                 return range;                           
497                         } else {
498                                 return -range;                                                          
499                         }
500
501                 } else {
502                         // still one or more bits left, encode value as precisely as possible
503
504                         int encodedValue = (encodedBits >>> 1) + 1;
505
506                         int realvalueForThisBitcount = 1 << remainingBitCount;
507
508                         // int valueMultiplier = range / realvalueForThisBitcount;
509                         int decodedValue = range * encodedValue / realvalueForThisBitcount;
510
511
512                         if (decodedValue > range) decodedValue = range;
513
514                         if (negativeBit == 0){
515                                 return decodedValue;                            
516                         } else {
517                                 return -decodedValue;                                                           
518                         }
519
520                 }               
521         }
522
523         public static int byteToInt(byte input){
524                 int result = input;
525                 if (result < 0) result = result + 256;
526                 return result;
527         }
528
529 }