Download as pdf or txt
Download as pdf or txt
You are on page 1of 9

#!/usr/bin/env python3

"""Chat server for CST311 Programming Assignment 4"""

# NOTE: this module uses assignment expressions (``:=``), so it requires
# Python 3.8 or newer -- hence ``python3`` in the shebang above.

__author__ = "Bitwise"
__credits__ = [
    "Chris McMichael",
    "Jake Kroeker",
    "Jeremiah McGrath",
    "Miguel Jimenez",
]

# Import statements
from typing import *
import socket as s
import ssl
import threading as t

# Configure logging
import logging
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)  # emit everything from DEBUG upwards

# --- Network -------------------------------------------------------------
# Server host name; also used below to locate the TLS certificate and key.
SERVER_NAME = "chatserver.pa4.cst311.test"
# TCP port the listening socket binds to.
SERVER_PORT: int = 12000

# --- TLS material (file names derived from SERVER_NAME) ------------------
CERT_FILE = f"/etc/ssl/demoCA/newcerts/{SERVER_NAME}-cert.pem"
KEY_FILE = f"/etc/ssl/demoCA/private/{SERVER_NAME}-key.pem"

# --- Chat ----------------------------------------------------------------
# How many client connections are accepted.
NUM_CLIENTS: int = 2
# Display names handed out to clients in connection order.
CLIENT_NAMES: Tuple[str, str] = ("X", "Y")
# Message (compared case-insensitively) that ends the chat.
BYE_COMMAND = "bye"

###########
# Classes #
###########

# Handles queueing of messages to be broadcast to clients
class SendQueue:
    """Thread-safe broadcast queue in which every value reaches every consumer.

    The queue is a doubly linked list bounded by two sentinel entries,
    ``head`` and ``tail``, which never carry a value. A consumer holds a
    reference to the entry it consumed last (starting at ``get_head()``) and
    advances with ``Entry.wait_for_next()``. Once all ``num_consumers``
    consumers have consumed an entry, it is unlinked from the list the next
    time one of them moves past it.
    """

    # A single entry to be sent
    class Entry:
        """One linked-list node carrying one queued value."""

        def __init__(self, queue, previous_entry=None, next_entry=None,
                     value=None):
            # Owning SendQueue (supplies num_consumers, locks, size and tail)
            self.queue = queue
            self.previous = previous_entry
            self.next = next_entry
            self.value = value
            # Guards increments of num_consumed in _consume()
            self.num_consumed_lock: t.Lock = t.Lock()
            # Consumers that have consumed this entry so far; set to -1 once
            # the entry has been unlinked so it is never unlinked twice
            self.num_consumed = 0

        def _consume(self):
            """Record one consumption of this entry and return its value.

            The consumer whose increment brings the count to
            ``queue.num_consumers`` decrements the queue's size.
            """
            value = self.value
            queue = self.queue

            # Increment and test under the same lock so exactly one consumer
            # observes the final count. Testing outside the lock would let two
            # consumers both see the final count (size decremented twice), or
            # let a consumer read the -1 written by _check_free_entry (size
            # never decremented).
            with self.num_consumed_lock:
                self.num_consumed += 1
                consumed_by_all = self.num_consumed == queue.num_consumers

            if consumed_by_all:
                with queue.size_lock:
                    queue.size -= 1

            return value

        def _check_free_entry(self, next_entry):
            """Unlink this entry from the list if every consumer consumed it.

            Called by a consumer that is moving from this entry to
            ``next_entry``, which is never ``queue.tail``.
            """
            queue = self.queue

            # free_lock ensures only one consumer claims the entry; writing
            # the -1 marker makes every later check fail
            with queue.free_lock:
                should_free = self.num_consumed == queue.num_consumers
                if should_free:
                    self.num_consumed = -1

            if should_free:
                # The following is safe to do without a lock. In this context:
                # 1) The previous node is always queue.head
                # 2) The next node will not change and is NOT queue.tail
                self.previous.next = self.next
                next_entry.previous = self.previous

        def has_next(self) -> bool:
            """Return True if a value after this entry is waiting to be consumed."""
            return self.next is not self.queue.tail

        def wait_for_next(self):
            """Block until an entry follows this one, then consume it.

            Returns:
                ``(next_entry, value)``. The caller uses ``next_entry`` as its
                position for the following call.
            """
            queue = self.queue
            next_entry: SendQueue.Entry
            # Hold the condition's lock while testing the "new entry"
            # condition so a notify from add() cannot be missed
            with (cond := queue.new_entry_cond):
                while (next_entry := self.next) is queue.tail:
                    cond.wait()

            self._check_free_entry(next_entry)
            return next_entry, next_entry._consume()

    def __init__(self, num_consumers):
        """Create an empty queue; each value is consumed ``num_consumers`` times."""
        self.num_consumers = num_consumers
        # Sentinels: real entries always sit strictly between head and tail
        self.head = SendQueue.Entry(self)
        self.tail = SendQueue.Entry(self, self.head)
        self.head.next = self.tail
        # Serializes add() calls
        self.add_lock = t.Lock()
        # Notified (notify_all) each time add() links a new entry
        self.new_entry_cond = t.Condition()
        # Guards size
        self.size_lock = t.Lock()
        # Entries added but not yet consumed by every consumer
        self.size = 0
        # Serializes the "may this entry be unlinked?" decision
        self.free_lock: t.Lock = t.Lock()

    def get_head(self):
        """Return the head sentinel, the starting position for a consumer."""
        return self.head

    def add(self, value):
        """Append *value* and wake every consumer waiting for a new entry."""
        # Increment before linking the entry so size can never be observed
        # as < 0 (the "new entry" condition is fulfilled and all consumers
        # consume the entry before size is incremented)
        with self.size_lock:
            self.size += 1

        # Lock to prevent multiple threads from calling add() simultaneously
        with self.add_lock:
            inserting_after = self.tail.previous
            new_node = SendQueue.Entry(self, inserting_after, self.tail, value)
            self.tail.previous = new_node
            # This line fulfils the "new entry" condition, which allows
            # consumer threads to start working
            inserting_after.next = new_node

        with self.new_entry_cond:
            self.new_entry_cond.notify_all()

    def get_size(self):
        """Return the number of entries not yet consumed by every consumer.

        May return a number greater than the current size, in which case one
        or more messages are about to be added.
        """
        return self.size

# Per-client record shared by that client's handler threads
class ClientState:
    """State shared by one client's receive thread and send thread.

    Attributes:
        client_name: display name assigned to the client.
        secure_socket: the TLS-wrapped connection to the client.
        consume_entry: the SendQueue entry this client's sender consumed
            last (main() starts it at the queue head).
        threads: handler threads created for this client.
    """

    def __init__(self, client_name: str, secure_socket, consume_entry):
        self.client_name = client_name
        self.secure_socket = secure_socket
        self.consume_entry = consume_entry
        self.threads: list[t.Thread] = []

# Globals shared by main() and the client handler threads
# Serializes console output (see print_lock())
server_print_lock: t.Lock = t.Lock()
# Held while reading/clearing `running` so that clearing it and queueing the
# "bye" message appear atomic to the send handlers
running_lock: t.Lock = t.Lock()
# Cleared when a client sends BYE_COMMAND; the handler loops then wind down
running = True
# One ClientState per accepted client, in connection order
client_states: List[ClientState] = []
# Broadcast queue with one consumer per client (NUM_CLIENTS consumers)
send_queue: SendQueue = SendQueue(NUM_CLIENTS)

# Thread-safe console output
def print_lock(value):
    """Print *value* while holding server_print_lock.

    Handler threads print concurrently; holding the lock keeps each printed
    line from interleaving with another thread's output.
    """
    server_print_lock.acquire()
    try:
        print(value)
    finally:
        server_print_lock.release()

# Handles receiving client responses
def client_receive_handler(client_state: ClientState):
    """Receive chat messages from one client and queue them for broadcast.

    Each non-empty message is printed server-side as
    ``"Client <name> : <message>"`` and added to ``send_queue`` as a
    ``(client_name, encoded_message)`` tuple. When the message equals
    BYE_COMMAND (case-insensitively), ``running`` is cleared and the message
    is queued while holding ``running_lock``.

    The loop ends when ``running`` is cleared, when the client closes its
    connection, or when receiving from the socket fails.
    """
    global running

    client_name = client_state.client_name
    secure_socket = client_state.secure_socket

    while running:
        # Blocks until the client sends data or closes the connection
        try:
            data = secure_socket.recv(1024)
        except OSError as err:  # ssl.SSLError is a subclass of OSError
            log.warning("Receiving from client %s failed: %s", client_name, err)
            break

        # An empty read means the peer closed the connection and no more data
        # will arrive. Without this break the loop would call recv() again
        # immediately, spinning at full CPU for as long as `running` is set.
        if not data:
            break

        # A multi-byte UTF-8 character may be split across two recv() calls;
        # replace undecodable bytes instead of killing the thread.
        client_response = data.decode(errors="replace")

        # Format and print the chat message server-side
        server_response = f"Client {client_name} : {client_response}"
        print_lock(server_response)
        encoded_response = server_response.encode()

        if client_response.lower() == BYE_COMMAND:
            # Locking to ensure that senders see running = False and the
            # "bye" message together:
            # 1) running = False without the "bye" message would let a sender
            #    exit without sending "bye"
            # 2) running = True with the "bye" message would let a sender
            #    block again before running = False
            with running_lock:
                # Flag that all threads should exit
                running = False
                # Queue the "bye" message to be broadcast to all clients
                send_queue.add((client_name, encoded_response))
        else:
            # Queue the message to be broadcast to all clients
            send_queue.add((client_name, encoded_response))

# Handles sending messages to clients
def client_send_handler(client_state: ClientState):
    """Deliver queued messages from other clients to this client.

    Consumes ``send_queue`` from ``client_state.consume_entry`` onward,
    storing each new position back on ``client_state``. Every message whose
    source is a different client is sent to this client. Returns once
    ``running`` is cleared and this client has consumed every queued entry.
    """
    client_name = client_state.client_name
    secure_socket = client_state.secure_socket

    # Continue to send messages while running, or while there are still
    # unsent messages in the queue
    while True:
        # Check `running` and has_next() under running_lock. The receive
        # handler clears `running` and queues "bye" under the same lock, so
        # this thread cannot exit while that "bye" is still unsent.
        with running_lock:
            if not running and not client_state.consume_entry.has_next():
                break

        # Wait for and consume the next queued message
        client_state.consume_entry, consume_value = (
            client_state.consume_entry.wait_for_next()
        )
        source_client_name, encoded_message = consume_value

        # Broadcast the message to every client except its sender.
        # sendall() keeps writing until the whole buffer is sent; send() may
        # write only part of it.
        if client_name != source_client_name:
            secure_socket.sendall(encoded_message)

def main():
    """Run the chat server until every client handler thread has exited.

    Accepts NUM_CLIENTS TLS connections on SERVER_PORT and names the clients
    from CLIENT_NAMES in connection order. Each client receives a welcome
    message and gets a receive thread and a send thread that share one
    ClientState. All client sockets and the listening socket are closed
    before returning, including when an exception propagates.
    """
    # Create the TLS context from the server certificate and private key
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ssl_context.load_cert_chain(CERT_FILE, KEY_FILE)

    # Create a TCP (SOCK_STREAM) socket; the with-block closes it even if
    # bind() or listen() fail
    with s.socket(s.AF_INET, s.SOCK_STREAM) as server_socket:
        # To fix subsequent server instances from failing to bind
        server_socket.setsockopt(s.SOL_SOCKET, s.SO_REUSEADDR, 1)

        # Bind to the chosen port on all local addresses
        server_socket.bind(("", SERVER_PORT))

        # Configure how many requests can be queued on the server at once
        server_socket.listen(1)

        # Alert user we are now online
        print("The server is ready to receive on port " + str(SERVER_PORT))

        # Ensure client sockets are closed however this block is left
        try:
            # Accept NUM_CLIENTS connections
            while (num_client_states := len(client_states)) < NUM_CLIENTS:
                connection_socket, address = server_socket.accept()
                try:
                    secure_socket = ssl_context.wrap_socket(
                        connection_socket, server_side=True)
                except OSError as err:  # includes ssl.SSLError
                    # A failed TLS handshake should neither take down the
                    # server nor leak the raw socket; drop this connection
                    # and keep waiting for clients.
                    log.warning("TLS handshake with %s failed: %s",
                                address, err)
                    connection_socket.close()
                    continue

                # Handler threads of earlier clients may already be printing
                print_lock("Connected to client at " + str(address))

                # Shared state for this client's two handler threads
                client_state: ClientState = ClientState(
                    CLIENT_NAMES[num_client_states],
                    secure_socket,
                    send_queue.get_head(),
                )

                welcome_message = ("Welcome to the chat. Your name is: "
                                   + str(client_state.client_name))
                # sendall() guarantees the whole welcome message is written
                secure_socket.sendall(welcome_message.encode())

                # Create client handler threads
                receive_thread: t.Thread = t.Thread(
                    target=client_receive_handler, args=(client_state,))
                send_thread: t.Thread = t.Thread(
                    target=client_send_handler, args=(client_state,))

                # Record the threads and register the client
                client_state.threads.append(receive_thread)
                client_state.threads.append(send_thread)
                client_states.append(client_state)

                # Start the client handler threads
                receive_thread.start()
                send_thread.start()

            # Wait for all client handlers to exit
            for client_state in client_states:
                for thread in client_state.threads:
                    thread.join()

        finally:
            # Close client sockets (the listening socket is closed when the
            # with-block exits)
            for client_state in client_states:
                client_state.secure_socket.close()


if __name__ == "__main__":
    main()

You might also like