Skip to content

Latest commit

 

History

History
74 lines (63 loc) · 4.11 KB

File metadata and controls

74 lines (63 loc) · 4.11 KB

Chapter 03 Integration: Data Representation & Training Pipeline

Chapter 03 主题聚焦“数据表示 + 批处理 + 训练机制增强”。本项目对这些内容的实现与抽象如下:

概念 ↔ 代码映射表

概念/知识点 书本示例 项目实现 / 文件 说明
One-hot / Multi-hot 再次强化 词袋 / 类别向量 dlp.data.vectorization 统一入口,被后续文本任务 (Ch04) 复用
批处理(手写生成器) for i in range(0,n,b) dlp.data.batching.BatchGenerator Numpy 索引实现,可与 tf.data 对照
tf.data.Dataset 管线 Dataset.from_tensor_slices fit_basic / fit_supervised 内部使用 过渡到更高性能 API
训练循环增强 记录多指标 / 验证集 dlp.training.fit_supervised 新增 metrics, val loop, history 返回
指标状态管理 metric.update_state fit_supervised_prepare_metrics 演示字符串到对象的转换
验证集拆分策略 手动切片/validation_split CLI _prepare_data + validation_split 分类 vs 回归分支处理
复现与可重复性 随机种子 --seed + dlp.utils.set_seed 数据抽样/打乱一致性
基础日志记录 打印 / 画图 LossHistorySimple (Chapter 05 前置) 提前埋点方便后续过拟合分析

增强训练循环 fit_supervised

特性:

  • 显式 optimizer / loss 注入 (默认 RMSprop + SparseCategoricalCrossentropy)
  • 支持 metric 列表(当前示例:accuracy 或自定义 Metric 实例)
  • 每 epoch 验证评估 + history 字典输出 (keys: loss, accuracy, val_loss, val_accuracy ...)

最小示例:

import tensorflow as tf
from dlp.models import mnist_mlp
from dlp.training import fit_supervised
from dlp.utils import set_seed
from tensorflow.keras.datasets import mnist

set_seed(42)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28*28).astype("float32")/255
x_test  = x_test.reshape(-1, 28*28).astype("float32")/255
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(128)
val_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(256)
model = mnist_mlp()
history = fit_supervised(model, train_ds, val_dataset=val_ds, epochs=2, metrics=["accuracy"])
print(history.keys())

BatchGenerator vs tf.data 对照

维度 BatchGenerator tf.data.Dataset
写法复杂度 简单 (Numpy 索引) 略高 (链式调用)
可组合性 高 (map / cache / prefetch)
性能 中等 (内存复制) 更优 (流水线/并行)
教学价值 揭示批采样机制 连接到工业级输入管线

过渡路径 (Chapter 02 -> 03)

  1. fit_basic 理解“全量前向+反向”循环。
  2. 引入 BatchGenerator 说明“分批 + 打乱 + epoch”概念。
  3. 将 BatchGenerator 替换为 tf.data,展示更简洁与可扩展写法。
  4. 升级到 fit_supervised,观察 history(loss+accuracy)返回格式与 Keras 对照。

验收清单 (完成标志)

  • 能解释 update_stateresult 的关系。
  • 能将一段手写 for-loop 数据加载替换为 BatchGenerator 或 tf.data 实现。
  • 能使用 fit_supervised 追踪训练与验证指标。

后续扩展建议

任务 描述 价值
T20 fit_supervised 添加进度条 (tqdm) 可视化迭代进度
T21 加入学习率调度 (Warmup / Decay) 训练稳定性对比
T22 添加梯度裁剪选项 防止梯度爆炸案例演示
T23 History 序列化保存 (JSON) 与回调产物统一

设计取舍

  • 不直接在 fit_supervised 内复刻全部 Keras 功能,避免复杂度过高;强调“教学桥梁”角色。
  • metrics 解析仅支持一个简单字符串集合,逐步引导再到“传入 Metric 实例”与更复杂指标 (F1 / RMSE) 已在后续章节实现。

本文件用于巩固 Chapter 03 数据与训练机制的抽象升级,帮助学习者定位各概念在代码中的落点。