【caffe源码】Layer和LayerRegistry

在caffe中,Layer代表一层网络。而LayerRegister则是一个Factory模式,用来获取layer的实例。

首先看layer_factory.hpp
在运行时,使用LayerRegistry<Dtype>::CreateLayer(param);可以获取一个已注册的layer。

如果我们要注册layer。假设我们有一个自己实现的layer。

 template <typename Dtype>
 class MyAwesomeLayer : public Layer<Dtype> {
     // your implementations
 };

那么我们需要去掉最后的layer,使用REGISTER_LAYER_CLASS(MyAwesome);来注册我们的layer。

当然,如果你的layer使用了另一个创建者来产生,例如:

template <typename Dtype>
Layer<Dtype*> GetMyAwesomeLayer(const LayerParameter& param) {
    // your implementation
}

这样的话,我们就需要使用REGISTER_LAYER_CREATOR(MyAwesome, GetMyAwesomeLayer)来进行layer的注册了。

当然,一个layer只能被注册一次。

下面是这两个宏的实现。

#define REGISTER_LAYER_CREATOR(type, creator)                                  \
  static LayerRegisterer<float> g_creator_f_##type(#type, creator<float>);     \
  static LayerRegisterer<double> g_creator_d_##type(#type, creator<double>)    \

#define REGISTER_LAYER_CLASS(type)                                             \
  template <typename Dtype>                                                    \
  shared_ptr<Layer<Dtype> > Creator_##type##Layer(const LayerParameter& param) \
  {                                                                            \
    return shared_ptr<Layer<Dtype> >(new type##Layer<Dtype>(param));           \
  }                                                                            \
  REGISTER_LAYER_CREATOR(type, Creator_##type##Layer)

注:单井号(#)在宏定义中的作用就是 把传递过来的参数当成字符串进行替换
双井号(##)又称连接符,它的作用就是 将参数和前面或后面的子串连接起来,成为一个新的子串。

可以看出,在layer_factory内部,注册一个layer时,若没有creator的话,就会创建一个,然后使用REGISTER_LAYER_CREATOR来处理。

两个类型别名

  typedef shared_ptr<Layer<Dtype> > (*Creator)(const LayerParameter&);
  typedef std::map<string, Creator> CreatorRegistry;

一个creator的指针作为layer的参数,一个string和creator对应的map,作为Registry。
下面定义一个单例模式

  static CreatorRegistry& Registry() {
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }

获取当前的CreatorRegistry。

增加一个creator。

  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Layer type " << type << " already registered.";
    registry[type] = creator;
  }

使用layerParam获取creator。

  // Get a layer using a LayerParameter.
  static shared_ptr<Layer<Dtype> > CreateLayer(const LayerParameter& param) {
    if (Caffe::root_solver()) {
      LOG(INFO) << "Creating layer " << param.name();
    }
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown layer type: " << type
        << " (known types: " << LayerTypeListString() << ")";
    return registry[type](param);
  }

遍历所有的layer。

  static vector<string> LayerTypeList() {
    CreatorRegistry& registry = Registry();
    vector<string> layer_types;
    for (typename CreatorRegistry::iterator iter = registry.begin();
         iter != registry.end(); ++iter) {
      layer_types.push_back(iter->first);
    }
    return layer_types;
  }

辅助类,用于包装addCreator函数。

template <typename Dtype>
class LayerRegisterer {
 public:
  LayerRegisterer(const string& type,
                  shared_ptr<Layer<Dtype> > (*creator)(const LayerParameter&)) {
    // LOG(INFO) << "Registering layer type: " << type;
    LayerRegistry<Dtype>::AddCreator(type, creator);
  }
};

接下来看看Layer.hpp

其中最重要的就是定义了forward和backward两种操作,分别是前向和反向的传播。

初始化layer

  explicit Layer(const LayerParameter& param)
    : layer_param_(param) {
      // Set phase and copy blobs (if there are any).
      phase_ = param.phase();
      if (layer_param_.blobs_size() > 0) {
        blobs_.resize(layer_param_.blobs_size());
        for (int i = 0; i < layer_param_.blobs_size(); ++i) {
          blobs_[i].reset(new Blob<Dtype>());
          blobs_[i]->FromProto(layer_param_.blobs(i));
        }
      }
    }

void SetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
    CheckBlobCounts(bottom, top);
    LayerSetUp(bottom, top);
    Reshape(bottom, top);
    SetLossWeights(top);
  }

处理相关参数

  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}

根据底部的shape来改变top的shape。

  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) = 0;