#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import sys
def usage(command):
    """Print the command-line help table and terminate the program.

    Each of the eight counts is named by three letters giving what the
    baseline, the proposed system and the gold answer say (y = yes, n = no).

    Args:
        command: program name shown on the "Usage:" line (read_argv passes
            sys.argv[0]).

    Raises:
        SystemExit: always, with status 1, because read_argv only calls this
            when the command-line arguments are invalid.
    """
    print("Usage: {0} #yny #ynn #nyy #nyn #yyy #yyn #nny #nnn".format(command))
    print("        baseline proposed answer")
    rows = (
        ("#yny", "yes", "no", "yes"),
        ("#ynn", "yes", "no", "no"),
        ("#nyy", "no", "yes", "yes"),
        ("#nyn", "no", "yes", "no"),
        ("#yyy", "yes", "yes", "yes"),
        ("#yyn", "yes", "yes", "no"),
        ("#nny", "no", "no", "yes"),
        ("#nnn", "no", "no", "no"),
    )
    for row in rows:
        # Column widths line up with the header printed above.
        print("  {0}  {1:<8} {2:<8} {3}".format(*row))
    # Non-zero status: this is an error path, not a successful run.
    sys.exit(1)
def read_argv():
    """Read the eight contingency counts from the command line and echo them.

    sys.argv[1:9] must be decimal digit strings, in the order
    #yny #ynn #nyy #nyn #yyy #yyn #nny #nnn (letters: baseline, proposed,
    answer; y = yes, n = no). If fewer than eight are given, or any of them
    is not all digits, usage() is called, which prints help and exits.
    Arguments after the eighth are ignored.

    Returns:
        A list of the eight counts as ints. The one exception is a zero
        #yyy (index 4), which is replaced by 1e-16 so that every
        denominator in diff_f is strictly positive.
    """
    argvs = sys.argv
    if len(argvs) < 9 or not all(a.isdigit() for a in argvs[1:9]):
        usage(argvs[0])
    arg = [int(a) for a in argvs[1:9]]
    # Echo the counts exactly as the user gave them (before the #yyy
    # adjustment below), separating the two halves of the line.
    print("baseline-proposed-answer: ", end="")
    print("#yny = {0}, #ynn = {1}, #nyy = {2}, #nyn = {3}, ".format(*arg[:4]), end="")
    print("#yyy = {0}, #yyn = {1}, #nny = {2}, #nnn = {3}".format(*arg[4:]))
    if arg[4] == 0:
        # #yyy appears in both numerators and denominators of diff_f's
        # precision/recall. A tiny positive value keeps those denominators
        # and the p + r sums non-zero.
        arg[4] = 1.0 / 10 ** 16
    return arg
def diff_f(arg, print_flag):
    """Return F1(proposed) - F1(baseline) for the given contingency counts.

    Args:
        arg: sequence of counts ordered #yny, #ynn, #nyy, #nyn, #yyy, #yyn,
            #nny, #nnn (baseline, proposed, answer; y = yes, n = no). Only
            the first seven entries are read.
        print_flag: when true, also print precision, recall and F1 of both
            systems plus the difference, on one line without a newline.

    Returns:
        The F1 difference (proposed minus baseline) as a float.
    """
    yny, ynn, nyy, nyn = arg[0], arg[1], arg[2], arg[3]
    yyy, yyn, nny = arg[4], arg[5], arg[6]

    def harmonic(prec, rec):
        # F1 is the harmonic mean of precision and recall.
        return 2.0 * prec * rec / (prec + rec)

    # Items the gold answer marks "yes"; both recalls share this denominator.
    gold_yes = yny + nyy + yyy + nny
    hits_b = float(yny + yyy)
    hits_p = float(nyy + yyy)
    prec_b, rec_b = hits_b / (yny + ynn + yyy + yyn), hits_b / gold_yes
    prec_p, rec_p = hits_p / (nyy + nyn + yyy + yyn), hits_p / gold_yes
    f1_b = harmonic(prec_b, rec_b)
    f1_p = harmonic(prec_p, rec_p)
    gap = f1_p - f1_b
    if not print_flag:
        return gap
    print("[p_b: {0:.3f}, r_b: {1:.3f}, f_b: {2:.3f}],".format(prec_b, rec_b, f1_b), end="")
    print("[p_p: {0:.3f}, r_p: {1:.3f}, f_p: {2:.3f}],".format(prec_p, rec_p, f1_p), end="")
    print("[f_p - f_b = {0:.5f}],".format(gap), end="")
    sys.stdout.flush()
    return gap
def binomial_dist_list(n, cumulative_flag):
    """Return the binomial coefficients C(n, 0), ..., C(n, n) as exact ints.

    Args:
        n: non-negative number of trials.
        cumulative_flag: if true, return running sums instead, so element k
            is C(n, 0) + ... + C(n, k) and the last element is 2 ** n.

    Returns:
        A list of n + 1 ints. Ints (never floats) are returned even for
        n == 0. Callers multiply these values together, so they stay in exact
        integer arithmetic and cannot hit float overflow for large n.
    """
    coeffs = [1] * (n + 1)
    for k in range(1, n + 1):
        # C(n, k) = C(n, k - 1) * (n - k + 1) / k; the division is exact.
        coeffs[k] = coeffs[k - 1] * (n - k + 1) // k
    if cumulative_flag:
        total = 0
        for k in range(n + 1):
            total += coeffs[k]
            coeffs[k] = total
    return coeffs
def calc_count(arg, actual_diff, num_y, num_n):
    """Count re-assignments of disagreement items whose F1 difference >= actual_diff.

    Items on which the baseline and the proposed system disagree may be
    swapped between the two systems. There are num_y such items whose answer
    is "yes" and num_n whose answer is "no", giving 2 ** (num_y + num_n)
    equally weighted re-assignments. These are grouped by
      i = "yes" disagreements credited to the baseline (#yny = i,
          #nyy = num_y - i), and
      j = "no" disagreements charged to the proposed system (#nyn = j,
          #ynn = num_n - j).
    Group (i, j) contains C(num_y, i) * C(num_n, j) re-assignments.

    Args:
        arg: the eight counts (#yny, #ynn, #nyy, #nyn, #yyy, #yyn, #nny,
            #nnn). Only read; a private copy is modified during the search.
        actual_diff: observed F1 difference to compare against.
        num_y: number of disagreement items whose answer is "yes".
        num_n: number of disagreement items whose answer is "no".

    Returns:
        The number of re-assignments for which diff_f(...) >= actual_diff.
    """
    num_y = int(num_y)
    num_n = int(num_n)
    # Work on a copy so the caller's counts are left untouched.
    trial = list(arg)
    bin_dist_y = binomial_dist_list(num_y, False)
    c_bin_dist_n = binomial_dist_list(num_n, True)
    count = 0
    for i in range(num_y + 1):
        trial[0], trial[2] = i, num_y - i
        # Increasing j removes false positives from the baseline and adds
        # them to the proposed system. That can only lower diff_f, so the j
        # values with diff >= actual_diff form the prefix 0..first_fail-1.
        first_fail = num_n + 1
        for j in range(num_n + 1):
            trial[1], trial[3] = num_n - j, j
            if diff_f(trial, False) < actual_diff:
                first_fail = j
                break
        if first_fail > 0:
            # c_bin_dist_n[k] is the sum of C(num_n, m) over m <= k.
            count += bin_dist_y[i] * c_bin_dist_n[first_fail - 1]
    return count
def print_prob(prob, num):
    """Print the doubled (two-sided) p-value prob / 2 ** num.

    Args:
        prob: number of re-assignments at least as extreme as the observed
            one (normally an exact int; a float is also accepted).
        num: number of swappable items, so there are 2 ** num equally likely
            re-assignments.

    Prints "p = <value>" with 16 decimals. The value is capped at 1.0 so it
    is always a valid probability.
    """
    from fractions import Fraction
    # Exact rational division: 2 ** num can exceed the float range, and
    # halving an int with // before converting would truncate small counts.
    one_sided = float(Fraction(prob) / (2 ** num))
    print("p = {0:.16f}".format(min(1.0, one_sided * 2)))
if __name__ == '__main__':
    # Paired permutation test on the F1 difference between the two systems:
    # read the counts, measure the observed gap, then count how many
    # re-assignments of the disagreement items are at least as extreme.
    arg = read_argv()
    diff = abs(diff_f(arg, True))  # also prints the per-system P/R/F summary
    # Disagreement items whose answer is "yes" / "no"; these are the ones
    # that calc_count re-assigns between the two systems.
    num_y = arg[0] + arg[2]
    num_n = arg[1] + arg[3]
    if diff != 0:
        count = calc_count(arg, diff, num_y, num_n)
        print_prob(count, int(num_y + num_n))
    else:
        # No observed difference: every re-assignment is at least as extreme.
        print("p = {0:.16f}".format(1.0))