24c1e276663fbf677a341bbec2a439253e86bbdd
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageDecoder.java
1 /*
2  * Imagesqueeze - Image codec optimized for photos.
3  * Copyright (C) 2012, Svjatoslav Agejenko, svjatoslav@svjatoslav.eu
4  *
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of version 2 of the GNU General Public License
7  * as published by the Free Software Foundation.
8  */
9
10 package eu.svjatoslav.imagesqueeze.codec;
11
12 /**
13  * Compressed image pixels decoder.
14  */
15
16 import eu.svjatoslav.commons.data.BitInputStream;
17
18 import java.awt.image.DataBufferByte;
19 import java.awt.image.WritableRaster;
20 import java.io.IOException;
21
22 class ImageDecoder {
23
24     private final int width;
25     private final int height;
26     private final Image image;
27     private final byte[] decodedYRangeMap;
28     private final byte[] decodedYMap;
29     private final byte[] decodedURangeMap;
30     private final byte[] decodedUMap;
31     private final byte[] decodedVRangeMap;
32     private final byte[] decodedVMap;
33     private final Approximator approximator;
34     private final BitInputStream bitInputStream;
35     private final OperatingContext context = new OperatingContext();
36
37     public ImageDecoder(final Image image, final BitInputStream bitInputStream) {
38         approximator = new Approximator();
39
40         this.image = image;
41         this.bitInputStream = bitInputStream;
42
43         width = image.metaData.width;
44         height = image.metaData.height;
45
46         decodedYRangeMap = new byte[width * height];
47         decodedYRangeMap[0] = (byte) (255);
48         decodedYMap = new byte[width * height];
49
50         decodedURangeMap = new byte[width * height];
51         decodedURangeMap[0] = (byte) (255);
52         decodedUMap = new byte[width * height];
53
54         decodedVRangeMap = new byte[width * height];
55         decodedVRangeMap[0] = (byte) (255);
56         decodedVMap = new byte[width * height];
57
58     }
59
60     public static int readIntegerCompressed8(final BitInputStream inputStream)
61             throws IOException {
62
63         if (inputStream.readBits(1) == 0)
64             return inputStream.readBits(8);
65         else
66             return inputStream.readBits(32);
67     }
68
69     public void decode() throws IOException {
70         approximator.load(bitInputStream);
71         approximator.computeLookupTables();
72
73         final WritableRaster raster = image.bufferedImage.getRaster();
74         final DataBufferByte dbi = (DataBufferByte) raster.getDataBuffer();
75         final byte[] pixels = dbi.getData();
76
77         // load top-, left-most pixel.
78         decodedYMap[0] = (byte) bitInputStream.readBits(8);
79         decodedUMap[0] = (byte) bitInputStream.readBits(8);
80         decodedVMap[0] = (byte) bitInputStream.readBits(8);
81
82         final Color color = new Color();
83         color.y = ImageEncoder.byteToInt(decodedYMap[0]);
84         color.u = ImageEncoder.byteToInt(decodedUMap[0]);
85         color.v = ImageEncoder.byteToInt(decodedVMap[0]);
86
87         color.YUV2RGB();
88
89         pixels[0] = (byte) color.r;
90         pixels[0 + 1] = (byte) color.g;
91         pixels[0 + 2] = (byte) color.b;
92
93         // detect initial step
94         int largestDimension;
95         int initialStep = 2;
96         if (width > height)
97             largestDimension = width;
98         else
99             largestDimension = height;
100
101         while (initialStep < largestDimension)
102             initialStep = initialStep * 2;
103
104         grid(initialStep, pixels);
105     }
106
107     private void grid(final int step, final byte[] pixels) throws IOException {
108
109         gridDiagonal(step / 2, step / 2, step, pixels);
110         gridSquare(step / 2, 0, step, pixels);
111         gridSquare(0, step / 2, step, pixels);
112
113         if (step > 2)
114             grid(step / 2, pixels);
115     }
116
117     private void gridDiagonal(final int offsetX, final int offsetY,
118                               final int step, final byte[] pixels) throws IOException {
119
120         for (int y = offsetY; y < height; y = y + step)
121             for (int x = offsetX; x < width; x = x + step) {
122
123                 final int halfStep = step / 2;
124
125                 context.initialize(image, decodedYMap, decodedUMap, decodedVMap);
126                 context.measureNeighborEncode(x - halfStep, y - halfStep);
127                 context.measureNeighborEncode(x + halfStep, y - halfStep);
128                 context.measureNeighborEncode(x - halfStep, y + halfStep);
129                 context.measureNeighborEncode(x + halfStep, y + halfStep);
130
131                 loadPixel(step, offsetX, offsetY, x, y, pixels,
132                         context.colorStats.getAverageY(),
133                         context.colorStats.getAverageU(),
134                         context.colorStats.getAverageV());
135
136             }
137     }
138
139     private void gridSquare(final int offsetX, final int offsetY,
140                             final int step, final byte[] pixels) throws IOException {
141
142         for (int y = offsetY; y < height; y = y + step)
143             for (int x = offsetX; x < width; x = x + step) {
144
145                 final int halfStep = step / 2;
146
147                 context.initialize(image, decodedYMap, decodedUMap, decodedVMap);
148                 context.measureNeighborEncode(x - halfStep, y);
149                 context.measureNeighborEncode(x + halfStep, y);
150                 context.measureNeighborEncode(x, y - halfStep);
151                 context.measureNeighborEncode(x, y + halfStep);
152
153                 loadPixel(step, offsetX, offsetY, x, y, pixels,
154                         context.colorStats.getAverageY(),
155                         context.colorStats.getAverageU(),
156                         context.colorStats.getAverageV());
157
158             }
159     }
160
161     private int loadChannel(final byte[] decodedRangeMap,
162                             final byte[] decodedMap, final Table table,
163                             final int averageDecodedValue, final int index,
164                             final int parentIndex) throws IOException {
165         int decodedValue = averageDecodedValue;
166
167         final int inheritedRange = ImageEncoder
168                 .byteToInt(decodedRangeMap[parentIndex]);
169         int computedRange = inheritedRange;
170
171         final int bitCount = table.proposeBitcountForRange(inheritedRange);
172         int computedRangeBitCount;
173         if (bitCount > 0) {
174
175             final int rangeDecreases = bitInputStream.readBits(1);
176             if (rangeDecreases != 0)
177                 computedRange = table.proposeDecreasedRange(inheritedRange);
178
179             decodedRangeMap[index] = (byte) computedRange;
180             computedRangeBitCount = table
181                     .proposeBitcountForRange(computedRange);
182
183             if (computedRangeBitCount > 0) {
184
185                 final int encodedDifference = bitInputStream
186                         .readBits(computedRangeBitCount);
187
188                 final int decodedDifference = ImageEncoder
189                         .decodeValueFromGivenBits(encodedDifference,
190                                 computedRange, computedRangeBitCount);
191
192                 decodedValue = averageDecodedValue - decodedDifference;
193                 if (decodedValue > 255)
194                     decodedValue = 255;
195                 if (decodedValue < 0)
196                     decodedValue = 0;
197             }
198         } else
199             decodedRangeMap[index] = (byte) inheritedRange;
200         decodedMap[index] = (byte) decodedValue;
201         return decodedValue;
202     }
203
204     private void loadPixel(final int step, final int offsetX, final int offsetY,
205                            final int x, final int y, final byte[] pixels,
206                            final int averageDecodedY, final int averageDecodedU,
207                            final int averageDecodedV) throws IOException {
208
209         final int index = (y * width) + x;
210
211         final int halfStep = step / 2;
212
213         int parentIndex;
214         if (offsetX > 0) {
215             if (offsetY > 0)
216                 // diagonal approach
217                 parentIndex = ((y - halfStep) * width) + (x - halfStep);
218             else
219                 // take left pixel
220                 parentIndex = (y * width) + (x - halfStep);
221         } else
222             // take upper pixel
223             parentIndex = ((y - halfStep) * width) + x;
224
225         final int colorBufferIndex = index * 3;
226
227         final Color color = new Color();
228         color.y = loadChannel(decodedYRangeMap, decodedYMap,
229                 approximator.yTable, averageDecodedY, index, parentIndex);
230         color.u = loadChannel(decodedURangeMap, decodedUMap,
231                 approximator.uTable, averageDecodedU, index, parentIndex);
232         color.v = loadChannel(decodedVRangeMap, decodedVMap,
233                 approximator.vTable, averageDecodedV, index, parentIndex);
234
235         color.YUV2RGB();
236
237         pixels[colorBufferIndex] = (byte) color.r;
238         pixels[colorBufferIndex + 1] = (byte) color.g;
239         pixels[colorBufferIndex + 2] = (byte) color.b;
240
241     }
242
243 }