#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.

# Helper python script with utility methods

import sys
import os
import grp
import pwd
import getpass
import unicodedata
import subprocess
import gettext
import locale
import re
from configparser import ConfigParser
import logging as rootLogger

# Any messages logged through this logger do not require localization since it is hidden behind the -v flag.
logger = rootLogger.getLogger("mssql-conf")

#
# Static configuration values
#

# Data root, log directory, mssql.conf location and the master database file.
# checkEulaAgreement skips the EULA prompt when the master database file exists.
sqlPathRoot = "/var/opt/mssql"
sqlPathLogDir = "/var/opt/mssql/log"
configurationFilePath = os.path.join(sqlPathRoot, "mssql.conf")
masterDatabaseFilePath = os.path.join(sqlPathRoot, "data", "master.mdf")
# Section and option names used in mssql.conf.
eulaConfigSection="EULA"
eulaConfigSetting="accepteula"
eulaMlConfigSetting="accepteulaml"
licensingConfigSection = "licensing"
productCoveredBySASetting="productcoveredbysa"
telemetryConfigSetting="customerfeedback"
telemetryLocalAuditCacheDirectorySetting="userrequestedlocalauditdirectory"
# Exit codes passed to exit() / returned by helpers in this module.
errorExitCode = 1
successExitCode = 0
# Helper scripts and binaries. directoryOfScript is the directory containing
# this file after symlinks are resolved.
directoryOfScript = os.path.dirname(os.path.realpath(__file__))
checkInstallScript = directoryOfScript + "/checkinstall.sh"
sqlBinPathRoot = "/opt/mssql/bin"
launchpaddPath = sqlBinPathRoot + "/launchpadd"
mlservicesPathRoot = "/opt/mssql/mlservices"
checkInstallExtensibilityScript = sqlBinPathRoot + "/checkinstallextensibility.sh"
checkRunningInstanceScript = directoryOfScript + "/checkrunninginstance.sh"
invokeSqlservrScript = directoryOfScript + "/invokesqlservr.sh"
setCollationScript = directoryOfScript + "/set-collation.sh"
sudo = "sudo"
mssqlUser = "mssql"
# Environment variables read during setup. getSystemAdministratorPassword
# prefers MSSQL_SA_PASSWORD and falls back to SA_PASSWORD.
saPasswordEnvVariable = "SA_PASSWORD"
mssqlSaPasswordEnvVariable = "MSSQL_SA_PASSWORD"
mssqlLcidEnvVariable = "MSSQL_LCID"
mssqlPidEnvVariable = "MSSQL_PID"
# Section ("language") and option ("lcid") under which writeLcidToConfFile
# stores the chosen LCID.
language = "language"
lcid = "lcid"
# Edition names that validatePid accepts as a PID (compared case-insensitively).
expressEdition = "express"
evaluationEdition = "evaluation"
developerEdition = "developer"
webEdition = "web"
standardEdition = "standard"
enterpriseEdition = "enterprise"
enterpriseCoreEdition = "enterprisecore"
# Option in the licensing section recording Azure pay-as-you-go billing.
isAzureBilledSetting = "azurebilling"
# LCIDs accepted by isValidLcid; keep in sync with optionToLcid in languageSelect.
supportedLcids = ['1033', '1031', '3082', '1036', '1040',
                  '1041', '1042', '1046', '1049', '2052', '1028']
# Values getPIDCoveredBySA writes to licensing.productcoveredbysa.
pidPositiveResponse = "1"
pidNegativeResponse = "0"
pidInvalidResponse = "-1"

#
# Colors to use when printing to standard out
#
class bcolors:
    """ANSI SGR escape sequences used when writing coloured text to a terminal.

    ENDC (SGR 0) resets every attribute back to the terminal default.
    """

    # Bright foreground colours (SGR 9x).
    HEADER = '\x1b[95m'
    OKBLUE = '\x1b[94m'
    OKGREEN = '\x1b[92m'
    WARNING = '\x1b[93m'
    RED = '\x1b[91m'
    # Reset and text attributes.
    ENDC = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'

#
# Error codes for failed password validation
#
class passwordErrorCodes:
    """Result codes for password validation.

    SUCCESS (0) means the password is acceptable; every other value names the
    reason it was rejected and selects the message printPasswordErrorMessage shows.
    """

    (SUCCESS,
     TOO_SHORT,
     TOO_LONG,
     NOT_COMPLEX,
     ENCODING_ERROR,
     DECODING_ERROR,
     CONTROL_CHARS) = range(7)

def printError(text):
    """Report an error on the console (in bcolors.RED) and in the mssql-conf log.

    Args:
        text(str): Message to report
    """

    message = text
    # Console first, then the ERROR-level log record.
    _printTextInColor(message, bcolors.RED)
    logger.error(message)

def printException(text):
    """Report an error like printError, but log it via logger.exception.

    logger.exception records at ERROR level and appends the traceback of the
    exception currently being handled, so call this from inside an except block.

    Args:
        text(str): Message to report
    """

    message = text
    _printTextInColor(message, bcolors.RED)
    logger.exception(message)

def printErrorUnsupportedSetting(section_name, setting_name=None):
    """Print an "unsupported setting" message and terminate with errorExitCode.

    Args:
        section_name(str): The section name.
                           For legacy reasons section_name can contain setting_name as well,
                           for example: section_name.setting_name
        setting_name(str): The setting name (optional)

    Raises:
        SystemExit: always, with errorExitCode
    """

    if setting_name:
        print(_("The section '%s' and setting '%s' is not supported.") % (section_name, setting_name))
    else:
        print(_("The setting '%s' is not supported.") % section_name)
    # sys.exit rather than the site-module exit(), which is meant for the interactive shell.
    sys.exit(errorExitCode)

def checkColorSupported():
    """Decide whether ANSI colour codes should be written to stdout.

    Colour is disabled on 'Pocket PC', on win32 unless ANSICON is present in
    the environment, and whenever stdout is not an interactive terminal.

    Returns:
        True if color is supported, False otherwise
    """

    platformName = sys.platform
    if platformName == 'Pocket PC':
        return False
    if platformName == 'win32' and 'ANSICON' not in os.environ:
        return False

    stdout = sys.stdout
    return bool(hasattr(stdout, 'isatty') and stdout.isatty())

def setSettingsFromEnvironment():
    """Apply every supported setting whose environment variable is set.

    Each consumed variable is removed from os.environ before it is applied, so
    it is not passed on to SQL Server when setup runs. Processing stops at the
    first setting that fails to apply.

    Returns:
        True if all present settings were applied, False on the first failure
    """

    import mssqlsettingsmanager

    for candidate in mssqlsettingsmanager.supportedSettingsList:
        envName = candidate.environment_variable
        if envName is None or envName not in os.environ:
            continue

        # Take the value and drop it from the environment in one step.
        envValue = os.environ.pop(envName)

        if not setSetting(candidate, envValue):
            return False

    return True

def setSettingByName(section, name, value, overrideConfigFilePath=None):
    """Look up a setting by section/name and write value to mssql.conf.

    Args:
        section(str): Section name
        name(str): Setting name
        value(str): New value (an empty string unsets the setting)
        overrideConfigFilePath(str): Alternate configuration file (optional)

    Returns:
        True on success; False (after printing a warning) if the setting is
        unknown or could not be written.
    """

    import mssqlsettingsmanager

    logger.info("Changing setting with name [%s] in section [%s] to value [%s].", name, section, value)
    target = mssqlsettingsmanager.findSetting(section, name)

    applied = target is not None and setSetting(target, value, overrideConfigFilePath)
    if applied:
        logger.info("Successfully changed setting.")
        return True

    warning = _("Warning: Failed to set '%s.%s' to '%s' in mssql.conf. Please set this manually") % (section, name, value)
    logger.warning(warning)
    print(warning)
    return False

def setSetting(setting, value, overrideConfigFilePath=None):
    """Set (or, for an empty value, unset) one setting in mssql.conf.

    Args:
        setting(object): Setting descriptor from mssqlsettingsmanager
        value(str): New value; an empty string means "unset"
        overrideConfigFilePath(str): Alternate configuration file (optional)

    Returns:
        False if a non-empty value was rejected by mssqlsettingsmanager,
        True otherwise (unsetting always reports True).
    """

    import mssqlsettingsmanager

    configPath = configurationFilePath if overrideConfigFilePath is None else overrideConfigFilePath

    parser = ConfigParser()
    readConfigFromFile(parser, configPath)

    if len(value) == 0:
        # Unset: only rewrite the file when something was actually removed.
        if setting.sectionOnly:
            removed = mssqlsettingsmanager.unsetSectionOnlySettings(parser, setting)
        else:
            removed = mssqlsettingsmanager.unsetSetting(parser, setting)
        if removed:
            writeConfigToFile(parser, configPath)
        return True

    if setting.sectionOnly:
        accepted = mssqlsettingsmanager.setSectionOnlySettings(parser, setting, value, True)
    else:
        accepted = mssqlsettingsmanager.setSetting(parser, setting, value, True)

    if not accepted:
        return False

    writeConfigToFile(parser, configPath)
    return True

def languageSelect(noprompt):
    """Choose the SQL Server language and record its LCID in mssql.conf.

    MSSQL_LCID, when set, is used without prompting. Otherwise, when prompting
    is allowed and the default locale is neither unset nor en_US, the user
    picks from a menu.

    Args:
        noprompt(boolean): True if --noprompt specified, False otherwise

    Raises:
        SystemExit: with errorExitCode on an invalid menu choice (or, via
            writeLcidToConfFile, an unsupported LCID)
    """

    lcidFromEnv = os.environ.get(mssqlLcidEnvVariable)

    if lcidFromEnv is not None:
        print(_("Setting language using LCID from environment variable %s") % mssqlLcidEnvVariable)
        writeLcidToConfFile(lcidFromEnv)
        return

    # Only prompt when noprompt is explicitly False.
    if noprompt != False:
        return

    # NOTE(review): locale.getdefaultlocale() is deprecated since Python 3.11.
    defaultLocale = locale.getdefaultlocale()[0]
    if not defaultLocale or defaultLocale.lower() == "en_us":
        # Nothing to do as en_US will be chosen by default by the engine
        return

    print("")
    print(_("Choose the language for SQL Server:"))
    print("(1) English")
    print("(2) Deutsch")
    print("(3) Español")
    print("(4) Français")
    print("(5) Italiano")
    print("(6) 日本語")
    print("(7) 한국어")
    print("(8) Português")
    print("(9) Русский")
    print("(10) 中文 – 简体")
    print("(11) 中文 （繁體）")

    # Tolerate surrounding whitespace such as " 2".
    languageOption = input(_("Enter Option 1-11: ")).strip()

    optionToLcid = {'1': '1033',   # en-US
                    '2': '1031',   # de-DE
                    '3': '3082',   # es-ES
                    '4': '1036',   # fr-FR
                    '5': '1040',   # it-IT
                    '6': '1041',   # ja-JP
                    '7': '1042',   # ko-KR
                    '8': '1046',   # pt-BR
                    '9': '1049',   # ru-RU
                    '10': '2052',  # zh-CN
                    '11': '1028'}  # zh-TW

    if languageOption in optionToLcid:
        writeLcidToConfFile(optionToLcid[languageOption])
    else:
        print(_("Invalid Option. Exiting."))
        sys.exit(errorExitCode)


def isValidLcid(lcidValue):
    """Check whether an LCID is one SQL Server supports.

    Args:
        lcidValue(str or int): LCID value, e.g. '1033' or 1033

    Returns:
        True if the LCID is listed in supportedLcids, False otherwise
    """

    # supportedLcids holds strings; normalise so an int LCID is also recognised.
    return str(lcidValue) in supportedLcids

def writeLcidToConfFile(lcidValue):
    """Store an LCID as language.lcid in the main configuration file.

    Args:
        lcidValue(str or int): LCID value

    Raises:
        SystemExit: with errorExitCode if the LCID is not supported
    """

    # ConfigParser.set() only accepts strings, so normalise int input up front.
    lcidValue = str(lcidValue)

    if not isValidLcid(lcidValue):
        print(_("LCID %s is not supported.") % lcidValue)
        sys.exit(errorExitCode)

    config = ConfigParser(allow_no_value=True)
    readConfigFromFile(config, configurationFilePath)

    if not config.has_section(language):
        config.add_section(language)

    config.set(language, lcid, lcidValue)
    writeConfigToFile(config, configurationFilePath)

def printEngineWelcomeMessage():
    """Print a boxed note telling the user to run `mssql-conf setup`.

    Returns:
        successExitCode
    """

    border = "+--------------------------------------------------------------+"
    lines = (
        "",
        border,
        _("Please run 'sudo /opt/mssql/bin/mssql-conf setup'"),
        _("to complete the setup of Microsoft SQL Server"),
        border,
        "",
    )
    print("\n".join(lines))
    return successExitCode

def printAgentWelcomeMessage():
    """Print a boxed note saying mssql-server must be restarted to enable SQL Server Agent."""

    border = "+--------------------------------------------------------------------------------+"
    lines = (
        "",
        border,
        _("Please restart mssql-server to enable Microsoft SQL Server Agent."),
        border,
        "",
    )
    print("\n".join(lines))

def printFTSWelcomeMessage():
    """Print a boxed note saying mssql-server must be restarted to enable Full Text Search."""

    border = "+-------------------------------------------------------------------------------------+"
    lines = (
        "",
        border,
        _("Please restart mssql-server to enable Microsoft SQL Server Full Text Search."),
        border,
        "",
    )
    print("\n".join(lines))

def printPolyBaseWelcomeMessage():
    """Print a boxed note saying mssql-server must be restarted to enable PolyBase."""

    border = "+--------------------------------------------------------------------------------+"
    lines = (
        "",
        border,
        _("Please restart mssql-server to enable PolyBase."),
        border,
        "",
    )
    print("\n".join(lines))

def printPolyBaseHadoopWelcomeMessage():
    """Print a boxed note saying mssql-launchpadd must be restarted to enable PolyBase Hadoop."""

    border = "+--------------------------------------------------------------------------------+"
    lines = (
        "",
        border,
        _("Please restart mssql-launchpadd to enable PolyBase Hadoop."),
        border,
        "",
    )
    print("\n".join(lines))

def printMachineLearningServicesWelcomeMessage():
    """Print a boxed note telling the user how to accept the Machine Learning Services EULA."""

    border = "+-------------------------------------------------------------------------------------+"
    lines = (
        "",
        border,
        _("Please run 'sudo /opt/mssql/bin/mssql-conf setup accept-eula-ml'"),
        _("to accept the Machine Learning Services EULA."),
        border,
        "",
    )
    print("\n".join(lines))

def getFwlinkWithLocale(linkId):
    """Build the go.microsoft.com fwlink URL, localised to the current locale when known.

    Args:
        linkId(string): The fwlink ID

    Returns:
        The complete URL; "&clcid=<id>" is appended only for the locales in the
        table below, otherwise the plain link is returned.
    """

    baseUrl = "https://go.microsoft.com/fwlink/?LinkId=" + linkId

    try:
        localeCode = locale.getlocale()[0]
    except ValueError:
        # getlocale() raises ValueError for locale names it cannot parse; the
        # unlocalised link is still valid, so fall back instead of aborting.
        localeCode = None

    localeToClcid = {'en_US': '0x409',  # en-US
                     'de_DE': '0x407',  # de-DE
                     'es_ES': '0x40a',  # es-ES
                     'fr_FR': '0x40c',  # fr-FR
                     'it_IT': '0x410',  # it-IT
                     'ja_JP': '0x411',  # ja-JP
                     'ko_KR': '0x412',  # ko-KR
                     'pt_BR': '0x416',  # pt-BR
                     'ru_RU': '0x419',  # ru-RU
                     'zh_CN': '0x804',  # zh-CN
                     'zh_TW': '0x404'}  # zh-TW

    clcid = localeToClcid.get(localeCode)
    if clcid is None:
        return baseUrl
    return baseUrl + "&clcid=" + clcid


def checkEulaAgreement(eulaAccepted, configurationFilePath, ignoreMasterDatabase=False, isEvaluationEdition = False):
    """Check if the EULA agreement has been accepted, prompting if necessary.

    The EULA counts as already accepted when the master database exists (unless
    ignoreMasterDatabase is set), or when EULA.accepteula has any value in the
    configuration file. Otherwise the user is prompted (unless eulaAccepted),
    and acceptance is recorded as EULA.accepteula = Y.

    Args:
        eulaAccepted(boolean): User has indicated their acceptance via command-line
                               or environment variable.
        configurationFilePath(str): Configuration file path
        ignoreMasterDatabase(boolean): Ignore presence of master database
        isEvaluationEdition(boolean): True if edition selected is evaluation, false otherwise
                                      (not used by this function)

    Returns:
        True if accepted, False otherwise
    """

    print(_("The license terms for this product can be found in"))
    print("/usr/share/doc/mssql-server " + _("or downloaded from: https://aka.ms/useterms"))
    print("")
    print(_("The privacy statement can be viewed at:"))
    print(getFwlinkWithLocale("853010"))
    print("")

    if os.path.exists(masterDatabaseFilePath) and not ignoreMasterDatabase:
        return True

    config = ConfigParser(allow_no_value=True)
    readConfigFromFile(config, configurationFilePath)

    # fallback=None covers a missing section, a section without the option
    # (previously a NoOptionError crash), and a valueless option.
    if config.get(eulaConfigSection, eulaConfigSetting, fallback=None) is not None:
        return True

    if not eulaAccepted:
        agreement = input(_("Do you accept the license terms?") + " [Yes/No]:")
        print("")

        if agreement.strip().lower() not in ("yes", "y"):
            return False

    # The section may already exist (e.g. with a valueless accepteula); adding
    # it again would raise DuplicateSectionError.
    if not config.has_section(eulaConfigSection):
        config.add_section(eulaConfigSection)
    config.set(eulaConfigSection, eulaConfigSetting, "Y")
    writeConfigToFile(config, configurationFilePath)
    return True

def checkEulaMlAgreement(eulaMlAccepted, configurationFilePath):
    """Check (and if needed obtain) acceptance of the Machine Learning Services EULA.

    The SQL Server EULA (EULA.accepteula = Y) must already be accepted. When
    EULA.accepteulaml is not yet Y, the user is prompted unless eulaMlAccepted
    is set, and acceptance is written back as EULA.accepteulaml = Y.

    Args:
        eulaMlAccepted(boolean): User has indicated their acceptance via command-line
                               or environment variable.
        configurationFilePath(str): Configuration file path

    Returns:
        True if accepted, False otherwise
    """

    print(_("The license terms for this product can be downloaded from:"))
    print(getFwlinkWithLocale("2006040"))

    parser = ConfigParser(allow_no_value=True)
    readConfigFromFile(parser, configurationFilePath)

    sqlEulaAccepted = (parser.has_section(eulaConfigSection) and
                       parser.has_option(eulaConfigSection, eulaConfigSetting) and
                       parser.get(eulaConfigSection, eulaConfigSetting) == "Y")
    if not sqlEulaAccepted:
        print(_("The SQL Server EULA needs to be accepted first"))
        print("")
        return False

    if parser.has_option(eulaConfigSection, eulaMlConfigSetting) and \
            parser.get(eulaConfigSection, eulaMlConfigSetting) == "Y":
        return True

    if not eulaMlAccepted:
        answer = input(_("Do you accept the license terms for machine learning services?") + " [Yes/No]:")
        print("")
        if answer.strip().lower() not in ("yes", "y"):
            return False

    parser.set(eulaConfigSection, eulaMlConfigSetting, "Y")
    writeConfigToFile(parser, configurationFilePath)
    return True

def checkSudo():
    """Report whether the process runs with root privileges (effective uid 0).

    Returns:
        True if running as root, False otherwise
    """

    return os.geteuid() == 0

def checkSudoOrMssql():
    """Check if we're running as root or the user is in the mssql group.

    Membership counts whether 'mssql' is a supplementary group of the user (from
    the group database) or the user's primary group.

    Returns:
        True if running as root or in mssql group, False otherwise
    """

    if checkSudo():
        return True

    # NOTE(review): getpass.getuser() consults LOGNAME/USER/LNAME/USERNAME before
    # the password database, so the name may not match the effective uid.
    user = getpass.getuser()
    groups = {g.gr_name for g in grp.getgrall() if user in g.gr_mem}

    try:
        primaryGid = pwd.getpwnam(user).pw_gid
        groups.add(grp.getgrgid(primaryGid).gr_name)
    except KeyError:
        # The user or its primary group is missing from the system databases;
        # fall back to the supplementary memberships instead of crashing.
        pass

    return 'mssql' in groups

def printValidationErrorMessage(setting, errorMessage):
    """Report a failed validation for one setting.

    Prints a header naming the setting as 'section.name', followed by the
    detailed error message, both through printError.

    Args:
        setting(object): Setting descriptor exposing .section and .name
        errorMessage(str): Error message
    """

    header = _("Validation error on setting '%s.%s'") % (setting.section, setting.name)
    for text in (header, errorMessage):
        printError(text)

def printPasswordErrorMessage(errorCode):
    """Print the error message matching a password validation result.

    Nothing is printed for SUCCESS or for an unrecognised code.

    Args:
        errorCode(int): One of the passwordErrorCodes values
    """

    messages = (
        (passwordErrorCodes.TOO_SHORT,
         _("The specified password does not meet SQL Server password policy requirements because it is "
           "too short. The password must be at least 8 characters")),
        (passwordErrorCodes.TOO_LONG,
         _("The specified password does not meet SQL Server password policy requirements because it is "
           "too long. The password cannot exceed 128 characters")),
        (passwordErrorCodes.NOT_COMPLEX,
         _("The specified password does not meet SQL Server password policy requirements because it "
           "is not complex enough. The password must be at least 8 characters long and contain characters "
           "from three of the following four sets: uppercase letters, lowercase letters, numbers, "
           "and symbols.")),
        (passwordErrorCodes.ENCODING_ERROR,
         _("The specified password contains a character that cannot be encoded to UTF-8. "
           "Try using a password with only ASCII characters.")),
        (passwordErrorCodes.DECODING_ERROR,
         _("The specified password contains a character that cannot be decoded. "
           "Try using a password with only ASCII characters.")),
        (passwordErrorCodes.CONTROL_CHARS,
         _("The specified password contains an invalid character. Valid characters "
           "include uppercase letters, lowercase letters, numbers, symbols, "
           "punctuation marks, and unicode characters that are categorized as alphabetic "
           "but are not uppercase or lowercase.")),
    )

    for code, message in messages:
        if errorCode == code:
            printError(message)
            break

def makeDirectoryIfNotExists(directoryPath):
    """Make a directory (and any missing parents) if it does not exist.

    An existing path, even a non-directory one, is left untouched. An empty
    path, as produced by os.path.dirname() of a bare file name, refers to the
    current directory and is a no-op.

    Args:
        directoryPath(str): Directory path

    Raises:
        SystemExit: with errorExitCode if the directory cannot be created
    """

    import errno

    # The old recursive parent walk never terminated for "", because
    # os.path.dirname("") == "" and os.path.exists("") is False.
    if not directoryPath or os.path.exists(directoryPath):
        return

    try:
        # makedirs creates every missing parent; exist_ok tolerates another
        # process creating the directory between the check above and this call.
        os.makedirs(directoryPath, exist_ok=True)
    except (IOError, OSError) as err:
        if err.errno == errno.EACCES:
            printError(_("Permission denied to mkdir '%s'.") % (directoryPath))
        else:
            printError(str(err))
        sys.exit(errorExitCode)

def writeConfigToFile(config, configurationFilePath):
    """Write configuration to a file, creating its directory if needed.

    Args:
        config(object): Config parser object
        configurationFilePath(str): Configuration file path

    Raises:
        SystemExit: with errorExitCode if the file cannot be written. This
            includes permission denied: previously that case printed an error
            but returned normally, so callers went on to report success.
    """

    import errno

    logger.info("Writing configuration to file: [%s]", configurationFilePath)

    makeDirectoryIfNotExists(os.path.dirname(configurationFilePath))

    try:
        with open(configurationFilePath, 'w') as configFile:
            config.write(configFile)
    except (IOError, OSError) as err:
        if err.errno == errno.EACCES:
            printError(_("Permission denied to modify SQL Server configuration."))
        else:
            printError(str(err))
        sys.exit(errorExitCode)

def readConfigFromFile(config, configurationFilePath):
    """Read configuration from a file into an existing parser.

    A missing file is silently ignored, leaving config unchanged.

    Args:
        config(object): Config parser object
        configurationFilePath(str): Configuration file path

    Raises:
        SystemExit: with errorExitCode if the file cannot be parsed
    """

    if not os.path.exists(configurationFilePath):
        return

    try:
        config.read(configurationFilePath)
    except Exception:
        # ConfigParser.read() itself skips files it cannot open, so what
        # arrives here is most likely a parse/decode error
        # (configparser.Error, UnicodeDecodeError). Catching Exception rather
        # than everything lets KeyboardInterrupt/SystemExit propagate, and
        # printException logs the traceback behind the -v flag.
        printException(_("There was a parsing error in the configuration file."))
        sys.exit(errorExitCode)

def getSettings(configuration_file_path, section_name, setting_name=None):
    """
    Gets the settings and their values from the conf file.
    If only section_name is specified returns all settings in that section.
    Section names are matched case-insensitively in both modes (an exact-case
    match is preferred when looking up a single setting).

    Args:
        configuration_file_path(str): The path of the mssql-conf configuration file
        section_name(str): The name of the section a setting belongs to
        setting_name(str): Optional the setting name whose value to get

    Returns:
        A dictionary of setting_name : setting_value,
        Empty dictionary if cannot find the section or the setting
    """
    config = ConfigParser()
    readConfigFromFile(config, configuration_file_path)

    try:
        wanted = section_name.lower()

        if not setting_name:
            results = {}
            for section in config.sections():
                if section.lower() == wanted:
                    results.update(config.items(section))
            return results

        # Previously this lookup was case-sensitive while the branch above was
        # not; resolve the section the same way, keeping exact matches first.
        section = section_name
        if not config.has_section(section_name):
            section = next((s for s in config.sections() if s.lower() == wanted), section_name)
        return {setting_name: config.get(section, setting_name)}

    except Exception:
        # Best effort by contract: a missing section/option yields {}.
        return {}

def listSupportedSettings(supportedSettingsList):
    """Print every non-hidden setting as 'section.name  description', then exit.

    Settings are sorted by 'section.name'. The name column is padded to the
    longest name in the whole list (hidden settings included).

    Args:
        supportedSettingsList(list): Setting descriptors exposing .section,
            .name, .description and .hidden

    Raises:
        SystemExit: always, with successExitCode
    """

    def qualifiedName(setting):
        return "%s.%s" % (setting.section, setting.name)

    maxLength = max((len(qualifiedName(s)) for s in supportedSettingsList), default=0)
    formatString = "%-" + str(maxLength) + "s %s"

    for setting in sorted(supportedSettingsList, key=qualifiedName):
        if setting.hidden == False:
            print(formatString % (qualifiedName(setting), setting.description))

    sys.exit(successExitCode)

def printRestartRequiredMessage():
    """Tell the user to restart mssql-server so the new setting takes effect."""

    lines = (
        _("SQL Server needs to be restarted in order to apply this setting. Please run"),
        _("'systemctl restart mssql-server.service'."),
    )
    print("\n".join(lines))

def printLaunchpadRestartRequiredMessage():
    """Tell the user to restart mssql-launchpadd so the new setting takes effect."""

    lines = (
        _("SQL Server Extensibility Launchpad Daemon needs to be restarted in order to apply this setting. Please run"),
        _("'systemctl restart mssql-launchpadd.service'."),
    )
    print("\n".join(lines))

def printStartSqlServerMessage():
    """Tell the user how to start the mssql-server service."""

    hint = _("Please run 'sudo systemctl start mssql-server' to start SQL Server.")
    print(hint)

def printRevoScalePyNotInstalledMessage():
    """Warn that pythonbinpath was set but no revoscalepy installation was found."""

    lines = (
        _("pythonbinpath has been set successfully, however, a revoscalepy installation was not found."),
        _("revoscalepy is required to be installed in order to execute Python scripts."),
    )
    print("\n".join(lines))

def printRevoScaleRNotInstalledMessage():
    """Warn that rbinpath was set but no RevoScaleR installation was found."""

    lines = (
        _("rbinpath has been set successfully, however, a RevoScaleR installation was not found."),
        _("RevoScaleR is required to be installed in order to execute R scripts."),
    )
    print("\n".join(lines))

def validateCollation(collation):
    """Validate a collation name against collations.txt (case-insensitive).

    collations.txt is read from the directory containing this script, one
    collation per line; blank lines are ignored.

    Args:
        collation(str): Collation name

    Returns:
        True if supported; otherwise prints an error and returns False
    """

    collationsFile = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'collations.txt')
    with open(collationsFile) as f:
        # Skip blank lines: they previously made the empty string a "supported" collation.
        supportedCollations = {entry.strip().lower() for entry in f if entry.strip()}

    if collation.lower() in supportedCollations:
        return True

    printError(_("'%s' is not a supported SQL Server collation.") % collation)
    return False

def validatePasswordAndPrintIfError(password):
    """Validate a password, printing the policy message when it is rejected.

    Args:
         password(str): Password

    Returns:
         True if valid, False otherwise
    """

    result = isValidPassword(password)

    if result == passwordErrorCodes.SUCCESS:
        return True

    printPasswordErrorMessage(result)
    return False

def getSystemAdministratorPassword(noprompt):
    """Get the system administrator (sa) password from the environment or the user.

    MSSQL_SA_PASSWORD is preferred, with SA_PASSWORD as a fallback. Without
    either, the user is prompted interactively unless noprompt is set.

    Args:
        noprompt(bool): Never prompt when True

    Returns:
        The password (str); errorExitCode if the environment-supplied password
        fails validation; None if checkRunningInstance() reports a running
        instance, or if noprompt is set and no password variable exists.
        NOTE(review): the mixed str/int/None return relies on callers telling
        these cases apart.
    """

    if checkRunningInstance():
        return None

    envPassword = os.environ.get(mssqlSaPasswordEnvVariable)
    if envPassword is None:
        # Legacy variable name.
        envPassword = os.environ.get(saPasswordEnvVariable)

    if envPassword is not None:
        return envPassword if validatePasswordAndPrintIfError(envPassword) else errorExitCode

    if not noprompt:
        return getSystemAdministratorPasswordInteractive()

    print(_("The MSSQL_SA_PASSWORD environment variable must be set in order to change the"))
    print(_("system administrator password."))

    return None

def getSystemAdministratorPasswordInteractive():
    """Prompt (without echo) until a valid, confirmed sa password is entered.

    Returns:
        System administrator password
    """

    while True:
        candidate = getpass.getpass(_("Enter the SQL Server system administrator password: "))
        if not validatePasswordAndPrintIfError(candidate):
            continue

        confirmation = getpass.getpass(_("Confirm the SQL Server system administrator password: "))
        if candidate == confirmation:
            return candidate

        printError(_("The passwords do not match. Please try again."))

def isPreview():
    """Report whether the installed mssql-server package is a preview build.

    Lists installed packages with dpkg (or rpm when dpkg is unavailable) and
    looks for an mssql-server entry containing "preview".

    Returns:
        True if such a package is found, False otherwise (including when no
        supported package manager is available or the shell cannot run)
    """

    # Fixed command string with no external input, so shell=True is safe here.
    command = 'if   hash dpkg; then cmd="dpkg --list"; \
                   elif hash  rpm; then cmd="rpm -qa";\
                   else exit 1; fi; \
                   $cmd | grep mssql-server | grep preview'
    try:
        subprocess.check_call(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
    except (subprocess.CalledProcessError, OSError):
        # Non-zero exit (no match / no package manager) or the shell failed to
        # start. Unlike the former bare except, Ctrl-C still propagates.
        return False
    return True

def validatePid(pid):
    """Validate a product key.

    Accepts an edition name (case-insensitive: express, evaluation, developer,
    web, standard, enterprise, enterprisecore) or a 25-character retail key
    written as five dash-separated groups of five letters/digits.

    Args:
        pid(str): Product key

    Returns:
        Product key if valid, otherwise None (after printing an error)
    """

    knownEditions = (expressEdition, evaluationEdition, developerEdition, webEdition,
                     standardEdition, enterpriseEdition, enterpriseCoreEdition)

    # Raw string (the old pattern's "\-" escapes trigger SyntaxWarning), and
    # fullmatch so a trailing newline is rejected (re.match + "$" allowed one).
    isRetailKey = re.fullmatch(r"[A-Za-z0-9]{5}(?:-[A-Za-z0-9]{5}){4}", pid) is not None

    if pid.lower() not in knownEditions and not isRetailKey:
        printError(_("Invalid PID specified: %s.") % (pid))
        print("")
        return None

    return pid

def _promptUntilSACoverageRecorded():
    """Repeat the Software Assurance question until the answer is stored in mssql.conf."""
    while not getPIDCoveredBySA():
        pass

def getPidFromEditionSelected(edition):
    """Gets the correct pid to pass to the engine

    Paid editions (5-8) also ask whether the license is covered by Software
    Assurance; option 8 first prompts for a retail product key.

    Args:
        edition(string): Edition option "1"-"10"

    Returns:
        Pid as expected by the engine

    Raises:
        SystemExit: with errorExitCode for an unknown option
    """

    # Options that map straight to a PID with no further questions.
    directEditions = {"1": evaluationEdition,
                      "2": developerEdition,
                      "3": expressEdition,
                      "4": webEdition,
                      "9": standardEdition,
                      "10": enterpriseCoreEdition}
    # Paid editions that also record Software Assurance coverage.
    saEditions = {"5": standardEdition,
                  "6": enterpriseEdition,
                  "7": enterpriseCoreEdition}

    if edition in directEditions:
        return directEditions[edition]

    if edition in saEditions:
        _promptUntilSACoverageRecorded()
        return saEditions[edition]

    if edition == "8":
        while True:
            productKey = input(_("Enter the 25-character product key: "))
            print("")
            if validatePid(productKey):
                break
        _promptUntilSACoverageRecorded()
        return productKey

    print(_("Invalid option %s.") % edition)
    sys.exit(errorExitCode)

def getEditionAzureBilled(edition):
    """Record in licensing.azurebilling whether the chosen edition is billed through Azure.

    Options "9" and "10" (pay-as-you-go through Azure) store "true"; every
    other option stores "false".

    Args:
        edition(string): Edition option 1-10

    Returns:
        True if the setting was saved, False if there was an error.
    """

    azureBilledOptions = ("9", "10")
    isAzureBilled = "true" if edition in azureBilledOptions else "false"
    return setSettingByName(licensingConfigSection, isAzureBilledSetting, isAzureBilled)

def getPIDCoveredBySA():
    """Ask whether the product is covered by Software Assurance and record the answer.

    "yes"/"y" stores pidPositiveResponse, "no"/"n" stores pidNegativeResponse,
    and anything else stores pidInvalidResponse in licensing.productcoveredbysa.

    Returns:
        True if the setting was saved, False if there was an error.
    """

    answer = input(_("Is the product selected covered by Software Assurance?") + " [Yes/No]:")
    print("")

    normalized = answer.strip().lower()
    if normalized in ("yes", "y"):
        response = pidPositiveResponse
    elif normalized in ("no", "n"):
        response = pidNegativeResponse
    else:
        response = pidInvalidResponse

    return setSettingByName(licensingConfigSection, productCoveredBySASetting, response)

def getPreviewPid(noprompt=False):
    """Let the user know how long they have to test the preview build, and
       return the pid for the preview. We re-use the same pid as evaluation, to re-use
       the existing time bomb mechanisms.

    Args:
        noprompt(bool): Don't prompt the user if True, simply print out the information.

    Returns:
        Product key (evaluationEdition)

    Raises:
        SystemExit: with errorExitCode if the user declines to continue
    """

    previewString = 'This is a preview version (free, no production use rights, 180-day limit starting now)'

    if noprompt:
        print(previewString)
    else:
        agreement = input(previewString + ', continue? [Yes/No]:')
        print("")
        if agreement.strip().lower() not in ("yes", "y"):
            sys.exit(errorExitCode)
    return evaluationEdition

def getPid(noprompt=False):
    """Get product key from user

    The MSSQL_PID environment variable takes precedence and is passed through
    validatePid. Otherwise, with noprompt the developer edition is returned
    without asking. Without noprompt, the edition menu is shown. The user's
    choice is passed to getPidFromEditionSelected, and it is also passed to
    getEditionAzureBilled, which records whether the edition is billed through
    Azure (options 9 and 10).

    Args:
        noprompt(bool): Don't prompt user if True

    Returns:
        Product key (the result of validatePid), or developerEdition when
        noprompt is set and MSSQL_PID is not.
    """

    pidFromEnv = os.environ.get(mssqlPidEnvVariable)

    if pidFromEnv is not None:
        return validatePid(pidFromEnv)

    # If running with --noprompt and MSSQL_PID not set return developer edition
    #
    if noprompt:
        return developerEdition

    print(_("Choose an edition of SQL Server:"))
    print("  1) Evaluation " + _("(free, no production use rights, 180-day limit)"))
    print("  2) Developer " + _("(free, no production use rights)"))
    print("  3) Express " + _("(free)"))
    print("  4) Web " + _("(PAID)"))
    print("  5) Standard " + _("(PAID)"))
    print("  6) Enterprise " + _("(PAID)") + " - " + _("CPU core utilization restricted to 20 physical/40 hyperthreaded"))
    print("  7) Enterprise Core " + _("(PAID)") + " - " + _("CPU core utilization up to Operating System Maximum"))
    print("  8) " + _("I bought a license through a retail sales channel and have a product key to enter."))
    print("  9) Standard " + _("(Billed through Azure)") + " - " + _("Use pay-as-you-go billing through Azure."))
    print(" 10) Enterprise Core " + _("(Billed through Azure)") + " - " + _("Use pay-as-you-go billing through Azure."))
    print("")
    print(_("Details about editions can be found at"))
    print(getFwlinkWithLocale("2109348"))
    print("")
    print(_("Use of PAID editions of this software requires separate licensing through a"))
    print(_("Microsoft Volume Licensing program."))
    print(_("By choosing a PAID edition, you are verifying that you have the appropriate"))
    print(_("number of licenses in place to install and run this software."))
    print(_("By choosing an edition billed Pay-As-You-Go through Azure, you are verifying "))
    print(_("that the server and SQL Server will be connected to Azure by installing the "))
    print(_("management agent and Azure extension for SQL Server."))
    print("")
    edition = input(_("Enter your edition") + "(1-10): ")

    pid = getPidFromEditionSelected(edition)
    # NOTE(review): the success/failure result of writing the Azure billing
    # setting is ignored here, so a failed write does not stop setup -- confirm
    # that this is intended.
    getEditionAzureBilled(edition)
    return validatePid(pid)

def configureSqlservrWithArguments(*args, **kwargs):
    """Run the invokesqlservr script to configure SQL Server.

    Args:
        args(str): Parameters to SQL Server, appended after the script path
        kwargs(dict): Environment variables layered on top of (and overriding)
            the current process environment; keys and values are converted to str

    Returns:
        SQL Server exit code (the return value of subprocess.call)
    """
    command = [invokeSqlservrScript, *args]

    childEnvironment = {**os.environ, **kwargs}
    childEnvironment = {str(name): str(value) for name, value in childEnvironment.items()}

    print(_("Configuring SQL Server..."))
    return subprocess.call(command, env=childEnvironment)

def runScript(pathToScript, runAsRoot=False):
    """Runs a script and returns its exit code.

    The script's stdout and stderr are captured together and written to the
    verbose log through logger.info. The script itself is run with the
    current privileges. runAsRoot only requires that checkSudo() succeeds
    beforehand; it does not elevate.

    Args:
        pathToScript(str): Path to script to run
        runAsRoot(boolean): Require elevated privileges before running the script

    Returns:
        Script exit code, or errorExitCode if runAsRoot is set and the
        privilege check fails.
    """

    logger.info("Running script: [%s]", pathToScript)
    if runAsRoot and not checkSudo():
        printError(_("Elevated privileges required for this action. Please run in 'sudo' mode or as root."))
        return errorExitCode

    # Capture all output to the piped stdout.
    #
    process = subprocess.run([pathToScript], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    try:
        output = process.stdout.decode("utf-8")
        logger.info("Script gave output: [%s]", output)
    except UnicodeDecodeError:
        # The output is only used for logging, so non-UTF-8 bytes must not fail
        # the run. Only the decode error is caught; other exceptions propagate.
        #
        logger.exception("Could not process script output as utf-8")

    return process.returncode

def checkInstall(runAsRoot=True):
    """Checks installation of SQL Server by running the checkinstall.sh script

    Args:
        runAsRoot(boolean): Require elevated privileges for the check (default True)

    Returns:
        True if there are no problems, False otherwise
    """
    exitCode = runScript(checkInstallScript, runAsRoot)
    return exitCode == 0

def checkRunningInstance():
    """Check if an instance of SQL Server is running

    Runs checkrunninginstance.sh. A zero exit code means an instance is up;
    in that case the user is told how to stop the service.

    Returns:
        True if there is an instance running, False otherwise
    """
    if runScript(checkRunningInstanceScript) != 0:
        return False

    print(_("An instance of SQL Server is running. Please stop the SQL Server service"))
    print(_("using the following command"))
    print("")
    print("    sudo systemctl stop mssql-server")
    return True

def setupSqlServer(eulaAccepted, noprompt=False):
    """Setup and initialize SQL Server

    The steps run in this order:
      1. Check the installation.
      2. Refuse to continue if an instance is already running.
      3. Select the edition (product key).
      4. Require EULA acceptance.
      5. Select the language.
      6. Apply settings from the environment.
      7. Obtain the SA password.
      8. Run the sqlservr setup pass.
      9. Start and enable the mssql-server systemd unit.
    Any failure terminates the process.

    Args:
        eulaAccepted (boolean): Whether Eula was accepted on command line or via env variable
        noprompt (boolean): Don't prompt if True

    Returns:
        None. Exits the process with a non-zero code on failure.
    """

    # Make sure installation basics are OK
    #
    if not checkInstall():
        sys.exit(errorExitCode)

    # Check if SQL Server is already running
    #
    if checkRunningInstance():
        sys.exit(errorExitCode)

    # Get product key
    #
    pid = getPreviewPid(noprompt) if isPreview() else getPid(noprompt)

    if not noprompt and pid is None:
        sys.exit(errorExitCode)

    # Check for EULA acceptance and show EULA based on edition selected
    if not checkEulaAgreement(eulaAccepted, configurationFilePath, isEvaluationEdition = (pid == evaluationEdition)):
        printError(_("EULA not accepted. Exiting."))
        sys.exit(errorExitCode)

    # Select language and write LCID to configuration
    #
    languageSelect(noprompt)

    # Set settings from environment
    #
    if not setSettingsFromEnvironment():
        sys.exit(errorExitCode)

    # Get system administrator password
    #
    encodedPassword = getSystemAdministratorPassword(noprompt)
    if encodedPassword == errorExitCode or encodedPassword is None:
        sys.exit(errorExitCode)

    # The PID is only forwarded to sqlservr when one was selected.
    #
    sqlservrEnvironment = {mssqlSaPasswordEnvVariable: encodedPassword}
    if pid is not None:
        sqlservrEnvironment[mssqlPidEnvVariable] = pid

    ret = configureSqlservrWithArguments("--setup --reset-sa-password", **sqlservrEnvironment)

    # Treat every non-zero exit code as a failure. Comparing only against
    # errorExitCode (1) would let other failure codes, including negative codes
    # for signal termination, fall through and start the service anyway.
    #
    if ret != 0:
        print(_("Initial setup of Microsoft SQL Server failed. Please consult the ERRORLOG"))
        print(_("in %s for more information.") % (getErrorLogFile(configurationFilePath)))
        sys.exit(ret)

    # Start the SQL Server service
    #
    ret = subprocess.call(["systemctl", "start", "mssql-server"])
    if ret != 0:
        print(_("Attempting to start the Microsoft SQL Server service failed."))
        sys.exit(ret)

    # Enable SQL Server to run at startup
    #
    ret = subprocess.call(["systemctl", "enable", "mssql-server"])
    if ret != 0:
        print(_("Attempting to enable the Microsoft SQL Server to start at boot failed."))
        sys.exit(ret)

    print(_("Setup has completed successfully. SQL Server is now starting."))

    return None

def isMlServicesInstalled():
    """Check whether machine learning services has been installed.

    Returns:
        True when both the launchpadd binary path and the mlservices root
        directory exist, False otherwise.
    """
    return os.path.exists(launchpaddPath) and os.path.exists(mlservicesPathRoot)

def checkInstallExtensibility():
    """Checks installation of SQL Server Extensibility service

    Runs checkinstallextensibility.sh with elevated privileges required.

    Returns:
        True if there are no problems, False otherwise
    """
    exitCode = runScript(checkInstallExtensibilityScript, runAsRoot=True)
    return exitCode == 0

def setupSqlServerMlServices(eulaMlAccepted):
    """Setup and initialize SQL Server Machine Learning services (Launchpadd, R and/or python)

    Verifies the extensibility installation and requires the ML EULA. It then
    starts the mssql-launchpadd systemd unit and enables it at boot. The
    process exits on any failure.

    Args:
        eulaMlAccepted (boolean): Whether Eula was accepted on command line or via env variable
    """
    if not checkInstallExtensibility():
        exit(errorExitCode)

    if not checkEulaMlAgreement(eulaMlAccepted, configurationFilePath):
        printError(_("EULA for machine learning services not accepted. Exiting."))
        exit(errorExitCode)

    # Start the service now, then enable it so it also comes up at boot.
    # Each step stops the whole setup with systemctl's exit code on failure.
    systemctlSteps = (
        ("start", _("Attempting to start Microsoft SQL Server Machine Learning Services failed.")),
        ("enable", _("Attempting to enable Microsoft SQL Server Machine Learning Services to start at boot failed.")),
    )
    for action, failureMessage in systemctlSteps:
        returnCode = subprocess.call(["systemctl", action, "mssql-launchpadd"])
        if returnCode != 0:
            print(failureMessage)
            exit(returnCode)

    print(_("Setup has completed successfully. Microsoft SQL Server Machine Learning Services is now starting."))
    print("")

    printRestartRequiredMessage()

    return None

def isValidPassword(password):
    """Determines if a given password matches SQL Server policy requirements

    A password is accepted when all of the following hold:
      - It is 8-128 characters long.
      - It encodes to UTF-8.
      - It contains no control characters (Unicode category Cc).
      - It contains at least three of: an uppercase letter (Lu), a lowercase
        letter (Ll), a decimal digit (Nd), a symbol from SYMBOL_CHARACTERS.

    Args:
        password (str): the password to check, as a decoded text string

    Returns:
        Integer error code. 0 for success, [1-6] otherwise. Error codes are defined
        in the passwordErrorCodes class.

    Notes:
        This function cannot print anything to stdout or it will interfere with
        the validate-password flag in mssql-conf.py
    """

    # Characters that Windows treats as "special characters" for its password validation.
    #
    SYMBOL_CHARACTERS = "(`~!@#$%^&*_-+=|\\{}[]:;\"'<>,.?)/"

    # len() counts Unicode code points, so a character outside the Basic
    # Multilingual Plane (e.g. the emoji U+1F601) counts as one character here.
    # NOTE(review): Windows password policy presumably counts UTF-16 code units
    # (two for such characters). If so, this check may reject some passwords
    # that Windows would accept. Confirm against SQL Server behavior.
    #
    if len(password) < 8:
        return passwordErrorCodes.TOO_SHORT
    if len(password) > 128:
        return passwordErrorCodes.TOO_LONG

    # Exit if we cannot convert the password to UTF-8 for the PAL.
    # NOTE(review): with 'surrogatepass' every str appears to be encodable to
    # UTF-8, so this branch may be unreachable. Confirm the intent.
    #
    try:
        password.encode('utf-8', 'surrogatepass')
    except (UnicodeEncodeError, UnicodeDecodeError):
        return passwordErrorCodes.ENCODING_ERROR

    # Record which character classes occur in the password
    #
    containsUpper = 0
    containsLower = 0
    containsNumeric = 0
    containsSymbol = 0

    for unicodeChar in password:
        # Gets the general category assigned to a given unicode character
        # If it does not fall into any of these categories (For example, Sm, which includes U+221A)
        # then it is treated as a character for length, but does not count toward any category.
        #
        # This does not perfectly match the way that Windows evaluates characters, but is a
        # proper subset of characters accepted by Windows
        #
        # A "control character" such as escape rejects the whole password.
        #
        charCategory = unicodedata.category(unicodeChar)
        if unicodeChar in SYMBOL_CHARACTERS:
            containsSymbol = 1
        elif charCategory == 'Ll':
            containsLower = 1
        elif charCategory == 'Lu':
            containsUpper = 1
        elif charCategory == 'Nd':
            containsNumeric = 1
        elif charCategory == 'Cc':
            return passwordErrorCodes.CONTROL_CHARS

    if (containsUpper + containsLower + containsNumeric + containsSymbol) >= 3:
        return passwordErrorCodes.SUCCESS

    return passwordErrorCodes.NOT_COMPLEX

def _printTextInColor(text, bcolor):
    """Print text wrapped in a terminal color code when color is supported.

    Args:
        text(str): Text to print
        bcolor(str): Escape sequence that starts the color (presumably one of
            the bcolors attributes). bcolors.ENDC is appended to reset it.
            When checkColorSupported() is false, the text is printed
            without any color codes.
    """

    if checkColorSupported():
        print(bcolor + text + bcolors.ENDC)
    else:
        print(text)

# Provide a non-builtin definition for _ so that the typechecker (such as pylance)
# does not complain about it.
#
locstrings = None

def localizeText(origText):
    """Return the translation of origText from the loaded message catalog.

    Until initialize() has stored a catalog in locstrings, the text is
    returned unchanged.
    """
    if locstrings is not None:
        return locstrings.gettext(origText)
    return origText

# Can't just do def "_(origText): ..." because it breaks pygettext
#
_ = localizeText

def initialize():
    """Initialize mssqlconfhelper

    Loads the gettext catalog loc/mo/mssql-conf-<locale>.mo from the script
    directory into the module-level locstrings used by _(). The en_US catalog
    is used instead in three cases: the locale code is unset, no catalog
    exists for the locale, or determining the locale raises an exception.
    """
    global locstrings

    # Compute the fallback before the try block so the except path can always use it.
    #
    defaultMoFilePath = directoryOfScript + "/loc/mo/mssql-conf-en_US.mo"

    try:
        locale.setlocale(locale.LC_ALL, '')
        localeCode = locale.getlocale()[0]

        if localeCode is None:
            moFilePath = defaultMoFilePath
        else:
            moFilePath = directoryOfScript + "/loc/mo/mssql-conf-" + localeCode + ".mo"
            if not os.path.isfile(moFilePath):
                print("Locale %s not supported. Using en_US." % localeCode)
                moFilePath = defaultMoFilePath
    except Exception:
        # Localization is best effort: any failure here (e.g. locale.Error from
        # setlocale) falls back to English. KeyboardInterrupt/SystemExit are
        # no longer swallowed.
        print("Error in localization. Using en_US.")
        moFilePath = defaultMoFilePath

    # Don't use locstrings.install() because it doesn't play nice
    # with typecheckers. GNUTranslations reads the whole catalog in its
    # constructor, so the file can be closed right afterwards.
    #
    with open(moFilePath, "rb") as moFile:
        locstrings = gettext.GNUTranslations(moFile)

def getErrorLogFile(configFilePath = configurationFilePath):
    """Get the error log location configured in the mssql configuration file.

    Args:
        configFilePath(str): Path of the configuration file to read.

    Returns:
        The errorlogfile setting of the filelocation section. Returns
        sqlPathLogDir (the log directory) instead when reading that setting
        raises a configparser error, e.g. because the section or option is
        missing.
    """
    import configparser

    config = ConfigParser()
    readConfigFromFile(config, configFilePath)

    try:
        errorlog = config.get("filelocation", "errorlogfile")
    except configparser.Error:
        # configparser.Error covers NoSectionError, NoOptionError and
        # interpolation errors; unrelated exceptions are no longer hidden.
        errorlog = sqlPathLogDir

    return errorlog
