#include "store_ops.h"

#include "caffe2/core/blob_serialization.h"

namespace caffe2 {

// Argument keys shared by the store operators in this file.
// kBlobName: key override for StoreSet/StoreGet, required counter key for
// StoreAdd, and the (repeated) list of names for StoreWait.
constexpr const char* kBlobName = "blob_name";
// kAddValue: increment applied by StoreAdd.
constexpr const char* kAddValue = "add_value";

// Builds a StoreSet operator. The store key is taken from the optional
// 'blob_name' argument; when it is absent, the name of the DATA input blob
// is used as the key.
StoreSetOp::StoreSetOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobName_(OperatorBase::GetSingleArgument<std::string>(
          kBlobName,
          operator_def.input(DATA))) {}

bool StoreSetOp::RunOnDevice() {
  // Serialize and pass to store
  auto* handler =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
  handler->set(blobName_, SerializeBlob(InputBlob(DATA), blobName_));
  return true;
}

// StoreSet: CPU-only; consumes a store handler plus the blob to publish and
// produces no outputs.
REGISTER_CPU_OPERATOR(StoreSet, StoreSetOp);
OPERATOR_SCHEMA(StoreSet)
    .NumOutputs(0)
    .NumInputs(2)
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Input(1, "data", "data blob")
    .Arg("blob_name", "alternative key for the blob (optional)")
    .SetDoc(R"DOC(
Set a blob in a store. The key is the input blob's name and the value
is the data in that blob. The key can be overridden by specifying the
'blob_name' argument.
)DOC");

// Builds a StoreGet operator. The store key is taken from the optional
// 'blob_name' argument; when it is absent, the name of the DATA output blob
// is used as the key.
StoreGetOp::StoreGetOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobName_(OperatorBase::GetSingleArgument<std::string>(
          kBlobName,
          operator_def.output(DATA))) {}

bool StoreGetOp::RunOnDevice() {
  // Get from store and deserialize
  auto* handler =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
  DeserializeBlob(handler->get(blobName_), OperatorBase::Outputs()[DATA]);
  return true;
}

// StoreGet: CPU-only; consumes a store handler and writes the fetched blob
// to its single output.
REGISTER_CPU_OPERATOR(StoreGet, StoreGetOp);
OPERATOR_SCHEMA(StoreGet)
    .NumOutputs(1)
    .NumInputs(1)
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Output(0, "data", "data blob")
    .Arg("blob_name", "alternative key for the blob (optional)")
    .SetDoc(R"DOC(
Get a blob from a store. The key is the output blob's name. The key
can be overridden by specifying the 'blob_name' argument.
)DOC");

// Builds a StoreAdd operator.
//   'blob_name' (required): key of the counter in the store.
//   'add_value' (optional, default 1): amount added on each run.
// Fails at construction time if 'blob_name' is missing, rather than silently
// operating on the empty-string key.
StoreAddOp::StoreAddOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobName_(GetSingleArgument<std::string>(kBlobName, "")),
      addValue_(GetSingleArgument<int64_t>(kAddValue, 1)) {
  CAFFE_ENFORCE(
      HasArgument(kBlobName),
      "StoreAdd requires the '",
      kBlobName,
      "' argument");
}

bool StoreAddOp::RunOnDevice() {
  auto* handler =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
  Output(VALUE)->Resize(1);
  Output(VALUE)->mutable_data<int64_t>()[0] =
      handler->add(blobName_, addValue_);
  return true;
}

// StoreAdd: CPU-only; consumes a store handler and outputs the updated
// counter value.
REGISTER_CPU_OPERATOR(StoreAdd, StoreAddOp);
OPERATOR_SCHEMA(StoreAdd)
    .NumOutputs(1)
    .NumInputs(1)
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Output(0, "value", "the current value of the counter")
    .Arg("blob_name", "key of the counter (required)")
    .Arg("add_value", "value that is added (optional, default: 1)")
    .SetDoc(R"DOC(
Add a value to a remote counter. If the key is not set, the store
initializes it to 0 and then performs the add operation. The operation
returns the resulting counter value.
)DOC");

// Builds a StoreWait operator. The names to wait for are read from the
// repeated argument keyed by kBlobName ("blob_name", singular); the list is
// empty when the argument is absent.
StoreWaitOp::StoreWaitOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobNames_(OperatorBase::GetRepeatedArgument<std::string>(kBlobName)) {
}

// Blocks until every requested name is present in the store.
// Names come either from a string tensor passed as the optional second input
// or from the 'blob_name' argument; specifying both is an error.
bool StoreWaitOp::RunOnDevice() {
  auto* handler =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
  if (InputSize() == 2 && Input(1).IsType<std::string>()) {
    CAFFE_ENFORCE(
        blobNames_.empty(), "cannot specify both argument and input blob");
    const auto& namesTensor = Input(1);
    const auto* firstName = namesTensor.data<std::string>();
    // Copy the whole range at once: no per-element push_back and no
    // int-vs-int64 index comparison against the tensor size.
    const std::vector<std::string> blobNames(
        firstName, firstName + namesTensor.size());
    handler->wait(blobNames);
  } else {
    // NOTE(review): a second input that is not a string tensor is silently
    // ignored here and the argument-provided list is used instead.
    handler->wait(blobNames_);
  }
  return true;
}

// StoreWait: CPU-only; consumes a store handler and, optionally, a string
// tensor of names. The argument is documented under the key the operator
// actually reads (kBlobName == "blob_name"); the previous "blob_names" entry
// named a key that StoreWaitOp never looks up.
REGISTER_CPU_OPERATOR(StoreWait, StoreWaitOp);
OPERATOR_SCHEMA(StoreWait)
    .NumInputs(1, 2)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Wait for the specified blob names to be set. The blob names can be passed
either as an input blob with blob names or as the 'blob_name' argument,
but not both.
)DOC")
    .Arg("blob_name", "names of the blobs to wait for (optional, repeated)")
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Input(1, "names", "names of the blobs to wait for (optional)");
} // namespace caffe2
