/*
 * run_folder.cpp
 *
 */
#include "opt.hpp"
#include "rfold.hpp"

using namespace::RFOLD;

static Rfold folder;

/*
 * Build the command-line help text for run_rfold.
 *
 * Returns the complete multi-line message: the synopsis, the description of
 * the <seqfile> argument and the documented options.  Nothing is printed
 * here; the caller decides which stream receives the text.
 */
string
usage()
{
  // Adjacent literals are concatenated by the compiler, so the whole help
  // message is a single constant and needs no stream formatting.
  const char* const help_text =
    "usage: run_rfold [options] <seqfile>\n"
    "<seqfile>:\n"
    "  sequence file in fasta format\n\n"
    "options:\n"
    "  -command=<COMPUTE_MEA_FOLD|COMPUTE_PROB>\n"
    "    COMPUTE_MEA_FOLD predicts local secondary structures by using\n"
    "    the maximal expected accuracy (MEA) method.\n"
    "    COMPUTE_PROB only computes the base pairing probabilities.\n"
    "    (default: COMPUTE_MEA_FOLD)\n"
    "  -max_pair_dist=<integer>\n"
    "    set the maximal allowed spans W of base pairs.\n"
    "    set to (-1) for the computation without the constraint on the\n"
    "    maximal span. (default: -1)\n"
    "  -outfile=<filename>\n"
    "    set the output file for structure predictions. (default: rfold_out.txt)\n"
    "  -mea_outer_loop_coeff=<float>\n"
    "    compositional weight Co for outer bases. (default: 1.25)\n"
    "    The compositional weight Cp for base pairs is always set to Cp=1.0\n"
    "  -mea_inner_outer_ratio=<float>\n"
    "    set the ratio Ci/Co with the compositional weights Ci, Co for the inner\n"
    "    and outer unpaired bases, respectively. (default: 0.75)\n"
    "  -print_prob=<bool>\n"
    "    set to true in order to print out the  base pair probabilities\n"
    "    when -command=COMPUTE_MEA_FOLD (default: false)\n"
    "  -print_loop_prob=<bool>\n"
    "    set to true to print out loop probabilities Prob_L(i) (default: false)\n"
    "  -prob_file=<filename>\n"
    "    set output file for base pair and loop probabilities (default: rfold_prob.txt)\n";
  return string(help_text);
}
/*
 * Option table passed to the Opt constructor in main().
 *
 * Each entry is {option name, argument kind, default value as text}; the
 * list is terminated by the {NULL, Opt::NO_ARGUMENT, NULL} sentinel.
 * Defaults are kept as strings and converted by the Opt getters at the point
 * of use.  "NO FILE" and "NULL" are the markers main() compares against to
 * detect an option that was not supplied.
 *
 * NOTE(review): the `#if true` branch compiles in a test-data path as the
 * default for seqfile, so the "NO FILE" default in the `#else` branch is
 * currently inactive - confirm whether this is intended for release builds.
 */
const Opt::Entry opt_table[] = {
  {"tag",                       Opt::REQUIRED_ARGUMENT, "rfold"},
  {"command",                   Opt::REQUIRED_ARGUMENT, "COMPUTE_MEA_FOLD"},
#if true
  {"seqfile",                   Opt::REQUIRED_ARGUMENT, "../script/testdata/seq_data2.fa"},
#else
  {"seqfile",                   Opt::REQUIRED_ARGUMENT, "NO FILE"},
#endif
  {"param_file",                Opt::REQUIRED_ARGUMENT, "NO FILE"},
  {"param_vector_file",         Opt::REQUIRED_ARGUMENT, "NO FILE"},
  {"param_string",              Opt::REQUIRED_ARGUMENT, "NULL"},
  {"outfile",                   Opt::REQUIRED_ARGUMENT, "rfold_out.txt"},
  {"outfile_type",              Opt::REQUIRED_ARGUMENT, "TAB"},
  {"max_pair_dist",             Opt::REQUIRED_ARGUMENT, "-1"},
  {"allow_non_canonical_pairs", Opt::REQUIRED_ARGUMENT, "false"},
  {"mea_separate_loop_type",    Opt::REQUIRED_ARGUMENT, "false"},
  {"mea_outer_loop_coeff",      Opt::REQUIRED_ARGUMENT, "1.25"},
  {"mea_inner_outer_ratio",     Opt::REQUIRED_ARGUMENT, "0.75"},
  // Retired options, kept for reference:
  //   mea_inner_loop_coeff     (default 0.753963188937709)
  //   mea_inner_scale_by_outer (default true)
  {"print_prob",                Opt::REQUIRED_ARGUMENT, "false"},
  {"print_loop_prob",           Opt::REQUIRED_ARGUMENT, "false"},
  {"print_prob_cutoff",         Opt::REQUIRED_ARGUMENT, "0.0"},
  {"prob_file",                 Opt::REQUIRED_ARGUMENT, "rfold_prob.txt"},
  {"e_value_cutoff",            Opt::REQUIRED_ARGUMENT, "-1"},
  {"constraint",                Opt::REQUIRED_ARGUMENT, "false"},
  {"constraint_mea",            Opt::REQUIRED_ARGUMENT, "false"},
  {"show_dpv",                  Opt::REQUIRED_ARGUMENT, "false"},
  {NULL,                        Opt::NO_ARGUMENT,       NULL}
};
int
main(int argc, char** argv)
{
  Opt opt(opt_table);
  opt.set_from_args(argc, argv);
  if (argc >= 2) {
    cout << "remaining args >= 2" << endl;
    cout << usage();
    exit(EXIT_SUCCESS);
  } else if (argc == 1) {
    opt.set_value("seqfile", string(argv[0]));
  }
  if (opt.get_str("seqfile") == "NO FILE") {
    cout << usage();
    exit(EXIT_SUCCESS);
  }

  // tag
  string tag = opt.get_str("tag");

  // outfile
  ofstream fo(opt.get_str("outfile").c_str());
  if (opt.get_str("command") != "COMPUTE_PROB") {
    Check(fo, "cannot open outfile = %s", opt.get_str("outfile").c_str());
    fo << setprecision(16);
  }
  if (opt.get_str("command") == "COMPUTE_PROB") {
    opt.set_value("print_prob", "true");
  }

  // paramfile
  Model::FeatureCounter fc;
  fc.set_default_params();
  Array<ScoreT> pv0(fc.param_vector());

  if (opt.get_str("param_file") != "NO FILE") {
    ifstream fparam(opt.get_str("param_file").c_str());
    Check(fparam, "cannot open param_file = %s", 
	  opt.get_str("param_file").c_str());
    fc.read_from_file(fparam);
    fparam.close();
    pv0 = fc.param_vector();
  }
  if (opt.get_str("param_vector_file") != "NO FILE") {
    string buf;
    ifstream fparam(opt.get_str("param_vector_file").c_str());
    Check(fparam, "cannot open param_vector_file = %s", 
	  opt.get_str("param_vector_file").c_str());
    getline(fparam, buf);
    fparam.close();
    Vector<string> tokens = tokenize<string>(buf);
    Check((int)tokens.size() == fc.param_vector_size());
    for (int i = 0; i < (int)tokens.size(); i++) {
      pv0[i] = atof(tokens[i].c_str());
    }  
  }
  if (opt.get_str("param_string") != "NULL") {
    const char delim = ',';
    string str = opt.get_str("param_string");
    for (int i = 0; i < (int)str.size(); i++) {
      if (str[i] == delim) {
	str[i] = ' ';
      }
    }
    Vector<string> tokens = tokenize<string>(str);
    Check(tokens.size() % 2 == 0, 
	  "bad format for param_string=%s", opt.get_str("param_string").c_str());
    int nparam = (tokens.size() / 2);
    for (int i = 0; i < nparam; i++) {
      int idx = fc.param_index(tokens[2*i]);
      Check(idx >= 0, "unrecognized param %s", tokens[2*i].c_str());
      pv0[idx] = atof(tokens[2*i+1].c_str());
    }
  }
  fc.set_param_vector(pv0);

  // seqfile
  ifstream fseq(opt.get_str("seqfile").c_str());
  Check(fseq, "cannot open seqfile %s", opt.get_str("seqfile").c_str());
  SeqFile seq_file(tag + "_sequence.txt");

  folder.set_command(opt.get_str("command"));
  folder.set_max_pair_dist(opt.get_int("max_pair_dist"));
  folder.set_allow_non_canonical(opt.get_bool("allow_non_canonical_pairs"));
  folder.set_mea_separate_loop_type(opt.get_bool("mea_separate_loop_type"));
  folder.set_mea_outer_loop_coeff(opt.get_dbl("mea_outer_loop_coeff"));
  // if (opt.get_bool("mea_inner_scale_by_outer")) {
  //   folder.set_mea_inner_loop_coeff(opt.get_dbl("mea_outer_loop_coeff")
  // 				    * opt.get_dbl("mea_inner_loop_coeff"));
  // } else {
  //   folder.set_mea_inner_loop_coeff(opt.get_dbl("mea_inner_loop_coeff"));
  // }
  folder.set_mea_inner_loop_coeff(opt.get_dbl("mea_outer_loop_coeff")
				  * opt.get_dbl("mea_inner_outer_ratio"));
  folder.set_print_prob(opt.get_bool("print_prob"));
  folder.set_print_loop_prob(opt.get_bool("print_loop_prob"));
  folder.set_print_prob_cutoff(opt.get_dbl("print_prob_cutoff"));
  folder.set_constraint(opt.get_bool("constraint"));
  folder.set_constraint_mea(opt.get_bool("constraint_mea"));
  folder.set_param_vector(pv0);
  cout << "max_pair_dist=" << opt.get_int("max_pair_dist") << endl;
  ofstream f_prob;
  if (opt.get_bool("print_prob")) {
    f_prob.open(opt.get_str("prob_file").c_str());
    Check(f_prob, "cannot open probfile %s", opt.get_str("prob_file").c_str());
    folder.set_f_prob(&f_prob);
  } else {
    folder.set_f_prob(0);
  }

  int max_pair_dist = opt.get_int("max_pair_dist");
  bool allow_non_canonical_pairs = opt.get_bool("allow_non_canonical_pairs");
  string cmd = opt.get_str("command");
  if (cmd == "COMPUTE_EM_F" || cmd == "COMPUTE_CRF_F") {
    ScoreT dp_score = 0.0;
    int m = 0;
    while (seq_file.read_fasta(fseq)) {
      folder.set_seq(seq_file);
      if (seq_file.has_ct()) {
	if (max_pair_dist >= 0) {// modify ct
	  seq_file.ct().remove_distant_pairs(max_pair_dist);
	}
	if (!allow_non_canonical_pairs) {
	  seq_file.remove_non_canonical_pairs();
	}
	folder.set_ct_constr(seq_file.ct());
      }
      folder.run();
      const ScoreT& sc1 = folder.dp_score();
      dp_score += sc1;
      m++;
    }
    dp_score /= (ScoreT)m;
    fo << "---\n";
    fo << "dp_score: " << dp_score << "\n";
    fo << "param_vector0:\n";
    for (int i = 0; i < (int)pv0.size(); i++) {
      fo << "- " << pv0[i] << "\n";
    }
    cout << "dp_score: " << dp_score << endl;
  } else if (cmd == "COMPUTE_EM_FDF" || cmd == "COMPUTE_CRF_FDF") {
    Array<ScoreT> count_vector(pv0.size());
    count_vector.fill(0.0);
    ScoreT dp_score = 0.0;
    int m = 0;
    while (seq_file.read_fasta(fseq)) {
      folder.set_seq(seq_file);
      if (seq_file.has_ct()) {
	if (max_pair_dist >= 0) {// modify ct
	  seq_file.ct().remove_distant_pairs(max_pair_dist);
	}
	if (!allow_non_canonical_pairs) {
	  seq_file.remove_non_canonical_pairs();
	}
	folder.set_ct_constr(seq_file.ct());
	// cout << "ct_constr: " << folder.ct_constr().sscons() << endl;
      }
      folder.run();
      const Array<ScoreT>& cv1 = folder.count_vector();
      const ScoreT& sc1 = folder.dp_score();
      for (int k = 0; k < (int)cv1.size(); k++) {
	count_vector[k] += cv1[k];
      }
      dp_score += sc1;
      m++;
    }
    fseq.close();
    dp_score /= (ScoreT)m;
    for (int k = 0; k < (int)count_vector.size(); k++) {
      count_vector[k] /= (ScoreT)m;
    }

    fo << "---\n";
    fo << "dp_score: " << dp_score << "\n";
    fo << "param_vector0:\n";
    for (int i = 0; i < (int)pv0.size(); i++) {
      fo << "- " << pv0[i] << "\n";
    }
    fo << "count_vector:\n";
    for (int i = 0; i < (int)count_vector.size(); i++) {
      fo << "- " << count_vector[i] << "\n";
    }

    cout << "dp_score: " << dp_score << endl;

    // show dpv
    if (opt.get_bool("show_dpv")) {
      string cmd1;
      if (cmd == "COMPUTE_EM_FDF") cmd1 = "COMPUTE_EM_F";
      else if (cmd == "COMPUTE_CRF_FDF") cmd1 = "COMPUTE_CRF_F";
      else Die();
#if true
      folder.set_command(cmd1);
      Array<ScoreT> count_vector1(pv0.size());
      count_vector1.fill(0.0);
      ifstream fseq1(opt.get_str("seqfile").c_str());
      int m = 0;
      while (seq_file.read_fasta(fseq1)) {
	folder.set_seq(seq_file);
	if (seq_file.has_ct()) {
	  if (max_pair_dist >= 0) {// modify ct
	    seq_file.ct().remove_distant_pairs(max_pair_dist);
	  }
	  if (!allow_non_canonical_pairs) {
	    seq_file.remove_non_canonical_pairs();
	  }
	  folder.set_ct_constr(seq_file.ct());
	}
	folder.run();
	const Array<ScoreT>& cv = folder.compute_dpv();
	for (int k = 0; k < (int)cv.size(); k++) {
	  count_vector1[k] += cv[k];
	}
	m++;
      }
      fseq1.close();
      for (int k = 0; k < (int)count_vector1.size(); k++) {
	count_vector1[k] /= (ScoreT)m;
      }
      cout << "count_vector count_vector1\n";
      for (int i = 0; i < (int)count_vector.size(); i++) {
	const string name = fc.param_name(i);
	ScoreT z = count_vector[i];
	ScoreT z1 = count_vector1[i];
	ScoreT diff = (z - z1);
	ScoreT eps = 1.0e-8;
	if ((abs(diff) > eps * (1.0 + abs(z) + abs(z1)))) {
	  // if (z > 0.5) {
	  cout << name << " " << z << " " << z1 << " " << diff << "\n";
	}
      }
#else
      cout << "name count score energy\n";
      for (int i = 0; i < (int)count_vector.size(); i++) {
	const string name = fc.param_name(i);
	ScoreT z = count_vector[i];
	ScoreT z1 = pv0[i];
	ScoreT z2 = Model::score_to_energy(z1);
	if (z > 0.0) {
	  cout << setw(20) <<  name << " " << z << " " << z1 << " " << z2 << "\n";
	}
      }
#endif
    }
    
  } else if (cmd == "COMPUTE_ML_FOLD" || cmd == "COMPUTE_MEA_FOLD") {
    if (opt.get_str("outfile_type") == "TAB") {
      folder.set_f_struct(&fo);
    }
    int m = 0;
    while (seq_file.read_fasta(fseq)) {
      folder.set_seq(seq_file);
      if (seq_file.has_ct()) {
	if (max_pair_dist >= 0) {// modify ct
	  seq_file.ct().remove_distant_pairs(max_pair_dist);
	}
	if (!allow_non_canonical_pairs) {
	  seq_file.remove_non_canonical_pairs();
	}
	folder.set_ct_constr(seq_file.ct());
      }
      folder.run();
      ScoreT dp_score = folder.dp_score();
     
      if (opt.get_str("outfile_type") == "YML") {
	string sscons = folder.sscons();
	fo << "---\n";
	fo << "dp_score: " << dp_score << "\n";
	fo << "sscons: " << sscons << "\n";
      }

      // if (cmd == "COMPUTE_MEA_FOLD") {
      // 	cout << "partition_coeff: " << folder.partition_coeff() << endl;
      // }
      // cout << "dp_score: " << dp_score << endl;

      m++;
    }
  } else if (cmd == "COMPUTE_PROB") {
    folder.set_print_prob(true);
    if (!f_prob.is_open()) {
      f_prob.open(opt.get_str("prob_file").c_str());
      Check(f_prob, "cannot open probfile %s", opt.get_str("prob_file").c_str());
    }
    folder.set_f_prob(&f_prob);

    int m = 0;
    while (seq_file.read_fasta(fseq)) {
      folder.set_seq(seq_file);
      if (seq_file.has_ct()) {
	if (max_pair_dist >= 0) {// modify ct
	  seq_file.ct().remove_distant_pairs(max_pair_dist);
	}
	if (!allow_non_canonical_pairs) {
	  seq_file.remove_non_canonical_pairs();
	}
	folder.set_ct_constr(seq_file.ct());
      }
      folder.run();
      // ScoreT dp_score = folder.dp_score();
      // string sscons = folder.sscons();
      // fo << "---\n";
      // fo << "dp_score: " << dp_score << "\n";
      // cout << "dp_score: " << dp_score << endl;
      m++;
    }
  }
  fo.close();
  return 0;
}
