Skip to content

Instantly share code, notes, and snippets.

@Themaister
Created August 23, 2012 07:12
Show Gist options
  • Save Themaister/3433739 to your computer and use it in GitHub Desktop.
Save Themaister/3433739 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import math
def conv2d(r, c, data):
    """Return the separable weighted sum of a 2-D block.

    Computes ``sum(data[v][u] * r[v] * c[u])`` over every row ``v`` and
    column ``u`` covered by both the weights and the block.  Despite the
    name, no kernel is slid across the input: the result is a single number.

    Args:
        r: Sequence of per-row weights.
        c: Sequence of per-column weights.
        data: 2-D sequence (list of rows) of numbers.  The column count is
            taken from the first row.

    Returns:
        The accumulated sum, or ``0`` when ``data`` has no rows.
    """
    if not data:
        # Nothing to accumulate, and there is no data[0] to measure.
        return 0

    n_rows = min(len(r), len(data))
    n_cols = min(len(c), len(data[0]))

    total = 0
    for v in range(n_rows):
        row = data[v]
        row_weight = r[v]
        for u in range(n_cols):
            # Keep the (value * row weight) * column weight order so float
            # results do not depend on how the loop is written.
            total += row[u] * row_weight * c[u]
    return total
def idct2d(data):
    """Return the rounded 2-D inverse cosine transform of ``data``.

    Each output sample is::

        out[v][u] = round(sum_{y, x} data[y][x]
                          * a_r(y) * cos(pi * (v + 0.5) * y / rows)
                          * a_c(x) * cos(pi * (u + 0.5) * x / cols))

    where ``a(k)`` is ``sqrt(1 / n)`` for ``k == 0`` and ``sqrt(2 / n)``
    otherwise.  These are the same scaled cosines ``dct2d`` uses, with the
    roles of the input and output indices swapped.

    Args:
        data: Rectangular 2-D sequence (list of rows) of coefficients.

    Returns:
        A list of rows with the same shape as ``data`` holding the
        reconstructed samples, each passed through ``round``.  An empty
        input yields an empty list.
    """
    if not data:
        return []

    length_r = len(data)
    length_c = len(data[0])

    # The row basis depends only on the output row and the column basis
    # only on the output column, so build each table once instead of
    # re-evaluating the cosines for every output cell.
    basis_r = [
        [
            math.sqrt((1.0 if y == 0 else 2.0) / length_r)
            * math.cos(math.pi * (v + 0.5) * y / length_r)
            for y in range(length_r)
        ]
        for v in range(length_r)
    ]
    basis_c = [
        [
            math.sqrt((1.0 if x == 0 else 2.0) / length_c)
            * math.cos(math.pi * (u + 0.5) * x / length_c)
            for x in range(length_c)
        ]
        for u in range(length_c)
    ]

    return [
        [round(conv2d(row_basis, col_basis, data)) for col_basis in basis_c]
        for row_basis in basis_r
    ]
def dct2d(data):
    """Return the 2-D forward cosine transform of ``data``.

    Each coefficient is::

        out[v][u] = a_r(v) * a_c(u)
                    * sum_{y, x} data[y][x]
                    * cos(pi * (y + 0.5) * v / rows)
                    * cos(pi * (x + 0.5) * u / cols)

    where ``a(k)`` is ``sqrt(1 / n)`` for ``k == 0`` and ``sqrt(2 / n)``
    otherwise.

    Args:
        data: Rectangular 2-D sequence (list of rows) of samples.

    Returns:
        A list of rows with the same shape as ``data`` holding the
        unrounded coefficients.  An empty input yields an empty list.
    """
    if not data:
        return []

    length_r = len(data)
    length_c = len(data[0])

    # Build the scaled cosine tables once: the row basis depends only on v
    # and the column basis only on u, so recomputing them per output cell
    # is wasted work.
    basis_r = []
    for v in range(length_r):
        a_r = math.sqrt((1.0 if v == 0 else 2.0) / length_r)
        basis_r.append(
            [
                math.cos(math.pi * (y + 0.5) * v / length_r) * a_r
                for y in range(length_r)
            ]
        )

    basis_c = []
    for u in range(length_c):
        a_c = math.sqrt((1.0 if u == 0 else 2.0) / length_c)
        basis_c.append(
            [
                math.cos(math.pi * (x + 0.5) * u / length_c) * a_c
                for x in range(length_c)
            ]
        )

    return [
        [conv2d(row_basis, col_basis, data) for col_basis in basis_c]
        for row_basis in basis_r
    ]
def quantize(data, quant):
    """Divide each value by its quantization step and round the result.

    Args:
        data: 2-D sequence (list of rows) of numbers.
        quant: 2-D sequence of step sizes; must provide an entry for every
            position present in ``data``.

    Returns:
        A new list of rows where element ``[y][x]`` is
        ``round(data[y][x] / quant[y][x])``.

    Raises:
        IndexError: If ``quant`` lacks an entry for a position in ``data``.
        ZeroDivisionError: If a step size is zero.
    """
    # Walk each row by its own length so no value is dropped when a row is
    # longer than the first one.
    return [
        [round(value / quant[y][x]) for x, value in enumerate(row)]
        for y, row in enumerate(data)
    ]
def dequantize(data, quant):
    """Multiply each quantized value by its quantization step.

    Args:
        data: 2-D sequence (list of rows) of quantized values.
        quant: 2-D sequence of step sizes; must provide an entry for every
            position present in ``data``.

    Returns:
        A new list of rows where element ``[y][x]`` is
        ``data[y][x] * quant[y][x]``.

    Raises:
        IndexError: If ``quant`` lacks an entry for a position in ``data``.
    """
    # Walk each row by its own length so no value is dropped when a row is
    # longer than the first one.
    return [
        [value * quant[y][x] for x, value in enumerate(row)]
        for y, row in enumerate(data)
    ]
# 8x8 block of signed sample values used as the demo input.
data = [
    [-76, -73, -67, -62, -58, -67, -64, -55],
    [-65, -69, -73, -38, -19, -43, -59, -56],
    [-66, -69, -60, -15, 16, -24, -62, -55],
    [-65, -70, -57, -6, 26, -22, -58, -59],
    [-61, -67, -60, -24, -2, -40, -60, -58],
    [-49, -63, -68, -58, -51, -60, -70, -53],
    [-43, -57, -64, -69, -73, -67, -63, -45],
    [-41, -49, -59, -60, -63, -52, -50, -34],
]

# Per-coefficient quantization steps, same 8x8 shape as ``data``.
# NOTE(review): looks like the standard JPEG luminance table -- confirm.
quant = [
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99],
]

print("Orig:")
for row in data:
    print(row)

# Lossy round trip: forward transform, quantize, dequantize, then invert.
# After these three steps ``dct_data`` holds the dequantized coefficients;
# print its rows here to inspect them.
dct_data = dct2d(data)
dct_data = quantize(dct_data, quant)
dct_data = dequantize(dct_data, quant)

orig_data = idct2d(dct_data)
print("iDCT:")
for row in orig_data:
    print(row)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment