initial commit
[imagesqueeze.git] / src / main / java / eu / svjatoslav / imagesqueeze / codec / ImageDecoder.java
1 package eu.svjatoslav.imagesqueeze.codec;
2
3 /**
4  * Decoder for compressed image pixel data.
5  */
6
7 import java.awt.image.DataBufferByte;
8 import java.awt.image.WritableRaster;
9 import java.io.IOException;
10
11 public class ImageDecoder {
12
13         int width, height;
14         Image image;
15
16         byte [] decodedYRangeMap;
17         byte [] decodedYMap;
18
19         byte [] decodedURangeMap;
20         byte [] decodedUMap;
21
22         byte [] decodedVRangeMap;
23         byte [] decodedVMap;
24
25         Color tmpColor = new Color();
26         
27         Approximator approximator;
28         BitInputStream bitInputStream;
29         
30         ColorStats colorStats = new ColorStats();
31         OperatingContext context = new OperatingContext();
32         
33         public ImageDecoder (Image image, BitInputStream bitInputStream) {
34                 approximator = new Approximator();
35
36                 this.image = image;     
37                 this.bitInputStream = bitInputStream;
38
39                 width = image.metaData.width;
40                 height = image.metaData.height;
41
42                 decodedYRangeMap = new byte[width * height];
43                 decodedYRangeMap[0] = (byte)(255);
44                 decodedYMap = new byte[width * height];
45
46                 decodedURangeMap = new byte[width * height];
47                 decodedURangeMap[0] = (byte)(255);
48                 decodedUMap = new byte[width * height];
49
50                 decodedVRangeMap = new byte[width * height];
51                 decodedVRangeMap[0] = (byte)(255);
52                 decodedVMap = new byte[width * height];
53
54         }
55
56
57         public void decode() throws IOException {
58                 approximator.load(bitInputStream);
59                 approximator.computeLookupTables();
60                 
61                 WritableRaster raster = image.bufferedImage.getRaster();
62                 DataBufferByte dbi = (DataBufferByte)raster.getDataBuffer();
63                 byte [] pixels = dbi.getData();
64
65                 // load top-, left-most pixel. 
66                 decodedYMap[0] = (byte)bitInputStream.readBits(8);
67                 decodedUMap[0] = (byte)bitInputStream.readBits(8);
68                 decodedVMap[0] = (byte)bitInputStream.readBits(8);
69                 
70                 Color color = new Color();
71                 color.y = ImageEncoder.byteToInt(decodedYMap[0]);
72                 color.u = ImageEncoder.byteToInt(decodedUMap[0]);
73                 color.v = ImageEncoder.byteToInt(decodedVMap[0]);
74                 
75                 color.YUV2RGB();
76                 
77                 pixels[0] = (byte)color.r;
78                 pixels[0+1] = (byte)color.g;
79                 pixels[0+2] = (byte)color.b;                            
80
81                                 
82                 // detect initial step
83                 int largestDimension;
84                 int initialStep = 2;
85                 if (width > height) {
86                         largestDimension = width;
87                 } else {
88                         largestDimension = height;
89                 }
90                 
91                 while (initialStep < largestDimension){
92                         initialStep = initialStep * 2;
93                 }
94
95                 grid(initialStep, pixels);
96         }
97
98
99         public void grid(int step, byte [] pixels) throws IOException {
100
101                 gridDiagonal(step / 2, step / 2, step, pixels);
102                 gridSquare(step / 2, 0, step, pixels);
103                 gridSquare(0, step / 2, step, pixels);
104
105                 if (step > 2) grid(step / 2, pixels);
106         }
107
108
109         public void gridSquare(int offsetX, int offsetY, int step, byte [] pixels) throws IOException{
110
111                 for (int y = offsetY; y < height; y = y + step){
112                         for (int x = offsetX; x < width; x = x + step){
113
114
115                                 int halfStep = step / 2;
116                                 
117                                 context.initialize(image, decodedYMap, decodedUMap, decodedVMap);
118                                 context.measureNeighborEncode(x - halfStep, y);
119                                 context.measureNeighborEncode(x + halfStep, y);
120                                 context.measureNeighborEncode(x, y - halfStep);
121                                 context.measureNeighborEncode(x, y + halfStep);
122                                                                 
123                                 loadPixel(step, offsetX, offsetY,  x, y, pixels,
124                                                 context.colorStats.getAverageY(),
125                                                 context.colorStats.getAverageU(),
126                                                 context.colorStats.getAverageV());
127
128                         }                       
129                 }               
130         }
131
132
133         public void gridDiagonal(int offsetX, int offsetY, int step, byte [] pixels) throws IOException{
134
135                 for (int y = offsetY; y < height; y = y + step){
136                         for (int x = offsetX; x < width; x = x + step){
137
138                                 int halfStep = step / 2;
139
140                                 context.initialize(image, decodedYMap, decodedUMap, decodedVMap);
141                                 context.measureNeighborEncode(x - halfStep, y - halfStep);
142                                 context.measureNeighborEncode(x + halfStep, y - halfStep);
143                                 context.measureNeighborEncode(x - halfStep, y + halfStep);
144                                 context.measureNeighborEncode(x + halfStep, y + halfStep);
145                                                                 
146                                 loadPixel(step, offsetX, offsetY,  x, y, pixels,
147                                                 context.colorStats.getAverageY(),
148                                                 context.colorStats.getAverageU(),
149                                                 context.colorStats.getAverageV());
150
151                         }                       
152                 }               
153         }
154
155
156         public void loadPixel(int step, int offsetX, int offsetY, int x, int y, byte[] pixels,
157                         int averageDecodedY, int averageDecodedU, int averageDecodedV)
158         throws IOException{
159
160                 int index = (y * width) + x;
161
162                 int halfStep = step / 2;
163
164                 int parentIndex;
165                 if (offsetX > 0){
166                         if (offsetY > 0){
167                                 // diagonal approach
168                                 parentIndex = ((y - halfStep) * width) + (x - halfStep);                                                                                                        
169                         } else {
170                                 // take left pixel
171                                 parentIndex = (y * width) + (x - halfStep);                                                                     
172                         }                       
173                 } else {
174                         // take upper pixel
175                         parentIndex = ((y - halfStep) * width) + x;                                                                             
176                 }
177
178
179
180                 int colorBufferIndex = index * 3;
181
182                 Color color = new Color();
183                 color.y = loadChannel(decodedYRangeMap, decodedYMap, approximator.yTable, averageDecodedY, index, parentIndex);
184                 color.u = loadChannel(decodedURangeMap, decodedUMap, approximator.uTable, averageDecodedU, index, parentIndex);
185                 color.v = loadChannel(decodedVRangeMap, decodedVMap, approximator.vTable, averageDecodedV, index, parentIndex);
186                 
187                 color.YUV2RGB();
188                 
189                 pixels[colorBufferIndex] = (byte)color.r;
190                 pixels[colorBufferIndex+1] = (byte)color.g;
191                 pixels[colorBufferIndex+2] = (byte)color.b;                             
192
193         }
194
195
196         private int loadChannel(byte [] decodedRangeMap, byte [] decodedMap, Table table, int averageDecodedValue, int index,
197                         int parentIndex) throws IOException {
198                 int decodedValue = averageDecodedValue;
199
200                 int inheritedRange = ImageEncoder.byteToInt(decodedRangeMap[parentIndex]);
201                 int computedRange = inheritedRange;
202
203                 int bitCount = table.proposeBitcountForRange(inheritedRange);
204                 int computedRangeBitCount = 0;
205                 if ( bitCount > 0){
206
207                         int rangeDecreases = bitInputStream.readBits(1);
208                         if (rangeDecreases != 0){
209                                 computedRange = table.proposeDecreasedRange(inheritedRange);
210                         }                               
211
212                         decodedRangeMap[index] = (byte)computedRange;
213                         computedRangeBitCount = table.proposeBitcountForRange(computedRange);
214
215                         if (computedRangeBitCount > 0){
216
217                                 int encodedDifference = bitInputStream.readBits(computedRangeBitCount);
218
219                                 int decodedDifference = ImageEncoder.decodeValueFromGivenBits(encodedDifference, computedRange, computedRangeBitCount);
220
221                                 decodedValue = averageDecodedValue - decodedDifference;
222                                 if (decodedValue > 255) decodedValue = 255;
223                                 if (decodedValue < 0) decodedValue = 0;
224                         }
225                 } else {
226                         decodedRangeMap[index] = (byte)inheritedRange;
227                 }
228                 decodedMap[index] = (byte)decodedValue;
229                 return decodedValue;
230         }
231
232 }