Last active
August 21, 2021 14:06
-
-
Save KatsuhiroMorishita/5c693e23c725f373181f600461b01ee6 to your computer and use it in GitHub Desktop.
ビタビアルゴリズムに基づくデジタルデータの符号化・復号のサンプルコードです。デジタルデータを符号化したい、復号したい場合にお使いください。main関数内のcircuitという関数の実装を変更すれば様々な符号化器を試せます。なお、現状では復号時にレジスタの全パターンを生成していますのでマイコンでは動作しないと思います。また、連続動作時の高速化の余地もたくさんありますが、軽くググった感じでは歴史のあるアルゴリズムの割にヒットしないので有用かと思います。学習用にどうぞ。
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ビタビアルゴリズムに基づくデジタルデータの符号化・復号のサンプルコードである。 | |
# memo: 半端かもしれないが、クラス化した。 | |
# 現状ではストリームデータには対応していないが、拡張は簡単だと思う。 | |
# ref. http://www.mobile.ecei.tohoku.ac.jp/lecture/coding/coding_06.pdf | |
# author: Katsuhiro Morishita | |
# created: 2017-06-12 | |
# lisence: MIT | |
import copy | |
import numpy as np | |
class viterbi_encoder: | |
""" ビタビ符号に基づく送信データの符号化器 | |
""" | |
def __init__(self, register_size, circuit_func): | |
""" コンストラクタ | |
register_size: int, レジスタのサイズ。2bit格納なら2 | |
circuit_func: function, レジスタと入力値を基に出力と次のレジスタの状態を返す関数 | |
""" | |
self._reg_size = register_size | |
self._register = 0 | |
self._circuit_func = circuit_func | |
def _split(self, val): | |
""" 値を下位ビットからlistに格納して返す | |
""" | |
ans = [] | |
for i in range(self._reg_size): | |
ans.append((val >> (self._reg_size - i - 1)) & 0x01) | |
return ans | |
def encode(self, data): | |
""" 符号化する | |
data: list, this list has 0 or 1. | |
""" | |
reg = 0 | |
buff = data + [0] * self._reg_size # [0, ...]の追加はレジスタの最終状態を固定化させるため | |
output = [] | |
for x in buff: | |
out_val, reg = self._circuit_func(reg, x) | |
output += self._split(out_val) | |
return output | |
def get_register_paths(packet_size, reg_size=2): | |
""" トレリス線図の状態遷移を作る | |
格納されるのはレジスタの値 | |
packet_size: int, | |
""" | |
reg_paths = [[0]] | |
for x in range(packet_size): | |
_route = [] | |
for mem in reg_paths: | |
_route.append(mem[:]) | |
val1 = mem[-1] >> 1 | |
_route[-1].append(val1) # 最上位ビットに0が追加されたことと同じ | |
if x < packet_size - 2: | |
_route.append(mem[:]) | |
val2 = val1 | (1<<(reg_size - 1)) | |
_route[-1].append(val2) | |
reg_paths = _route | |
#print(len(reg_paths), reg_paths) | |
return reg_paths | |
def get_rx_paths(reg_paths, circuit_func, reg_size=2): | |
"""トレリス線図に基づく受信データ系列を作る | |
reg_paths: list<list>, トレリス線図でのジスタの状態遷移が格納されたリスト | |
circuit_func: function, レジスタと入力値を基に出力と次のレジスタの状態を返す関数 | |
register_size: int, レジスタのサイズ。2bit格納なら2 | |
""" | |
rx_paths = [] | |
for x in reg_paths: | |
_route = [] | |
for i in range(len(x)-1): | |
val, dummy = circuit_func(x[i], x[i+1]>>(reg_size-1)) # 次のレジスタの最上位ビットが今の入力データなのでビット演算で取り出す | |
_route.append(val) | |
rx_paths.append(_route) | |
rx_paths = np.array(rx_paths) # 行列状になる | |
#print(rx_paths) | |
return rx_paths | |
def get_tx_data(reg_path, reg_size=2): | |
""" レジスタの状態変化から送信されたデータを求める | |
reg_path: list<int> レジスタの遷移系列 | |
register_size: int, レジスタのサイズ。2bit格納なら2 | |
""" | |
tx_data = [] | |
for x in reg_path: | |
k = (x & 0b10) >> (reg_size - 1) # パスの最上位ビットが送信データに等しい | |
tx_data.append(k) | |
tx_data = tx_data[1:-2] | |
return tx_data | |
def candidate_reg_history(rx_paths, reg_paths, rx_data, n=3): | |
""" 最尤推定に基づくレジスタ遷移の候補を上位n個を返す | |
rx_data: list, 受信系列。例えば、110001なら、[3, 0, 1] | |
""" | |
diff = rx_paths - rx_data # 距離の計測 | |
#for x in diff: | |
# print(x) | |
d = [np.linalg.norm(x) for x in diff] # あり得る受信データとの間での距離を計測 | |
sorted_d = sorted(d) # 距離の短い順に並び替え | |
reg_histories = [] | |
for i in range(n): | |
index = d.index(sorted_d[i]) # 距離の短い順にインデックス(要素番号)を取得 | |
#rx_path = rx_paths[index] | |
reg_path = reg_paths[index] | |
#print("--{0}--, index {1}".format(i, index)) | |
#print("error", sorted_d[i]) | |
#print("rx path: ", rx_path) # 推定される受信データ系列 | |
#print("reg path: ", reg_path) # 推定されるレジスタの状態遷移系列 | |
reg_histories.append((reg_path, sorted_d[i])) | |
d[index] = 100 # 同値のデータが有った場合への対応のため、取り出し済みのデータには適当に大きな数値を書き込む | |
return reg_histories | |
class viterbi_decoder: | |
""" ビタビ符号に基づく受信データの復号器 | |
""" | |
def __init__(self, register_size, circuit_func): | |
""" コンストラクタ | |
register_size: int, レジスタのサイズ。2bit格納なら2 | |
circuit_func: function, レジスタと入力値を基に出力と次のレジスタの状態を返す関数 | |
""" | |
self._reg_size = register_size | |
self._circuit_func = circuit_func | |
def decode(self, rx_data, n=1): | |
""" 受信系列を基に最尤推定された送信データをn個返す | |
rx_data: list, 受信系列。例えば、110001なら、[3, 0, 1] | |
""" | |
ans = [] | |
reg_paths = get_register_paths(len(rx_data), self._reg_size) | |
rx_paths = get_rx_paths(reg_paths, self._circuit_func, self._reg_size) | |
reg_histories = candidate_reg_history(rx_paths, reg_paths, rx_data, n) # レジスタ状態遷移の最尤推定 | |
for reg, error in reg_histories: | |
# 送信されたデータを求める | |
tx_data = get_tx_data(reg) | |
ans.append((tx_data, error)) | |
return ans | |
def main(): | |
# 符号化と復号のテスト | |
reg_size = 2 | |
def circuit(register, new_val): | |
""" 現在のレジスタの値と新しい入力値で、出力値と次のレジスタを返す | |
符号化器の設計に合わせてこの関数の実装を変更する。 | |
register: int, 現在のレジスタ値(2進数で見たときの桁の0と1がそれぞれのレジスタを表現する) | |
new_val: int, new_val is 0 or 1. 新しい入力値 | |
register_size: int, レジスタのサイズ。2bit格納なら2 | |
""" | |
# レジスタの状態に応じた出力の計算 | |
_input = [new_val] | |
for i in range(reg_size): | |
b = (register >> (reg_size-i-1)) & 0x01 | |
_input.append(b) | |
c1 = (_input[0] + _input[1] + _input[2]) & 0x01 | |
c0 = (_input[0] + _input[2]) & 0x01 | |
c = (c1 << 1) + c0 | |
new_register = (register >> 1) | (new_val << (reg_size-1)) | |
return c, new_register | |
# 符号化器のテスト | |
encoder = viterbi_encoder(2, circuit) # 渡す関数を変えれば様々な符号化を試せる | |
print(encoder.encode([1,1,0,0,0,1,0])) | |
# 個々の関数のチェック | |
#test_data = np.array([3, 1, 1, 3, 0, 3, 2, 3, 0]) # ミスなしのデータ | |
test_data = np.array([3, 0, 1, 1, 0, 3, 2, 3, 0]) # 一部誤ったデータ | |
reg_paths = get_register_paths(len(test_data), reg_size) | |
rx_paths = get_rx_paths(reg_paths, circuit, reg_size) | |
reg_histories = candidate_reg_history(rx_paths, reg_paths, test_data) | |
for reg, error in reg_histories: | |
# 送信されたデータを求める | |
tx_data = get_tx_data(reg) | |
print("tx data", error, tx_data) | |
# 復号器のテスト | |
decoder = viterbi_decoder(2, circuit) | |
print(decoder.decode(test_data, 3)) | |
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment