/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.msq.statistics;

import com.google.common.math.LongMath;
import com.google.common.primitives.Ints;
import org.apache.datasketches.quantiles.ItemsSketch;
import org.apache.datasketches.quantiles.ItemsUnion;
import org.apache.datasketches.quantilescommon.QuantileSearchCriteria;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.IAE;

import javax.annotation.Nullable;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * A key collector that is used when not aggregating. It uses a quantiles sketch to track keys.
 * <br>
 * The collector maintains the averageKeyLength for all keys added through {@link #add(RowKey, long)} or
 * {@link #addAll(QuantilesSketchKeyCollector)}. The average is calculated as a running average and accounts for
 * the weight of each key added. The averageKeyLength is assumed to be unaffected by {@link #downSample()}, so
 * {@link #estimatedRetainedBytes()} is estimated as averageKeyLength multiplied by the number of retained keys.
 * <br>
 * Instances hold mutable state and perform no synchronization.
 */
public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketchKeyCollector>
{
  private final Comparator<byte[]> comparator;
  private ItemsSketch<byte[]> sketch;
  private double averageKeyLength;

  /**
   * Creates a collector backed by the given sketch.
   *
   * @param comparator       ordering of keys; also used for the union built in {@link #addAll}
   * @param sketch           backing sketch. NOTE(review): although annotated {@code @Nullable}, every method of this
   *                         class dereferences it, so a null sketch fails on first use — confirm callers never pass null
   * @param averageKeyLength average estimated size, in bytes, of the keys already represented by {@code sketch}
   */
  QuantilesSketchKeyCollector(
      final Comparator<byte[]> comparator,
      @Nullable final ItemsSketch<byte[]> sketch,
      double averageKeyLength
  )
  {
    this.comparator = comparator;
    this.sketch = sketch;
    this.averageKeyLength = averageKeyLength;
  }

  /**
   * Adds {@code key} to the sketch {@code weight} times, so that heavier keys occupy a proportionally larger share of
   * the sketch, and folds the key's estimated size into the running average key length.
   *
   * @param key    key to add
   * @param weight number of times to add the key. Non-positive weights are ignored: they cannot add anything to the
   *               sketch, and would otherwise distort the running average (or make it NaN when the sketch is empty).
   */
  @Override
  public void add(RowKey key, long weight)
  {
    if (weight <= 0) {
      return;
    }

    final byte[] keyBytes = key.array();

    // Bytes represented so far, plus this key's contribution at its weight. Multiply in double arithmetic so a very
    // large weight cannot overflow a long.
    final double estimatedTotalSketchSizeInBytes =
        averageKeyLength * sketch.getN() + (double) key.estimatedObjectSizeBytes() * weight;

    // Add the same key multiple times to make it "heavier". The counter is a long: an int counter would wrap around
    // and loop forever when weight > Integer.MAX_VALUE.
    for (long i = 0; i < weight; i++) {
      sketch.update(keyBytes);
    }

    // weight > 0, so the sketch now holds at least one item and this division is well-defined.
    averageKeyLength = estimatedTotalSketchSizeInBytes / sketch.getN();
  }

  /**
   * Merges {@code other} into this collector. The merged sketch uses the larger of the two K values, and the average
   * key length becomes the weight-weighted mean of both collectors' averages. {@code other} is not modified.
   *
   * @param other collector whose keys are merged into this one
   */
  @Override
  public void addAll(QuantilesSketchKeyCollector other)
  {
    final ItemsUnion<byte[]> union = ItemsUnion.getInstance(
        byte[].class,
        Math.max(sketch.getK(), other.sketch.getK()),
        comparator
    );

    final long thisWeight = sketch.getN();
    final long otherWeight = other.sketch.getN();
    final long combinedWeight = thisWeight + otherWeight;

    // Guard against 0/0: merging two empty collectors would otherwise set the average to NaN. NaN never recovers
    // (NaN * 0 is still NaN), and estimatedRetainedBytes() would then always round it to 0.
    if (combinedWeight > 0) {
      final double sketchBytesCount = averageKeyLength * thisWeight;
      final double otherBytesCount = other.averageKeyLength * otherWeight;
      averageKeyLength = (sketchBytesCount + otherBytesCount) / combinedWeight;
    }

    if (!sketch.isEmpty()) {
      union.union(sketch);
    }
    union.union(other.sketch);
    sketch = union.getResultAndReset();
  }

  /**
   * Returns whether the backing sketch contains no items.
   */
  @Override
  public boolean isEmpty()
  {
    return sketch.isEmpty();
  }

  /**
   * Returns the total weight added so far, i.e. the sketch's item count N.
   */
  @Override
  public long estimatedTotalWeight()
  {
    return sketch.getN();
  }

  /**
   * Estimates retained memory as the average key length times the number of retained keys, rounded to a long.
   */
  @Override
  public long estimatedRetainedBytes()
  {
    return Math.round(averageKeyLength * estimatedRetainedKeys());
  }

  /**
   * Returns the number of items currently retained by the sketch.
   */
  @Override
  public int estimatedRetainedKeys()
  {
    return sketch.getNumRetained();
  }

  /**
   * Halves the sketch's K parameter to reduce its footprint, trading away accuracy.
   *
   * @return true if the sketch holds at most one item (nothing to shrink) or K was halved; false if K is already 2,
   *         which this method treats as the floor
   */
  @Override
  public boolean downSample()
  {
    if (sketch.getN() <= 1) {
      return true;
    } else if (sketch.getK() == 2) {
      return false;
    } else {
      sketch = sketch.downSample(sketch.getK() / 2);
      return true;
    }
  }

  /**
   * Returns the smallest key seen by the sketch.
   *
   * @throws NoSuchElementException if the sketch is empty
   */
  @Override
  public RowKey minKey()
  {
    if (sketch.isEmpty()) {
      throw new NoSuchElementException();
    }
    return RowKey.wrap(sketch.getMinItem());
  }

  /**
   * Splits the key space into contiguous partitions of roughly {@code targetWeight} each, using the sketch's
   * partition boundaries. The final partition has no end key. Partitions whose start is not strictly less than
   * their end are omitted.
   *
   * @param targetWeight desired weight per partition; must be positive
   * @return the generated partitions, or a single universal partition if the sketch is empty
   * @throws IAE if {@code targetWeight} is not positive
   */
  @Override
  public ClusterByPartitions generatePartitionsWithTargetWeight(final long targetWeight)
  {
    if (targetWeight <= 0) {
      throw new IAE("targetPartitionWeight must be positive, but was [%d]", targetWeight);
    }

    if (sketch.getN() == 0) {
      return ClusterByPartitions.oneUniversalPartition();
    }

    final int numPartitions = Ints.checkedCast(LongMath.divide(sketch.getN(), targetWeight, RoundingMode.CEILING));

    // The loop below reads indices 0..numPartitions-1, so it relies on getPartitionBoundaries returning at least
    // numPartitions boundaries.
    final byte[][] quantiles = (sketch.getPartitionBoundaries(numPartitions, QuantileSearchCriteria.EXCLUSIVE)).boundaries;
    final List<ClusterByPartition> partitions = new ArrayList<>(numPartitions);

    for (int i = 0; i < numPartitions; i++) {
      final boolean isFinalPartition = i == numPartitions - 1;

      if (isFinalPartition) {
        partitions.add(new ClusterByPartition(RowKey.wrap(quantiles[i]), null));
      } else {
        final int cmp = comparator.compare(quantiles[i], quantiles[i + 1]);
        if (cmp < 0) {
          // Only emit non-empty ranges: start == end would be empty, and start > end is not expected but is
          // skipped as well.
          final ClusterByPartition partition = new ClusterByPartition(RowKey.wrap(quantiles[i]), RowKey.wrap(quantiles[i + 1]));
          partitions.add(partition);
        }
      }
    }

    return new ClusterByPartitions(partitions);
  }

  /**
   * Returns the sketch's K parameter, which governs its accuracy.
   */
  @Override
  public int sketchAccuracyFactor()
  {
    return sketch.getK();
  }

  /**
   * Retrieves the backing sketch. Exists for usage by {@link QuantilesSketchKeyCollectorFactory}.
   */
  ItemsSketch<byte[]> getSketch()
  {
    return sketch;
  }

  /**
   * Retrieves the average key length. Exists for usage by {@link QuantilesSketchKeyCollectorFactory}.
   */
  double getAverageKeyLength()
  {
    return averageKeyLength;
  }

  @Override
  public String toString()
  {
    return "QuantilesSketchKeyCollector{" +
           "sketch=ItemsSketch{N=" + sketch.getN() +
           ", K=" + sketch.getK() +
           ", retainedKeys=" + sketch.getNumRetained() +
           "}, averageKeyLength=" + averageKeyLength +
           '}';
  }
}
