Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_postgres_connect(monkeypatch):
    """PostgresDbHelper.connect passes a key=value DSN string to psycopg2.connect.

    psycopg2.connect is replaced by a Mock so no real server is contacted;
    the password is injected through the DB_PASSWORD environment variable.
    """
    # Arrange
    pg_params = DbParams(
        dbtype='PG',
        host='server',
        port='1521',
        dbname='testdb',
        user='testuser',
        odbc_driver='test driver',
    )
    monkeypatch.setenv('DB_PASSWORD', 'mypassword')
    wanted_dsn = ('host=server port=1521 dbname=testdb '
                  'user=testuser password=mypassword')
    fake_connect = Mock()
    monkeypatch.setattr(psycopg2, 'connect', fake_connect)

    # Act
    PostgresDbHelper().connect(pg_params, 'DB_PASSWORD')

    # Assert
    fake_connect.assert_called_with(wanted_dsn)
These tests currently run against an internal BGS instance.
"""
# pylint: disable=unused-argument, missing-docstring
from datetime import datetime, date
import os
from textwrap import dedent
import cx_Oracle
import pytest
from etlhelper import connect, get_rows, copy_rows, DbParams
from etlhelper.exceptions import ETLHelperConnectionError
from test.conftest import db_is_unreachable
# Oracle connection settings are read with the TEST_ORACLE_ prefix.
ORADB = DbParams.from_environment(prefix='TEST_ORACLE_')

# None of the tests below can pass without a reachable server, so skip the
# whole module at collection time rather than failing each test.
if db_is_unreachable(ORADB.host, ORADB.port):
    pytest.skip(
        'Oracle test database is unreachable',
        allow_module_level=True,
    )
# -- Tests here --
def test_connect():
    """connect() with the configured Oracle params yields a cx_Oracle.Connection."""
    connection = connect(ORADB, 'TEST_ORACLE_PASSWORD')
    assert isinstance(connection, cx_Oracle.Connection)
def test_connect_wrong_password(monkeypatch):
    """A wrong Oracle password is reported as ETLHelperConnectionError."""
    password_variable = 'TEST_ORACLE_PASSWORD'
    # monkeypatch restores the real value once the test finishes.
    monkeypatch.setitem(os.environ, password_variable, 'bad_password')

    with pytest.raises(ETLHelperConnectionError):
        connect(ORADB, password_variable)
def test_sqlserver_connect(monkeypatch):
    """SqlServerDbHelper.connect passes a semicolon-separated ODBC string to pyodbc.

    pyodbc.connect is replaced by a Mock so no real server is contacted.
    """
    # Arrange
    mssql_params = DbParams(
        dbtype='MSSQL',
        host='server',
        port='1521',
        dbname='testdb',
        user='testuser',
        odbc_driver='test driver',
    )
    monkeypatch.setenv('DB_PASSWORD', 'mypassword')
    wanted_conn_str = ';'.join([
        'DRIVER=test driver',
        'SERVER=tcp:server',
        'PORT=1521',
        'DATABASE=testdb',
        'UID=testuser',
        'PWD=mypassword',
    ])
    fake_connect = Mock()
    monkeypatch.setattr(pyodbc, 'connect', fake_connect)

    # Act
    SqlServerDbHelper().connect(mssql_params, 'DB_PASSWORD')

    # Assert
    fake_connect.assert_called_with(wanted_conn_str)
def test_from_db_params_not_registered():
    """
    Requesting a helper for an unregistered dbtype raises ETLHelperHelperError
    whose message starts with 'Unsupported DbParams.dbtype'.
    """
    unknown_params = MagicMock(DbParams)
    unknown_params.dbtype = 'Not a real type'
    message_pattern = r'Unsupported DbParams.dbtype.*'

    with pytest.raises(ETLHelperHelperError, match=message_pattern):
        DB_HELPER_FACTORY.from_db_params(unknown_params)
def test_sqlite_dbparam_not_supported():
    """Calling is_reachable() on SQLite DbParams raises ValueError."""
    # Named sqlite_params so it does not shadow the module-level sqlitedb.
    sqlite_params = DbParams(dbtype='SQLITE', filename='sqlite.db',
                             dbname='etlhelper', user='etlhelper_user')

    with pytest.raises(ValueError):
        sqlite_params.is_reachable()
@pytest.fixture
def sqlitedb(tmp_path):
    """Yield DbParams for a SQLite database file in the test's ``tmp_path``.

    Without the ``@pytest.fixture`` decorator pytest never registers this
    generator, and tests requesting ``sqlitedb`` fail with "fixture not found".
    The database file is created inside ``tmp_path`` so that it stays within
    the per-test temporary directory, not beside it.
    """
    filename = str(tmp_path.absolute() / 'sqlite.db')
    yield DbParams(dbtype='SQLITE', filename=filename)
def test_sqlserver_sqlalchemy_connect(monkeypatch):
    """SQL Server params render as an mssql+pyodbc SQLAlchemy URL.

    The space in the ODBC driver name is expected to appear as '+'.
    """
    mssql_params = DbParams(
        dbtype='MSSQL',
        host='server',
        port='1521',
        dbname='testdb',
        user='testuser',
        odbc_driver='test driver',
    )
    monkeypatch.setenv('DB_PASSWORD', 'mypassword')

    url = SqlServerDbHelper().get_sqlalchemy_connection_string(
        mssql_params, 'DB_PASSWORD')

    assert url == ('mssql+pyodbc://testuser:mypassword@server:1521/testdb'
                   '?driver=test+driver')
def test_postgres_sqlalchemy_connect(monkeypatch):
    """PostgreSQL params render as a postgresql:// SQLAlchemy URL."""
    pg_params = DbParams(
        dbtype='PG',
        host='server',
        port='1521',
        dbname='testdb',
        user='testuser',
    )
    monkeypatch.setenv('DB_PASSWORD', 'mypassword')

    url = PostgresDbHelper().get_sqlalchemy_connection_string(
        pg_params, 'DB_PASSWORD')

    assert url == 'postgresql://testuser:mypassword@server:1521/testdb'
def test_oracle_sqlalchemy_conn_string(monkeypatch):
    """Oracle params render as an oracle:// SQLAlchemy URL."""
    ora_params = DbParams(
        dbtype='ORACLE',
        host='server',
        port='1521',
        dbname='testdb',
        user='testuser',
    )
    monkeypatch.setenv('DB_PASSWORD', 'mypassword')

    url = OracleDbHelper().get_sqlalchemy_connection_string(
        ora_params, 'DB_PASSWORD')

    assert url == 'oracle://testuser:mypassword@server:1521/testdb'
def params():
    """Return ORACLE-type DbParams filled with placeholder test values."""
    settings = dict(
        dbtype='ORACLE',
        odbc_driver='test driver',
        host='testhost',
        port=1521,
        dbname='testdb',
        user='testuser',
    )
    return DbParams(**settings)