深度学习之Iris Flower智能分类

posted at 2024.10.22 12:59 by Administrator

Iris Flower鸢尾花分类任务是一个经典的深度学习项目,也是监督式学习(supervised learning)项目。鸢尾花分为三个不同的品种,即Setosa(山鸢尾)、Versicolor(变色鸢尾)和Virginica(维吉尼亚鸢尾)。该任务的目标是通过花卉的特征判断出其品种。这是一个多类别分类问题,通常用于演示和测试分类算法的性能。

一、一般流程

首先要获取训练用的数据,即训练集。在机器学习中,数据可以从多种渠道获得,比如从硬盘中读取。从网络上下载。直接通过程序生成或简单的硬编码。这里采用最后一种方法;其次,将数据转换为张量。让他们能够输入模型中。下一步就是创建模型。这与函数的概念类似,相当于设计一个可训练函数。能够将输入数据映射到预测目标在本例中,输入数据和预测目标都是数字。一旦准备好模型和数据后,就可以开始训练模型。并查看他在训练过程中生成的度量指标报告在这一切完成后,就可以用训练好的模型来预测未曾出现的数据,同时评估模型的准确率。

      二、数据集

鸢尾花分类任务使用的数据集通常是著名的鸢尾花数据集(Iris dataset)。该数据集包含了150个鸢尾花样本,每个样本的第一列为花萼长度值,第二列为花萼宽度值,第三列为花瓣长度值,第四列为花瓣宽度值,第五列对应是种类(三类鸢尾花分别用012表示)。

例如: [5.1, 3.5, 1.4, 0.2, 2] 

    三、架构

本例使用Javascript语言,采用的是TensorFlow.js架构。这个架构包含以下特性:

支持训练和推断;支持Web浏览器和node.js两种环境;能够利用GPU加速;支持用JavaScript定义神经网络模型架构;支持模型的序列化和反序列化;支持与Python深度学习框架间的双向模型格式转换;兼容Python深度学习框架使用的API;内置数据获取和和可视化所需的API

   四、数据处理

利用convertToTensors(data, targets, testSplit)函数将原始数据集分为训练数据和测试数据,85%的数据用作训练集,15%用作测试集。训练数据用于构建模型,测试数据用于评估模型性能。

getIrisData(testSplit)函数具体执行数据的读取。

为了使得这种非连续关系的数据输出,取值不具有偏序性,并且到原点的距离是相等的,本例采用独热编码One-Hot来表示鸢尾花的类别。每种类别用一个一维向量来表示,其中的三个元素分别对应三个类别,属于哪个类别对应的元素就是1,其余元素为0

所谓One-Hot独热,就是指在一个向量中只有一个元素是1,也就是热的,其他元素都是0。显然采用独热编码,需要占用更多的空间,但是它能够更加合理地表示数据之间的关系,它将一维空间中三个标量的点扩展到了三维空间中,其中的每一个点到原点的距离都是相等的,采用独热编码可以有效的避免学习过程中的偏差。在机器学习中通常会将离散的特征以及多分类问题中的类别标签,采用独热编码的方式来表示。除此之外,还有独冷编码,它和独热编码相反,就是向量中只有一个元素为0,其他元素都为1 

    五、定义模型

Iris Flower分类问题的多层神经网络模型代码:

   const model = tf.sequential();
  model.add(tf.layers.dense(
      {units:10, activation:'sigmoid', inputShape: [xTrain.shape[1]]}));
  model.add(tf.layers.dense({units:3, activation:'softmax'}));
  model.summary();
 
  const optimizer tf.train.adam(params.learningRate);
  model.compile({
    optimizer:optimizer,
    loss:'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

定义的模型的拓扑结构报告如下:

 

从上面的报告可见,这是一个非常简单的模型,它的权重参数也相对较少。第二个密集层的输出形状[null,3]对应分数目标的one-hot编码。

最后一层使用的激活函数是归一化指数函数,即softmax函数,它首先把输入向量中的每个元素转换为对应的自然常数e的指数,进行指数运算,这样就显着的拉开了它们之间的差距,使得大数变得更大,小数更小,把它们的和作为分母对输出进行归一化,这样就保证了 它们的值都在0~1之间,它们的和等于1。例如当softmax()函数的输入是123时,指数运算的结果是2.71.320,显然他们之间的差距明显拉大了。在归一化之后输出是0.090.240.67

之所以称为softmax(),可以理解为以更加soft的方式标记出最大的数。这种多分类模型也称为softmax回归,即 Y=softmax(W^TX)。和逻辑回归一样,softmax回归也是一种广义线性回归,用来完成分类任务。

    六、模型编译

  模型编译代码如下:

  const optimizer tf.train.adam(params.learningRate);
  model.compile({
    optimizer:optimizer,
    loss:'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

   这里将’adam’指定为优化器,旨在通过使用倍增因子(multiplicationfactor)来解决sgd的短板,倍增因子可以随着梯度历史智能变化。除此之外,他还会对不同的模型权重参数使用不同的倍增因子。因此,对很多不同类型的深度学习模型而言,adam通常会带来更好的收敛。并且与sgd相比,adam对学习率选择的依赖更小,这就是它流行的原因。

同时,设计了一个分类交叉熵(categorical cross entropy,即二元交叉熵对超出两种类型的泛化形式。交叉熵做损失函数模型能有更好的收敛。这是因为,它是对数函数,在接近上边界的时候,其仍然可以保持在高梯度状态,或者说传递的梯度值更均匀,更合适(相对于均方误差),模型的收敛速度不会受到影响。

    七、模型训练

Iris Flower鸢尾花分类任务主界面控制面板是这样的:

你可以在通过此控制面板在浏览器上设置训练轮次、学习率,从零开始训练模型,可以保存模型到当地,转入当地保存的模型或预训练模型,也可以删除当地模型。

从训练的轮次来看,训练400轮比训练40轮损失要小很多,准确度显著提高。 

     八、实测经过训练的模型

           Javascript代码:

async function predictOnManualInput(model) {
  if (model == null) {
    ui.setManualInputWinnerMessage('ERROR: Please load or train model first.');
    return;
  }
  tf.tidy(() => {
   
    const inputData ui.getManualInputData();
    const input tf.tensor2d([inputData], [1, 4]);
    const predictOut model.predict(input);
    const logits Array.from(predictOut.dataSync());
    const winner data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
    ui.setManualInputWinnerMessage(winner);
    ui.renderLogitsForManualInput(logits);
  });

效果图:

     九、本例Javascript源代码 

                  iris.zip (37.49 kb)

     十、部分名词理解

深度学习(DLDeep Learning):基于深层神经网络模型和方法的机器学习,具有自动提取特征的能力。

监督式学习(supervised learning):用带标签的样例逐步减少模型转出误差的机器学习。

TensorFlow.js 一个用于使用 JavaScript 进行机器学习开发的库.

模型:可进行计算的方式和规则。一个训练好的机器学习模型,可以用于预测或分类。

层:构建模型时,可以添加不同的Layer来处理数据。

数据预处理: 将输入数据转换为模型可以理解的形式的过程。

训练:使用训练数据调整模型权重的过程。

损失函数:评估模型预测质量的指标。

优化器:更新模型权重以减少损失的算法。

拟合:使模型权重尽可能准确地近似真实函数参数的过程。

学习率(Learning Rate):每次迭代中成本函数中最小化的量。

激活函数(Activation Function):就是在人工神经网络的神经元上运行的函数,负责将神经元的输入映射到输出端。

交叉熵(Cross EntropyShannon信息论中主要用于度量两个概率分布间的差异性信息。

 

 

 

 

 

 

 

 

 

 

 

 

iris.zip (37.49 kb)

Tags: , , , , , , , , ,

IT技术

添加评论

  Country flag

biuquote
  • 评论
  • 在线预览
Loading