Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements logging infrastructure for extracting features and
10
// rewards for mlgo policy training.
11
//
12
//===----------------------------------------------------------------------===//
13
#include "llvm/Analysis/TensorSpec.h"
14
#include "llvm/Config/config.h"
15
16
#include "llvm/ADT/Twine.h"
17
#include "llvm/Analysis/Utils/TrainingLogger.h"
18
#include "llvm/Support/CommandLine.h"
19
#include "llvm/Support/Debug.h"
20
#include "llvm/Support/JSON.h"
21
#include "llvm/Support/MemoryBuffer.h"
22
#include "llvm/Support/Path.h"
23
#include "llvm/Support/raw_ostream.h"
24
25
#include <cassert>
26
#include <numeric>
27
28
using namespace llvm;
29
30
0
void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
31
0
  json::OStream JOS(*OS);
32
0
  JOS.object([&]() {
33
0
    JOS.attributeArray("features", [&]() {
34
0
      for (const auto &TS : FeatureSpecs)
35
0
        TS.toJSON(JOS);
36
0
    });
37
0
    if (IncludeReward) {
38
0
      JOS.attributeBegin("score");
39
0
      RewardSpec.toJSON(JOS);
40
0
      JOS.attributeEnd();
41
0
    }
42
0
    if (AdviceSpec.has_value()) {
43
0
      JOS.attributeBegin("advice");
44
0
      AdviceSpec->toJSON(JOS);
45
0
      JOS.attributeEnd();
46
0
    }
47
0
  });
48
0
  *OS << "\n";
49
0
}
50
51
0
// Make Name the current context (subsequent observation IDs and rewards are
// keyed by it) and record the switch as a one-line JSON object of the form
// {"context": Name}, followed by a newline.
void Logger::switchContext(StringRef Name) {
  CurrentContext = Name.str();
  json::OStream Writer(*OS);
  Writer.object([&]() { Writer.attribute("context", Name); });
  *OS << "\n";
}
57
58
0
void Logger::startObservation() {
59
0
  auto I = ObservationIDs.insert({CurrentContext, 0});
60
0
  size_t NewObservationID = I.second ? 0 : ++I.first->second;
61
0
  json::OStream JOS(*OS);
62
0
  JOS.object([&]() {
63
0
    JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
64
0
  });
65
0
  *OS << "\n";
66
0
}
67
68
0
// End the current observation. The only output is a newline on the log
// stream.
void Logger::endObservation() {
  *OS << "\n";
}
69
70
0
void Logger::logRewardImpl(const char *RawData) {
71
0
  assert(IncludeReward);
72
0
  json::OStream JOS(*OS);
73
0
  JOS.object([&]() {
74
0
    JOS.attribute("outcome", static_cast<int64_t>(
75
0
                                 ObservationIDs.find(CurrentContext)->second));
76
0
  });
77
0
  *OS << "\n";
78
0
  writeTensor(RewardSpec, RawData);
79
0
  *OS << "\n";
80
0
}
81
82
// Construct a logger that takes ownership of the output stream OS and
// immediately writes the JSON header describing FeatureSpecs, RewardSpec
// (only when IncludeReward is set) and AdviceSpec (only when provided).
Logger::Logger(std::unique_ptr<raw_ostream> OS,
               const std::vector<TensorSpec> &FeatureSpecs,
               const TensorSpec &RewardSpec, bool IncludeReward,
               std::optional<TensorSpec> AdviceSpec)
    : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
      IncludeReward(IncludeReward) {
  // Note: inside this body the (now moved-from) parameter OS shadows the
  // member; writeHeader() writes through the member, which owns the stream.
  // AdviceSpec is dead after this call and writeHeader takes it by value, so
  // move it rather than copying the TensorSpec.
  writeHeader(std::move(AdviceSpec));
}