Chapter 03 主题聚焦“数据表示 + 批处理 + 训练机制增强”。本项目对这些内容的实现与抽象如下:
| 概念/知识点 | 书本示例 | 项目实现 / 文件 | 说明 |
|---|---|---|---|
| One-hot / Multi-hot 再次强化 | 词袋 / 类别向量 | dlp.data.vectorization |
统一入口,被后续文本任务 (Ch04) 复用 |
| 批处理(手写生成器) | for i in range(0,n,b) | dlp.data.batching.BatchGenerator |
Numpy 索引实现,可与 tf.data 对照 |
| tf.data.Dataset 管线 | Dataset.from_tensor_slices | fit_basic / fit_supervised 内部使用 |
过渡到更高性能 API |
| 训练循环增强 | 记录多指标 / 验证集 | dlp.training.fit_supervised |
新增 metrics, val loop, history 返回 |
| 指标状态管理 | metric.update_state | fit_supervised 中 _prepare_metrics |
演示字符串到对象的转换 |
| 验证集拆分策略 | 手动切片/validation_split | CLI _prepare_data + validation_split |
分类 vs 回归分支处理 |
| 复现与可重复性 | 随机种子 | --seed + dlp.utils.set_seed |
数据抽样/打乱一致性 |
| 基础日志记录 | 打印 / 画图 | LossHistorySimple (Chapter 05 前置) |
提前埋点方便后续过拟合分析 |
特性:
- 显式 optimizer / loss 注入 (默认 RMSprop + SparseCategoricalCrossentropy)
- 支持 metric 列表(当前示例:accuracy 或自定义 Metric 实例)
- 每 epoch 验证评估 + history 字典输出 (keys: loss, accuracy, val_loss, val_accuracy ...)
最小示例:
import tensorflow as tf
from dlp.models import mnist_mlp
from dlp.training import fit_supervised
from dlp.utils import set_seed
from tensorflow.keras.datasets import mnist
set_seed(42)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28*28).astype("float32")/255
x_test = x_test.reshape(-1, 28*28).astype("float32")/255
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(128)
val_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(256)
model = mnist_mlp()
history = fit_supervised(model, train_ds, val_dataset=val_ds, epochs=2, metrics=["accuracy"])
print(history.keys())| 维度 | BatchGenerator | tf.data.Dataset |
|---|---|---|
| 写法复杂度 | 简单 (Numpy 索引) | 略高 (链式调用) |
| 可组合性 | 低 | 高 (map / cache / prefetch) |
| 性能 | 中等 (内存复制) | 更优 (流水线/并行) |
| 教学价值 | 揭示批采样机制 | 连接到工业级输入管线 |
- 用
fit_basic理解“全量前向+反向”循环。 - 引入 BatchGenerator 说明“分批 + 打乱 + epoch”概念。
- 将 BatchGenerator 替换为 tf.data,展示更简洁与可扩展写法。
- 升级到
fit_supervised,观察 history(loss+accuracy)返回格式与 Keras 对照。
- 能解释
update_state与result的关系。 - 能将一段手写 for-loop 数据加载替换为 BatchGenerator 或 tf.data 实现。
- 能使用
fit_supervised追踪训练与验证指标。
| 任务 | 描述 | 价值 |
|---|---|---|
| T20 | 为 fit_supervised 添加进度条 (tqdm) |
可视化迭代进度 |
| T21 | 加入学习率调度 (Warmup / Decay) | 训练稳定性对比 |
| T22 | 添加梯度裁剪选项 | 防止梯度爆炸案例演示 |
| T23 | History 序列化保存 (JSON) | 与回调产物统一 |
- 不直接在
fit_supervised内复刻全部 Keras 功能,避免复杂度过高;强调“教学桥梁”角色。 - metrics 解析仅支持一个简单字符串集合,逐步引导再到“传入 Metric 实例”与更复杂指标 (F1 / RMSE) 已在后续章节实现。
本文件用于巩固 Chapter 03 数据与训练机制的抽象升级,帮助学习者定位各概念在代码中的落点。