Program Listing for File model_sampler.h

Return to documentation for file (src/model_sampler.h)

#ifndef SHARPSAT_SOLUTION_RECIPE_H
#define SHARPSAT_SOLUTION_RECIPE_H

#include <gmpxx.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <list>
#include <random>
#include <sstream>
#include <string>
#include <vector>

#include "structures.h"
#include "component_types/component.h"
#include "statistics.h"
#include "alt_component_analyzer.h"
#include "solver_config.h"
#include "cached_assignment.h"
#include "rand_distributions.h"

/// Word storage for the packed per-variable assignments held by SampleAssignment.
using AssignmentContainer = std::vector<uint32_t>;

/**
 * Packed (partial) variable assignment shared by sample_count_ identical samples.
 *
 * Each variable occupies bits_per_var_ bits inside assn_ and reads back as one of ASSN_F,
 * ASSN_T or ASSN_U.  A freshly constructed object has every bit set, which the code treats as
 * "all variables unassigned" (setVarAssignment requires ASSN_U before writing).  Most mutators
 * are private and are driven by SamplesManager, which is a friend.
 */
class SampleAssignment {
  friend class SamplesManager;
 private:
  /**
   * Split constructor used by split().
   *
   * Copies the assignment bits, the cached set-variable count, the emancipated variables and the
   * cached component IDs of @p other, but uses @p sample_count as the new sample count.  The
   * remaining-variable cache and the component cache key are not copied and start out empty.
   */
  explicit SampleAssignment(SampleSize sample_count, const SampleAssignment &other) :
      sample_count_(sample_count), assn_(other.assn_), num_vars_set_(other.num_vars_set_),
      emancipated_vars_(other.emancipated_vars_), cache_comp_ids_(other.cache_comp_ids_) { }
  /// Number of samples that share this assignment.
  SampleSize sample_count_;
  /// Packed variable values; see calculateWordAndBitNumbers for the layout.
  AssignmentContainer assn_;
  /// Cached number of assigned variables.  num_set_vars() only returns it as-is while
  /// remaining_vars_ is non-empty; otherwise it is recomputed from assn_.
  VariableIndex num_vars_set_ = 0;
  /// Lazily built list of unassigned, non-emancipated variables.  Empty means "not built or
  /// stale", so every mutation that can change its contents must clear it.
  std::vector<VariableIndex> remaining_vars_;
  /// Unassigned variables with a free value; BuildRandomizedPartialAssignment picks their values
  /// uniformly at random.
  std::vector<VariableIndex> emancipated_vars_;
  /// Cache entry IDs of the components attached to this sample.
  std::vector<CacheEntryID> cache_comp_ids_;
  /// Cached key for cache_comp_ids_; cleared whenever those IDs change.
  std::string comp_cache_key_;
  /// Number of variables in the formula; shared by all samples and set once by set_num_var().
  static VariableIndex num_var_;
  /// Number of words of assn_ needed to hold num_var_ variables.
  static VariableIndex var_vec_len_;
  /// Number of bits in one word of AssignmentContainer.
  static const unsigned int size_of_word_ = sizeof(typename AssignmentContainer::value_type)
                                            * BITS_PER_BYTE;
  /// Bits used to encode a single variable's value.
  static const VariableIndex bits_per_var_ = 2;
  /// Mask selecting one variable's bits once they are shifted down to bit 0.
  static const AssignmentContainer::value_type var_mask_ = 0x3;
  /**
   * Debug check that @p other can be stitched into this assignment: no variable may be assigned
   * conflicting (non-ASSN_U) values in the two objects.  Prints the first conflict to stderr.
   *
   * @return true if the two assignments are compatible.
   */
  bool VerifyStitchingCompatibility(const SampleAssignment &other) const {
    assert(assn_.size() == other.assn_.size());
    assert(sample_count_ <= other.sample_count_);  // Less than or equal for simplified code
                                                  // when stitching the SamplesManager objects

    AssignmentEncoding this_val, other_val;
    for (VariableIndex i = FIRST_VAR; i <= num_var_; i++) {
      this_val = this->var_assignment(i);
      other_val = other.var_assignment(i);
      if (this_val != other_val && this_val != ASSN_U && other_val != ASSN_U) {
        std::cerr << "Stitching error on variable #" << i << std::endl
                  << "Implicit Value = " << this_val << ". Other Value = "
                  << other_val << std::endl;
        return false;
      }
    }
    return true;
  }
  /**
   * Locates the bits of variable @p var_num inside assn_.
   *
   * @param var_num Variable number in [FIRST_VAR, num_var_].
   * @param word_num Output: index of the word in assn_.
   * @param bit_num Output: offset of the variable's lowest bit within that word.
   */
  inline static void calculateWordAndBitNumbers(const VariableIndex &var_num,
                                                VariableIndex &word_num, VariableIndex &bit_num) {
    assert(var_num >= FIRST_VAR && var_num <= num_var_);
    word_num = (var_num * bits_per_var_) / size_of_word_;
    assert(word_num >= 0 && word_num < var_vec_len_);
    bit_num = (var_num * bits_per_var_) % size_of_word_;
  }
  /**
   * Debug check that the emancipated variables contain no duplicates and are all unassigned.
   */
  const bool VerifyEmancipatedVars() const {
    if (emancipated_vars_.empty())
      return true;

    std::vector<VariableIndex> all_vars = emancipated_vars_;
    std::sort(all_vars.begin(), all_vars.end());
    if (var_assignment(all_vars[0]) != ASSN_U)
      return false;
    for (VariableIndex i = 1; i < all_vars.size(); i++) {
      if (all_vars[i] == all_vars[i-1] || var_assignment(all_vars[i]) != ASSN_U)
        return false;
    }
    return true;
  }
  /**
   * Debug check that @p vec holds no repeated element.
   *
   * @return true if every element of @p vec is unique (trivially true when empty).
   */
  template <typename T>
  static const bool VerifyNoDuplicates(const std::vector<T> &vec) {
    std::vector<T> vec_copy = vec;
    std::sort(vec_copy.begin(), vec_copy.end());
    // After sorting, any duplicate pair is adjacent.
    return std::adjacent_find(vec_copy.begin(), vec_copy.end()) == vec_copy.end();
  }
  const std::vector<VariableIndex>& emancipated_vars() const { return emancipated_vars_; }
  /**
   * Removes each variable in @p vars_to_delete from the emancipated set.
   *
   * Variables that are not currently emancipated are skipped.  Previously this case called
   * erase(end()), which is undefined behavior.
   */
  void DeleteEmancipatedVars(const std::vector<VariableIndex> &vars_to_delete) {
    for (const auto &var_id : vars_to_delete) {
      auto itr = std::find(emancipated_vars_.begin(), emancipated_vars_.end(), var_id);
      if (itr != emancipated_vars_.end())
        emancipated_vars_.erase(itr);
    }
    // Variables that are no longer emancipated may now count as "remaining".
    remaining_vars_.clear();
  }
  /**
   * Appends @p new_emancipated_vars to the emancipated set.  In debug builds this asserts that
   * the result has no duplicates and contains only unassigned variables.
   */
  void addEmancipatedVars(const std::vector<VariableIndex> &new_emancipated_vars) {
    if (new_emancipated_vars.empty())
      return;

    emancipated_vars_.insert(emancipated_vars_.end(), new_emancipated_vars.begin(),
                             new_emancipated_vars.end());
    // Emancipated variables are excluded from remaining_vars_, so the cache is now stale.
    remaining_vars_.clear();
    assert(VerifyEmancipatedVars());
  }
  /**
   * Appends @p cache_comp_ids to this sample's cached component IDs and invalidates the key
   * derived from them.
   */
  void addCachedCompIds(const std::vector<CacheEntryID> &cache_comp_ids) {
    comp_cache_key_.clear();
    cache_comp_ids_.insert(cache_comp_ids_.end(), cache_comp_ids.begin(), cache_comp_ids.end());
    assert(VerifyNoDuplicates<CacheEntryID>(cache_comp_ids_));
  }
  /// Lowers the sample count by @p dec_sample_count; the count must stay strictly positive.
  void DecreaseSampleCount(SampleSize dec_sample_count) {
    assert(dec_sample_count > 0 && dec_sample_count < sample_count_);
    sample_count_ -= dec_sample_count;
  }
  void zeroSampleCount() { sample_count_ = 0; }
  inline const VariableIndex num_var() const { return num_var_; }
  /**
   * Applies the literals and emancipated variables of a cached assignment.  Every literal must
   * target a currently unassigned variable.
   */
  void IncorporateCachedAssignment(const CachedAssignment & cached_assn) {
    for (auto lit : cached_assn.literals()) {
      assert(var_assignment(lit.var()) == ASSN_U);  // Only assign to unassigned variables
      setVarAssignment(lit.var(), lit.sign() ? ASSN_T : ASSN_F);
    }
    addEmancipatedVars(cached_assn.emancipated_vars());
  }
  /**
   * Merges a compatible assignment @p other into this one.
   *
   * Assignment words are combined with bitwise AND, so an all-ones (ASSN_U) field takes the
   * other side's value.  Emancipated variables and cached component IDs are appended.
   */
  inline void stitch(const SampleAssignment &other) {
    assert(VerifyStitchingCompatibility(other));
    for (VariableIndex i = 0; i < assn_.size(); i++)
      assn_[i] &= other.assn_[i];

    addEmancipatedVars(other.emancipated_vars_);
    addCachedCompIds(other.cache_comp_ids_);
    num_vars_set_ += other.num_vars_set_;

    remaining_vars_.clear();
    // Make sure no variables miraculously materialized.
    assert(num_set_vars_const() + emancipated_vars_.size() <= num_var());
  }
  /**
   * Moves @p new_assn_sample_count samples out of this object into a new, identical assignment.
   *
   * @return The split-off assignment.
   */
  inline SampleAssignment split(SampleSize new_assn_sample_count) {
    DecreaseSampleCount(new_assn_sample_count);
    return SampleAssignment(new_assn_sample_count, *this);
  }
  inline void set_sample_count(SampleSize new_sample_count) { sample_count_ = new_sample_count; }
  /**
   * Fills @p all_vars with this assignment and gives every emancipated variable a uniformly
   * random ASSN_F/ASSN_T value.
   */
  inline void BuildRandomizedPartialAssignment(PartialAssignment &all_vars) const {
    GetPartialAssignment(all_vars);
    for (auto var : emancipated_vars_)
      all_vars[var] = (Random::uniform(0, 1)) ? ASSN_F : ASSN_T;
  }
  /**
   * Builds a mask for unsetVariableAssignments().
   *
   * Variables in @p vars_to_keep are set to ASSN_F and every other variable is left all-ones
   * (ASSN_U).  OR-ing the mask into a sample therefore unsets all variables except those listed.
   * NOTE(review): this relies on ASSN_F being encoded as all-zero bits; confirm in structures.h.
   */
  static SampleAssignment buildUnsetterAssignment(const std::vector<VariableIndex> &vars_to_keep) {
    SampleAssignment new_assn;
    for (const auto &var : vars_to_keep)
      new_assn.setVarAssignment(var, ASSN_F);
    return new_assn;
  }
  /**
   * Applies a mask built by buildUnsetterAssignment() with bitwise OR.  This unsets every
   * variable that is not kept by the mask.
   */
  void unsetVariableAssignments(const SampleAssignment &unsetter) {
    assert(assn_.size() == unsetter.assn_.size());
    for (VariableIndex i = 0; i < assn_.size(); i++)
      assn_[i] |= unsetter.assn_[i];
    // Variables were unset, so the cached counts and remaining list are stale.
    remaining_vars_.clear();
  }
  /// @return The unassigned, non-emancipated variables, rebuilding the cache if needed.
  std::vector<VariableIndex> GetRemainingVariables() {
    if (remaining_vars_.empty())
      num_set_vars();  // Builds the unset variables list.
    return remaining_vars_;
  }

 public:
  /// Creates an assignment for @p sample_count samples with every variable unassigned.
  explicit SampleAssignment(SampleSize sample_count) : sample_count_(sample_count) {
    assn_.resize(var_vec_len_, static_cast<typename AssignmentContainer::value_type>(-1));
  }
  SampleAssignment() : SampleAssignment(0) {}
  /// @return One character per variable: '0' (false), '1' (true) or '*' (unassigned).
  inline std::string ToString() const {
    std::stringstream ss;
    for (VariableIndex i = FIRST_VAR; i <= num_var(); i++) {
      AssignmentEncoding var_val = var_assignment(i);
      switch (var_val) {
        case ASSN_F: ss << '0'; break;
        case ASSN_T: ss << '1'; break;
        case ASSN_U: ss << '*'; break;
      }
    }
    return ss.str();
  }
  /// @return The encoded value of variable @p var.
  inline const AssignmentEncoding var_assignment(const VariableIndex &var) const {
    VariableIndex word_num, bit_num;
    calculateWordAndBitNumbers(var, word_num, bit_num);
    return static_cast<AssignmentEncoding>((assn_[word_num] >> bit_num) & var_mask_);
  }
  /// Resizes @p all_vars to num_var_ + FIRST_VAR entries and copies each variable's value.
  inline void GetPartialAssignment(PartialAssignment &all_vars) const {
    all_vars.clear();
    all_vars.resize(num_var_ + FIRST_VAR);
    for (VariableIndex i = FIRST_VAR; i <= num_var_; i++)
      all_vars[i] = var_assignment(i);
  }
  /**
   * Assigns @p val (ASSN_F or ASSN_T) to the currently unassigned variable @p var.
   */
  inline void setVarAssignment(const VariableIndex var, const AssignmentEncoding &val) {
    assert((val == ASSN_F || val == ASSN_T) && var_assignment(var) == ASSN_U);

    VariableIndex word_num, bit_num;
    calculateWordAndBitNumbers(var, word_num, bit_num);

    // Widen to the unsigned word type before shifting.  The enum would otherwise promote to
    // (signed) int, and shifting into the sign bit is not well defined in every standard.
    const auto val_bits = static_cast<AssignmentContainer::value_type>(val);
    // Updates the bits of interest only by masking then setting them.
    assn_[word_num] = (assn_[word_num] & ~(var_mask_ << bit_num)) | (val_bits << bit_num);
    num_vars_set_++;
    remaining_vars_.clear();
  }
  /// Sets the formula's variable count for all samples.  This must be called exactly once.
  static void set_num_var(const VariableIndex num_var) {
    assert(num_var_ == 0);  // This function should only be called once.
    num_var_ = num_var;
    var_vec_len_ = ((num_var + 1) * bits_per_var_ / size_of_word_)+ 1;
  }
  /// @return true if every variable is either assigned or emancipated.
  const bool IsComplete() const {
    return num_set_vars_const() + emancipated_vars_.size() == num_var_;
  }
  const SampleSize sample_count() const { return sample_count_; }
  /**
   * @return Number of assigned variables.  If the remaining-variable cache is stale, this also
   * rebuilds remaining_vars_.
   */
  const VariableIndex num_set_vars() {
    if (!remaining_vars_.empty())
      return num_vars_set_;

    // ToDo once variable setting is efficient, just return the set variable count
    num_vars_set_ = 0;
    for (VariableIndex i = FIRST_VAR; i <= num_var(); i++) {
      if (var_assignment(i) != ASSN_U) {
        num_vars_set_++;
      } else if (!IsVarEmancipated(i)) {
        remaining_vars_.emplace_back(i);
      }
    }
    return num_vars_set_;
  }
  /// @return Number of assigned variables, computed without touching any cache.
  const VariableIndex num_set_vars_const() const {
    // ToDo once variable setting is efficient, just return the set variable count
    VariableIndex num_vars_set = 0;
    for (VariableIndex i = FIRST_VAR; i <= num_var(); i++)
      if (var_assignment(i) != ASSN_U)
        num_vars_set++;
    return num_vars_set;
  }
  /// @return Number of variables that are neither assigned nor emancipated.
  const VariableIndex num_unset_vars() {
    return num_var_ - num_set_vars() - emancipated_vars_.size();
  }
  /// @return true if @p var is in the emancipated set.  This is a linear scan.
  const bool IsVarEmancipated(VariableIndex var) const {
    return std::find(emancipated_vars_.begin(), emancipated_vars_.end(), var)
           != emancipated_vars_.end();
  }
  const std::vector<CacheEntryID>& cache_comp_ids() const {
    return cache_comp_ids_;
  }
  /// Drops all cached component IDs together with the key derived from them.
  void clear_cache_comp_ids() {
    cache_comp_ids_.clear();
    comp_cache_key_.clear();
  }
};


/// Ordered collection of sample assignments; SamplesManager splits and splices its nodes.
using ListOfSamples = std::list<SampleAssignment>;

/**
 * Owns the list of (run-length compressed) sample assignments for one component, together with
 * that component's model count.  Each list node stands for sample_count() identical samples.
 */
class SamplesManager {
 private:
  /// Model count of the formula/component these samples were drawn from.
  mpz_class solution_count_ = 0;
  /// Run-length compressed samples; each node represents sample_count() identical samples.
  ListOfSamples samples_;
  /// Number of samples this manager is expected to hold.
  SampleSize tot_num_samples_;
  /// Solver configuration (not owned).
  SolverConfiguration *config_;
  static void BuildSample(SampleAssignment &new_sample,
                          const Component *active_comp,
                          const std::vector<LiteralID> &literal_stack,
                          VariableIndex last_branch_lit,
                          const AltComponentAnalyzer &ana,
                          const std::vector<VariableIndex> &freed_vars,
                          const std::vector<CacheEntryID> &cached_comp_ids);
  /**
   * @return true if any literal at index >= @p start_ofs of @p literal_stack is on @p var_num.
   *         An offset past the end of the stack yields false instead of an invalid iterator.
   */
  static bool isVarInLitStack(const VariableIndex var_num,
                              const std::vector<LiteralID> &literal_stack,
                              const VariableIndex start_ofs = 0) {
    if (start_ofs >= literal_stack.size())
      return false;
    return std::any_of(literal_stack.begin() + start_ofs, literal_stack.end(),
                       [var_num](const LiteralID &lit) { return lit.var() == var_num; });
  }
  static void buildCnfClauseLiterals(const std::string &input_file_path,
                                     std::vector<std::vector<signed long>> &clauses);
  /**
   * Moves @p new_assn_sample_count samples out of *itr into a new node inserted just before it.
   * On return @p itr points at the new node.
   */
  void splitSampleAndInsert(std::list<SampleAssignment>::iterator &itr,
                            const SampleSize &new_assn_sample_count) {
    assert(new_assn_sample_count > 0 && new_assn_sample_count < itr->sample_count());
    // insert() places the split-off node before itr and returns an iterator to it.
    itr = samples_.insert(itr, itr->split(new_assn_sample_count));
  }

 public:
  SamplesManager(SampleSize num_samples, SolverConfiguration &config)
      : tot_num_samples_(num_samples), config_(&config) {}
  void exportFinal(std::ostream &out, const DataAndStatistics &statistics,
                   const SolverConfiguration& config);
  void reservoirSample(const Component * active_comp,
                       const std::vector<LiteralID> & literal_stack,
                       const mpz_class &solution_weight,
                       const mpz_class &weight_multiplier,
                       const AltComponentAnalyzer &ana,
                       VariableIndex literal_stack_ofs,
                       const std::vector<VariableIndex>& freed_vars,
                       const std::vector<CacheEntryID> &cached_comp_ids,
                       const CachedAssignment& cached_assn,
                       SampleAssignment& cached_sample);
  inline void GenerateSamplesToReplace(const mpz_class &new_sample_weight,
                                       std::vector<SampleSize> &samples_to_replace) const;
  /**
   * Combines this manager's samples with those of an independent manager @p other.
   *
   * If this manager's model count is still zero, it adopts @p other's samples and count.
   * Otherwise the counts are multiplied, and if @p other is UNSAT all samples are dropped.
   * In the remaining case every sample here is paired with a randomly permuted sample of
   * @p other and the two assignments are merged.  On return every node of @p other has a sample
   * count of zero.
   */
  inline void stitch(SamplesManager &other) {
    assert(this->tot_num_samples_ == other.tot_num_samples_);
    if (solution_count_ == 0) {
      this->samples_ = other.samples_;
      solution_count_ = other.solution_count_;
      return;
    }
    solution_count_ *= other.solution_count_;

    // Handle the UNSAT case
    if (other.solution_count_ == 0) {
      this->samples_.clear();
      return;
    }

    StitchShuffledArray(other);
    assert(verifyPostStitchingCorrectness(other));
  }
  /**
   * Pairs each sample of this manager with a uniformly shuffled sample of @p other and stitches
   * them.  This manager's nodes are split as needed, and @p other's node counts are consumed
   * down to zero.
   */
  inline void StitchShuffledArray(SamplesManager &other) {
    assert(this->GetActualSampleCount() == other.GetActualSampleCount());

    // Positional handles to every node of other's list.
    std::vector<ListOfSamples::iterator> other_samples_itrs;
    other_samples_itrs.reserve(other.samples_.size());
    // One entry per individual sample of other, holding the index of the node that owns it.
    std::vector<SampleSize> other_sample_order;
    other_sample_order.reserve(other.tot_num_samples_);

    SampleSize node_idx = 0;
    for (auto other_itr = other.samples_.begin(); other_itr != other.samples_.end(); ++other_itr) {
      other_samples_itrs.emplace_back(other_itr);
      for (SampleSize j = 0; j < other_itr->sample_count(); j++)
        other_sample_order.emplace_back(node_idx);
      ++node_idx;
    }
    Random::shuffle<SampleSize>(other_sample_order.begin(), other_sample_order.end());

    // Tally of how many samples of the current this-node map to each other-node.  It is sized
    // by the number of other-nodes (the valid index range) and reused across iterations.
    std::vector<SampleSize> samples_per_element(other_samples_itrs.size(), 0);
    SampleSize sample_offset = 0;
    for (auto this_itr = samples_.begin(); this_itr != samples_.end(); ) {
      // Read the count before SplitAndStitch splits and advances this_itr.
      const SampleSize sample_end = sample_offset + this_itr->sample_count();
      std::fill(samples_per_element.begin(), samples_per_element.end(), 0);
      for (SampleSize sample_cnt = sample_offset; sample_cnt < sample_end; ++sample_cnt)
        samples_per_element[other_sample_order[sample_cnt]]++;

      // Samples mapped to the same other-node are identical, so split once per other-node.
      for (SampleSize other_cnt = 0; other_cnt < samples_per_element.size(); ++other_cnt) {
        if (samples_per_element[other_cnt] == 0)
          continue;
        SplitAndStitch(this_itr, other_samples_itrs[other_cnt], samples_per_element[other_cnt]);
      }

      sample_offset = sample_end;
    }
  }
  /**
   * Splits @p num_new_samples samples off *this_itr (if needed), stitches them with *other_itr,
   * advances @p this_itr, and consumes @p num_new_samples from *other_itr.
   */
  void SplitAndStitch(ListOfSamples::iterator &this_itr, ListOfSamples::iterator &other_itr,
                      const SampleSize &num_new_samples) {
    assert(num_new_samples > 0 && num_new_samples <= this_itr->sample_count()
           && num_new_samples <= other_itr->sample_count());

    // If applicable add a new node
    if (num_new_samples != this_itr->sample_count())
      splitSampleAndInsert(this_itr, num_new_samples);

    this_itr->stitch(*other_itr);
    ++this_itr;
    if (other_itr->sample_count() == num_new_samples)
      other_itr->zeroSampleCount();
    else
      other_itr->DecreaseSampleCount(num_new_samples);
  }
  /// Debug check after stitching: @p other is fully consumed and this manager's count is exact.
  const bool verifyPostStitchingCorrectness(const SamplesManager &other) const {
    for (const auto &sample : other.samples_) {
      if (sample.sample_count() != 0) {
        PrintInColor("An other_sample had non-zero size.", PrintColor::COLOR_RED);
        return false;
      }
    }

    if (!this->verifySampleCount()) {
      PrintInColor("The sample count of the stitched object is incorrect.", PrintColor::COLOR_RED);
      return false;
    }

    return true;
  }
  void merge(SamplesManager &other, const mpz_class &other_multiplier,
             const std::vector<VariableIndex> &freed_vars,
             const std::vector<CacheEntryID> &cached_comp_ids,
             const CachedAssignment & cached_assn,
             SampleAssignment& cached_sample);
  bool VerifySolutions(const std::string &input_file_path,
                       bool skip_unassigned = false) const;
  const mpz_class &model_count() const { return solution_count_; }
  /**
   * Marks @p emancipated_vars as free in every sample and multiplies the model count by
   * 2^|emancipated_vars|.
   */
  void AddEmancipatedVars(const std::vector<VariableIndex> &emancipated_vars) {
    for (auto &sample : samples_)
      sample.addEmancipatedVars(emancipated_vars);

    // Multiply by 2^num_unused_vars since those represent a parallel cylinder
    mpz_mul_2exp(solution_count_.get_mpz_t(), solution_count_.get_mpz_t(),
                 emancipated_vars.size());
  }
  /// Appends @p cached_comp_ids to every sample.
  void AddCachedCompIds(const std::vector<CacheEntryID> &cached_comp_ids) {
    for (auto &sample : samples_)
      sample.addCachedCompIds(cached_comp_ids);
  }
  /**
   * Replaces this manager's assignments with those in @p others.
   *
   * Only the variables that others' first sample leaves unassigned (and not emancipated) are
   * kept here.  Everything else is unset and then stitched in from a copy of @p others.
   * @p others must be non-empty and hold num_samples() samples in total.
   */
  void TransferVariableAssignments(ListOfSamples &others) {
    assert(!others.empty());  // front() on an empty list is undefined behavior.
    std::vector<VariableIndex> unset_vars = others.front().GetRemainingVariables();

    assert(!unset_vars.empty());
    auto unsetter = SampleAssignment::buildUnsetterAssignment(unset_vars);
    // Delete redundant emancipated variables.
    for (auto &sample : samples_) {
      sample.unsetVariableAssignments(unsetter);
      sample.DeleteEmancipatedVars(others.front().emancipated_vars());
    }

    SampleSize sample_count = 0;
    for (auto &other : others) {
      sample_count += other.sample_count();
      other.clear_cache_comp_ids();  // Out-of-date cache IDs from the previous run.
    }
    assert(sample_count == this->num_samples());
    SamplesManager others_manager(sample_count, *config_);
    others_manager.solution_count_ = 1;
    others_manager.samples_ = others;

    this->stitch(others_manager);
  }
  ListOfSamples &samples() { return samples_; }
  /// @return true if every sample is a complete assignment.
  inline bool IsComplete() const {
    return std::all_of(samples_.begin(), samples_.end(),
                       [](const SampleAssignment &sample) { return sample.IsComplete(); });
  }
  /// Moves all samples of @p other to the end of this manager's list.
  inline void append(SamplesManager &other) {
    append(other.samples_);
  }
  /// Moves all nodes of @p other to the end of this manager's list, leaving @p other empty.
  inline void append(ListOfSamples &other) {
    samples_.splice(samples_.end(), other);
    assert(GetActualSampleCount() <= tot_num_samples_);
  }
  /// @return The expected (configured) number of samples.
  inline const SampleSize num_samples() const {
    return tot_num_samples_;
  }
  /**
   * Removes the samples at the given zero-based positions of the expanded sample list.
   *
   * The scan walks the list backwards only, so @p samples_to_remove is expected to be sorted in
   * descending order (presumably without duplicates; verify against callers).
   */
  void RemoveSamples(std::vector<SampleSize> &samples_to_remove) {
    if (samples_to_remove.empty())
      return;
    if (num_samples() == samples_to_remove.size()) {
      samples_.clear();
      return;
    }

    // Delete from back to front to prevent deletion affecting counts.
    assert(!samples_.empty());
    auto sample_itr = --(samples_.end());
    SampleSize cur_sample_start_count = num_samples() - sample_itr->sample_count();
    SampleSize num_to_remove = 0;
    for (auto &sample_to_remove : samples_to_remove) {
      assert(sample_to_remove >= 0 && sample_to_remove < num_samples());
      // Skip to the next sample to remove.
      while (sample_to_remove < cur_sample_start_count) {
        if (num_to_remove > 0) {
          if (sample_itr->sample_count() == num_to_remove)
            // If no samples are left then remove the object
            sample_itr = samples_.erase(sample_itr);
          else
            sample_itr->DecreaseSampleCount(num_to_remove);
          num_to_remove = 0;
        }
        --sample_itr;
        cur_sample_start_count -= sample_itr->sample_count();
      }
      num_to_remove++;
    }

    // Handle the last element that requires removal
    if (sample_itr->sample_count() == num_to_remove)
      sample_itr = samples_.erase(sample_itr);
    else
      sample_itr->DecreaseSampleCount(num_to_remove);
    assert(GetActualSampleCount() == tot_num_samples_ - samples_to_remove.size());
  }
  /**
   * Keeps only the samples at the given zero-based positions of the expanded sample list.
   *
   * Like RemoveSamples, the scan walks the list backwards only, so @p samples_to_keep is expected
   * to be sorted in descending order.
   */
  void KeepSamples(std::vector<SampleSize> &samples_to_keep) {
    if (samples_to_keep.empty()) {
      samples_.clear();
      return;
    }
    if (num_samples() == samples_to_keep.size())
      return;

    // Delete from back to front to prevent deletion affecting counts.
    assert(!samples_.empty());
    auto sample_itr = --(samples_.end());
    SampleSize cur_sample_start_count = num_samples() - sample_itr->sample_count();
    SampleSize num_to_keep = 0;
    for (auto &sample_to_keep : samples_to_keep) {
      assert(sample_to_keep >= 0 && sample_to_keep < num_samples());
      // Skip to the next sample to check
      while (sample_to_keep < cur_sample_start_count) {
        if (num_to_keep > 0) {
          sample_itr->set_sample_count(num_to_keep);
          num_to_keep = 0;
        } else {
          sample_itr = samples_.erase(sample_itr);
        }
        --sample_itr;
        cur_sample_start_count -= sample_itr->sample_count();
      }
      num_to_keep++;
    }

    // Handle the last node to keep
    sample_itr->set_sample_count(num_to_keep);
    // Any nodes never encountered have to be removed.
    if (sample_itr != samples_.begin())
      samples_.erase(samples_.begin(), sample_itr);
    assert(GetActualSampleCount() == samples_to_keep.size());
  }
  /// @return true if the stored samples add up to the expected sample count.
  const bool verifySampleCount() const {
    return GetActualSampleCount() == tot_num_samples_;
  }
  /// @return Sum of the sample counts of all stored nodes.
  const SampleSize GetActualSampleCount() const {
    SampleSize actual_sample_count = 0;
    for (const auto &sample : samples_) {
      assert(sample.sample_count() > 0);  // Make sure no dead samples
      actual_sample_count += sample.sample_count();
    }
    return actual_sample_count;
  }
  void clear() { samples_.clear(); }
};


// * Recipe Structure Initializer
// *
// * Initializes recipe static structures.  This requires a dedicated function
// * because of the scope requirements of static objects.
// *
// * @param num_var Number of variables in the Boolean formula.
// */
//void InitializeSamplerStructures(VariableIndex num_var);
// * Sample Size Initializer
// *
// * Initializes the number of samples to be collected by the algorithm.
// *
// * @param sample_count Number of samples
// */
//void InitializeSampleCount(SampleSize sample_count);
//
//
//bool IsVarInLiteralStack(const std::vector<LiteralID> &literal_stack, VariableIndex var);

#endif //SHARPSAT_SOLUTION_RECIPE_H