import { AIModelPredict } from './models';

/**
 * Expected return of a prediction as measured on the model's test set at
 * training time.
 *
 * Weighs the target gain (`spec.target.chg_pct`) by the predicted probability
 * and the prediction's `loss` by the complementary probability:
 * `chg_pct * prob + loss * (1 - prob)`.
 */
export const caculateModelExpected = (predict: AIModelPredict) => {
  const { prob, loss } = predict;
  const gainPct = predict.ai_model.spec.target.chg_pct;
  const missProb = 1 - prob;
  return gainPct * prob + loss * missProb;
};

/**
 * Expected return implied by the model's recent live statistics.
 *
 * Picks the stats bucket for the prediction's probability band
 * (`prob_threshold === 0.9` when `prob >= 0.9`, otherwise `0.8`). It then
 * computes `precision * chg_pct + fp_avg_loss * (1 - precision)`.
 *
 * @returns 0 when there is no bucket for the band, or when the bucket has
 *   fewer than 10 positive samples (too few to be meaningful).
 */
export const caculateStatsExpected = (predict: AIModelPredict) => {
  const target = predict.ai_model.spec.target.chg_pct;
  const threshold = predict.prob >= 0.9 ? 0.9 : 0.8;
  const stats = predict.ai_model.spec.stats.find((s) => s.prob_threshold === threshold);

  // A missing bucket previously crashed on `stats.positive_count`; treat it
  // like an under-sampled bucket instead.
  if (!stats || stats.positive_count < 10) {
    return 0;
  }
  // Expected return from recent real-world statistics.
  return stats.precision * target + stats.fp_avg_loss * (1 - stats.precision);
};

/**
 * Expected return of a prediction.
 *
 * Averages the training-time expectation (`caculateModelExpected`) with the
 * recent live-stats expectation (`caculateStatsExpected`). It falls back to
 * the training-time value alone in two cases:
 * - there is no stats bucket for the prediction's probability band
 *   (0.9 when `prob >= 0.9`, otherwise 0.8);
 * - the bucket has fewer than 10 positive samples.
 */
export const caculateExpectProfit = (predict: AIModelPredict) => {
  const modelExpected = caculateModelExpected(predict);
  const threshold = predict.prob >= 0.9 ? 0.9 : 0.8;
  const stats = predict.ai_model.spec.stats.find((s) => s.prob_threshold === threshold);

  if (!stats || stats.positive_count < 10) {
    // Recent stats missing or too sparse to be meaningful.
    return modelExpected;
  }
  // Expected return from recent real-world statistics.
  const statsExpected = caculateStatsExpected(predict);

  return (modelExpected + statsExpected) / 2;
};

/**
 * Daily expected return of a prediction, weighted by model quality.
 *
 * Divides `expected.total` by the target's `time_limit`, then scales the
 * result by the mean of the validation AUC and a stats AUC. The stats AUC is
 * currently fixed at 1 because the `spec.stats.auc` term is disabled, so the
 * weight is `val_auc * 0.5 + 0.5`.
 *
 * NOTE(review): presumably `time_limit` is expressed in days (per the
 * function name) — confirm against the model spec.
 */
export const caculateDailyExpectProfit = (predict: AIModelPredict) => {
  const { expected, ai_model: model } = predict;
  const perPeriod = expected.total / model.spec.target.time_limit;
  // Placeholder for the (disabled) recent-stats AUC.
  const statsAuc = 1;
  const aucWeight = model.detail.val_auc * 0.5 + statsAuc * 0.5;
  return perPeriod * aucWeight;
};

/**
 * Mean of `caculateDailyExpectProfit` over a list of predictions.
 *
 * @returns 0 for an empty list.
 */
export const caculateAvgExpectProfit = (predicts: AIModelPredict[]) => {
  const count = predicts.length;
  if (!count) {
    return 0;
  }
  const total = predicts.reduce((acc, p) => acc + caculateDailyExpectProfit(p), 0);
  return total / count;
};


/**
 * Per-model confidence score.
 *
 * @deprecated Confidence is now computed on the backend; kept for reference.
 *
 * Selects the recent-stats detail for the prediction's probability band:
 * `detail[0.9]` when `prob >= 0.9`, otherwise `detail[0.8]`.
 *
 * @returns `prob * stats.auc * tp_pct`, or 0 in any of these cases:
 * - the band has no detail entry;
 * - the band has no recent positive samples;
 * - the band's recent precision (`tp_pct`) is below 0.5.
 *
 * NOTE(review): this reads `spec.stats` as an object with `auc` and
 * `detail`. The other helpers in this file treat `spec.stats` as an array
 * of per-threshold entries. Confirm which shape `AIModelPredict` really has.
 */
export const cacualteModelConfidence = (predict: AIModelPredict) => {
  const stats = predict.ai_model.spec.stats;
  const statsDetail = predict.prob >= 0.9 ? stats.detail[0.9] : stats.detail[0.8];

  if (!statsDetail) {
    // No recent stats for this band; previously this crashed.
    return 0;
  }

  if (statsDetail.positive_count === 0) {
    // No recent positive samples. An earlier version fell back to test-set
    // metrics here (prob * val_auc * val_pre / 100); it now returns 0.
    return 0;
  }

  if (statsDetail.tp_pct < 0.5) {
    // Recent precision is too low to be meaningful.
    return 0;
  }

  return predict.prob * stats.auc * statsDetail.tp_pct;
};

/**
 * Overall confidence across a set of predictions.
 *
 * Skips warning models (target `chg_pct <= 0`) and predictions whose `score`
 * is not positive. It then sums the scores of the 3 highest-scoring
 * predictions that remain. The caller's array is not modified.
 *
 * @returns The sum formatted with two decimals, as a string (e.g. "9.00").
 */
export const caculateConfidence = (predicts: AIModelPredict[]) => {
  const topN = 3;
  const eligible = predicts.filter(
    (p) => p.ai_model.spec.target.chg_pct > 0 && p.score > 0,
  );
  eligible.sort((a, b) => b.score - a.score);

  let total = 0;
  for (const p of eligible.slice(0, topN)) {
    total += p.score;
  }
  return total.toFixed(2);
};

/**
 * Sorts predictions by `score`, highest first.
 *
 * The sort happens IN PLACE. The same array instance is returned, so callers
 * that ignore the return value still see the sorted order.
 *
 * History: this used to rank by recent precision (`val_pre` from the detail
 * entry matching the prediction's 0.9 / 0.8 threshold). It now ranks by
 * `score`.
 */
export const sortPredicts = (predicts: AIModelPredict[]) => {
  predicts.sort((first, second) => second.score - first.score);
  return predicts;
};
