/src/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file implements logging infrastructure for extracting features and |
10 | | // rewards for mlgo policy training. |
11 | | // |
12 | | //===----------------------------------------------------------------------===// |
13 | | #include "llvm/Analysis/TensorSpec.h" |
14 | | #include "llvm/Config/config.h" |
15 | | |
16 | | #include "llvm/ADT/Twine.h" |
17 | | #include "llvm/Analysis/Utils/TrainingLogger.h" |
18 | | #include "llvm/Support/CommandLine.h" |
19 | | #include "llvm/Support/Debug.h" |
20 | | #include "llvm/Support/JSON.h" |
21 | | #include "llvm/Support/MemoryBuffer.h" |
22 | | #include "llvm/Support/Path.h" |
23 | | #include "llvm/Support/raw_ostream.h" |
24 | | |
25 | | #include <cassert> |
26 | | #include <numeric> |
27 | | |
28 | | using namespace llvm; |
29 | | |
30 | 0 | void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) { |
31 | 0 | json::OStream JOS(*OS); |
32 | 0 | JOS.object([&]() { |
33 | 0 | JOS.attributeArray("features", [&]() { |
34 | 0 | for (const auto &TS : FeatureSpecs) |
35 | 0 | TS.toJSON(JOS); |
36 | 0 | }); |
37 | 0 | if (IncludeReward) { |
38 | 0 | JOS.attributeBegin("score"); |
39 | 0 | RewardSpec.toJSON(JOS); |
40 | 0 | JOS.attributeEnd(); |
41 | 0 | } |
42 | 0 | if (AdviceSpec.has_value()) { |
43 | 0 | JOS.attributeBegin("advice"); |
44 | 0 | AdviceSpec->toJSON(JOS); |
45 | 0 | JOS.attributeEnd(); |
46 | 0 | } |
47 | 0 | }); |
48 | 0 | *OS << "\n"; |
49 | 0 | } |
50 | | |
51 | 0 | void Logger::switchContext(StringRef Name) { |
52 | 0 | CurrentContext = Name.str(); |
53 | 0 | json::OStream JOS(*OS); |
54 | 0 | JOS.object([&]() { JOS.attribute("context", Name); }); |
55 | 0 | *OS << "\n"; |
56 | 0 | } |
57 | | |
58 | 0 | void Logger::startObservation() { |
59 | 0 | auto I = ObservationIDs.insert({CurrentContext, 0}); |
60 | 0 | size_t NewObservationID = I.second ? 0 : ++I.first->second; |
61 | 0 | json::OStream JOS(*OS); |
62 | 0 | JOS.object([&]() { |
63 | 0 | JOS.attribute("observation", static_cast<int64_t>(NewObservationID)); |
64 | 0 | }); |
65 | 0 | *OS << "\n"; |
66 | 0 | } |
67 | | |
68 | 0 | void Logger::endObservation() { *OS << "\n"; } |
69 | | |
70 | 0 | void Logger::logRewardImpl(const char *RawData) { |
71 | 0 | assert(IncludeReward); |
72 | 0 | json::OStream JOS(*OS); |
73 | 0 | JOS.object([&]() { |
74 | 0 | JOS.attribute("outcome", static_cast<int64_t>( |
75 | 0 | ObservationIDs.find(CurrentContext)->second)); |
76 | 0 | }); |
77 | 0 | *OS << "\n"; |
78 | 0 | writeTensor(RewardSpec, RawData); |
79 | 0 | *OS << "\n"; |
80 | 0 | } |
81 | | |
82 | | Logger::Logger(std::unique_ptr<raw_ostream> OS, |
83 | | const std::vector<TensorSpec> &FeatureSpecs, |
84 | | const TensorSpec &RewardSpec, bool IncludeReward, |
85 | | std::optional<TensorSpec> AdviceSpec) |
86 | | : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), |
87 | 0 | IncludeReward(IncludeReward) { |
88 | 0 | writeHeader(AdviceSpec); |
89 | 0 | } |