
Paddle build_cinn_pass_test source code reading (under the fluid directory)


The code is at paddle\fluid\framework\paddle2cinn\build_cinn_pass_test.cc. Because the CINN and PIR parts of Paddle are still being updated frequently, the version you see may differ from mine.

inline bool CheckNodeExisted(const std::unordered_set<Node*>& nodes,
                             const std::string& op_name) {
  return std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
           return node->Name() == op_name;
         }) != nodes.end();
}

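CheckNodeExisted is an inline function that checks whether any node in an unordered_set<Node*> is named op_name. It is implemented with std::find_if, whose third argument is a lambda (anonymous function). [&op_name] is the lambda's capture list: variables listed inside the square brackets [] of a lambda declaration are captured by the closure, either by value or by reference, and the & here means op_name is captured by reference.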

For more on lambda closures, see this article: blogs /pzhfei/archive/2013/01/14/lambda_expression.html
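To see the difference between the two capture modes, here is a standalone sketch that uses plain strings instead of Paddle's Node. ContainsByRef, ContainsByValue and CaptureDemo are illustrative names, not Paddle functions:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Exact-name lookup with a by-reference capture, as in CheckNodeExisted.
inline bool ContainsByRef(const std::vector<std::string>& names,
                          const std::string& op_name) {
  return std::find_if(names.begin(), names.end(),
                      [&op_name](const std::string& n) { return n == op_name; }) !=
         names.end();
}

// Same lookup with a by-value capture: op_name is copied into the closure.
inline bool ContainsByValue(const std::vector<std::string>& names,
                            const std::string& op_name) {
  return std::find_if(names.begin(), names.end(),
                      [op_name](const std::string& n) { return n == op_name; }) !=
         names.end();
}

// The two modes differ once the captured variable changes after the lambda
// is created: the by-reference closure sees the new value, the by-value
// closure keeps its own copy.
inline std::pair<bool, bool> CaptureDemo() {
  std::string target = "mul";
  auto by_ref = [&target](const std::string& n) { return n == target; };
  auto by_val = [target](const std::string& n) { return n == target; };
  target = "relu";
  return {by_ref("relu"), by_val("relu")};  // {true, false}
}
```

In CheckNodeExisted the lambda is consumed immediately by find_if, so capturing op_name by reference is safe and avoids copying the string.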

Next, CountNode returns how many nodes are named op_name, using std::count_if with the same kind of lambda:

inline int CountNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  return std::count_if(
      nodes.begin(), nodes.end(), [&op_name](const Node* node) {
        return node->Name() == op_name;
      });
}

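Next, GetNode returns the node whose name matches op_name. Why is there a * in front of std::find_if? Because find_if returns an iterator, and dereferencing the iterator yields a Node*. Note two differences from the helpers above: the lambda uses Name().find(op_name) != std::string::npos, which is a substring match rather than an exact one, and the result is dereferenced without comparing it to end(), so calling GetNode with a name that matches nothing is undefined behavior.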

inline Node* GetNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  return *std::find_if(
      nodes.begin(), nodes.end(), [&op_name](const Node* node) {
        return node->Name().find(op_name) != std::string::npos;
      });
}
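A sketch of a safer variant, assuming a hypothetical FakeNode type that models only Name(). It keeps GetNode's substring match but returns nullptr when nothing matches, next to an exact-match counter in the style of CountNode:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for ir::Node: only the Name() accessor is modeled.
struct FakeNode {
  std::string name;
  const std::string& Name() const { return name; }
};

// Exact match, as in CountNode.
inline int CountExact(const std::unordered_set<FakeNode*>& nodes,
                      const std::string& op_name) {
  return static_cast<int>(std::count_if(
      nodes.begin(), nodes.end(),
      [&op_name](const FakeNode* n) { return n->Name() == op_name; }));
}

// Substring match like GetNode, but checks for end() instead of
// dereferencing it when no node matches.
inline FakeNode* GetNodeOrNull(const std::unordered_set<FakeNode*>& nodes,
                               const std::string& op_name) {
  auto it = std::find_if(nodes.begin(), nodes.end(), [&op_name](const FakeNode* n) {
    return n->Name().find(op_name) != std::string::npos;
  });
  return it == nodes.end() ? nullptr : *it;
}
```

With substring matching, a query for "var1" may also hit "var10", and which one comes back depends on the unspecified iteration order of unordered_set. The node names in this test never collide like that, so it only matters if the helper is reused elsewhere.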

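CheckGraphIndependence defines a lambda check_node_ok inside it; its parameters n1 and n2 are both Node pointers. (Note: before Paddle PIR, a graph node could be either an Op or a Var.) It can only return true when one of n1 and n2 is an Op and the other is a Var: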

inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
  auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
    if (n1->IsOp() && !n2->IsVar()) {
      return false;
    }
    if (n1->IsVar() && !n2->IsOp()) {
      return false;
    }
    if (nodes.count(n2) == 0) {
      return false;
    }
    return true;
  };
  for (auto node : nodes) {
    for (auto in : node->inputs) {
      if (!check_node_ok(node, in)) {
        return false;
      }
    }
    for (auto out : node->outputs) {
      if (!check_node_ok(node, out)) {
        return false;
      }
    }
  }
  return true;
}

A note on this: before Paddle PIR, both Ops and Vars are nodes, so edges are defined like this:

var1 -> op1 -> var2
op3 -> var3 -> op4

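op1's input is var1 and its output is var2; in the second line, var3's input is op3 and var3's output is op4. In other words, a var node's inputs are the ops that produce it and its outputs are the ops that consume it. It reads a little oddly, but that is indeed how it is defined.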

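So CheckGraphIndependence works in two steps: first it checks that every edge is op->var or var->op, then it checks that every neighbor of each op/var is in the graph's unordered_set<Node*>, i.e. the node set never points at a node outside itself.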

Later calls pass the graph's nodes g->Nodes() to CheckGraphIndependence and fail the assertion if it does not return true:

ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
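The same check on a hypothetical toy node type (GNode, with an is_op flag instead of IsOp()/IsVar()), to show which graphs pass and which fail. Since a toy node is either an op or a var, "the two endpoints have different kinds" covers both of the original if-branches:

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Hypothetical minimal pre-PIR node: either an op or a var, with pointers
// to its input and output neighbours.
struct GNode {
  bool is_op;
  std::vector<GNode*> inputs, outputs;
};

// Same logic as CheckGraphIndependence: every edge must alternate
// op <-> var, and every neighbour must belong to the given node set.
inline bool IsIndependent(const std::unordered_set<GNode*>& nodes) {
  auto ok = [&nodes](const GNode* n1, GNode* n2) {
    if (n1->is_op == n2->is_op) return false;  // op-op or var-var edge
    return nodes.count(n2) != 0;               // false if neighbour is outside
  };
  for (GNode* n : nodes) {
    for (GNode* in : n->inputs)
      if (!ok(n, in)) return false;
    for (GNode* out : n->outputs)
      if (!ok(n, out)) return false;
  }
  return true;
}
```

For var1 -> op1 -> var2 the full set passes, while dropping var1 from the set makes op1 point outside it and the check fails.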

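This function takes the operators::kCompilationKey attribute of every kCinnLaunchOp node and pushes it into the compilation_keys vector. What it is for is not clear yet at this point; the AllOpSupportCinn test below passes these keys to CinnCompiler::FindGraph to look up the compiled subgraph.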

// Get compilation_key values
std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
  std::vector<int64_t> compilation_keys;
  for (auto& node : graph.Nodes()) {
    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
      compilation_keys.emplace_back(PADDLE_GET_CONST(
          int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
    }
  }
  return compilation_keys;
}
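PADDLE_GET_CONST(int64_t, ...) reads a typed value out of an attribute. As far as I know Paddle's Attribute is itself a variant type, so the idea can be sketched with std::variant. OpNodeStub, Attr and CollectKeys are hypothetical stand-ins, and the op/attribute names used with them are placeholders, not the real values of kCinnLaunchOp and operators::kCompilationKey:

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Hypothetical attribute type and op node.
using Attr = std::variant<int, int64_t, std::string>;

struct OpNodeStub {
  std::string name;
  std::map<std::string, Attr> attrs;
};

// Collect the int64_t key attribute of every launch op. std::get throws
// std::bad_variant_access if the stored type is not int64_t
// (PADDLE_GET_CONST likewise fails on a type mismatch, as far as I can tell).
inline std::vector<int64_t> CollectKeys(const std::vector<OpNodeStub>& ops,
                                        const std::string& launch_op,
                                        const std::string& key_attr) {
  std::vector<int64_t> keys;
  for (const auto& op : ops) {
    if (op.name == launch_op) {
      keys.emplace_back(std::get<int64_t>(op.attrs.at(key_attr)));
    }
  }
  return keys;
}
```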

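Next, BuildNoCinnSubgraph builds a graph that contains no CINN-supported op (and therefore no CINN subgraph): it creates an empty Graph and then adds the op and var nodes one by one.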

std::unique_ptr<Graph> BuildNoCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);
  // var1 --
  //        | --> fake1 --> var3 --> fake2 --> var4
  // var2 --
  // The *Desc classes are used below to create the op nodes and var nodes
  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");
  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  // Then the graph's Create*Node methods create the corresponding ir::Node objects
  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  // With all nodes created, wire the ops and vars together
  // fill op node
  fake1->inputs = {v1, v2};
  fake1->outputs = {v3};
  fake2->inputs = {v3};
  fake2->outputs = {v4};
  // fill variable node
  v1->outputs = {fake1};
  v2->outputs = {fake1};
  v3->inputs = {fake1};
  v3->outputs = {fake2};
  v4->inputs = {fake2};
  return g;
}
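Every edge above is written twice: once on the op (fake1->inputs) and once on the var (v1->outputs). A hypothetical Link helper that records both sides at once, plus a symmetry check, sketched on a toy node type N rather than ir::Node:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical minimal node used only for this illustration.
struct N {
  std::string name;
  bool is_op;
  std::vector<N*> inputs, outputs;
};

// Hypothetical helper: record the edge src -> dst on both endpoints.
inline void Link(N* src, N* dst) {
  src->outputs.push_back(dst);
  dst->inputs.push_back(src);
}

// True if every edge is recorded on both of its endpoints.
inline bool EdgesSymmetric(const std::vector<N*>& nodes) {
  auto has = [](const std::vector<N*>& v, const N* x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  for (const N* n : nodes) {
    for (const N* out : n->outputs)
      if (!has(out->inputs, n)) return false;
    for (const N* in : n->inputs)
      if (!has(in->outputs, n)) return false;
  }
  return true;
}
```

Building var1/var2 -> fake1 -> var3 -> fake2 -> var4 with five Link calls gives the same wiring as the ten assignments above, and a one-sided edge added by hand is caught by EdgesSymmetric.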

Next comes the first unit test:

TEST(BuildCinnPassTest, NoCinnSubgraph) {
  auto g = BuildNoCinnSubgraph();  // build the graph with the function above
  auto previous_nodes = g->Nodes();  // snapshot the graph's node set
  // Get the pass; this appears to be an old-IR (pre-PIR) pass
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  // g is a std::unique_ptr<Graph>; g.get() returns the raw Graph pointer
  pass->Apply(g.get());
  // After search, origin graph should no change
  // i.e. after the pass searches, the original graph must be left unchanged
  ASSERT_EQ(previous_nodes, g->Nodes());
  // the graph must still be well-formed and self-contained
  ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
  // After search, there should be no cinn subgraph
  // fake1/fake2 are not CINN-supported ops, so no kCinnLaunchOp is created
  // and there is no compilation key
  ASSERT_TRUE(GetCompilationKeys(*g).empty());
}

Next is BuildAllOpSupportCinnGraph, which is not much different from the previous graph-building function.

The graph is more complex, and the op types change from fake1/fake2 to elementwise_add, mul and relu:
std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);
  // v1 --
  //      | --> mul --> v3 --
  // v2 --                   | --> add --> v5 --> relu --> v6
  //                    v4 --
  OpDesc add_op;
  add_op.SetType("elementwise_add");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  VarDesc var5("var5");
  VarDesc var6("var6");
  ir::Node* add = g->CreateOpNode(&add_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  // an "empty" var node (presumably without a VarDesc); what is it meant to test?
  ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable);
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  ir::Node* v5 = g->CreateVarNode(&var5);
  ir::Node* v6 = g->CreateVarNode(&var6);
  ir::Node* v7 = g->CreateControlDepVar();
  // fill op node
  mul->inputs = {v0, v1, v2};
  mul->outputs = {v3};
  add->inputs = {v3, v4};
  add->outputs = {v5};
  relu->inputs = {v5};
  relu->outputs = {v6, v7};
  // fill variable node
  v0->outputs = {mul};
  v1->outputs = {mul};
  v2->outputs = {mul};
  v3->inputs = {mul};
  v3->outputs = {add};
  v4->outputs = {add};
  v5->inputs = {add};
  v5->outputs = {relu};
  v6->inputs = {relu};
  v7->inputs = {relu};
  return g;
}

The diagram comment above is slightly wrong:

// v1 --
//      | --> mul --> v3 --
// v2 --                   | --> add --> v5 --> relu --> v6
//                    v4 --

It should be:

// v0 --|
// v1 --|
// v2 --| --> mul --> v3 --|
//                    v4 --| --> add --> v5 --> relu --> v6
//                                                   --> v7

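The following TEST has the same structure as before, except that, with the new graph, the pass replaces mul, elementwise_add and relu with a single kCinnLaunchOp: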

TEST(BuildCinnPassTest, AllOpSupportCinn) {
  auto g = BuildAllOpSupportCinnGraph();
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());
  // After search, the graph should as following
  // v0 --|
  // v1 --|                   |--> v6
  // v2 --| --> kCinnLaunchOp |--> v7
  // v4 --|
  const auto& nodes = g->Nodes();
  // 7 nodes: 4 inputs, 2 outputs and 1 op node
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
  // the graph must be self-contained (no edges to nodes outside it)
  ASSERT_TRUE(CheckGraphIndependence(nodes));
  // A new op named kCinnLaunchOp should be added
  // kCinnLaunchOp is a constant string; check the node set contains such an op
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
  auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
  // get the var node pointers one by one
  auto* v0 = GetNode(nodes, "var0");
  auto* v1 = GetNode(nodes, "var1");
  auto* v2 = GetNode(nodes, "var2");
  auto* v4 = GetNode(nodes, "var4");
  auto* v6 = GetNode(nodes, "var6");
  auto* v7 = GetNode(nodes, Node::kControlDepVarName);
  // cinn_op's inputs must be {v0, v1, v2, v4} and its outputs {v6, v7}
  ASSERT_EQ(
      std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
      std::unordered_set<Node*>({v0, v1, v2, v4}));
  ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
                                      cinn_op->outputs.end()),
            std::unordered_set<Node*>({v6, v7}));
  // the var nodes must now be connected to cinn_op
  ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
  ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));
  // previous op (mul, add, relu) should all removed
  // mul/elementwise_add/relu were merged into cinn_op, so none of them may remain
  ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));
  // After search, there should has just one cinn subgraph
  // feed --> v1 --
  //               | --> mul --> v3 --
  // feed --> v2 --                   | --> add --> v5 --> relu --> v6 --> fetch
  //                    feed --> v4 --
  // get the compilation keys; each key is later used to look up its subgraph
  auto compilation_keys = GetCompilationKeys(*g);
  // only one kCinnLaunchOp, so exactly one key
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  // look up the subgraph by key
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
  const auto& subnodes = subgraph.Nodes();  // the subgraph's node set
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));
  // the cinn op is the fusion of mul, elementwise_add and relu
  ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
  ASSERT_EQ(CountNode(subnodes, "feed"), 3);   // 3 feed ops, as in the diagram
  ASSERT_EQ(CountNode(subnodes, "fetch"), 1);  // 1 fetch op
  // both parameter and non-parameter inputs of kCinnLaunchOp get a feed op
  // No-parameter input should has feed op
  auto new_v1 = GetNode(subnodes, "var1");
  ASSERT_EQ(new_v1->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");
  // Parameter input should also have the feed op
  auto new_v2 = GetNode(subnodes, "var2");
  ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
  // the output of kCinnLaunchOp gets a fetch op
  // output should has fetch op
  auto new_v6 = GetNode(subnodes, "var6");
  ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
  ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
}
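How does the pass arrive at inputs {v0, v1, v2, v4} and outputs {v6, v7}? A simplified model, which is my own assumption rather than Paddle's implementation: inputs are vars the fused ops read but do not produce, internal vars are produced and read only inside, and outputs are produced inside and read outside or not at all. OpStub, ClusterIO and ComputeClusterIO are hypothetical names, and "ctrl_dep" stands in for the control-dependency var v7:

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical op: the vars it reads and the vars it writes.
struct OpStub {
  std::string type;
  std::vector<std::string> ins, outs;
};

struct ClusterIO {
  std::set<std::string> inputs, internals, outputs;
};

// `cluster` holds the indices of the ops that get fused into one launch op.
inline ClusterIO ComputeClusterIO(const std::vector<OpStub>& ops,
                                  const std::set<std::size_t>& cluster) {
  std::set<std::string> produced, read_inside, read_outside;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    bool in_cluster = cluster.count(i) != 0;
    for (const auto& v : ops[i].ins)
      (in_cluster ? read_inside : read_outside).insert(v);
    if (in_cluster) produced.insert(ops[i].outs.begin(), ops[i].outs.end());
  }
  ClusterIO io;
  // read inside but produced elsewhere -> must be fed in
  for (const auto& v : read_inside)
    if (produced.count(v) == 0) io.inputs.insert(v);
  // produced inside: internal if only read inside, otherwise an output
  for (const auto& v : produced) {
    bool used_outside = read_outside.count(v) != 0;
    bool used_inside = read_inside.count(v) != 0;
    (used_outside || !used_inside ? io.outputs : io.internals).insert(v);
  }
  return io;
}
```

On the AllOpSupportCinn wiring this yields 4 inputs + 2 outputs + 1 launch op = 7 outer nodes, matching the test, while var3 and var5 stay inside the subgraph.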

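The first test has only fake ops, which the pass cannot optimize; in the second, every op is supported by CINN. The next graph mixes the two: some fake ops and some ops supported by the CINN pass.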

std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);
  // fake1 --> v1 --
  //                | --> mul --> v3 --> relu --> v4 --> fake2
  //           v2 --
  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");
  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  // fill op node
  fake1->outputs = {v1};
  mul->inputs = {v2, v1};
  mul->outputs = {v3};
  relu->inputs = {v3};
  relu->outputs = {v4};
  fake2->inputs = {v4};
  // fill variable node
  v2->outputs = {mul};
  v1->inputs = {fake1};
  v1->outputs = {mul};
  v3->inputs = {mul};
  v3->outputs = {relu};
  v4->inputs = {relu};
  v4->outputs = {fake2};
  return g;
}

The function above builds this graph:

// fake1 --> v1 --
//                | --> mul --> v3 --> relu --> v4 --> fake2
//           v2 --

After the CINN pass, the graph becomes:

// fake1 --> v1 --
//                | --> kCinnLaunchOp --> v4 --> fake2
//           v2 --

There is only one kCinnLaunchOp, and its subgraph has 9 nodes:

// feed --> v1 --
//               | --> mul --> v3 --> relu --> v4 --> fetch
// feed --> v2 --
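The subgraph node counts asserted in these tests (9 here, 13 in AllOpSupportCinn) can be reproduced with a simple counting rule. The rule is inferred from the tests' own numbers, not read from Paddle's code: the subgraph keeps the fused ops and the vars they touch, plus one feed op per input var and one fetch op per fetched var. For AllOpSupportCinn to come out at 13, var0 and the control-dependency var v7 must not be counted:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical counting rule inferred from the test assertions.
constexpr std::size_t SubgraphNodeCount(std::size_t ops, std::size_t vars,
                                        std::size_t feeds, std::size_t fetches) {
  return ops + vars + feeds + fetches;
}

// OneCinnSubgraph: {mul, relu} + {var1..var4} + 2 feeds + 1 fetch
static_assert(SubgraphNodeCount(2, 4, 2, 1) == 9, "OneCinnSubgraph");
// AllOpSupportCinn: {mul, add, relu} + {var1..var6} + 3 feeds + 1 fetch
static_assert(SubgraphNodeCount(3, 6, 3, 1) == 13, "AllOpSupportCinn");
```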

The previous graph yields a single CINN op; the next test covers multiple CINN ops:

std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);
  // fake1 --> v1 --
  //                | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
  //           v2 --
  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");
  OpDesc fake3_op;
  fake3_op.SetType("fake3");
  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  VarDesc var5("var5");
  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);
  ir::Node* fake3 = g->CreateOpNode(&fake3_op);
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  ir::Node* v5 = g->CreateVarNode(&var5);
  // fill op node
  fake1->outputs = {v1};
  mul->inputs = {v2, v1};
  mul->outputs = {v3};
  fake2->inputs = {v3};
  fake2->outputs = {v4};
  relu->inputs = {v4};
  relu->outputs = {v5};
  fake3->inputs = {v5};
  // fill variable node
  v2->outputs = {mul};
  v1->inputs = {fake1};
  v1->outputs = {mul};
  v3->inputs = {mul};
  v3->outputs = {fake2};
  v4->inputs = {fake2};
  v4->outputs = {relu};
  v5->inputs = {relu};
  v5->outputs = {fake3};
  return g;
}

The code above builds this graph:

// fake1 --> v1 --
//                | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
//           v2 --

With fake2 as the boundary, the pass creates two CINN ops:

// fake1 -> v1 -
//              | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp -> v5 -> fake3
//          v2 -
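Why does fake2 split the work into two CINN ops? A simplified clustering model, which is an assumption and not the pass's actual algorithm (a real pass generally also has to avoid introducing cycles, which this ignores): two supported ops land in the same cluster only when one directly reads a var the other writes. Op, Find and CountClusters are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string type;
  std::vector<std::string> ins, outs;
};

// Union-find root lookup with path halving.
inline int Find(std::vector<int>& parent, int x) {
  while (parent[x] != x) {
    parent[x] = parent[parent[x]];
    x = parent[x];
  }
  return x;
}

// Number of clusters formed by supported ops linked through a shared var.
inline std::size_t CountClusters(const std::vector<Op>& ops,
                                 const std::set<std::string>& supported) {
  const int n = static_cast<int>(ops.size());
  std::vector<int> parent(n);
  std::iota(parent.begin(), parent.end(), 0);
  std::map<std::string, int> producer;  // var name -> index of the op writing it
  for (int i = 0; i < n; ++i)
    for (const auto& v : ops[i].outs) producer[v] = i;
  for (int i = 0; i < n; ++i) {
    if (supported.count(ops[i].type) == 0) continue;
    for (const auto& v : ops[i].ins) {
      auto it = producer.find(v);
      // merge only if the producing op is supported as well
      if (it != producer.end() && supported.count(ops[it->second].type) != 0)
        parent[Find(parent, i)] = Find(parent, it->second);
    }
  }
  std::set<int> roots;
  for (int i = 0; i < n; ++i)
    if (supported.count(ops[i].type) != 0) roots.insert(Find(parent, i));
  return roots.size();
}
```

In the multi-subgraph graph, mul and relu are connected only through fake2, which is unsupported, so they end up in two separate clusters; in BuildGraphWithOneCinnSubgraph, relu reads var3 written by mul, so both join one cluster.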

Applying the CINN pass takes just two lines:

auto pass =
    paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
pass->Apply(g.get());

This is the check that there are two kCinnLaunchOp nodes:

// A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 2);

After the CINN pass, compilation produces two subgraphs:

// subgraph1:
// feed --> v4 --> relu --> v5 --> fetch
// subgraph2:
// feed --> v1 --
//               | --> mul --> v3 --> fetch
//          v2 --

BuildGraphWithNoNeedBufferInput builds a graph like this:

// fake1 --> v1 --                 --> v4 --> relu_grad --> v6
//           v2 -- | --> add_grad |
//           v3 --                 --> v5 --> fake2

Unlike the previous builders, BuildGraphWithNoNeedBufferInput sets add_grad_op's inputs with the SetInput API:

OpDesc add_grad_op;
add_grad_op.SetType("elementwise_add_grad");
add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
add_grad_op.SetInput("X", {"var2"});
add_grad_op.SetInput("Y", {"var3"});

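The test that follows checks no_need_buffer_x, and I'm not sure what it means. A likely explanation: as far as I know, Paddle registers elementwise_add_grad with X and Y as "no need buffer" inputs, because the gradient of an addition needs only their shapes, not their data. That would explain why the pass puts var1 (the Out gradient) into the kX input of kCinnLaunchOp and var2/var3 into kNoNeedBufferX.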

// A new op named kCinnLaunchOp should be added and
// its input arguments are set correctly
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
          std::vector<std::string>({"var1"}));
auto& no_need_buffer_x = cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
                                          no_need_buffer_x.end()),
          std::unordered_set<std::string>({"var2", "var3"}));
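A sketch of how the split between kX and kNoNeedBufferX could be decided, assuming each op declares which inputs it needs only for metadata, and that a cluster input goes to the no-need-buffer list only when every reader inside the cluster agrees. This is a hypothetical model (GradOp, LaunchInputs, SplitInputs), not the pass's actual code:

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

struct GradOp {
  std::string type;
  std::vector<std::string> ins;
  std::set<std::string> no_need_buffer_ins;  // inputs read only for metadata
};

struct LaunchInputs {
  std::set<std::string> x, no_need_buffer_x;
};

inline LaunchInputs SplitInputs(const std::vector<GradOp>& cluster,
                                const std::set<std::string>& cluster_inputs) {
  // var -> true if at least one reader in the cluster needs its data
  std::map<std::string, bool> needs_data;
  for (const auto& op : cluster) {
    for (const auto& v : op.ins) {
      bool this_op_needs = op.no_need_buffer_ins.count(v) == 0;
      needs_data[v] = needs_data[v] || this_op_needs;
    }
  }
  LaunchInputs result;
  for (const auto& v : cluster_inputs) {
    auto it = needs_data.find(v);
    // conservative: an input that no op reads is treated as needing its data
    bool need = it == needs_data.end() || it->second;
    (need ? result.x : result.no_need_buffer_x).insert(v);
  }
  return result;
}
```

With elementwise_add_grad marking var2 and var3 as metadata-only, this reproduces the test's expectation: kX = {var1}, kNoNeedBufferX = {var2, var3}.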

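What does no_need_buffer_feeds mean here? Presumably it is the same information recorded on the compiled subgraph: the subset of its feed vars (var2, var3) whose data buffers are not needed.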

ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
const auto& no_need_buffer_feeds =
    subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
ASSERT_EQ(no_need_buffer_feeds.size(), 2);
ASSERT_EQ(no_need_buffer_feeds,
          std::unordered_set<std::string>({"var2", "var3"}));
// check the attributes of variable lists are saved correctly
ASSERT_TRUE(subgraph.Has(kInputVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInputVars),
          std::vector<std::string>({"var1"}));
ASSERT_TRUE(subgraph.Has(kInternalVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInternalVars),
          std::vector<std::string>({"var4"}));
ASSERT_TRUE(subgraph.Has(kOutputVars));
const auto& output_vars =
    subgraph.Get<std::vector<std::string>>(kOutputVars);
EXPECT_EQ(
    std::unordered_set<std::string>(output_vars.begin(), output_vars.end()),
    std::unordered_set<std::string>({"var5", "var6"}));
The last test is TestSkipGcVars:
TEST(BuildCinnPassTest, TestSkipGcVars) {
  auto g = BuildGraphWithOneCinnSubgraph();
  // What does this mean? Judging from the test's own comments further down,
  // kSkipGcVarNames lists vars that must not be garbage-collected, and
  // SetNotOwned attaches the set to the graph without transferring ownership
  std::unordered_set<std::string> all_skip_gc_vars = {"var1", "var3"};
  g->SetNotOwned(kSkipGcVarNames, &all_skip_gc_vars);
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());
  // After search, the graph should as following
  // fake1 --> v1 --
  //                | --> kCinnLaunchOp --> v4 --> fake2
  //           v2 --
  const auto& nodes = g->Nodes();
  // Why 7? The diagram above shows only 6 nodes; presumably v3, being in
  // kSkipGcVarNames, also becomes an output of kCinnLaunchOp and stays here
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
  ASSERT_TRUE(CheckGraphIndependence(nodes));
  // A new op named kCinnLaunchOp should be added
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
  // After search, there should has just one cinn subgraph
  // Note v3 has fetched because of v3 in kSkipGcVarNames
  // And v1 is a feed var so v1 no need fetched though it in kSkipGcVarNames
  // feed --> v1 --
  //               | --> mul --> v3 --> relu --> v4 --> fetch
  // feed --> v2 --               --> fetch
  auto compilation_keys = GetCompilationKeys(*g);
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
  const auto& subnodes = subgraph.Nodes();
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(10));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));
  ASSERT_EQ(CountNode(subnodes, "feed"), 2);
  // var3 and var4 should has fetch op
  ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
}
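The skip-GC behaviour described in the test's comments can be sketched on top of precomputed cluster sets: an internal var listed in the skip set is promoted to an output (and so gets a fetch op), while a listed input is left alone because it is already fed in. IOSets, ApplySkipGc and OuterNodeCount are hypothetical names; the outer count rule (inputs + outputs + 1 launch op + untouched nodes) is inferred from the diagrams, which give 6 nodes for OneCinnSubgraph:

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

struct IOSets {
  std::set<std::string> inputs, internals, outputs;
};

// Promote internal vars that must survive garbage collection to outputs.
inline IOSets ApplySkipGc(IOSets io, const std::set<std::string>& skip_gc) {
  for (auto it = io.internals.begin(); it != io.internals.end();) {
    if (skip_gc.count(*it) != 0) {
      io.outputs.insert(*it);
      it = io.internals.erase(it);
    } else {
      ++it;
    }
  }
  return io;  // inputs are untouched: they are fed in anyway
}

// Outer-graph nodes after fusion: fed vars + fetched vars + 1 launch op
// + the nodes outside the cluster (fake1 and fake2 in this test).
inline std::size_t OuterNodeCount(const IOSets& io, std::size_t other_nodes) {
  return io.inputs.size() + io.outputs.size() + 1 + other_nodes;
}
```

Starting from inputs {var1, var2}, internals {var3} and outputs {var4}, applying the skip set {var1, var3} moves var3 to the outputs, which takes the outer count from 6 to the 7 asserted in the test and gives the subgraph its second fetch op.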

I didn't fully understand the last two TESTs, so I'm leaving these questions open.

