Cache and re-use a shared WebDriver.

master
Pacman Ghost 5 years ago
parent ffed68d236
commit 617e5deda4
  1. 256
      vasl_templates/tools/webdriver_stress_test.py
  2. 14
      vasl_templates/webapp/__init__.py
  3. 2
      vasl_templates/webapp/snippets.py
  4. 27
      vasl_templates/webapp/static/snippets.js
  5. 12
      vasl_templates/webapp/tests/test_scenario_persistence.py
  6. 17
      vasl_templates/webapp/tests/test_snippets.py
  7. 81
      vasl_templates/webapp/tests/utils.py
  8. 63
      vasl_templates/webapp/vassal.py
  9. 95
      vasl_templates/webapp/webdriver.py

@ -0,0 +1,256 @@
#!/usr/bin/env python3
""" Stress-test the shared WebDriver. """
import os
import threading
import signal
import http.client
import time
import datetime
import base64
import random
import json
import logging
from collections import defaultdict
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.keys import Keys
import click
from vasl_templates.webapp.webdriver import WebDriver
from vasl_templates.webapp.tests.test_scenario_persistence import load_scenario
from vasl_templates.webapp.tests.utils import wait_for, find_child, find_snippet_buttons, \
select_tab, select_menu_option, click_dialog_button, set_stored_msg, get_stored_msg
shutdown_event = threading.Event()
thread_count = None
stats_lock = threading.Lock()
stats = defaultdict( lambda: [0,0] ) # nb: [ #runs, total elapsed time ]
# ---------------------------------------------------------------------
@click.command()
@click.option( "--server-url", default="http://localhost:5010", help="Webapp server URL." )
@click.option( "--snippet-images", default=1, help="Number of 'snippet image' threads to run." )
@click.option( "--update-vsav", default=1, help="Number of 'update VSAV' threads to run." )
@click.option( "--vsav","vsav_fname", help="VASL scenario file (.vsav) to be updated." )
def main( server_url, snippet_images, update_vsav, vsav_fname ):
"""Stress-test the shared WebDriver."""
# initialize
logging.disable( logging.CRITICAL )
# read the VASL scenario file
vsav_data = None
if update_vsav > 0:
vsav_data = open( vsav_fname, "rb" ).read()
# prepare the test threads
threads = []
for i in range(0,snippet_images):
threads.append( threading.Thread(
target = snippet_images_thread,
name = "snippet-images/{:02d}".format( 1+i ),
args = ( server_url, )
) )
for i in range(0,update_vsav):
threads.append( threading.Thread(
target = update_vsav_thread,
name = "update-vsav/{:02d}".format( 1+i ),
args = ( server_url, vsav_fname, vsav_data )
) )
# launch the test threads
start_time = time.time()
global thread_count
thread_count = len(threads)
for thread in threads:
thread.start()
# wait for Ctrl-C
def on_sigint( signum, stack ): #pylint: disable=missing-docstring,unused-argument
print( "\n*** SIGINT received ***\n" )
shutdown_event.set()
signal.signal( signal.SIGINT, on_sigint )
while not shutdown_event.is_set():
time.sleep( 1 )
# wait for the test threads to shutdown
for thread in threads:
print( "Waiting for thread to finish:", thread )
thread.join()
elapsed_time = time.time() - start_time
print()
# output stats
print( "=== STATS ===")
print()
print( "Total run time: {}".format( datetime.timedelta( seconds=int(elapsed_time) ) ) )
for key,val in stats.items():
print( "- {:<14} {}".format( key+":", val[0] ), end="" )
if val[0] > 0:
print( " (avg={:.3f}s)".format( float(val[1])/val[0] ) )
else:
print()
# ---------------------------------------------------------------------
def snippet_images_thread( server_url ):
"""Test generating snippet images."""
with WebDriver() as webdriver:
# initialize
webdriver = webdriver.driver
init_webapp( webdriver, server_url,
[ "snippet_image_persistence", "scenario_persistence" ]
)
# load a scenario (so that we get some sortable's)
scenario_data = {
"SCENARIO_NOTES": [ { "caption": "Scenario note #1" } ],
"OB_SETUPS_1": [ { "caption": "OB setup note #1" } ],
"OB_NOTES_1": [ { "caption": "OB note #1" } ],
"OB_SETUPS_2": [ { "caption": "OB setup note #2" } ],
"OB_NOTES_2": [ { "caption": "OB note #2" } ],
}
load_scenario( scenario_data, webdriver )
# locate all the "generate snippet" buttons
snippet_btns = find_snippet_buttons( webdriver )
tab_ids = list( snippet_btns.keys() )
while not shutdown_event.is_set():
try:
# clear the return buffer
ret_buffer = find_child( "#_snippet-image-persistence_", webdriver )
assert ret_buffer.tag_name == "textarea"
webdriver.execute_script( "arguments[0].value = arguments[1]", ret_buffer, "" )
# generate a snippet
tab_id = random.choice( tab_ids )
btn = random.choice( snippet_btns[ tab_id ] )
log( "Getting snippet image: {}", btn.get_attribute("data-id") )
select_tab( tab_id, webdriver )
start_time = time.time()
ActionChains( webdriver ) \
.key_down( Keys.SHIFT ) \
.click( btn ) \
.key_up( Keys.SHIFT ) \
.perform()
# wait for the snippet image to be generated
wait_for( 10*thread_count, lambda: ret_buffer.get_attribute( "value" ) )
_, img_data = ret_buffer.get_attribute( "value" ).split( "|", 1 )
elapsed_time = time.time() - start_time
# update the stats
with stats_lock:
stats["snippet image"][0] += 1
stats["snippet image"][1] += elapsed_time
# FUDGE! Generating the snippet image for a sortable entry is sometimes interpreted as
# a request to edit the entry (Selenium problem?) - we dismiss the dialog here and continue.
dlg = find_child( ".ui-dialog", webdriver )
if dlg and dlg.is_displayed():
click_dialog_button( "Cancel", webdriver )
except ( ConnectionRefusedError, ConnectionResetError, http.client.RemoteDisconnected ):
if shutdown_event.is_set():
break
raise
# check the generated snippet
img_data = base64.b64decode( img_data )
log( "Received snippet image: #bytes={}", len(img_data) )
assert img_data[:6] == b"\x89PNG\r\n"
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def update_vsav_thread( server_url, vsav_fname, vsav_data ):
"""Test updating VASL scenario files."""
# initialize
vsav_data_b64 = base64.b64encode( vsav_data ).decode( "utf-8" )
with WebDriver() as webdriver:
# initialize
webdriver = webdriver.driver
init_webapp( webdriver, server_url,
[ "vsav_persistence", "scenario_persistence" ]
)
# load a test scenario
fname = os.path.join( os.path.split(__file__)[0], "../webapp/tests/fixtures/update-vsav/full.json" )
saved_scenario = json.load( open( fname, "r" ) )
load_scenario( saved_scenario, webdriver )
while not shutdown_event.is_set():
try:
# send the VSAV data to the front-end to be updated
log( "Updating VSAV: {}", vsav_fname )
set_stored_msg( "_vsav-persistence_", vsav_data_b64, webdriver )
select_menu_option( "update_vsav", webdriver )
start_time = time.time()
# wait for the front-end to receive the data
wait_for( 2*thread_count,
lambda: get_stored_msg( "_vsav-persistence_", webdriver ) == ""
)
# wait for the updated data to arrive
wait_for( 60*thread_count,
lambda: get_stored_msg( "_vsav-persistence_", webdriver ) != ""
)
elapsed_time = time.time() - start_time
# get the updated VSAV data
updated_vsav_data = get_stored_msg( "_vsav-persistence_", webdriver )
if updated_vsav_data.startswith( "ERROR: " ):
raise RuntimeError( updated_vsav_data )
updated_vsav_data = base64.b64decode( updated_vsav_data )
# check the updated VSAV
log( "Received updated VSAV data: #bytes={}", len(updated_vsav_data) )
assert updated_vsav_data[:2] == b"PK"
# update the stats
with stats_lock:
stats["update vsav"][0] += 1
stats["update vsav"][1] += elapsed_time
except (ConnectionRefusedError, ConnectionResetError, http.client.RemoteDisconnected):
if shutdown_event.is_set():
break
raise
# ---------------------------------------------------------------------
def log( fmt, *args, **kwargs ):
"""Log a message."""
now = time.time()
msec = now - int(now)
now = "{}.{:03d}".format( time.strftime("%H:%M:%S",time.localtime(now)), int(msec*1000) )
msg = fmt.format( *args, **kwargs )
msg = "{} | {:17} | {}".format( now, threading.current_thread().name, msg )
print( msg )
# ---------------------------------------------------------------------
def init_webapp( webdriver, server_url, options ):
"""Initialize the webapp."""
log( "Initializing the webapp." )
url = server_url + "?" + "&".join( "{}=1".format(opt) for opt in options )
url += "&store_msgs=1" # nb: stop notification balloons from building up
webdriver.get( url )
wait_for( 5, lambda: find_child("#_page-loaded_",webdriver) is not None )
# ---------------------------------------------------------------------
if __name__ == "__main__":
main() #pylint: disable=no-value-for-parameter

@ -2,6 +2,7 @@
import sys
import os
import signal
import configparser
import logging
import logging.config
@ -28,6 +29,16 @@ def load_debug_config( fname ):
# ---------------------------------------------------------------------
cleanup_handlers = []
def on_sigint( signum, stack ): #pylint: disable=unused-argument
"""Clean up after a SIGINT."""
for handler in cleanup_handlers:
handler()
raise SystemExit()
# ---------------------------------------------------------------------
# initialize Flask
app = Flask( __name__ )
@ -65,6 +76,9 @@ if app.config.get( "ENABLE_REMOTE_TEST_CONTROL" ):
print( "*** WARNING: Remote test control enabled! ***" )
import vasl_templates.webapp.testing #pylint: disable=cyclic-import
# install our signal handler (must be done in the main thread)
signal.signal( signal.SIGINT, on_sigint )
# ---------------------------------------------------------------------
@app.context_processor

@ -120,7 +120,7 @@ def make_snippet_image():
# generate an image for the snippet
snippet = request.data.decode( "utf-8" )
try:
with WebDriver() as webdriver:
with WebDriver.get_instance() as webdriver:
img = webdriver.get_snippet_screenshot( None, snippet )
except SimpleError as ex:
return "ERROR: {}".format( ex )

@ -31,21 +31,26 @@ function generate_snippet( $btn, evt, extra_params )
// check if the user is requesting the snippet as an image
if ( evt.shiftKey ) {
// yup - send the snippet to the backend to generate the image
var $dlg = $( "#make-snippet-image" ).dialog( {
dialogClass: "make-snippet-image",
modal: true,
width: 300,
height: 60,
resizable: false,
closeOnEscape: false,
} ) ;
var $dlg = null ;
var timeout_id = setTimeout( function() {
$dlg = $( "#make-snippet-image" ).dialog( {
dialogClass: "make-snippet-image",
modal: true,
width: 300,
height: 60,
resizable: false,
closeOnEscape: false,
} ) ;
}, 1*1000 ) ;
$.ajax( {
url: gMakeSnippetImageUrl,
type: "POST",
data: snippet.content,
contentType: "text/html",
} ).done( function( resp ) {
$dlg.dialog( "close" ) ;
clearTimeout( timeout_id ) ;
if ( $dlg )
$dlg.dialog( "close" ) ;
if ( resp.substr( 0, 6 ) === "ERROR:" ) {
showErrorMsg( resp.substr(7) ) ;
return ;
@ -68,7 +73,9 @@ function generate_snippet( $btn, evt, extra_params )
download( atob(resp), _make_snippet_image_filename(snippet), "image/png" ) ;
}
} ).fail( function( xhr, status, errorMsg ) {
$dlg.dialog( "close" ) ;
clearTimeout( timeout_id ) ;
if ( $dlg )
$dlg.dialog( "close" ) ;
showErrorMsg( "Can't get the snippet image:<div class='pre'>" + escapeHTML(errorMsg) + "</div>" ) ;
} ) ;
return ;

@ -287,12 +287,14 @@ def test_unknown_vo( webapp, webdriver ):
# ---------------------------------------------------------------------
def load_scenario( scenario ):
def load_scenario( scenario, webdriver=None ):
"""Load a scenario into the UI."""
set_stored_msg( "_scenario-persistence_", json.dumps(scenario) )
_ = set_stored_msg_marker( "_last-info_" )
select_menu_option( "load_scenario" )
wait_for( 2, lambda: get_stored_msg("_last-info_") == "The scenario was loaded." )
set_stored_msg( "_scenario-persistence_", json.dumps(scenario), webdriver )
_ = set_stored_msg_marker( "_last-info_", webdriver )
select_menu_option( "load_scenario", webdriver )
wait_for( 2,
lambda: get_stored_msg( "_last-info_", webdriver ) == "The scenario was loaded."
)
def save_scenario():
"""Save the scenario."""

@ -4,7 +4,7 @@ from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.keys import Keys
from vasl_templates.webapp.tests.utils import \
init_webapp, select_tab, select_tab_for_elem, set_template_params, wait_for, wait_for_clipboard, \
init_webapp, select_tab, find_snippet_buttons, set_template_params, wait_for, wait_for_clipboard, \
get_stored_msg, set_stored_msg_marker, find_child, find_children, adjust_html, \
for_each_template, add_simple_note, edit_simple_note, \
get_sortable_entry_count, generate_sortable_entry_snippet, drag_sortable_entry_to_trash, \
@ -33,12 +33,6 @@ def test_snippet_ids( webapp, webdriver ):
def check_snippet( btn ):
"""Generate a snippet and check that it has an ID."""
select_tab_for_elem( btn )
if not btn.is_displayed():
# FUDGE! All nationality-specific buttons are created on each OB tab, and the ones not relevant
# to each player are just hidden. This is not real good since we have multiple elements with the same ID :-/
# but we work around this by checking if the button is visible. Sigh...
return
btn.click()
wait_for_clipboard( 2, "<!-- vasl-templates:id ", contains=True )
@ -49,10 +43,11 @@ def test_snippet_ids( webapp, webdriver ):
set_scenario_date( scenario_date )
# check each snippet
for btn in find_children( "button.generate" ):
check_snippet( btn )
for btn in find_children( "img.snippet" ):
check_snippet( btn )
snippet_btns = find_snippet_buttons()
for tab_id,btns in snippet_btns.items():
select_tab( tab_id )
for btn in btns:
check_snippet( btn )
# test snippets with German/Russian
do_test( "" )

@ -6,6 +6,7 @@ import json
import time
import re
import uuid
from collections import defaultdict
import pytest
from PyQt5.QtWidgets import QApplication
@ -115,31 +116,42 @@ def for_each_template( func ): #pylint: disable=too-many-branches
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def select_tab( tab_id ):
def select_tab( tab_id, webdriver=None ):
"""Select a tab in the main page."""
elem = find_child( "#tabs .ui-tabs-nav a[href='#tabs-{}']".format( tab_id ) )
if not webdriver:
webdriver = _webdriver
elem = find_child( "#tabs .ui-tabs-nav a[href='#tabs-{}']".format( tab_id ), webdriver )
elem.click()
def select_tab_for_elem( elem ):
"""Select the tab that contains the specified element."""
while True:
select_tab( get_tab_for_elem( elem ) )
def get_tab_for_elem( elem ):
"""Identify the tab that contains the specified element."""
while elem.tag_name not in ("html","body"):
elem = elem.find_element_by_xpath( ".." )
if elem.tag_name == "div":
div_id = elem.get_attribute( "id" )
if div_id.startswith( "tabs-" ):
select_tab( div_id[5:] )
break
return div_id[5:]
return None
def select_menu_option( menu_id ):
def select_menu_option( menu_id, webdriver=None ):
"""Select a menu option."""
elem = find_child( "#menu" )
if not webdriver:
webdriver = _webdriver
elem = find_child( "#menu", webdriver )
elem.click()
elem = wait_for_elem( 2, "a.PopMenu-Link[data-name='{}']".format( menu_id ) )
elem = wait_for_elem( 2, "a.PopMenu-Link[data-name='{}']".format( menu_id ), webdriver )
elem.click()
wait_for( 2, lambda: find_child("#menu .PopMenu-Container") is None ) # nb: wait for the menu to go away
if pytest.config.option.webdriver == "chrome": #pylint: disable=no-member
# FUDGE! Work-around weird "is not clickable" errors because the PopMenu is still around :-/
time.sleep( 0.25 )
wait_for( 2, lambda: find_child("#menu .PopMenu-Container",webdriver) is None ) # nb: wait for the menu to go away
try:
if pytest.config.option.webdriver == "chrome": #pylint: disable=no-member
# FUDGE! Work-around weird "is not clickable" errors because the PopMenu is still around :-/
time.sleep( 0.25 )
except AttributeError:
pass
def new_scenario():
"""Reset the scenario."""
@ -323,24 +335,28 @@ def find_sortable_helper( sortable, tag ):
# ---------------------------------------------------------------------
def get_stored_msg( msg_type ):
def get_stored_msg( msg_type, webdriver=None ):
"""Get a message stored for us by the front-end."""
elem = find_child( "#" + msg_type )
if not webdriver:
webdriver = _webdriver
elem = find_child( "#" + msg_type, webdriver )
assert elem.tag_name == "textarea"
return elem.get_attribute( "value" )
def set_stored_msg( msg_type, val ):
def set_stored_msg( msg_type, val, webdriver=None ):
"""Set a message for the front-end."""
elem = find_child( "#" + msg_type )
if not webdriver:
webdriver = _webdriver
elem = find_child( "#" + msg_type, webdriver )
assert elem.tag_name == "textarea"
_webdriver.execute_script( "arguments[0].value = arguments[1]", elem, val )
webdriver.execute_script( "arguments[0].value = arguments[1]", elem, val )
def set_stored_msg_marker( msg_type ):
def set_stored_msg_marker( msg_type, webdriver=None ):
"""Store marker text in the message buffer (so we can tell if the front-end changes it)."""
# NOTE: Care should taken when using this function with "_clipboard_",
# since the tests might be using the real clipboard!
marker = "marker:{}:{}".format( msg_type, uuid.uuid4() )
set_stored_msg( msg_type, marker )
set_stored_msg( msg_type, marker, webdriver )
return marker
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@ -362,6 +378,29 @@ def find_children( sel, parent=None ):
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def find_snippet_buttons( webdriver=None ):
"""Find all "generate snippet" buttons.
NOTE: We only return the 1st snippet button in the "extras" tab.
"""
snippet_btns = defaultdict( list )
# find all normal "generate snippet" buttons
for btn in find_children( "button.generate", webdriver ):
snippet_btns[ get_tab_for_elem(btn) ].append( btn )
# find "generate snippet" buttons in sortable lists
for btn in find_children( "ul.sortable img.snippet", webdriver ):
snippet_btns[ get_tab_for_elem(btn) ].append( btn )
# FUDGE! All nationality-specific buttons are created on each OB tab, and the ones not relevant
# to each player are just hidden. This is not real good since we have multiple elements with
# the same ID :-/ but we work around this by checking if the button is visible. Sigh...
snippet_btns2 = {}
for tab_id,btns in snippet_btns.items():
select_tab( tab_id, webdriver )
snippet_btns2[ tab_id ] = [ btn for btn in btns if btn.is_displayed() ]
return snippet_btns2
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def select_droplist_val( sel, val ):
"""Select a droplist option by value."""
_do_select_droplist( sel, val )
@ -407,10 +446,10 @@ def dismiss_notifications():
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def click_dialog_button( caption ):
def click_dialog_button( caption, webdriver=None ):
"""Click a dialog button."""
btn = next(
elem for elem in find_children(".ui-dialog button")
elem for elem in find_children( ".ui-dialog button", webdriver )
if elem.text == caption
)
btn.click()

@ -126,33 +126,41 @@ def update_vsav(): #pylint: disable=too-many-statements
},
} )
def _save_snippets( snippets, fp ):
def _save_snippets( snippets, fp ): #pylint: disable=too-many-locals
"""Save the snippets in a file.
NOTE: We save the snippets as XML because Java :-/
"""
def do_save_snippets( webdriver ): #pylint: disable=too-many-locals
"""Save the snippets."""
root = ET.Element( "snippets" )
for snippet_id,snippet_info in snippets.items():
# add the next snippet
auto_create = "true" if snippet_info["auto_create"] else "false"
elem = ET.SubElement( root, "snippet", id=snippet_id, autoCreate=auto_create )
elem.text = snippet_info["content"]
label_area = snippet_info.get( "label_area" )
if label_area:
elem.set( "labelArea", label_area )
# add the raw content
elem2 = ET.SubElement( elem, "rawContent" )
for node in snippet_info.get( "raw_content", [] ):
ET.SubElement( elem2, "phrase" ).text = node
# include the size of the snippet
if webdriver:
# NOTE: We used to create a WebDriver here and re-use it for each snippet screenshot,
# but when we implemented the shared WebDriver, we changed things to request it for each
# snippet. If we did things the old way, the WebDriver wouldn't be able to shutdown
# until it had finished *all* the snippet screenshots (since we would have it locked);
# the new way, we only have to wait for it to finish the snippet it's on, the WebDriver
# will be unlocked, and then the other thread will be able to grab the lock and shut
# it down. The downside is that if the user has to disable the shared WebDriver, things
# will run ridiculously slowly, since we will be launching a new webdriver for each snippet.
# We optimize for the case where things work properly... :-/
root = ET.Element( "snippets" )
for snippet_id,snippet_info in snippets.items():
# add the next snippet
auto_create = "true" if snippet_info["auto_create"] else "false"
elem = ET.SubElement( root, "snippet", id=snippet_id, autoCreate=auto_create )
elem.text = snippet_info["content"]
label_area = snippet_info.get( "label_area" )
if label_area:
elem.set( "labelArea", label_area )
# add the raw content
elem2 = ET.SubElement( elem, "rawContent" )
for node in snippet_info.get( "raw_content", [] ):
ET.SubElement( elem2, "phrase" ).text = node
# include the size of the snippet
if not app.config.get( "DISABLE_UPDATE_VSAV_SCREENSHOTS" ):
with WebDriver.get_instance() as webdriver:
try:
start_time = time.time()
img = webdriver.get_snippet_screenshot( snippet_id, snippet_info["content"] )
@ -173,15 +181,8 @@ def _save_snippets( snippets, fp ):
logging.error( "Can't get snippet screenshot: %s", ex )
logging.error( traceback.format_exc() )
ET.ElementTree( root ).write( fp )
return root
# save the snippets
if app.config.get( "DISABLE_UPDATE_VSAV_SCREENSHOTS" ):
return do_save_snippets( None )
else:
with WebDriver() as webdriver:
return do_save_snippets( webdriver )
ET.ElementTree( root ).write( fp )
return root
def _parse_label_report( fname ):
"""Read the label report generated by the VASSAL shim."""

@ -1,26 +1,55 @@
""" Wrapper for a Selenium webdriver. """
import os
import threading
import tempfile
import atexit
import logging
from selenium import webdriver
from PIL import Image, ImageChops
from vasl_templates.webapp import app
from vasl_templates.webapp import app, cleanup_handlers
from vasl_templates.webapp.utils import TempFile, SimpleError
_logger = logging.getLogger( "webdriver" )
# ---------------------------------------------------------------------
class WebDriver:
"""Wrapper for a Selenium webdriver."""
# NOTE: The thread-safety lock controls access to the _shared_instance variable,
# not the WebDriver it points to (it has its own lock).
_shared_instance_lock = threading.RLock()
_shared_instance = None
def __init__( self ):
self.driver = None
self.lock = threading.RLock() # nb: the shared instance must be locked for use
self.start_count = 0
_logger.debug( "Created WebDriver: %x", id(self) )
def __del__( self ):
try:
_logger.debug( "Destroying WebDriver: %x", id(self) )
except NameError:
pass # nb: workaround a weird crash during shutdown (name 'open' is not defined)
def start( self ):
"""Start the webdriver."""
self.lock.acquire()
self._do_start()
def _do_start( self ):
"""Start the webdriver."""
# initialize
self.start_count += 1
_logger.info( "Starting WebDriver (%x): count=%d", id(self), self.start_count )
if self.start_count > 1:
assert self.driver
return
assert not self.driver
# locate the webdriver executable
@ -36,6 +65,7 @@ class WebDriver:
# It's pretty icky to have to do this, but since we're in a virtualenv, it's not too bad...
# create the webdriver
_logger.debug( "- Launching webdriver process: %s", webdriver_path )
kwargs = { "executable_path": webdriver_path }
if "chromedriver" in webdriver_path:
options = webdriver.ChromeOptions()
@ -57,14 +87,23 @@ class WebDriver:
self.driver = webdriver.Firefox( **kwargs )
else:
raise SimpleError( "Can't identify webdriver: {}".format( webdriver_path ) )
return self
_logger.debug( "- Started OK." )
def stop( self ):
"""Stop the webdriver."""
self._do_stop()
self.lock.release()
def _do_stop( self ):
"""Stop the webdriver."""
assert self.driver
self.driver.quit()
self.driver = None
self.start_count -= 1
_logger.info( "Stopping WebDriver (%x): count=%d", id(self), self.start_count )
if self.start_count == 0:
_logger.debug( "- Stopping webdriver process." )
self.driver.quit()
_logger.debug( "- Stopped OK." )
self.driver = None
def get_screenshot( self, html, window_size, large_window_size=None ):
"""Get a preview screenshot of the specified HTML."""
@ -118,6 +157,52 @@ class WebDriver:
window_size, window_size2 = window_size2, None
return self.get_screenshot( snippet, window_size, window_size2 )
@staticmethod
def get_instance():
"""Return the shared WebDriver instance.
A Selenium webdriver has a hefty startup time, so we create one on first use, and then re-use it.
There are 2 main issues with this approach:
- thread-safety: Flask handles requests in multiple threads, so we need to serialize access.
- clean-up: it's difficult to know when to clean up the shared WebDriver object. The WebDriver object
wraps a chrome/geckodriver process, so we can't just let it leak, since these abandoned processes
will just build up. We install atexit and SIGINT handlers, but webdriver processes will still leak
if we abend.
There is a script to stress-test this in the tools directory.
"""
# NOTE: We provide a debug switch to disable the shared instance, in case it causes problems
# (although things will, of course, run insanely slowly :-/).
if app.config.get( "DISABLE_SHARED_WEBDRIVER" ):
return WebDriver()
with WebDriver._shared_instance_lock:
# check if we've already created the shared WebDriver
if WebDriver._shared_instance:
# yup - just return it (nb: the caller is responsible for locking it)
_logger.info( "Returning shared WebDriver: %x", id(WebDriver._shared_instance) )
return WebDriver._shared_instance
# nope - create a new WebDriver instance
# NOTE: We start it here to keep it alive even after the caller has finished with it,
# and take steps to make sure it gets stopped and cleaned up when the program exits.
wdriver = WebDriver()
_logger.info( "Created shared WebDriver: %x", id(wdriver) )
wdriver._do_start() #pylint: disable=protected-access
WebDriver._shared_instance = wdriver
# make sure the shared WebDriver gets cleaned up
def cleanup(): #pylint: disable=missing-docstring
_logger.info( "Cleaning up shared WebDriver: %x", id(wdriver) )
wdriver._do_stop() #pylint: disable=protected-access
atexit.register( cleanup )
cleanup_handlers.append( cleanup )
return wdriver
def __enter__( self ):
self.start()
return self

Loading…
Cancel
Save