Q2.ipynb - Colab

You might also like

Download as pdf or txt
Download as pdf or txt
You are on page 1 of 3

5/12/24, 8:37 AM Q2.ipynb - Colab

import math
import time

import numpy as np

class Agent:
    """Bandit-style batsman that picks one arm (shot) per ball using UCB.

    The agent is driven through ``get_action``: each call first records the
    outcome of the previously chosen arm, then returns the arm to play next.
    Every arm is played once in index order before UCB selection starts.
    """

    def __init__(self):
        # Running totals over every ball whose outcome has been reported.
        self.balls_played = 0
        self.runs_scored = 0
        self.wickets_down = 0
        # Arm returned by the previous get_action call; its outcome is only
        # reported on the *next* call, so it has to be remembered here.
        self.last_played = 0

        # Same values as the environment's action -> runs table. The agent
        # only uses its length (the number of arms).
        self.runs = np.array([0, 1, 2, 3, 4, 6])
        self.num_arms = len(self.runs)

        # Per-arm statistics.
        self.pulls = np.zeros(self.num_arms, dtype=np.float32)
        self.arm_wickets = np.zeros(self.num_arms, dtype=np.float32)
        self.arm_runs = np.zeros(self.num_arms, dtype=np.float32)

        # Latest upper-confidence index of each arm.
        self.ucb_arms = np.zeros(self.num_arms, dtype=np.float32)

    def calculate_ucb(self):
        """Refresh ``self.ucb_arms`` from the current per-arm statistics.

        For an arm with ``n`` pulls after ``t`` balls the index is
        ``mean_runs + sqrt(2 * log(1 + t * log(t)**2) / n)``.
        Arms that have never been pulled get an infinite index so they are
        tried first instead of triggering a division by zero.
        """
        t = self.balls_played
        # log(0) is undefined; with no balls played there is no exploration
        # bonus to add for already-pulled arms.
        log_term = math.log(1 + t * math.log(t) ** 2) if t > 0 else 0.0

        pulled = self.pulls > 0
        self.ucb_arms[~pulled] = np.inf

        # Compute in float64 and let the assignment round to float32.
        counts = self.pulls[pulled].astype(np.float64)
        means = self.arm_runs[pulled] / counts
        self.ucb_arms[pulled] = means + np.sqrt(2 * log_term / counts)

    def get_action(self, wicket, runs_scored):
        """Record the previous ball's outcome and choose the next arm.

        Args:
            wicket: 1 if the previous ball took a wicket, else 0. Ignored on
                the very first call, when there is no previous ball.
            runs_scored: Runs scored off the previous ball. Ignored on the
                very first call.

        Returns:
            The index of the arm to play, as a plain ``int``.
        """
        if self.balls_played > 0:
            # Credit the reported outcome to the arm chosen last time.
            self.runs_scored += runs_scored
            self.wickets_down += wicket

            self.pulls[self.last_played] += 1
            self.arm_wickets[self.last_played] += wicket
            self.arm_runs[self.last_played] += runs_scored

        if self.balls_played < self.num_arms:
            # Initial sweep: play each arm once, in index order.
            action = self.balls_played
        else:
            self.calculate_ucb()
            best = np.flatnonzero(self.ucb_arms == self.ucb_arms.max())
            # Ties are broken in favour of the highest arm index.
            action = int(best[-1])

        self.last_played = action
        self.balls_played += 1
        return action

https://colab.research.google.com/drive/1TPAmFOjsgK5CGwUKBhIbQVtrRyzn1yXu#printMode=true 1/3
5/12/24, 8:37 AM Q2.ipynb - Colab
class Environment:
    """Simulated innings that queries an agent for one action per ball.

    Each action ``a`` takes a wicket with probability ``p_out[a]``; otherwise
    it scores ``action_runs_map[a]`` runs with probability ``p_run[a]`` and
    zero runs the rest of the time.

    Args:
        num_balls: Number of balls bowled per call to ``innings``.
        agent: Object exposing ``get_action(wicket, runs_scored)`` that
            returns an action index valid for the probability tables.
    """

    def __init__(self, num_balls, agent):
        self.num_balls = num_balls
        self.agent = agent

        # Time spent inside the agent, in seconds (cumulative).
        self.__run_time = 0
        self.__start_time = 0
        self.__end_time = 0

        # Totals for the innings in progress.
        self.__total_runs = 0
        self.__total_wickets = 0

        # Outcome of the most recent ball, handed to the agent on the next ball.
        self.__runs_scored = 0
        self.__wicket = 0

        # Cumulative regret against the best arm for each criterion.
        self.__regret_w = 0
        self.__regret_s = 0
        self.__regret_rho = 0

        self.__p_out = np.array([0.001, 0.01, 0.02, 0.03, 0.1, 0.3])
        self.__p_run = np.array([1, 0.9, 0.85, 0.8, 0.75, 0.7])
        self.__action_runs_map = np.array([0, 1, 2, 3, 4, 6])
        # Expected runs per ball for each action.
        self.__s = (1 - self.__p_out) * self.__p_run * self.__action_runs_map
        # Expected runs per ball divided by the wicket probability.
        self.__rho = self.__s / self.__p_out

    def __get_action(self):
        """Ask the agent for its next action and time how long it takes."""
        # perf_counter is monotonic, so the measured interval cannot go
        # negative if the system clock is adjusted mid-run.
        self.__start_time = time.perf_counter()
        action = self.agent.get_action(self.__wicket, self.__runs_scored)
        self.__end_time = time.perf_counter()
        self.__run_time = self.__run_time + self.__end_time - self.__start_time
        return action

    def __get_outcome(self, action):
        """Sample ``(wicket, runs)`` for one ball played with ``action``."""
        pout = self.__p_out[action]
        prun = self.__p_run[action]
        wicket = np.random.choice(2, 1, p=[1 - pout, pout])[0]
        runs = 0
        if wicket == 0:
            # Runs are only possible when the batsman survives the ball.
            runs = (
                self.__action_runs_map[action]
                * np.random.choice(2, 1, p=[1 - prun, prun])[0]
            )
        return wicket, runs

    def innings(self):
        """Play ``num_balls`` balls and report the resulting statistics.

        Run and wicket totals are reset at the start of each call. The three
        regret figures and the agent run time are NOT reset, so they keep
        accumulating if ``innings`` is called more than once.

        Returns:
            Tuple ``(regret_w, regret_s, regret_rho, total_runs,
            total_wickets, run_time)``.
        """
        self.__total_runs = 0
        self.__total_wickets = 0
        # Clear both halves of the last-ball feedback together so the agent
        # never receives a stale wicket paired with zero runs.
        self.__runs_scored = 0
        self.__wicket = 0

        for _ in range(self.num_balls):
            action = self.__get_action()
            self.__wicket, self.__runs_scored = self.__get_outcome(action)
            self.__total_runs = self.__total_runs + self.__runs_scored
            self.__total_wickets = self.__total_wickets + self.__wicket
            self.__regret_w = self.__regret_w + (
                self.__p_out[action] - np.min(self.__p_out)
            )
            self.__regret_s = self.__regret_s + (np.max(self.__s) - self.__s[action])
            self.__regret_rho = self.__regret_rho + (
                np.max(self.__rho) - self.__rho[action]
            )
        return (
            self.__regret_w,
            self.__regret_s,
            self.__regret_rho,
            self.__total_runs,
            self.__total_wickets,
            self.__run_time,
        )

https://colab.research.google.com/drive/1TPAmFOjsgK5CGwUKBhIbQVtrRyzn1yXu#printMode=true 2/3
5/12/24, 8:37 AM Q2.ipynb - Colab
# Run a single 1000-ball innings with a fresh agent and print the summary
# in the order returned by Environment.innings().
agent = Agent()
environment = Environment(1000, agent)
regret_w, regret_s, regret_rho, total_runs, total_wickets, run_time = (
    environment.innings()
)

print(regret_w, regret_s, regret_rho, total_runs, total_wickets, run_time)

216.82300000000305 137.4969999999991 69904.50000000116 2813 231 0.03276634216308594

https://colab.research.google.com/drive/1TPAmFOjsgK5CGwUKBhIbQVtrRyzn1yXu#printMode=true 3/3

You might also like