"""
pip install soundfile
"""
import os
import argparse
import numpy as np
import tensorflow as tf
import librosa
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import soundfile as sf
import warnings
import pandas as pd

# Ask TensorFlow's native logging to print ERROR messages only.
# NOTE(review): this assignment runs after `import tensorflow` at the top of
# the file -- confirm the setting is still picked up at that point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Hide UserWarning / FutureWarning noise.
for _warning_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_warning_category)

path = "assets/01/01/mic/"  # default data folder; also the prefix of every default output path

# Command-line interface: (flag, type, default, help) for each option, in the
# order they should appear in --help.
_ARGUMENT_SPECS = (
    ('--data_dir', str, path, '資料資料夾路徑'),
    ('--epochs', int, 50, '訓練 epochs 數'),
    ('--batch_size', int, 32, 'Batch size'),
    ('--model_path', str, f'{path}model.h5', '模型儲存路徑'),
    ('--weights_path', str, f'{path}weights.h5', '權重儲存路徑'),
    ('--sr', int, 16000, '取樣率'),
    ('--duration', float, 3.0, '每個音檔長度(秒)'),
    ('--n_mels', int, 40, 'Mel bins 數量'),
    ('--history_path', str, f'{path}history.csv', '訓練歷史紀錄CSV儲存路徑'),
    ('--tflite_path', str, f'{path}model.tflite', 'TFLite模型儲存路徑'),
    ('--header_path', str, f'{path}model.h', 'C header檔案儲存路徑'),
)

parser = argparse.ArgumentParser(description='Train a speech classifier with TensorFlow')
for _flag, _arg_type, _arg_default, _arg_help in _ARGUMENT_SPECS:
    parser.add_argument(_flag, type=_arg_type, default=_arg_default, help=_arg_help)
args = parser.parse_args()

# Scan the data directory: every sub-folder is one class, and the folder name
# is the class label. Sorting keeps the label -> index mapping stable.
labels = sorted(
    d for d in os.listdir(args.data_dir)
    if os.path.isdir(os.path.join(args.data_dir, d))
)
label2idx = {label: idx for idx, label in enumerate(labels)}

# Every clip is forced to exactly this many samples so that all mel
# spectrograms share one shape and np.array(X) can stack them into one tensor.
target_len = int(args.sr * args.duration)

X = []
y = []

print('Loading data...')
for label in labels:
    folder = os.path.join(args.data_dir, label)
    # Sorted so the sample order (and therefore the seeded split below) does
    # not depend on the order in which the file system lists the files.
    for fname in sorted(os.listdir(folder)):
        if not fname.lower().endswith('.wav'):
            continue
        wav_path = os.path.join(folder, fname)
        try:
            wav, sr = sf.read(wav_path)
            if wav.ndim > 1:
                wav = wav.mean(axis=1)  # down-mix to mono
            wav = librosa.resample(wav, orig_sr=sr, target_sr=args.sr)
            # Fix the length: cut clips that are too long, zero-pad the ones
            # that are too short.
            wav = wav[:target_len]
            if len(wav) < target_len:
                wav = np.pad(wav, (0, target_len - len(wav)))
            # Log-scaled mel spectrogram, referenced to the clip's own peak.
            mel = librosa.feature.melspectrogram(y=wav, sr=args.sr, n_mels=args.n_mels)
            mel_db = librosa.power_to_db(mel, ref=np.max)
            X.append(mel_db)
            y.append(label2idx[label])
        except Exception as e:
            # Best effort: report the unreadable file and keep loading the rest.
            print(f'Error processing {wav_path}: {e}')

if not X:
    # Fail early with a clear message instead of an obscure error from the
    # split / training code further down.
    raise SystemExit(f'No usable .wav files found under {args.data_dir}')

X = np.array(X)
y = np.array(y)

# Append a trailing channel axis for the CNN: (samples, height, width, channels)
X = X[..., np.newaxis]
y_cat = to_categorical(y, num_classes=len(labels))

# Stratified 80/20 train/test split with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y_cat, test_size=0.2, random_state=42, stratify=y)

# 建立簡單 CNN 模型
def build_model(input_shape, num_classes):
    """Create and compile a small convolutional classifier.

    The network is two Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 32
    and 64 filters, followed by Flatten, a 128-unit ReLU dense layer and a
    softmax output layer with ``num_classes`` units.

    Args:
        input_shape: Shape of one input sample, passed to the first Conv2D
            layer as ``input_shape``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The ``tf.keras.Sequential`` model, compiled with the Adam optimizer,
        categorical cross-entropy loss and the accuracy metric.
    """
    layers = tf.keras.layers
    stack = [
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ]
    net = tf.keras.Sequential(stack)
    net.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return net

# Build the network for the per-sample shape of the training data.
model = build_model(X_train.shape[1:], len(labels))

print('Training...')
# The held-out split doubles as validation data during training and as the
# final test set below.
fit_options = {
    'validation_data': (X_test, y_test),
    'epochs': args.epochs,
    'batch_size': args.batch_size,
}
history = model.fit(X_train, y_train, **fit_options)

# Final evaluation on the held-out split
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)
separator = '-' * 20
print(separator)
print(f'Test accuracy: {test_acc:.4f}')
print(separator)

# Persist the per-epoch training history as CSV
history_df = pd.DataFrame(history.history)
history_df.to_csv(args.history_path, index=False)
print(f'Training history saved to {args.history_path}')

# Save the full model and, separately, its weights
print(f'Saving model to {args.model_path} and weights to {args.weights_path}')
model.save(args.model_path)
model.save_weights(args.weights_path)

# Convert the trained Keras model to TFLite and write the flatbuffer to disk
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(args.tflite_path, 'wb') as f:
    f.write(tflite_model)
print(f'TFLite model saved to {args.tflite_path}')

# 轉換為 .h 檔

def tflite_to_header(tflite_path, header_path, var_name='g_model'):
    """Dump a binary file as a C array definition in a header file.

    The header defines ``const unsigned char <var_name>[]`` holding the bytes
    of ``tflite_path`` (12 bytes per line, lowercase hex) followed by
    ``const int <var_name>_len`` with the byte count.

    Args:
        tflite_path: Path of the binary file to read.
        header_path: Path of the header file to write (overwritten if present).
        var_name: C identifier used for the array; the length variable is
            named ``<var_name>_len``.
    """
    bytes_per_line = 12

    with open(tflite_path, 'rb') as f:
        tflite_bytes = f.read()

    # Assemble the whole header in memory and write it once, instead of
    # issuing one tiny write per byte. Each row ends with a comma (valid in a
    # C initializer list) and carries no trailing whitespace.
    parts = [f'const unsigned char {var_name}[] = {{\n']
    for start in range(0, len(tflite_bytes), bytes_per_line):
        row = tflite_bytes[start:start + bytes_per_line]
        parts.append('  ' + ', '.join(f'0x{b:02x}' for b in row) + ',\n')
    parts.append(f'}};\nconst int {var_name}_len = {len(tflite_bytes)};\n')

    with open(header_path, 'w', encoding='utf-8') as f:
        f.write(''.join(parts))
    print(f'Header file saved to {header_path}')

# Export the TFLite file written above as a C header, using the default
# array name `g_model`.
tflite_to_header(args.tflite_path, args.header_path)

print('Done!')
