73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
import inspect
|
|
import statistics
|
|
from dataclasses import asdict
|
|
from datetime import timedelta
|
|
from typing import List
|
|
|
|
import requests_cache
|
|
from flask import Flask, jsonify
|
|
|
|
from server.nightr.strategies import dmi, steam
|
|
from server.nightr.util import Context
|
|
|
|
# The Flask application object; the route decorator below registers on it.
app = Flask(import_name=__name__)

# Globally patch `requests` so HTTP responses are cached in
# "requests_cache.sqlite" and expire after ten minutes (600 s) — presumably
# so strategies hitting upstream APIs don't refetch on every request.
requests_cache.install_cache(
    cache_name="requests_cache.sqlite",
    expire_after=timedelta(seconds=600),
)
|
|
|
|
|
|
# Strategy registry: name -> (weight, probability function).
# The endpoint iterates this mapping in insertion order.
strategies = dict(
    dmi=(0.5, dmi.probability),
    steam=(1.0, steam.probability),
)
|
|
|
|
|
|
@app.route("/", methods=["GET", "POST"])
def probabilities():
    """Run every registered strategy and combine their night predictions.

    Each entry of ``strategies`` is called with a ``Context`` built from the
    phone data. A strategy that raises is logged (with traceback) and skipped.
    For every successful prediction the response carries its name, docstring,
    weight, weighted probability (``probability * weight``), a per-strategy
    ``night`` verdict (``probability > 0.5``) and the prediction's own
    dataclass fields.

    The overall verdict uses the weight-normalised mean of the probabilities,
    ``sum(w * p) / sum(w)``, so it is on the same 0..1 scale as the
    per-strategy 0.5 threshold (a plain mean of ``p * w`` would bias the
    verdict toward "day" whenever weights are below 1). Predictions agreeing
    with the verdict get a ``contribution`` equal to their share of the
    agreeing weight; dissenting ones get 0.0.

    Returns:
        A JSON response with the predictions, the weighted mean and median,
        and the overall ``night`` verdict. If no strategy produced a usable
        (positively weighted) prediction, a JSON error body is returned with
        HTTP status 503 instead of crashing on empty statistics.
    """
    phone_data = {}  # TODO: get from POST request
    context = Context(**phone_data)

    predictions: List[dict] = []
    for name, (weight, strategy) in strategies.items():
        try:
            prediction = strategy(context)
        except Exception:
            # Best effort: one failing strategy must not take down the whole
            # endpoint, but keep the traceback for debugging.
            app.logger.exception("Strategy %s failed", name)
            continue
        predictions.append({
            "name": name,
            "description": inspect.getdoc(strategy),
            "weight": weight,
            "weighted_probability": prediction.probability * weight,
            "night": prediction.probability > 0.5,
            **asdict(prediction),
        })

    total_weight = sum(p["weight"] for p in predictions)
    if not predictions or total_weight <= 0:
        # statistics.median raises StatisticsError on empty input and the
        # weighted mean is undefined without positive total weight.
        return jsonify({
            "predictions": predictions,
            "error": "No strategy produced a usable prediction",
        }), 503

    # Weight-normalised mean: sum(w * p) / sum(w).
    mean = sum(p["weighted_probability"] for p in predictions) / total_weight
    # NOTE(review): the median is still taken over the per-strategy weighted
    # values (probability * weight), as before; it is reported only and does
    # not drive the verdict.
    median = statistics.median(p["weighted_probability"] for p in predictions)
    night = mean > 0.5

    # Calculate contributions of predictions: agreeing predictions share the
    # credit in proportion to their weight.
    consensus_weight_sum = sum(p["weight"] for p in predictions if p["night"] == night)
    for prediction in predictions:
        if prediction["night"] == night and consensus_weight_sum > 0:
            prediction["contribution"] = prediction["weight"] / consensus_weight_sum
        else:
            prediction["contribution"] = 0.0

    return jsonify({
        "predictions": predictions,
        "weighted_probabilities_mean": mean,
        "weighted_probabilities_median": median,
        "night": night,
    })
|
|
|
|
|
|
def main():
    """Start Flask's built-in server, listening on every network interface."""
    # 0.0.0.0 (rather than localhost) accepts connections from other machines
    # on the network; the port is left to Flask's default.
    bind_host = "0.0.0.0"
    app.run(host=bind_host)


if __name__ == "__main__":
    main()
|