PS-Lite源码分析

Parameter Server中文名称叫做参数服务器,是分布式机器学习框架中用来做参数同步的框架。具体介绍可以参考后面链接,这里主要学习一下其实现。

ps-lite是Paramter Server的实现的一个框架,其中参数处理具体相关策略需自己实现。

Parameter Server框架

Parameter Server包含三种角色:Worker、Server、Scheduler。具体关系如下图:

ps-lite_01.png

  • Worker节点负责计算参数,并发参数push到Server,同时从Serverpull参数回来。
  • Server节点负责管理Worker节点发送来的参数,并“合并”,之后供各个Worker使用。
  • Scheduler几点负责管理Worker节点和Server节点的状态。

PS-Lite实现

总体概览

简单看一下各个类以及它们之间的关系
ps-lite_02.svg

  • Postoffice是全局管理类,单例模式创建。管理当前节点角色、其他节点的连接、心跳信息、配置信息。
  • Van是负责通信的类,是Postoffice的成员。Van中std::unordered_map<int, void*> senders_保存了node_id到连接的映射。Van只是定义了接口,具体实现是依赖ZMQ实现的ZMQVan
  • Customer用来通信,跟踪request和response。每一个连接对应一个Customer实例,连接对方的id和Customer实例的id相同。
  • SimpleApp是一个基类;提供了发送接收int型的head和string型的body消息,以及注册消息处理函数。它有2个派生类。
  • KVServer是SimpleApp的派生类,用来保存key-values数据。
  • KVWorker是SimpleApp的派生类,用来想Server Push/Pull key-value数据。
  • KVPairs封装了Key-Value结构,还包含了一个长度选项。
  • SArray是Shared array,像智能指针一样共享数据,接口类似vector。
  • Node封装了节点的信息,例如角色、ip、端口、是否是恢复节点。
  • Control封装了控制信息,例如命令类型、目的节点、barrier_group的id、签名。
  • Meta封装了元数据,发送者、接受者、时间戳、请求还是相应等。
  • Message是要发送的信息,除了元数据外,还包括发送的数据。

节点角色ID

三种节点,从上图可以看出Scheduler节点只有一个,多个Worker和多个Server可以组成一个Group,因此有WorkerGroup和ServerGroup;还有Worker节点和Server节点。每个节点以及每一个Group都有唯一确定的ID。
Scheduler、ServerGroup、WorkerGroup节点ID确定如下:

1
2
3
static const int kScheduler = 1;
static const int kServerGroup = 2;
static const int kWorkerGroup = 4;

1、2、4的二进制表示分别为:001、010、001。这样可以做Group之间的合并,例如要和ServerGroup和WorkerGroup发信息,只需要destination node id设为2+4=6。
1-7用来表示节点的组合。单个节点的ID从8开始。单个Server和单个Worker节点从自己的rank(0、1、2……)转换到其ID:

1
2
3
4
5
6
static inline int WorkerRankToID(int rank) {
return rank * 2 + 9;
}
static inline int ServerRankToID(int rank) {
return rank * 2 + 8;
}

ID到其rank转换:

1
2
3
static inline int IDtoRank(int id) {
return std::max((id - 8) / 2, 0);
}

Postofficetd::unordered_map<int, std::vector<int>> node_ids_保存了Node/NodeGroup与连接节点集合的对应关系。

消息封装

  • 首先使用了自定义的SArray,Smart Array。共享数据,减少数据拷贝,且提供了类似vector的接口。
  • 元数据Meta使用了Protobuf,进行了数据压缩。
  • 消息分层比较清晰。Node包含节点的角色、id、ip、端口信息;Control包含了命令信息、签名等;Meta是元数据,包含时间戳、发送者、接受者、控制信息等;Message才是发送的信息,包含元数据和发送的数据。
  • 参数有key-value组成,对应KVPairs

通信机制

Scheduler节点管理所有节点的地址。每个节点要知道Scheduler节点的IP、port;启动时绑定一个本地端口,并向Scheduler节点报告。Scheduler节点在每个几点启动后,给节点分配ID,把节点信息通知出去(例如Worker节点要知道Server节点IP和端口,Server节点要知道Worker节点的IP和端口)。节点在建立连接后,才会正式启动。

同步策略

异步工作时,Worker计算参数可能要依赖前面Pull是否完成。如果需要等待某一步操作,可以调用SimpleApp::Wait操作。具体实现是调用了Customer::WaitRequest(),它会跟踪request和response数量是否相同,直到相同才会返回;tracker_类型为std::vector<std::pair<int, int>>,记录了request和response数量,这个数据结构一直增长,会造成内存一直增长。

消息处理流程

每个节点都监听了本地一个端口;该连接的节点在启动时已经连接。
对于Server节点:
1、Van::Receiving()函数是单独一个线程来接收数据。数据接收后,根据不同命令执行不同动作,例如Control::ADD_NODE就是添加节点。如果需要下一步处理,会将消息传递给Customer::Accept函数。
2、Customer::Accept()函数将消息添加到一个队列recv_queue_Customer::Receiving()是一个线程在运行,从队列取消息处理;处理过程中会使用函数对象recv_handle_处理消息,这个函数对象是SimpleApp::Process函数。
3、SimpleApp::Process根据是消息类型(请求or响应,调用用户注册的函数来处理消息,request_handle_response_handle_分别处理请求和响应。

对于Worker节点,上面第3点略有不同。因为Worker都是通过PushPull来通信,而且参数都是key-value对。Pull·参数时,通过KVWorker::Process调用回调函数来处理消息。

调试及启动流程

PS Lite通过环境变量和外界交互。

启动流程:
1、首先启动Scheduler节点。这是要固定好Server和Worker数量。
2、启动Worker或Server节点。启动时连接Scheduler节点,绑定本地端口,并向Scheduler节点注册自己信息。
3、Scheduler等待所有Worker节点都注册后,给其分配id,并把节点信息传送出去。此时Scheduler节点已经准备好。
4、Worker或Server接收到Scheduler传送的信息后,建立对应节点的连接。此时Worker或Server已经准备好。

调试时,通过环境变量来控制调试日志。
PS_VERBOSE=1,会打印连接日志。
PS_VERBOSE=2,会打印所有数据通信日志。

一个例子

参考源码给出的例子。KVPair中一个key对应多个value,具体数量在lens中记录。Server使用key-vector映射存储数据。Server收到的Push数据,只是将对应key的值相加。最后Worker Pull的数据,按照key打印。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <iostream>
#include "ps/ps.h"
using namespace std;
using namespace ps;
template <class Val>
class MyKVServerHandle {
public:
void operator() (const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
size_t n = req_data.keys.size();
KVPairs<Val> res;
if (req_meta.push) { // push
CHECK_EQ(n, req_data.lens.size());
} else { // pull
res.keys = req_data.keys;
res.lens.resize(res.keys.size());
}

size_t cur_idx = 0;
for (size_t i = 0;i < n; ++i) {
Key key = req_data.keys[i];
if(req_meta.push){ //push
int len = req_data.lens[i];
if(store.count(key) == 0){//第一次push,开辟空间
store[key] = vector<Val>(len, 0);
}

for(int idx = 0; idx < len; ++idx){
store[key][idx] += req_data.vals[cur_idx++];
}
}
else{ // pull
res.lens[i] = store[key].size();
for(int idx = 0; idx < res.lens[i]; ++idx){
res.vals.push_back(store[key][idx]);
}
}
}
server->Response(req_meta, res);
}

private:
std::unordered_map<Key, vector<Val>> store;
};
void StartServer() {
if (!IsServer()) return;
cout << "num of workers[" << NumWorkers() << "]" << endl;
cout << "num of servers[" << NumServers() << "]" << endl;
auto server = new KVServer<float>(0);
server->set_request_handle(MyKVServerHandle<float>());
RegisterExitCallback([server](){ delete server; });
}
void RunWorker() {
if (!IsWorker()) return;
cout << "start Worker rank = " << MyRank() << endl;
KVWorker<float> kv(0);
// init
int key_num = 10;
int val_num = 0;
vector<Key> keys(key_num);
vector<int> len(key_num);
for(int i = 0; i < key_num; ++i){
keys[i] = i;
len[i] = i + 1;
val_num += len[i];
}

vector<float> vals(val_num);
for (int i = 0;i < val_num; ++i) {
vals[i] = i / 10;
}
// push
int repeat = 10;
vector<int> ts;
for (int i = 0;i < repeat; ++i) {
ts.push_back(kv.Push(keys, vals, len));
}
for (int t : ts) kv.Wait(t);
// pull
std::vector<float> ret_val;
std::vector<int> ret_len;
kv.Wait(kv.Pull(keys, &ret_val, &ret_len));
CHECK_EQ(keys.size(), ret_len.size());

size_t cur_idx = 0;
for (size_t i = 0;i < keys.size(); ++i) {
std::cout<<MyRank()<<" key ["<<keys[i]<<"] vals [";
for(int idx = 0; idx < ret_len[i]; ++idx){
std::cout<<" "<<ret_val[cur_idx++];
}
std::cout<<"]"<<std::endl;
}
cout << endl;
}
int main(int argc, char* argv[]) {
StartServer();
Start();
RunWorker();
Finalize();
return 0;
}

一个Server,两个Worker

1
./local.sh 1 2 ./test_example
num of workers[2]
num of servers[1]
start Worker rank = 0
start Worker rank = 1
0 key [0] vals [ 0]
0 key [1] vals [ 0 0]
0 key [2] vals [ 0 0 0]
0 key [3] vals [ 0 0 0 0]
0 key [4] vals [ 20 20 20 20 20]
0 key [5] vals [ 20 20 20 20 20 40]
1 key [0] vals [ 0]
0 key [6] vals [ 40 40 40 40 40 40 40]
1 key [1] vals [ 0 0]
1 key [2] vals [ 0 0 0]
0 key [7] vals [ 40 40 60 60 60 60 60 60]
1 key [3] vals [ 0 0 0 0]
0 key [8] vals [ 60 60 60 60 80 80 80 80 80]
1 key [4] vals [ 20 20 20 20 20]
0 key [9] vals [ 80 80 80 80 80 100 100 100 100 100]
1 key [5] vals [ 20 20 20 20 20 40]

1 key [6] vals [ 40 40 40 40 40 40 40]
1 key [7] vals [ 40 40 60 60 60 60 60 60]
1 key [8] vals [ 60 60 60 60 80 80 80 80 80]
1 key [9] vals [ 80 80 80 80 80 100 100 100 100 100]

参考:

Parameter Server for Distributed Machine Learning
PS-Lite Documents
ps-lite源码剖析

文章目录
  1. 1. Parameter Server框架
  2. 2. PS-Lite实现
    1. 2.1. 总体概览
    2. 2.2. 节点角色ID
    3. 2.3. 消息封装
    4. 2.4. 通信机制
    5. 2.5. 同步策略
    6. 2.6. 消息处理流程
    7. 2.7. 调试及启动流程
    8. 2.8. 一个例子
  3. 3. 参考:
,
#add by kangyabing