Skip to content

qlearning on tensorflow. 0.84 success rate #5

Open
wants to merge 2 commits into
base: nlu-train
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 0 additions & 2 deletions src/deep_dialog/checkpoints/rl_agent/checkpoint

This file was deleted.

1 change: 0 additions & 1 deletion src/deep_dialog/qlearning/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self,**kwargs):
self.epsilon_max = kwargs.get('e_greedy', 0.9)
self.batch_size = kwargs.get('batch_size', 16)
self.epsilon_increment = kwargs.get('e_greedy_increment')
#self.epsilon = 0 if kwargs.get('e_greedy_increment', False) else self.epsilon_max
self.epsilon = 1

self.double_q = kwargs.get('double_q', True) # decide to use double q or not
Expand Down
2 changes: 1 addition & 1 deletion src/draw_learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def read_performance_records(path):
""" load the performance score (.json) file """

data = json.load(open(path, 'r'))
data = json.load(open(path, 'rt'))
numbers = {'x': [], 'success_rate': [], 'ave_turns': [], 'ave_rewards': []}
for key in data['success_rate'].keys():
if int(key) > -1:
Expand Down
219 changes: 109 additions & 110 deletions src/telegram_bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def get_random_emoji(num = 1):
msg = input()
episode_over, agent_ans = dia_manager.next_turn(msg)
turn_count+=1
# bot.send_message(msg, agent_ans+' ' + get_random_emoji(1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

а почему здесь был отсыл сообщения ботом, а теперь нет?

Copy link
Author

@lightforever lightforever Jun 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

я не знаю. В твоём замечании по этому куску было написано "убрать или раскомментировать"

я и убрал

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

я буду выражаться яснее в следующий раз, имелось в виду следующее: не должно быть закомментированного кода, либо он раскомментирован, либо его нет

у тебя этот кусок кода был закомментирован, я указал, что это неправильно; правильно - понять, что там должно быть и - например, раскомментированная строка, или какая-то другая строка; кажется, что функционал в этом месте изменился после убирания этого кода, что нехорошо

print("turn #{}: {}".format(turn_count, agent_ans))
if episode_over:
turn_count = 0
Expand All @@ -99,112 +98,112 @@ def get_random_emoji(num = 1):



# @bot.message_handler(commands=['help'])
# def handle_help(message):
# help_message = "Hello, friend!\n" + get_random_emoji(4) + \
# "I can help you to buy tickets " + emojize(":ticket:")+" to the cinema.\n" \
# "=====================================\n" \
# "* Print /start to start a conversation;\n" \
# "* Print /end to end the dialog;\n"
#
# bot.send_message(message.chat.id, help_message)
#
#
# @bot.message_handler(commands=['start'])
# def handle_start(message):
# global turn_count
# turn_count = 1
# greetings = "Hello! I can help you to buy tickets to the cinema.\nWhat film would you like to watch?"
# bot.send_message(message.chat.id, greetings)
#
#
# @bot.message_handler(commands=['end'])
# def handle_end(message):
# global turn_count
# turn_count = 0
# goodbye = "Farewell! Let me know if you would like to buy tickets again." + get_random_emoji()
# bot.send_message(message.chat.id, goodbye)
#
# @bot.message_handler(commands=['films'])
# def show_films(message):
# '''
#
# These should return a
# list of available films for user
#
# '''
#
# available_films = []
# warning = 'currently not available'
# if len(available_films) == 0:
# bot.send_message(message.chat.id, warning)
# else:
# bot.send_message(message.chat.id, available_films)
#
#
# @bot.message_handler(func=lambda message: True, content_types=["text"])
# def handle_text(message):
# global turn_count
# if turn_count > 0:
# if turn_count == 1:
# dia_manager.initialize_episode()
#
# episode_over, agent_ans = dia_manager.next_turn(message.text)
# turn_count+=1
# bot.send_message(message.chat.id, agent_ans+' ' + get_random_emoji(1))
# if episode_over:
# turn_count = 0
# else:
# bot.reply_to(message, message.text)
#
#
# if WEBHOOKS_AVAIL:
#
# WEBHOOK_HOST = config['WEBHOOK_HOST']
# WEBHOOK_PORT = config['WEBHOOK_PORT']
# WEBHOOK_LISTEN = config['WEBHOOK_LISTEN']
#
# WEBHOOK_SSL_CERT = config['WEBHOOK_SSL_CERT'] ## sertificat path
# WEBHOOK_SSL_PRIV = config['WEBHOOK_SSL_PRIV'] ## private key path
#
# WEBHOOK_URL_BASE = "https://%s:%s" % (WEBHOOK_HOST, WEBHOOK_PORT)
# WEBHOOK_URL_PATH = "/%s/" % config['token']
#
#
# class WebhookServer(object):
# @cherrypy.expose
# def index(self):
# if 'content-length' in cherrypy.request.headers and \
# 'content-type' in cherrypy.request.headers and \
# cherrypy.request.headers['content-type'] == 'application/json':
# length = int(cherrypy.request.headers['content-length'])
# json_string = cherrypy.request.body.read(length).decode("utf-8")
# update = telebot.types.Update.de_json(json_string)
# bot.process_new_updates([update])
# return ''
# else:
# raise cherrypy.HTTPError(403)
#
#
# # Снимаем вебхук перед повторной установкой (избавляет от некоторых проблем)
# bot.remove_webhook()
#
# # Ставим заново вебхук
# bot.set_webhook(url=WEBHOOK_URL_BASE + WEBHOOK_URL_PATH,
# certificate=open(WEBHOOK_SSL_CERT, 'r'))
#
# # Указываем настройки сервера CherryPy
# cherrypy.config.update({
# 'server.socket_host': WEBHOOK_LISTEN,
# 'server.socket_port': WEBHOOK_PORT,
# 'server.ssl_module': 'builtin',
# 'server.ssl_certificate': WEBHOOK_SSL_CERT,
# 'server.ssl_private_key': WEBHOOK_SSL_PRIV
# })
#
#
# cherrypy.quickstart(WebhookServer(), WEBHOOK_URL_PATH, {'/': {}})
#
# else:
# bot.delete_webhook()
# bot.polling(none_stop=True) ## uncomment it for local testing;
@bot.message_handler(commands=['help'])
def handle_help(message):
help_message = "Hello, friend!\n" + get_random_emoji(4) + \
"I can help you to buy tickets " + emojize(":ticket:")+" to the cinema.\n" \
"=====================================\n" \
"* Print /start to start a conversation;\n" \
"* Print /end to end the dialog;\n"

bot.send_message(message.chat.id, help_message)


@bot.message_handler(commands=['start'])
def handle_start(message):
global turn_count
turn_count = 1
greetings = "Hello! I can help you to buy tickets to the cinema.\nWhat film would you like to watch?"
bot.send_message(message.chat.id, greetings)


@bot.message_handler(commands=['end'])
def handle_end(message):
global turn_count
turn_count = 0
goodbye = "Farewell! Let me know if you would like to buy tickets again." + get_random_emoji()
bot.send_message(message.chat.id, goodbye)

@bot.message_handler(commands=['films'])
def show_films(message):
'''

These should return a
list of available films for user

'''

available_films = []
warning = 'currently not available'
if len(available_films) == 0:
bot.send_message(message.chat.id, warning)
else:
bot.send_message(message.chat.id, available_films)


@bot.message_handler(func=lambda message: True, content_types=["text"])
def handle_text(message):
global turn_count
if turn_count > 0:
if turn_count == 1:
dia_manager.initialize_episode()

episode_over, agent_ans = dia_manager.next_turn(message.text)
turn_count+=1
bot.send_message(message.chat.id, agent_ans+' ' + get_random_emoji(1))
if episode_over:
turn_count = 0
else:
bot.reply_to(message, message.text)


if WEBHOOKS_AVAIL:

WEBHOOK_HOST = config['WEBHOOK_HOST']
WEBHOOK_PORT = config['WEBHOOK_PORT']
WEBHOOK_LISTEN = config['WEBHOOK_LISTEN']

WEBHOOK_SSL_CERT = config['WEBHOOK_SSL_CERT'] ## sertificat path
WEBHOOK_SSL_PRIV = config['WEBHOOK_SSL_PRIV'] ## private key path

WEBHOOK_URL_BASE = "https://%s:%s" % (WEBHOOK_HOST, WEBHOOK_PORT)
WEBHOOK_URL_PATH = "/%s/" % config['token']


class WebhookServer(object):
@cherrypy.expose
def index(self):
if 'content-length' in cherrypy.request.headers and \
'content-type' in cherrypy.request.headers and \
cherrypy.request.headers['content-type'] == 'application/json':
length = int(cherrypy.request.headers['content-length'])
json_string = cherrypy.request.body.read(length).decode("utf-8")
update = telebot.types.Update.de_json(json_string)
bot.process_new_updates([update])
return ''
else:
raise cherrypy.HTTPError(403)


# Снимаем вебхук перед повторной установкой (избавляет от некоторых проблем)
bot.remove_webhook()

# Ставим заново вебхук
bot.set_webhook(url=WEBHOOK_URL_BASE + WEBHOOK_URL_PATH,
certificate=open(WEBHOOK_SSL_CERT, 'r'))

# Указываем настройки сервера CherryPy
cherrypy.config.update({
'server.socket_host': WEBHOOK_LISTEN,
'server.socket_port': WEBHOOK_PORT,
'server.ssl_module': 'builtin',
'server.ssl_certificate': WEBHOOK_SSL_CERT,
'server.ssl_private_key': WEBHOOK_SSL_PRIV
})


cherrypy.quickstart(WebhookServer(), WEBHOOK_URL_PATH, {'/': {}})

else:
bot.delete_webhook()
bot.polling(none_stop=True) ## uncomment it for local testing;