#include "riegeli_s3_writer.h"

#include <sstream>

#include "aws/s3/model/CompleteMultipartUploadRequest.h"
#include "aws/s3/model/CompleteMultipartUploadResult.h"
#include "aws/s3/model/CompletedMultipartUpload.h"
#include "aws/s3/model/CompletedPart.h"
#include "aws/s3/model/CreateMultipartUploadRequest.h"
#include "aws/s3/model/CreateMultipartUploadResult.h"
#include "aws/s3/model/UploadPartRequest.h"
#include "aws/s3/model/UploadPartResult.h"
#include "glog/logging.h"

// Creates a writer targeting s3://<bucket_name>/<key_path>.
//
// The multipart upload is started right away (see initialize()), so
// construction CHECK-fails if S3 rejects CreateMultipartUpload.
// NOTE(review): `s3_client` is stored as given — presumably not owned; confirm
// that it outlives this writer.
RiegeliS3Writer::RiegeliS3Writer(Aws::S3::S3Client* s3_client,
                                 const std::string& bucket_name,
                                 const std::string& key_path)
    : s3_client_(s3_client),
      bucket_name_(bucket_name),
      key_path_(key_path) {
  // Obtain the upload id that every later UploadPart / CompleteMultipartUpload
  // request is tagged with.
  initialize();
}

// The multipart upload must already have been completed by finish() (reached
// from Done() or Flush(kFromMachine)). An uncompleted S3 multipart upload never
// produces an object, so destroying the writer earlier is treated as a fatal
// programming error. The key is included in the message so the crash can be
// traced to the writer that was not closed.
RiegeliS3Writer::~RiegeliS3Writer() {
  CHECK(finished_) << "RiegeliS3Writer for " << key_path_
                   << " destroyed before its multipart upload was completed";
}

// Uploads `src` as the next part of the multipart upload and records the ETag
// S3 returns, so finish() can list it when completing the upload.
//
// Each non-empty call produces exactly one part, numbered by part_number_
// (post-incremented). NOTE(review): S3 rejects non-final parts smaller than
// 5 MiB when the upload is completed; this presumably relies on the buffer size
// configured for this writer being at least that large — confirm.
//
// Returns true; S3 errors and writes after completion are fatal (CHECK).
bool RiegeliS3Writer::WriteInternal(absl::string_view src) {
  LOG(INFO) << "WriteInternal for " << key_path_ << " with " << src.size() << " bytes";
  // Once finish() has completed the upload, its id can no longer accept parts;
  // fail with a clear message instead of an opaque S3 error.
  CHECK(!finished_) << "WriteInternal for " << key_path_
                    << " after the multipart upload was completed";
  if (src.empty()) {
    // An empty part carries no data but would still consume a part number.
    return true;
  }
  Aws::S3::Model::UploadPartRequest request;
  request.SetBucket(bucket_name_);
  request.SetKey(key_path_);
  request.SetUploadId(upload_id_);
  request.SetPartNumber(part_number_++);
  // Write the bytes straight into the stream: building a temporary std::string
  // first and passing it to the stringstream constructor copies them twice.
  auto body = std::make_shared<std::stringstream>(
      std::ios::in | std::ios::out | std::ios::binary);
  body->write(src.data(), static_cast<std::streamsize>(src.size()));
  request.SetBody(body);
  const Aws::S3::Model::UploadPartOutcome outcome = s3_client_->UploadPart(request);
  CHECK(outcome.IsSuccess()) << "UploadPart failed: " << outcome.GetError().GetMessage();
  etags_.push_back(outcome.GetResult().GetETag());
  return true;
}

// Pushes any buffered bytes to S3 (as a new part, via WriteInternal()).
//
// For kFromObject / kFromProcess that is all that happens. For kFromMachine the
// multipart upload is also completed via finish(), so the object becomes
// visible in S3; presumably no further parts can be added to that upload
// afterwards. Returns false only if pushing the buffer fails; an unknown
// flush_type is fatal.
bool RiegeliS3Writer::Flush(riegeli::FlushType flush_type) {
  LOG(INFO) << "Flush for " << key_path_;
  const bool pushed = PushInternal();
  if (!pushed) return false;

  switch (flush_type) {
    case riegeli::FlushType::kFromMachine:
      finish();
      break;
    case riegeli::FlushType::kFromObject:
    case riegeli::FlushType::kFromProcess:
      // Nothing beyond the push above.
      break;
    default:
      CHECK(false) << "Invalid flush_type";
  }
  return true;
}

void RiegeliS3Writer::initialize() {
  LOG(INFO) << "Initialize for " << key_path_;
  Aws::S3::Model::CreateMultipartUploadRequest request;
  request.SetBucket(bucket_name_);
  request.SetKey(key_path_);
  Aws::S3::Model::CreateMultipartUploadOutcome outcome;
  outcome = s3_client_->CreateMultipartUpload(request);
  CHECK(outcome.IsSuccess()) << "CreateMultipartUpload failed: " << outcome.GetError().GetMessage();
  upload_id_ = outcome.GetResult().GetUploadId();
}

void RiegeliS3Writer::finish() {
  LOG(INFO) << "Finish for " << key_path_;
  std::vector<Aws::S3::Model::CompletedPart> completed_parts;
  for (size_t i = 0; i < etags_.size(); i++) {
    completed_parts.push_back(
        Aws::S3::Model::CompletedPart().WithETag(etags_[i]).WithPartNumber(i));
  }
  Aws::S3::Model::CompleteMultipartUploadRequest request;
  request.SetBucket(bucket_name_);
  request.SetKey(key_path_);
  request.SetUploadId(upload_id_);
  request.SetMultipartUpload(Aws::S3::Model::CompletedMultipartUpload().WithParts(completed_parts));
  const Aws::S3::Model::CompleteMultipartUploadOutcome outcome =
      s3_client_->CompleteMultipartUpload(request);
  CHECK(outcome.IsSuccess()) << "CompleteMultipartUpload failed: "
                             << outcome.GetError().GetMessage();
  finished_ = true;
}

// Riegeli close hook (presumably invoked from Close()). It does three things,
// in order:
//   1. push any still-buffered bytes to S3 as a final part;
//   2. complete the multipart upload;
//   3. hand off to the base class.
// NOTE(review): the result of PushInternal() is ignored and the upload is
// completed regardless — confirm that is intended.
void RiegeliS3Writer::Done() {
  LOG(INFO) << "Done for " << key_path_;
  static_cast<void>(PushInternal());
  finish();
  Writer::Done();
}