V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
爱意满满的作品展示区。
c0xt30a
V2EX  ›  分享创造

一个小小的深度学习引擎 (c++)

  •  
  •   c0xt30a · 2021-07-02 03:58:16 +08:00 · 2571 次点击
    这是一个创建于 1238 天前的主题,其中的信息可能已经有所发展或是发生改变。

    大致是跟着 tensorflow 1 的流程走了一遍,用 variable, placeholder, operator, constant 来建立一个编译时计算图。正向传播时候从全局拉出一个 session,把 placeholder 绑定到 tensor 上就可以开始了。正向传播完了之后再反向传播一遍所有 variable 的梯度也就出来了。然后用个 optimizer 来更新所有的 variable 一遍就是训练了一次。

    起初只是为了练手,大多数算子写的都是很暴露,少有间接层,几乎所有的细节都可以就地找到,代码也很简洁。后来发现这东西只是拿来自娱自乐意思不大,就尽量往 kera 的 api 上靠,希望能够引起注意。

    目前还是非常 tiny,能做一些简单的 classification 和 GAN 的训练。典型的 classification 代码大致如下:

    #include "../include/ceras.hpp"
    #include <iostream>
    int main()
    {
        using namespace ceras;
        random_generator.seed( 42 );
    
        auto input = Input(); // shape( 28, 28 )
        auto l0 = Reshape({28*28,})( input );
        auto l1 = ReLU( Dense( 512, 28*28 )( l0 ) );
        auto l2 = ReLU( Dense( 256, 512 )( l1 ) );
        auto output = Dense( 10, 256 )( l2 );
        auto m = model( input, output );
    
        m.summary( "./mnist_minimal.dot" );
    
        std::size_t const batch_size = 10;
        float learning_rate = 0.005f;
        auto cm = m.compile( CategoricalCrossentropy(), SGD(batch_size, learning_rate) );
    
        unsigned long epoches = 50;
        int verbose = 1;
        double validation_split = 0.1;
        auto const& [x_training, y_training, x_test, y_test] = dataset::mnist::load_data();
    
        auto history = cm.fit( x_training.as_type<float>()/255.0f, y_training.as_type<float>(), batch_size, epoches, verbose, validation_split );
    
        auto error = cm.evaluate( x_test.as_type<float>()/255.0, y_test.as_type<float>(), batch_size );
    
        std::cout << "\nPrediction error on the test set is " << error << std::endl;
    
        return 0;
    }
    
    

    希望能有人讨论,代码在这里: https://github.com/fengwang/ceras

    7 条回复    2021-07-07 09:45:19 +08:00
    towser
        1
    towser  
       2021-07-02 07:11:41 +08:00
    很酷
    cclin
        2
    cclin  
       2021-07-02 08:03:08 +08:00 via Android
    学习一下,今天的摸鱼材料有了
    shm7
        3
    shm7  
       2021-07-02 12:55:41 +08:00 via iPhone
    为啥不按照 torch 写 sess 这种有毛病不能调的痛
    c0xt30a
        4
    c0xt30a  
    OP
       2021-07-02 18:45:02 +08:00
    @shm7 很早的时候上了 tensorflow 的贼船,因为 keras 太好用了。习惯使然……
    leimao
        5
    leimao  
       2021-07-03 02:29:03 +08:00
    header-only 是痛处,应该去掉。
    c0xt30a
        6
    c0xt30a  
    OP
       2021-07-04 02:37:56 +08:00
    @leimao 为什么是痛处?我为了让它脱离二进制依赖特意花了功夫特意这么设计的。而且还把 computation graph 做成了编译时的。
    275761919
        7
    275761919  
       2021-07-07 09:45:19 +08:00
    太秀了
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   2777 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 24ms · UTC 15:09 · PVG 23:09 · LAX 07:09 · JFK 10:09
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.