博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习笔记(4):多类逻辑回归-使用gluton
阅读量:7030 次
发布时间:2019-06-28

本文共 3194 字,大约阅读时间需要 10 分钟。

接上一篇继续,这次改用gluton来实现关键处理,原文见 ,代码如下:

import matplotlib.pyplot as pltimport mxnet as mxfrom mxnet import gluonfrom mxnet import ndarray as ndfrom mxnet import autograddef transform(data, label):    return data.astype('float32')/255, label.astype('float32')mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)def show_images(images):    n = images.shape[0]    _, figs = plt.subplots(1, n, figsize=(15, 15))    for i in range(n):        figs[i].imshow(images[i].reshape((28, 28)).asnumpy())        figs[i].axes.get_xaxis().set_visible(False)        figs[i].axes.get_yaxis().set_visible(False)    plt.show()def get_text_labels(label):    text_labels = [        'T 恤', '长 裤', '套头衫', '裙 子', '外 套',        '凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'    ]    return [text_labels[int(i)] for i in label]data, label = mnist_train[0:10]print('example shape: ', data.shape, 'label:', label)show_images(data)print(get_text_labels(label))batch_size = 256train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)num_inputs = 784num_outputs = 10W = nd.random_normal(shape=(num_inputs, num_outputs))b = nd.random_normal(shape=num_outputs)params = [W, b]for param in params:    param.attach_grad()def accuracy(output, label):    return nd.mean(output.argmax(axis=1) == label).asscalar()def _get_batch(batch):    if isinstance(batch, mx.io.DataBatch):        data = batch.data[0]        label = batch.label[0]    else:        data, label = batch    return data, labeldef evaluate_accuracy(data_iterator, net):    acc = 0.    if isinstance(data_iterator, mx.io.MXDataIter):        data_iterator.reset()    for i, batch in enumerate(data_iterator):        data, label = _get_batch(batch)        output = net(data)        acc += accuracy(output, label)    return acc / (i+1)#使用gluon定义计算模型net = gluon.nn.Sequential()with net.name_scope():    net.add(gluon.nn.Flatten())    net.add(gluon.nn.Dense(10))net.initialize()#损失函数(使用交叉熵函数)softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()#使用梯度下降法生成训练器,并设置学习率为0.1trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})for epoch in range(5):    train_loss = 0.    train_acc = 0.    for data, label in train_data:        with autograd.record():            output = net(data)            #计算损失            loss = softmax_cross_entropy(output, label)         loss.backward()        #使用sgd的trainer继续向前"走一步"        trainer.step(batch_size)                train_loss += nd.mean(loss).asscalar()        train_acc += accuracy(output, label)    test_acc = evaluate_accuracy(test_data, net)    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (        epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))data, label = mnist_test[0:10]show_images(data)print('true labels')print(get_text_labels(label))predicted_labels = net(data).argmax(axis=1)print('predicted labels')print(get_text_labels(predicted_labels.asnumpy()))

相对上一版原始手动方法,使用gluon修改的地方都加了注释,不多解释。运行效果如下:

相对可以发现,几乎相同的参数,但是准确度有所提升,从0.7几上升到0.8几,10个里错误的预测数从4个下降到3个,说明gluon在一些细节上做了更好的优化。关于优化的细节,

转载地址:http://ehrxl.baihongyu.com/

你可能感兴趣的文章
shader 讲解的第二天 把兰伯特模型改成半兰泊特模型 函数图形绘制工具
查看>>
python3.5安装Numpy、mayploylib、opencv等额外库
查看>>
优雅绝妙的Javascript跨域问题解决方案
查看>>
Java 接口技术 Interface
查看>>
函数草稿
查看>>
织梦系统学习:文章页当前位置的写法(自认对SEO有用)
查看>>
PHP经验——PHPDoc PHP注释的标准文档(翻译自Wiki)
查看>>
vue input输入框长度限制
查看>>
深入理解Java虚拟机(类加载机制)
查看>>
在500jsp错误页面获取错误信息
查看>>
iOS-CALayer遮罩效果
查看>>
为什么需要版本管理
查看>>
五、Dart 关键字
查看>>
React Native学习笔记(一)附视频教学
查看>>
记Promise得一些API
查看>>
javascript事件之调整大小(resize)事件
查看>>
20145234黄斐《Java程序设计》第六周学习总结
查看>>
【CLRS】《算法导论》读书笔记(四):栈(Stack)、队列(Queue)和链表(Linked List)...
查看>>
hibernate 和 mybatis区别
查看>>
互联网广告综述之点击率特征工程
查看>>